diff --git a/grudge/execution.py b/grudge/execution.py index c20aa4bc6271f1b9a8f5620c8b3021c22563880c..9e48aba39cd6e3aa5481f0bf9f31f3340381e4f6 100644 --- a/grudge/execution.py +++ b/grudge/execution.py @@ -31,6 +31,7 @@ from pytools import memoize_in import grudge.symbolic.mappers as mappers from grudge import sym +from grudge.function_registry import base_function_registry import logging logger = logging.getLogger(__name__) @@ -50,6 +51,7 @@ class ExecutionMapper(mappers.Evaluator, super(ExecutionMapper, self).__init__(context) self.discrwb = bound_op.discrwb self.bound_op = bound_op + self.function_registry = bound_op.function_registry self.queue = queue # {{{ expression mappings ------------------------------------------------- @@ -95,45 +97,8 @@ class ExecutionMapper(mappers.Evaluator, return value def map_call(self, expr): - from pymbolic.primitives import Variable - assert isinstance(expr.function, Variable) - - # FIXME: Make a way to register functions - args = [self.rec(p) for p in expr.parameters] - from numbers import Number - representative_arg = args[0] - if ( - isinstance(representative_arg, Number) - or (isinstance(representative_arg, np.ndarray) - and representative_arg.shape == ())): - func = getattr(np, expr.function.name) - return func(*args) - - cached_name = "map_call_knl_" - - i = Variable("i") - func = Variable(expr.function.name) - if expr.function.name == "fabs": # FIXME - func = Variable("abs") - cached_name += "abs" - else: - cached_name += expr.function.name - - @memoize_in(self.bound_op, cached_name) - def knl(): - knl = lp.make_kernel( - "{[i]: 0<=i<n}", - [ - lp.Assignment(Variable("out")[i], - func(Variable("a")[i])) - ], default_offset=lp.auto) - return lp.split_iname(knl, "i", 128, outer_tag="g.0", inner_tag="l.0") - - assert len(args) == 1 - evt, (out,) = knl()(self.queue, a=args[0]) - - return out + return self.function_registry[expr.function.name](self.queue, *args) def map_nodal_sum(self, op, field_expr): # FIXME: Could allow array scalars @@ -473,12 +438,15 @@ class MPISendFuture(object): # {{{ bound operator class BoundOperator(object): - def __init__(self, discrwb, discr_code, eval_code, debug_flags, allocator=None): + + def __init__(self, discrwb, discr_code, eval_code, debug_flags, + function_registry, allocator=None): self.discrwb = discrwb self.discr_code = discr_code self.eval_code = eval_code self.operator_data_cache = {} self.debug_flags = debug_flags + self.function_registry = function_registry self.allocator = allocator def __str__(self): @@ -625,7 +593,8 @@ def process_sym_operator(discrwb, sym_operator, post_bind_mapper=None, def bind(discr, sym_operator, post_bind_mapper=lambda x: x, - debug_flags=set(), allocator=None): + debug_flags=set(), allocator=None, + function_registry=base_function_registry): # from grudge.symbolic.mappers import QuadratureUpsamplerRemover # sym_operator = QuadratureUpsamplerRemover(self.quad_min_degrees)( # sym_operator) @@ -648,9 +617,10 @@ def bind(discr, sym_operator, post_bind_mapper=lambda x: x, dumper=dump_sym_operator) from grudge.symbolic.compiler import OperatorCompiler - discr_code, eval_code = OperatorCompiler(discr)(sym_operator) + discr_code, eval_code = OperatorCompiler(discr, function_registry)(sym_operator) bound_op = BoundOperator(discr, discr_code, eval_code, + function_registry=function_registry, debug_flags=debug_flags, allocator=allocator) if "dump_op_code" in debug_flags: diff --git a/grudge/function_registry.py b/grudge/function_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..6efa39da38a8bb30aead074ac7d96856ec70fc33 --- /dev/null +++ b/grudge/function_registry.py @@ -0,0 +1,198 @@ +from __future__ import division, with_statement + +__copyright__ = """ +Copyright (C) 2013 Andreas Kloeckner +Copyright (C) 2019 Matt Wala +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + + +import loopy as lp +import numpy as np + +from pytools import RecordWithoutPickling, memoize_in + + +# {{{ function + +class FunctionNotFound(KeyError): + pass + + +class Function(RecordWithoutPickling): + """ + .. attribute:: identifier + .. attribute:: supports_codegen + .. automethod:: __call__ + .. automethod:: get_result_dofdesc + """ + + def __init__(self, identifier, **kwargs): + super(Function, self).__init__(identifier=identifier, **kwargs) + + def __call__(self, queue, *args, **kwargs): + """Call the function implementation, if available.""" + raise TypeError("function '%s' is not callable" % self.identifier) + + def get_result_dofdesc(self, arg_dds): + """Return the :class:`grudge.symbolic.primitives.DOFDesc` for the return value + of the function. + + :arg arg_dds: A list of :class:`grudge.symbolic.primitives.DOFDesc` instances + for each argument + """ + raise NotImplementedError + + +class CElementwiseUnaryFunction(Function): + + supports_codegen = True + + def get_result_dofdesc(self, arg_dds): + assert len(arg_dds) == 1 + return arg_dds[0] + + def __call__(self, queue, *args): + assert len(args) == 1 + + func_name = self.identifier + + arg, = args + from numbers import Number + if ( + isinstance(arg, Number) + or (isinstance(arg, np.ndarray) + and arg.shape == ())): + func = getattr(np, func_name) + return func(arg) + + cached_name = "map_call_knl_" + + from pymbolic.primitives import Variable + i = Variable("i") + + if self.identifier == "fabs": # FIXME + func_name = "abs" + + cached_name += func_name + + @memoize_in(self, cached_name) + def knl(): + knl = lp.make_kernel( + "{[i]: 0<=i<n}", + [ + lp.Assignment(Variable("out")[i], + Variable(func_name)(Variable("a")[i])) + ], default_offset=lp.auto) + return lp.split_iname(knl, "i", 128, outer_tag="g.0", inner_tag="l.0") + + evt, (out,) = knl()(queue, a=arg) + return out + + +class CBesselFunction(Function): + + supports_codegen = True + + def get_result_dofdesc(self, arg_dds): + assert len(arg_dds) == 2 + return arg_dds[1] + + +class FixedDOFDescExternalFunction(Function): + + supports_codegen = False + + def __init__(self, identifier, implementation, dd): + super(FixedDOFDescExternalFunction, self).__init__( + identifier, + implementation=implementation, + dd=dd) + + def __call__(self, queue, *args, **kwargs): + return self.implementation(queue, *args, **kwargs) + + def get_result_dofdesc(self, arg_dds): + return self.dd + +# }}} + + +# {{{ function registry + +class FunctionRegistry(RecordWithoutPickling): + def __init__(self, id_to_function=None): + if id_to_function is None: + id_to_function = {} + + super(FunctionRegistry, self).__init__( + id_to_function=id_to_function) + + def register(self, function): + """Return a copy of *self* with *function* registered.""" + + if function.identifier in self.id_to_function: + raise ValueError("function '%s' is already registered" + % function.identifier) + + new_id_to_function = self.id_to_function.copy() + new_id_to_function[function.identifier] = function + return self.copy(id_to_function=new_id_to_function) + + def __getitem__(self, function_id): + try: + return self.id_to_function[function_id] + except KeyError: + raise FunctionNotFound( + "unknown function: '%s'" + % function_id) + + def __contains__(self, function_id): + return function_id in self.id_to_function + +# }}} + + +def _make_bfr(): + bfr = FunctionRegistry() + + bfr = bfr.register(CElementwiseUnaryFunction("sqrt")) + bfr = bfr.register(CElementwiseUnaryFunction("exp")) + bfr = bfr.register(CElementwiseUnaryFunction("fabs")) + bfr = bfr.register(CElementwiseUnaryFunction("sin")) + bfr = bfr.register(CElementwiseUnaryFunction("cos")) + bfr = bfr.register(CBesselFunction("bessel_j")) + bfr = bfr.register(CBesselFunction("bessel_y")) + + return bfr + + +base_function_registry = _make_bfr() + + +def register_external_function( + function_registry, identifier, implementation, dd): + return function_registry.register( + FixedDOFDescExternalFunction( + identifier, implementation, dd)) + +# vim: foldmethod=marker diff --git a/grudge/symbolic/compiler.py b/grudge/symbolic/compiler.py index f8dc173e771010ff597fd3831f83ffdccfabc0a4..c2750e33b8430fda4060a1f9b89e220c2651ed02 100644 --- a/grudge/symbolic/compiler.py +++ b/grudge/symbolic/compiler.py @@ -31,7 +31,7 @@ from six.moves import zip, reduce from pytools import Record, memoize_method, memoize from grudge import sym import grudge.symbolic.mappers as mappers -from pymbolic.primitives import Variable, Subscript +from pymbolic.primitives import Variable, Subscript, Call from six.moves import intern from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_1 # noqa: F401 @@ -447,7 +447,8 @@ class Code(object): available_insns = [ (insn, insn.priority) for insn in self.instructions if insn not in done_insns - and all(dep.name in available_names + and all((dep.aggregate.name if isinstance(dep, Subscript) + else dep.name) in available_names for dep in insn.get_dependencies())] if not available_insns: @@ -455,7 +456,8 @@ class Code(object): from pytools import flatten discardable_vars = set(available_names) - set(flatten( - [dep.name for dep in insn.get_dependencies()] + [dep.aggregate.name if isinstance(dep, Subscript) else dep.name + for dep in insn.get_dependencies()] for insn in self.instructions if insn not in done_insns)) @@ -607,6 +609,8 @@ def aggregate_assignments(inf_mapper, instructions, result, max_vectors_in_batch_expr): from pymbolic.primitives import Variable + function_registry = inf_mapper.function_registry + # {{{ aggregation helpers def get_complete_origins_set(insn, skip_levels=0): @@ -674,6 +678,7 @@ def aggregate_assignments(inf_mapper, instructions, result, isinstance(insn, Assign) and not isinstance(insn, ToDiscretizationScopedAssign) and not isinstance(insn, FromDiscretizationScopedAssign) + and not is_external_call(insn.exprs[0], function_registry) and not any( inf_mapper.infer_for_name(n).domain_tag == DTAG_SCALAR for n in insn.names)), @@ -824,9 +829,21 @@ def aggregate_assignments(inf_mapper, instructions, result, # {{{ to-loopy mapper +def is_external_call(expr, function_registry): + if not isinstance(expr, Call): + return False + return not is_function_loopyable(expr.function, function_registry) + + +def is_function_loopyable(function, function_registry): + assert isinstance(function, Variable) + return function_registry[function.name].supports_codegen + + class ToLoopyExpressionMapper(mappers.IdentityMapper): def __init__(self, dd_inference_mapper, temp_names, iname): self.dd_inference_mapper = dd_inference_mapper + self.function_registry = dd_inference_mapper.function_registry self.temp_names = temp_names self.iname = iname from pymbolic import var @@ -887,8 +904,9 @@ class ToLoopyExpressionMapper(mappers.IdentityMapper): "%s_%d" % (expr.aggregate.name, subscript)) def map_call(self, expr): - if isinstance(expr.function, sym.CFunction): + if is_function_loopyable(expr.function, self.function_registry): from pymbolic import var + func_name = expr.function.name if func_name == "fabs": func_name = "abs" @@ -968,11 +986,18 @@ def bessel_function_mangler(kernel, name, arg_dtypes): class ToLoopyInstructionMapper(object): def __init__(self, dd_inference_mapper): self.dd_inference_mapper = dd_inference_mapper + self.function_registry = dd_inference_mapper.function_registry self.insn_count = 0 def map_insn_assign(self, insn): from grudge.symbolic.primitives import OperatorBinding - if len(insn.exprs) == 1 and isinstance(insn.exprs[0], OperatorBinding): + + if ( + len(insn.exprs) == 1 + and ( + isinstance(insn.exprs[0], OperatorBinding) + or is_external_call( + insn.exprs[0], self.function_registry))): return insn iname = "grdg_i" @@ -1097,7 +1122,8 @@ class CodeGenerationState(Record): class OperatorCompiler(mappers.IdentityMapper): - def __init__(self, discr, prefix="_expr", max_vectors_in_batch_expr=None): + def __init__(self, discr, function_registry, + prefix="_expr", max_vectors_in_batch_expr=None): super(OperatorCompiler, self).__init__() self.prefix = prefix @@ -1113,6 +1139,7 @@ class OperatorCompiler(mappers.IdentityMapper): self.assigned_names = set() self.discr = discr + self.function_registry = function_registry from pytools import UniqueNameGenerator self.name_gen = UniqueNameGenerator() @@ -1146,7 +1173,8 @@ class OperatorCompiler(mappers.IdentityMapper): del self.discr_code from grudge.symbolic.dofdesc_inference import DOFDescInferenceMapper - inf_mapper = DOFDescInferenceMapper(discr_code + eval_code) + inf_mapper = DOFDescInferenceMapper( + discr_code + eval_code, self.function_registry) eval_code = aggregate_assignments( inf_mapper, eval_code, result, self.max_vectors_in_batch_expr) @@ -1266,21 +1294,19 @@ class OperatorCompiler(mappers.IdentityMapper): return result_var def map_call(self, expr, codegen_state): - from grudge.symbolic.primitives import CFunction - if isinstance(expr.function, CFunction): + if is_function_loopyable(expr.function, self.function_registry): return super(OperatorCompiler, self).map_call(expr, codegen_state) else: # If it's not a C-level function, it shouldn't get muddled up into # a vector math expression. - return self.assign_to_new_var( - codegen_state, - type(expr)( - expr.function, - [self.assign_to_new_var( - codegen_state, - self.rec(par, codegen_state)) - for par in expr.parameters])) + codegen_state, + type(expr)( + expr.function, + [self.assign_to_new_var( + codegen_state, + self.rec(par, codegen_state)) + for par in expr.parameters])) def map_ref_diff_op_binding(self, expr, codegen_state): try: diff --git a/grudge/symbolic/dofdesc_inference.py b/grudge/symbolic/dofdesc_inference.py index 92be126f7f3f081b668247e1fe25a73b122a4887..26e5e19a605e0268a01e3abd7d324d86b3a5b8c9 100644 --- a/grudge/symbolic/dofdesc_inference.py +++ b/grudge/symbolic/dofdesc_inference.py @@ -76,7 +76,8 @@ class InferrableMultiAssignment(object): class DOFDescInferenceMapper(RecursiveMapper, CSECachingMapperMixin): - def __init__(self, assignments, name_to_dofdesc=None, check=True): + def __init__(self, assignments, function_registry, + name_to_dofdesc=None, check=True): """ :arg assignments: a list of objects adhering to :class:`InferrableMultiAssignment`. @@ -98,6 +99,8 @@ class DOFDescInferenceMapper(RecursiveMapper, CSECachingMapperMixin): self.name_to_dofdesc = name_to_dofdesc + self.function_registry = function_registry + def infer_for_name(self, name): try: return self.name_to_dofdesc[name] @@ -186,7 +189,9 @@ class DOFDescInferenceMapper(RecursiveMapper, CSECachingMapperMixin): self.rec(par) for par in expr.parameters] - assert arg_dds + return ( + self.function_registry[expr.function.name] + .get_result_dofdesc(arg_dds)) # FIXME return arg_dds[0] diff --git a/grudge/symbolic/mappers/__init__.py b/grudge/symbolic/mappers/__init__.py index 8b9bf73b7a7d5080ffe0652a52c0e1926681b749..bfb60879c911918f901f460f197d4004a38068ae 100644 --- a/grudge/symbolic/mappers/__init__.py +++ b/grudge/symbolic/mappers/__init__.py @@ -211,8 +211,6 @@ class IdentityMapperMixin(LocalOpReducerMixin, FluxOpReducerMixin): # it's a leaf--no changing children return expr - map_c_function = map_grudge_variable - map_ones = map_grudge_variable map_node_coordinate_component = map_grudge_variable @@ -279,7 +277,7 @@ class FlopCounter( def map_grudge_variable(self, expr): return 0 - def map_c_function(self, expr): + def map_call(self, expr): return 1 def map_ones(self, expr): @@ -841,9 +839,6 @@ class StringifyMapper(pymbolic.mapper.stringifier.StringifyMapper): self.rec(expr.op, PREC_NONE), self.rec(expr.field, PREC_NONE)) - def map_c_function(self, expr, enclosing_prec): - return expr.name - def map_grudge_variable(self, expr, enclosing_prec): return "%s:%s" % (expr.name, self._format_dd(expr.dd)) @@ -1232,7 +1227,8 @@ class CollectorMixin(OperatorReducerMixin, LocalOpReducerMixin, FluxOpReducerMix return OrderedSet() map_grudge_variable = map_constant - map_c_function = map_grudge_variable + # Found in function call nodes + map_variable = map_grudge_variable map_ones = map_grudge_variable map_node_coordinate_component = map_grudge_variable diff --git a/grudge/symbolic/primitives.py b/grudge/symbolic/primitives.py index 231a70dfa374d8c0e60e2830ff13427f5d33c0dd..3c2dc96ef25af3254b9bac3e7399d9567223c97d 100644 --- a/grudge/symbolic/primitives.py +++ b/grudge/symbolic/primitives.py @@ -76,7 +76,6 @@ Symbols .. autoclass:: ScalarVariable .. autoclass:: make_sym_array .. autoclass:: make_sym_mv -.. autoclass:: CFunction .. function :: sqrt(arg) .. function :: exp(arg) @@ -289,7 +288,16 @@ class HasDOFDesc(object): discretization on which this property is given. """ - def __init__(self, dd): + def __init__(self, *args, **kwargs): + # The remaining arguments are passed to the chained superclass. + + if "dd" in kwargs: + dd = kwargs.pop("dd") + else: + dd = args[-1] + args = args[:-1] + + super(HasDOFDesc, self).__init__(*args, **kwargs) self.dd = dd def __getinitargs__(self): @@ -298,9 +306,7 @@ class HasDOFDesc(object): def with_dd(self, dd): """Return a copy of *self*, modified to the given DOF descriptor. """ - return type(self)( - *self.__getinitargs__()[:-1], - dd=dd or self.dd) + return type(self)(*self.__getinitargs__()) # }}} @@ -320,8 +326,7 @@ class Variable(HasDOFDesc, ExpressionBase, pymbolic.primitives.Variable): if dd is None: dd = DD_VOLUME - HasDOFDesc.__init__(self, dd) - pymbolic.primitives.Variable.__init__(self, name) + super(Variable, self).__init__(name, dd) def __getinitargs__(self): return (self.name, self.dd,) @@ -349,30 +354,14 @@ def make_sym_mv(name, dim, var_factory=None): make_sym_array(name, dim, var_factory)) -class CFunction(pymbolic.primitives.Variable): - """A symbol representing a C-level function, to be used as the function - argument of :class:`pymbolic.primitives.Call`. - """ - def stringifier(self): - from grudge.symbolic.mappers import StringifyMapper - return StringifyMapper - - def __call__(self, *exprs): - from pytools.obj_array import with_object_array_or_scalar_n_args - from functools import partial - return with_object_array_or_scalar_n_args( - partial(pymbolic.primitives.Expression.__call__, self), - *exprs) - - mapper_method = "map_c_function" - - -sqrt = CFunction("sqrt") -exp = CFunction("exp") -sin = CFunction("sin") -cos = CFunction("cos") -bessel_j = CFunction("bessel_j") -bessel_y = CFunction("bessel_y") +# function symbols +CFunction = Variable +sqrt = Variable("sqrt") +exp = Variable("exp") +sin = Variable("sin") +cos = Variable("cos") +bessel_j = Variable("bessel_j") +bessel_y = Variable("bessel_y") # }}} @@ -424,16 +413,13 @@ class PrioritizedSubexpression(pymbolic.primitives.CommonSubexpression): # }}} -class Ones(ExpressionBase, HasDOFDesc): - def __getinitargs__(self): - return () - +class Ones(HasDOFDesc, ExpressionBase): mapper_method = intern("map_ones") # {{{ geometry data -class DiscretizationProperty(ExpressionBase, HasDOFDesc): +class DiscretizationProperty(HasDOFDesc, ExpressionBase): pass diff --git a/test/test_grudge.py b/test/test_grudge.py index 97c5a012fce82c377373889ac6c053d9189db096..f646dff4cb8b1c54377b9580cc1baf6001e4c852 100644 --- a/test/test_grudge.py +++ b/test/test_grudge.py @@ -539,6 +539,41 @@ def test_bessel(ctx_factory): assert z < 1e-15 +def test_external_call(ctx_factory): + cl_ctx = ctx_factory() + queue = cl.CommandQueue(cl_ctx) + + def double(queue, x): + return 2 * x + + from meshmode.mesh.generation import generate_regular_rect_mesh + + dims = 2 + + mesh = generate_regular_rect_mesh(a=(0,) * dims, b=(1,) * dims, n=(4,) * dims) + discr = DGDiscretizationWithBoundaries(cl_ctx, mesh, order=1) + + ones = sym.Ones(sym.DD_VOLUME) + from pymbolic.primitives import Variable + op = ( + ones * 3 + + Variable("double")(ones)) + + from grudge.function_registry import ( + base_function_registry, register_external_function) + + freg = register_external_function( + base_function_registry, + "double", + implementation=double, + dd=sym.DD_VOLUME) + + bound_op = bind(discr, op, function_registry=freg) + + result = bound_op(queue, double=double) + assert (result == 5).get().all() + + # You can test individual routines by typing # $ python test_grudge.py 'test_routine()'