From 89496642e526a2b8d1c50df6d366b538b4c603f3 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Sat, 27 Aug 2011 15:23:09 +0200 Subject: [PATCH] cse() for vectors, plus make_sym_vector. --- pymbolic/__init__.py | 2 ++ pymbolic/primitives.py | 50 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/pymbolic/__init__.py b/pymbolic/__init__.py index de5ad50..04a06f0 100644 --- a/pymbolic/__init__.py +++ b/pymbolic/__init__.py @@ -18,6 +18,8 @@ subscript = pymbolic.primitives.subscript flattened_product = pymbolic.primitives.flattened_product quotient = pymbolic.primitives.quotient linear_combination = pymbolic.primitives.linear_combination +cse = pymbolic.primitives.make_common_subexpression +make_sym_vector = pymbolic.primitives.make_sym_vector parse = pymbolic.parser.parse evaluate = pymbolic.mapper.evaluator.evaluate diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index bae27e3..825ebd1 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -826,3 +826,53 @@ def is_zero(value): + +def make_common_subexpression(field, prefix=None): + try: + from pytools.obj_array import log_shape + except ImportError: + have_obj_array = False + else: + have_obj_array = True + + if have_obj_array: + ls = log_shape(field) + + if have_obj_array and ls != (): + from pytools import indices_in_shape + result = numpy.zeros(ls, dtype=object) + + for i in indices_in_shape(ls): + if prefix is not None: + component_prefix = prefix+"_".join(str(i_i) for i_i in i) + else: + component_prefix = None + + if is_constant(field[i]): + result[i] = field[i] + else: + result[i] = CommonSubexpression(field[i], component_prefix) + + return result + else: + if is_constant(field): + return field + else: + return CommonSubexpression(field, prefix) + + + + +def make_sym_vector(name, components): + """Return an object array of *components* subscripted + :class:`Field` instances. + + :param components: The number of components in the vector. + """ + if isinstance(components, int): + components = range(components) + + from hedge.tools import join_fields + vfld = Variable(name) + return join_fields(*[vfld[i] for i in components]) + -- GitLab