diff --git a/grudge/symbolic/compiler.py b/grudge/symbolic/compiler.py index cedd152e653418e6776ab36421050bf880368fa4..b82442390fbce4763fe5352886f4869d69f37617 100644 --- a/grudge/symbolic/compiler.py +++ b/grudge/symbolic/compiler.py @@ -28,7 +28,10 @@ import numpy as np import six # noqa from six.moves import zip, reduce + from pytools import Record, memoize_method, memoize +from pytools.obj_array import obj_array_vectorize + from grudge import sym import grudge.symbolic.mappers as mappers from pymbolic.primitives import Variable, Subscript @@ -363,9 +366,7 @@ def dot_dataflow_graph(code, max_node_label_length=30, for dep in insn.get_dependencies(): gen_expr_arrow(dep, node_names[insn]) - from pytools.obj_array import is_obj_array - - if is_obj_array(code.result): + if isinstance(code.result, np.ndarray) and code.result.dtype.char == "O": for subexp in code.result: gen_expr_arrow(subexp, "result") else: @@ -469,8 +470,6 @@ class Code(object): # {{{ make sure results do not get discarded - from pytools.obj_array import with_object_array_or_scalar - dm = mappers.DependencyMapper(composite_leaves=False) def remove_result_variable(result_expr): @@ -483,7 +482,7 @@ class Code(object): assert isinstance(var, Variable) discardable_vars.discard(var.name) - with_object_array_or_scalar(remove_result_variable, self.result) + obj_array_vectorize(remove_result_variable, self.result) # }}} @@ -593,12 +592,11 @@ class Code(object): if log_quantities is not None: exec_sub_timer.stop().submit() - from pytools.obj_array import with_object_array_or_scalar if profile_data is not None: profile_data['total_time'] = time() - start_time - return (with_object_array_or_scalar(exec_mapper, self.result), + return (obj_array_vectorize(exec_mapper, self.result), profile_data) - return with_object_array_or_scalar(exec_mapper, self.result) + return obj_array_vectorize(exec_mapper, self.result) # }}} @@ -767,8 +765,7 @@ def aggregate_assignments(inf_mapper, instructions, result, for insn in processed_assigns + other_insns for expr in insn.get_dependencies()) - from pytools.obj_array import is_obj_array - if is_obj_array(result): + if isinstance(result, np.ndarray) and result.dtype.char == "O": externally_used_names |= set(expr for expr in result) else: externally_used_names |= set([result]) diff --git a/grudge/symbolic/operators.py b/grudge/symbolic/operators.py index c4afc0ec32d0d2d1b8a1ecb79003cf919bd5ca18..9813b648c8ef308ff06a1e6db189b581a30e30d9 100644 --- a/grudge/symbolic/operators.py +++ b/grudge/symbolic/operators.py @@ -111,7 +111,7 @@ class Operator(pymbolic.primitives.Expression): return StringifyMapper def __call__(self, expr): - from pytools.obj_array import with_object_array_or_scalar + from pytools.obj_array import obj_array_vectorize from grudge.tools import is_zero def bind_one(subexpr): @@ -121,7 +121,7 @@ class Operator(pymbolic.primitives.Expression): from grudge.symbolic.primitives import OperatorBinding return OperatorBinding(self, subexpr) - return with_object_array_or_scalar(bind_one, expr) + return obj_array_vectorize(bind_one, expr) def with_dd(self, dd_in=None, dd_out=None): """Return a copy of *self*, modified to the given DOF descriptors. @@ -151,7 +151,7 @@ class InterpolationOperator(Operator): super(InterpolationOperator, self).__init__(dd_in, dd_out) def __call__(self, expr): - from pytools.obj_array import with_object_array_or_scalar + from pytools.obj_array import obj_array_vectorize def interp_one(subexpr): from pymbolic.primitives import is_constant @@ -164,7 +164,7 @@ class InterpolationOperator(Operator): from grudge.symbolic.primitives import OperatorBinding return OperatorBinding(self, subexpr) - return with_object_array_or_scalar(interp_one, expr) + return obj_array_vectorize(interp_one, expr) mapper_method = intern("map_interpolation") diff --git a/grudge/symbolic/primitives.py b/grudge/symbolic/primitives.py index d2f1a01f1c5415e7c2877625156bc0087ab104a6..e6c466c7b2ee95c2a05810c8b8f64384b9e41441 100644 --- a/grudge/symbolic/primitives.py +++ b/grudge/symbolic/primitives.py @@ -363,8 +363,8 @@ class FunctionSymbol(ExpressionBase, VariableBase): """ def __call__(self, *exprs): - from pytools.obj_array import with_object_array_or_scalar_n_args - return with_object_array_or_scalar_n_args( + from pytools.obj_array import obj_array_vectorize_n_args + return obj_array_vectorize_n_args( super(FunctionSymbol, self).__call__, *exprs) mapper_method = "map_function_symbol" @@ -397,10 +397,9 @@ class OperatorBinding(ExpressionBase): return self.op, self.field def is_equal(self, other): - from pytools.obj_array import obj_array_equal return (other.__class__ == self.__class__ and other.op == self.op - and obj_array_equal(other.field, self.field)) + and np.array_equal(other.field, self.field)) def get_hash(self): from pytools.obj_array import obj_array_to_hashable