diff --git a/pymbolic/__init__.py b/pymbolic/__init__.py index de5ad50958a7447a9e4f8376b7c210cce4e8a2f2..04a06f0ddeb63b47a4cf0dbe9d8bf20993746918 100644 --- a/pymbolic/__init__.py +++ b/pymbolic/__init__.py @@ -18,6 +18,8 @@ subscript = pymbolic.primitives.subscript flattened_product = pymbolic.primitives.flattened_product quotient = pymbolic.primitives.quotient linear_combination = pymbolic.primitives.linear_combination +cse = pymbolic.primitives.make_common_subexpression +make_sym_vector = pymbolic.primitives.make_sym_vector parse = pymbolic.parser.parse evaluate = pymbolic.mapper.evaluator.evaluate diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index bae27e304dc8fed8bf849c884201b78d203b8eb5..825ebd1008f1531a1ad2b8fc7a5054d513cddb77 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -826,3 +826,53 @@ def is_zero(value): + +def make_common_subexpression(field, prefix=None): + try: + from pytools.obj_array import log_shape + except ImportError: + have_obj_array = False + else: + have_obj_array = True + + if have_obj_array: + ls = log_shape(field) + + if have_obj_array and ls != (): + from pytools import indices_in_shape + result = numpy.zeros(ls, dtype=object) + + for i in indices_in_shape(ls): + if prefix is not None: + component_prefix = prefix+"_".join(str(i_i) for i_i in i) + else: + component_prefix = None + + if is_constant(field[i]): + result[i] = field[i] + else: + result[i] = CommonSubexpression(field[i], component_prefix) + + return result + else: + if is_constant(field): + return field + else: + return CommonSubexpression(field, prefix) + + + + +def make_sym_vector(name, components): + """Return an object array of *components* subscripted + :class:`Field` instances. + + :param components: The number of components in the vector. + """ + if isinstance(components, int): + components = range(components) + + from hedge.tools import join_fields + vfld = Variable(name) + return join_fields(*[vfld[i] for i in components]) +