diff --git a/examples/dagrt_fusion/fusion-study.py b/examples/dagrt_fusion/fusion-study.py index 64d0839fff97ea52b921d93184f3ff2197c6b00d..7554ccc206a3e9f31f55615e07d028d545e0bce3 100755 --- a/examples/dagrt_fusion/fusion-study.py +++ b/examples/dagrt_fusion/fusion-study.py @@ -43,6 +43,7 @@ import pymbolic.primitives as p import grudge.symbolic.mappers as gmap import grudge.symbolic.operators as op from grudge.execution import ExecutionMapper +from grudge.function_registry import base_function_registry from pymbolic.mapper.evaluator import EvaluationMapper \ as PymbolicEvaluationMapper from pytools import memoize @@ -246,6 +247,7 @@ class RK4TimeStepperBase(object): } def set_up_stepper(self, discr, field_var_name, sym_rhs, num_fields, + function_registry=base_function_registry, exec_mapper_factory=ExecutionMapper): dt_method = LSRK4Method(component_id=field_var_name) dt_code = dt_method.generate() @@ -270,7 +272,9 @@ class RK4TimeStepperBase(object): flattened_results = join_fields(output_t, output_dt, *output_states) self.bound_op = bind( - discr, flattened_results, exec_mapper_factory=exec_mapper_factory) + discr, flattened_results, + function_registry=function_registry, + exec_mapper_factory=exec_mapper_factory) def run(self, fields, t_start, dt, t_end, return_profile_data=False): context = self.get_initial_context(fields, t_start, dt) @@ -327,9 +331,9 @@ class RK4TimeStepper(RK4TimeStepperBase): # Construct sym_rhs to have the effect of replacing the RHS calls in the # dagrt code with calls of the grudge operator. - from grudge.symbolic.primitives import ExternalCall, Variable + from grudge.symbolic.primitives import FunctionSymbol, Variable call = sym.cse(ExternalCall( - var("grudge_op"), + FunctionSymbol("grudge_op"), ( (Variable("t", dd=sym.DD_SCALAR),) + tuple( @@ -342,17 +346,29 @@ class RK4TimeStepper(RK4TimeStepperBase): self.queue = queue self.grudge_bound_op = grudge_bound_op + + from dagrt.function_registry import register_external_function + + freg = register_external_function( + base_function_registry, + "grudge_op", + implementation=self._bound_op, + dd=sym.DD_VOLUME) + self.set_up_stepper( - discr, field_var_name, sym_rhs, num_fields, exec_mapper_factory) + discr, field_var_name, sym_rhs, num_fields, + freg, + exec_mapper_factory) + self.component_getter = component_getter - def _bound_op(self, t, *args, profile_data=None): + def _bound_op(self, queue, t, *args, profile_data=None): from pytools.obj_array import join_fields context = { "t": t, self.field_var_name: join_fields(*args)} result = self.grudge_bound_op( - self.queue, profile_data=profile_data, **context) + queue, profile_data=profile_data, **context) if profile_data is not None: result = result[0] return result @@ -369,7 +385,9 @@ class FusedRK4TimeStepper(RK4TimeStepperBase): component_getter, exec_mapper_factory=ExecutionMapper): super().__init__(queue, component_getter) self.set_up_stepper( - discr, field_var_name, sym_rhs, num_fields, exec_mapper_factory) + discr, field_var_name, sym_rhs, num_fields, + base_function_registry, + exec_mapper_factory) # }}} @@ -477,17 +495,16 @@ class ExecutionMapperWithMemOpCounting(ExecutionMapper): def __init__(self, queue, context, bound_op): super().__init__(queue, context, bound_op) - def map_external_call(self, expr): + def map_call(self, expr): # Should have been caught by our op counter - assert False, ("map_external_call called: %s" % expr) + assert False, ("map_call called: %s" % expr) # {{{ expressions - def map_profiled_external_call(self, expr, profile_data): - from pymbolic.primitives import Variable - assert isinstance(expr.function, Variable) + def map_profiled_call(self, expr, profile_data): args = [self.rec(p) for p in expr.parameters] - return self.context[expr.function.name](*args, profile_data=profile_data) + return self.function_registry[expr.function.name]( + self.queue, *args, profile_data=profile_data) def map_profiled_essentially_elementwise_linear(self, op, field_expr, profile_data): @@ -517,9 +534,9 @@ class ExecutionMapperWithMemOpCounting(ExecutionMapper): # {{{ instruction mappings def process_assignment_expr(self, expr, profile_data): - if isinstance(expr, sym.ExternalCall): - assert expr.mapper_method == "map_external_call" - val = self.map_profiled_external_call(expr, profile_data) + if isinstance(expr, p.Call): + assert expr.mapper_method == "map_call" + val = self.map_profiled_call(expr, profile_data) elif isinstance(expr, sym.OperatorBinding): if isinstance( @@ -774,19 +791,18 @@ def time_insn(f): class ExecutionMapperWithTiming(ExecutionMapper): - def map_external_call(self, expr): + def map_call(self, expr): # Should have been caught by our implementation. - assert False, ("map_external_call called: %s" % (expr)) + assert False, ("map_call called: %s" % (expr)) def map_operator_binding(self, expr): # Should have been caught by our implementation. assert False, ("map_operator_binding called: %s" % expr) - def map_profiled_external_call(self, expr, profile_data): - from pymbolic.primitives import Variable - assert isinstance(expr.function, Variable) + def map_profiled_call(self, expr, profile_data): args = [self.rec(p) for p in expr.parameters] - return self.context[expr.function.name](*args, profile_data=profile_data) + return self.function_registry[expr.function.name]( + self.queue, *args, profile_data=profile_data) def map_profiled_operator_binding(self, expr, profile_data): if profile_data is None: @@ -808,9 +824,9 @@ class ExecutionMapperWithTiming(ExecutionMapper): def map_insn_assign(self, insn, profile_data): if len(insn.exprs) == 1: - if isinstance(insn.exprs[0], sym.ExternalCall): - assert insn.exprs[0].mapper_method == "map_external_call" - val = self.map_profiled_external_call(insn.exprs[0], profile_data) + if isinstance(insn.exprs[0], p.Call): + assert insn.exprs[0].mapper_method == "map_call" + val = self.map_profiled_call(insn.exprs[0], profile_data) return [(insn.names[0], val)], [] elif isinstance(insn.exprs[0], sym.OperatorBinding): assert insn.exprs[0].mapper_method == "map_operator_binding" diff --git a/grudge/execution.py b/grudge/execution.py index d215e0bb943a31cedea2c1d5c4dbcf55f1b4978c..66e483f33998340ff1b6b1f0839a73acaf5102e5 100644 --- a/grudge/execution.py +++ b/grudge/execution.py @@ -31,6 +31,7 @@ from pytools import memoize_in import grudge.symbolic.mappers as mappers from grudge import sym +from grudge.function_registry import base_function_registry import logging logger = logging.getLogger(__name__) @@ -50,6 +51,7 @@ class ExecutionMapper(mappers.Evaluator, super(ExecutionMapper, self).__init__(context) self.discrwb = bound_op.discrwb self.bound_op = bound_op + self.function_registry = bound_op.function_registry self.queue = queue # {{{ expression mappings ------------------------------------------------- @@ -94,56 +96,9 @@ class ExecutionMapper(mappers.Evaluator, value = ary return value - def map_external_call(self, expr): - from pymbolic.primitives import Variable - assert isinstance(expr.function, Variable) - args = [self.rec(p) for p in expr.parameters] - - return self.context[expr.function.name](*args) - def map_call(self, expr): - from pymbolic.primitives import Variable - assert isinstance(expr.function, Variable) - args = [self.rec(p) for p in expr.parameters] - - # Function lookup precedence: - # * Numpy functions - # * OpenCL functions - - from numbers import Number - representative_arg = args[0] - if ( - isinstance(representative_arg, Number) - or (isinstance(representative_arg, np.ndarray) - and representative_arg.shape == ())): - func = getattr(np, expr.function.name) - return func(*args) - - cached_name = "map_call_knl_" - - i = Variable("i") - func = Variable(expr.function.name) - if expr.function.name == "fabs": # FIXME - func = Variable("abs") - cached_name += "abs" - else: - cached_name += expr.function.name - - @memoize_in(self.bound_op, cached_name) - def knl(): - knl = lp.make_kernel( - "{[i]: 0<=i<n}", - [ - lp.Assignment(Variable("out")[i], - func(Variable("a")[i])) - ], default_offset=lp.auto) - return lp.split_iname(knl, "i", 128, outer_tag="g.0", inner_tag="l.0") - - assert len(args) == 1 - evt, (out,) = knl()(self.queue, a=args[0]) - - return out + return self.function_registry[expr.function.name](self.queue, *args) def map_nodal_sum(self, op, field_expr): # FIXME: Could allow array scalars @@ -483,13 +438,15 @@ class MPISendFuture(object): # {{{ bound operator class BoundOperator(object): + def __init__(self, discrwb, discr_code, eval_code, debug_flags, - allocator=None, exec_mapper_factory=ExecutionMapper): + function_registry, exec_mapper_factory, allocator=None): self.discrwb = discrwb self.discr_code = discr_code self.eval_code = eval_code self.operator_data_cache = {} self.debug_flags = debug_flags + self.function_registry = function_registry self.allocator = allocator self.exec_mapper_factory = exec_mapper_factory @@ -636,8 +593,10 @@ def process_sym_operator(discrwb, sym_operator, post_bind_mapper=None, # }}} -def bind(discr, sym_operator, post_bind_mapper=lambda x: x, debug_flags=set(), - allocator=None, exec_mapper_factory=ExecutionMapper): +def bind(discr, sym_operator, post_bind_mapper=lambda x: x, + function_registry=base_function_registry, + exec_mapper_factory=ExecutionMapper, + debug_flags=frozenset(), allocator=None): # from grudge.symbolic.mappers import QuadratureUpsamplerRemover # sym_operator = QuadratureUpsamplerRemover(self.quad_min_degrees)( # sym_operator) @@ -660,12 +619,14 @@ def bind(discr, sym_operator, post_bind_mapper=lambda x: x, debug_flags=set(), dumper=dump_sym_operator) from grudge.symbolic.compiler import OperatorCompiler - discr_code, eval_code = OperatorCompiler(discr)(sym_operator) + discr_code, eval_code = OperatorCompiler(discr, function_registry)(sym_operator) bound_op = BoundOperator(discr, discr_code, eval_code, - debug_flags=debug_flags, allocator=allocator, - exec_mapper_factory=exec_mapper_factory) - + function_registry=function_registry, + exec_mapper_factory=exec_mapper_factory, + debug_flags=debug_flags, + allocator=allocator) + if "dump_op_code" in debug_flags: from grudge.tools import open_unique_debug_file outf, _ = open_unique_debug_file("op-code", ".txt") diff --git a/grudge/function_registry.py b/grudge/function_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..6efa39da38a8bb30aead074ac7d96856ec70fc33 --- /dev/null +++ b/grudge/function_registry.py @@ -0,0 +1,198 @@ +from __future__ import division, with_statement + +__copyright__ = """ +Copyright (C) 2013 Andreas Kloeckner +Copyright (C) 2019 Matt Wala +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + + +import loopy as lp +import numpy as np + +from pytools import RecordWithoutPickling, memoize_in + + +# {{{ function + +class FunctionNotFound(KeyError): + pass + + +class Function(RecordWithoutPickling): + """ + .. attribute:: identifier + .. attribute:: supports_codegen + .. automethod:: __call__ + .. automethod:: get_result_dofdesc + """ + + def __init__(self, identifier, **kwargs): + super(Function, self).__init__(identifier=identifier, **kwargs) + + def __call__(self, queue, *args, **kwargs): + """Call the function implementation, if available.""" + raise TypeError("function '%s' is not callable" % self.identifier) + + def get_result_dofdesc(self, arg_dds): + """Return the :class:`grudge.symbolic.primitives.DOFDesc` for the return value + of the function. + + :arg arg_dds: A list of :class:`grudge.symbolic.primitives.DOFDesc` instances + for each argument + """ + raise NotImplementedError + + +class CElementwiseUnaryFunction(Function): + + supports_codegen = True + + def get_result_dofdesc(self, arg_dds): + assert len(arg_dds) == 1 + return arg_dds[0] + + def __call__(self, queue, *args): + assert len(args) == 1 + + func_name = self.identifier + + arg, = args + from numbers import Number + if ( + isinstance(arg, Number) + or (isinstance(arg, np.ndarray) + and arg.shape == ())): + func = getattr(np, func_name) + return func(arg) + + cached_name = "map_call_knl_" + + from pymbolic.primitives import Variable + i = Variable("i") + + if self.identifier == "fabs": # FIXME + func_name = "abs" + + cached_name += func_name + + @memoize_in(self, cached_name) + def knl(): + knl = lp.make_kernel( + "{[i]: 0<=i<n}", + [ + lp.Assignment(Variable("out")[i], + Variable(func_name)(Variable("a")[i])) + ], default_offset=lp.auto) + return lp.split_iname(knl, "i", 128, outer_tag="g.0", inner_tag="l.0") + + evt, (out,) = knl()(queue, a=arg) + return out + + +class CBesselFunction(Function): + + supports_codegen = True + + def get_result_dofdesc(self, arg_dds): + assert len(arg_dds) == 2 + return arg_dds[1] + + +class FixedDOFDescExternalFunction(Function): + + supports_codegen = False + + def __init__(self, identifier, implementation, dd): + super(FixedDOFDescExternalFunction, self).__init__( + identifier, + implementation=implementation, + dd=dd) + + def __call__(self, queue, *args, **kwargs): + return self.implementation(queue, *args, **kwargs) + + def get_result_dofdesc(self, arg_dds): + return self.dd + +# }}} + + +# {{{ function registry + +class FunctionRegistry(RecordWithoutPickling): + def __init__(self, id_to_function=None): + if id_to_function is None: + id_to_function = {} + + super(FunctionRegistry, self).__init__( + id_to_function=id_to_function) + + def register(self, function): + """Return a copy of *self* with *function* registered.""" + + if function.identifier in self.id_to_function: + raise ValueError("function '%s' is already registered" + % function.identifier) + + new_id_to_function = self.id_to_function.copy() + new_id_to_function[function.identifier] = function + return self.copy(id_to_function=new_id_to_function) + + def __getitem__(self, function_id): + try: + return self.id_to_function[function_id] + except KeyError: + raise FunctionNotFound( + "unknown function: '%s'" + % function_id) + + def __contains__(self, function_id): + return function_id in self.id_to_function + +# }}} + + +def _make_bfr(): + bfr = FunctionRegistry() + + bfr = bfr.register(CElementwiseUnaryFunction("sqrt")) + bfr = bfr.register(CElementwiseUnaryFunction("exp")) + bfr = bfr.register(CElementwiseUnaryFunction("fabs")) + bfr = bfr.register(CElementwiseUnaryFunction("sin")) + bfr = bfr.register(CElementwiseUnaryFunction("cos")) + bfr = bfr.register(CBesselFunction("bessel_j")) + bfr = bfr.register(CBesselFunction("bessel_y")) + + return bfr + + +base_function_registry = _make_bfr() + + +def register_external_function( + function_registry, identifier, implementation, dd): + return function_registry.register( + FixedDOFDescExternalFunction( + identifier, implementation, dd)) + +# vim: foldmethod=marker diff --git a/grudge/models/em.py b/grudge/models/em.py index f0e44f903d226e28710709651e948293a534d1b8..bf7495e2c9447b54d37d49a78905e549d9df54d8 100644 --- a/grudge/models/em.py +++ b/grudge/models/em.py @@ -355,7 +355,7 @@ class MaxwellOperator(HyperbolicOperator): return 1/sqrt(self.epsilon*self.mu) # a number else: import grudge.symbolic as sym - return sym.NodalMax()(1/sym.CFunction("sqrt")(self.epsilon*self.mu)) + return sym.NodalMax()(1/sym.FunctionSymbol("sqrt")(self.epsilon*self.mu)) def max_eigenvalue(self, t, fields=None, discr=None, context={}): if self.fixed_material: diff --git a/grudge/models/gas_dynamics/__init__.py b/grudge/models/gas_dynamics/__init__.py index e5a8ddc9418f1bac4776f252d2951fab3d3c1a2b..4df408d811a826a1f924b855f0ff8a9fce8a29f2 100644 --- a/grudge/models/gas_dynamics/__init__.py +++ b/grudge/models/gas_dynamics/__init__.py @@ -326,8 +326,8 @@ class GasDynamicsOperator(TimeDependentOperator): def characteristic_velocity_optemplate(self, state): from grudge.symbolic.operators import ElementwiseMaxOperator - from grudge.symbolic.primitives import CFunction - sqrt = CFunction("sqrt") + from grudge.symbolic.primitives import FunctionSymbol + sqrt = FunctionSymbol("sqrt") sound_speed = cse(sqrt( self.equation_of_state.gamma*self.cse_p(state)/self.cse_rho(state)), @@ -743,8 +743,8 @@ class GasDynamicsOperator(TimeDependentOperator): volq_flux = self.flux(self.volq_state()) faceq_flux = self.flux(self.faceq_state()) - from grudge.symbolic.primitives import CFunction - sqrt = CFunction("sqrt") + from grudge.symbolic.primitives import FunctionSymbol + sqrt = FunctionSymbol("sqrt") speed = self.characteristic_velocity_optemplate(self.state()) diff --git a/grudge/models/wave.py b/grudge/models/wave.py index 40710eb2363b7e475f1be299f1b69540343309be..0ac60a336a4f72f12ba4f3eeb0bc2d79c9f07509 100644 --- a/grudge/models/wave.py +++ b/grudge/models/wave.py @@ -459,7 +459,7 @@ class VariableCoefficientWeakWaveOperator(HyperbolicOperator): self.radiation_tag]) def max_eigenvalue(self, t, fields=None, discr=None): - return sym.NodalMax()(sym.CFunction("fabs")(self.c)) + return sym.NodalMax()(sym.FunctionSymbol("fabs")(self.c)) # }}} diff --git a/grudge/symbolic/compiler.py b/grudge/symbolic/compiler.py index b60fd6d1f8856be0ebc3e2c5c62da31460b89e5f..3aac28bb3c3a8f74d6db255863c811314d464d36 100644 --- a/grudge/symbolic/compiler.py +++ b/grudge/symbolic/compiler.py @@ -606,6 +606,8 @@ def aggregate_assignments(inf_mapper, instructions, result, max_vectors_in_batch_expr): from pymbolic.primitives import Variable + function_registry = inf_mapper.function_registry + # {{{ aggregation helpers def get_complete_origins_set(insn, skip_levels=0): @@ -666,14 +668,14 @@ def aggregate_assignments(inf_mapper, instructions, result, for assignee in insn.get_assignees()) from pytools import partition - from grudge.symbolic.primitives import DTAG_SCALAR, ExternalCall + from grudge.symbolic.primitives import DTAG_SCALAR unprocessed_assigns, other_insns = partition( lambda insn: ( isinstance(insn, Assign) and not isinstance(insn, ToDiscretizationScopedAssign) and not isinstance(insn, FromDiscretizationScopedAssign) - and not isinstance(insn.exprs[0], ExternalCall) + and not is_external_call(insn.exprs[0], function_registry) and not any( inf_mapper.infer_for_name(n).domain_tag == DTAG_SCALAR for n in insn.names)), @@ -824,9 +826,23 @@ def aggregate_assignments(inf_mapper, instructions, result, # {{{ to-loopy mapper +def is_external_call(expr, function_registry): + from pymbolic.primitives import Call + if not isinstance(expr, Call): + return False + return not is_function_loopyable(expr.function, function_registry) + + +def is_function_loopyable(function, function_registry): + from grudge.symbolic.primitives import FunctionSymbol + assert isinstance(function, FunctionSymbol) + return function_registry[function.name].supports_codegen + + class ToLoopyExpressionMapper(mappers.IdentityMapper): def __init__(self, dd_inference_mapper, temp_names, iname): self.dd_inference_mapper = dd_inference_mapper + self.function_registry = dd_inference_mapper.function_registry self.temp_names = temp_names self.iname = iname from pymbolic import var @@ -886,13 +902,10 @@ class ToLoopyExpressionMapper(mappers.IdentityMapper): expr, "%s_%d" % (expr.aggregate.name, subscript)) - def map_external_call(self, expr): - raise ValueError( - "Cannot map external call '%s' into loopy" % expr.function) - def map_call(self, expr): - if isinstance(expr.function, sym.CFunction): + if is_function_loopyable(expr.function, self.function_registry): from pymbolic import var + func_name = expr.function.name if func_name == "fabs": func_name = "abs" @@ -972,14 +985,18 @@ def bessel_function_mangler(kernel, name, arg_dtypes): class ToLoopyInstructionMapper(object): def __init__(self, dd_inference_mapper): self.dd_inference_mapper = dd_inference_mapper + self.function_registry = dd_inference_mapper.function_registry self.insn_count = 0 def map_insn_assign(self, insn): - from grudge.symbolic.primitives import OperatorBinding, ExternalCall + from grudge.symbolic.primitives import OperatorBinding if ( len(insn.exprs) == 1 - and isinstance(insn.exprs[0], (OperatorBinding, ExternalCall))): + and ( + isinstance(insn.exprs[0], OperatorBinding) + or is_external_call( + insn.exprs[0], self.function_registry))): return insn iname = "grdg_i" @@ -1104,7 +1121,8 @@ class CodeGenerationState(Record): class OperatorCompiler(mappers.IdentityMapper): - def __init__(self, discr, prefix="_expr", max_vectors_in_batch_expr=None): + def __init__(self, discr, function_registry, + prefix="_expr", max_vectors_in_batch_expr=None): super(OperatorCompiler, self).__init__() self.prefix = prefix @@ -1120,6 +1138,7 @@ class OperatorCompiler(mappers.IdentityMapper): self.assigned_names = set() self.discr = discr + self.function_registry = function_registry from pytools import UniqueNameGenerator self.name_gen = UniqueNameGenerator() @@ -1153,7 +1172,8 @@ class OperatorCompiler(mappers.IdentityMapper): del self.discr_code from grudge.symbolic.dofdesc_inference import DOFDescInferenceMapper - inf_mapper = DOFDescInferenceMapper(discr_code + eval_code) + inf_mapper = DOFDescInferenceMapper( + discr_code + eval_code, self.function_registry) eval_code = aggregate_assignments( inf_mapper, eval_code, result, self.max_vectors_in_batch_expr) @@ -1272,20 +1292,8 @@ class OperatorCompiler(mappers.IdentityMapper): prefix=name_hint) return result_var - def map_external_call(self, expr, codegen_state): - return self.assign_to_new_var( - codegen_state, - type(expr)( - expr.function, - [self.assign_to_new_var( - codegen_state, - self.rec(par, codegen_state)) - for par in expr.parameters], - expr.dd)) - def map_call(self, expr, codegen_state): - from grudge.symbolic.primitives import CFunction - if isinstance(expr.function, CFunction): + if is_function_loopyable(expr.function, self.function_registry): return super(OperatorCompiler, self).map_call(expr, codegen_state) else: # If it's not a C-level function, it shouldn't get muddled up into diff --git a/grudge/symbolic/dofdesc_inference.py b/grudge/symbolic/dofdesc_inference.py index 96ead0885d2e300e7e87fed4395ec74ff771ccb2..c44a7940b9ec142171a0b0a67885c041ec9d43cb 100644 --- a/grudge/symbolic/dofdesc_inference.py +++ b/grudge/symbolic/dofdesc_inference.py @@ -76,7 +76,8 @@ class InferrableMultiAssignment(object): class DOFDescInferenceMapper(RecursiveMapper, CSECachingMapperMixin): - def __init__(self, assignments, name_to_dofdesc=None, check=True): + def __init__(self, assignments, function_registry, + name_to_dofdesc=None, check=True): """ :arg assignments: a list of objects adhering to :class:`InferrableMultiAssignment`. @@ -98,6 +99,8 @@ class DOFDescInferenceMapper(RecursiveMapper, CSECachingMapperMixin): self.name_to_dofdesc = name_to_dofdesc + self.function_registry = function_registry + def infer_for_name(self, name): try: return self.name_to_dofdesc[name] @@ -186,13 +189,9 @@ class DOFDescInferenceMapper(RecursiveMapper, CSECachingMapperMixin): self.rec(par) for par in expr.parameters] - assert arg_dds - - # FIXME - return arg_dds[0] - - def map_external_call(self, expr): - return expr.dd + return ( + self.function_registry[expr.function.name] + .get_result_dofdesc(arg_dds)) # }}} diff --git a/grudge/symbolic/mappers/__init__.py b/grudge/symbolic/mappers/__init__.py index 23ff1e50fd41512adabebf92cbeb00664baa0b5b..8266db938bdb89bc2dde0bce52b96c55474de9b2 100644 --- a/grudge/symbolic/mappers/__init__.py +++ b/grudge/symbolic/mappers/__init__.py @@ -211,17 +211,10 @@ class IdentityMapperMixin(LocalOpReducerMixin, FluxOpReducerMixin): # it's a leaf--no changing children return expr - map_c_function = map_grudge_variable - + map_function_symbol = map_grudge_variable map_ones = map_grudge_variable map_node_coordinate_component = map_grudge_variable - def map_external_call(self, expr, *args, **kwargs): - return type(expr)( - self.rec(expr.function, *args, **kwargs), - self.rec(expr.parameters, *args, **kwargs), - dd=expr.dd) - # }}} @@ -275,15 +268,6 @@ class DependencyMapper( map_ones = _map_leaf map_node_coordinate_component = _map_leaf - def map_external_call(self, expr): - result = self.map_call(expr) - if self.include_calls == "descend_args": - # Unlike regular calls, we regard the function as an argument, - # because it's user-supplied (and thus we need to pick it up as a - # dependency). - result = self.combine((result, self.rec(expr.function))) - return result - class FlopCounter( CombineMapperMixin, @@ -294,10 +278,7 @@ class FlopCounter( def map_grudge_variable(self, expr): return 0 - def map_c_function(self, expr): - return 1 - - def map_external_call(self, expr): + def map_function_symbol(self, expr): return 1 def map_ones(self, expr): @@ -859,24 +840,15 @@ class StringifyMapper(pymbolic.mapper.stringifier.StringifyMapper): self.rec(expr.op, PREC_NONE), self.rec(expr.field, PREC_NONE)) - def map_c_function(self, expr, enclosing_prec): - return expr.name - def map_grudge_variable(self, expr, enclosing_prec): return "%s:%s" % (expr.name, self._format_dd(expr.dd)) + def map_function_symbol(self, expr, enclosing_prec): + return expr + def map_interpolation(self, expr, enclosing_prec): return "Interp" + self._format_op_dd(expr) - def map_external_call(self, expr, enclosing_prec): - from pymbolic.mapper.stringifier import PREC_CALL, PREC_NONE - return ( - self.parenthesize_if_needed( - "External:%s:%s" % ( - self.map_call(expr, PREC_NONE), self._format_dd(expr.dd)), - enclosing_prec, - PREC_CALL)) - class PrettyStringifyMapper( pymbolic.mapper.stringifier.CSESplittingStringifyMapperMixin, @@ -1260,7 +1232,7 @@ class CollectorMixin(OperatorReducerMixin, LocalOpReducerMixin, FluxOpReducerMix map_variable = map_constant map_grudge_variable = map_constant - map_c_function = map_grudge_variable + map_function_symbol = map_constant map_ones = map_grudge_variable map_node_coordinate_component = map_grudge_variable @@ -1277,9 +1249,6 @@ class BoundOperatorCollector(CSECachingMapperMixin, CollectorMixin, CombineMappe def __init__(self, op_class): self.op_class = op_class - def map_external_call(self, expr): - return self.map_call(expr) - map_common_subexpression_uncached = \ CombineMapper.map_common_subexpression @@ -1331,13 +1300,6 @@ class SymbolicEvaluator(pymbolic.mapper.evaluator.EvaluationMapper): for key, val in six.iteritems(expr.kw_parameters)) ) - def map_external_call(self, expr, *args, **kwargs): - return type(expr)( - expr.function, - tuple(self.rec(child, *args, **kwargs) - for child in expr.parameters), - expr.dd) - def map_common_subexpression(self, expr): return type(expr)(self.rec(expr.child), expr.prefix, expr.scope) diff --git a/grudge/symbolic/operators.py b/grudge/symbolic/operators.py index 53fb1422619db271786abb9c2c17df5b3f9a97c0..a1899ed21ef6431f7b57a81fcf143680433b7b53 100644 --- a/grudge/symbolic/operators.py +++ b/grudge/symbolic/operators.py @@ -602,16 +602,16 @@ def norm(p, arg, dd=None): if p == 2: norm_squared = sym.NodalSum(dd_in=dd)( - sym.CFunction("fabs")( + sym.FunctionSymbol("fabs")( arg * sym.MassOperator()(arg))) if isinstance(norm_squared, np.ndarray): norm_squared = norm_squared.sum() - return sym.CFunction("sqrt")(norm_squared) + return sym.FunctionSymbol("sqrt")(norm_squared) elif p == np.Inf: - result = sym.NodalMax(dd_in=dd)(sym.CFunction("fabs")(arg)) + result = sym.NodalMax(dd_in=dd)(sym.FunctionSymbol("fabs")(arg)) from pymbolic.primitives import Max if isinstance(result, np.ndarray): diff --git a/grudge/symbolic/primitives.py b/grudge/symbolic/primitives.py index d35da7c00648c63fdfa649544439114e5d184dfd..c59f1518b5034a9b69c3aca4a22c0c4bab92d7a6 100644 --- a/grudge/symbolic/primitives.py +++ b/grudge/symbolic/primitives.py @@ -74,10 +74,8 @@ Symbols .. autoclass:: Variable .. autoclass:: ScalarVariable -.. autoclass:: ExternalCall .. autoclass:: make_sym_array .. autoclass:: make_sym_mv -.. autoclass:: CFunction .. function :: sqrt(arg) .. function :: exp(arg) @@ -344,18 +342,6 @@ class ScalarVariable(Variable): super(ScalarVariable, self).__init__(name, DD_SCALAR) -class ExternalCall(HasDOFDesc, ExpressionBase, pymbolic.primitives.Call): - """A call to a user-supplied function with a :class:`DOFDesc`. - """ - - init_arg_names = ("function", "parameters", "dd") - - def __getinitargs__(self): - return (self.function, self.parameters, self.dd) - - mapper_method = "map_external_call" - - def make_sym_array(name, shape, dd=None): def var_factory(name): return Variable(name, dd) @@ -368,27 +354,21 @@ def make_sym_mv(name, dim, var_factory=None): make_sym_array(name, dim, var_factory)) -class CFunction(ExpressionBase, pymbolic.primitives.Variable): - """A symbol representing a C-level function, to be used as the function - argument of :class:`pymbolic.primitives.Call`. - """ +class FunctionSymbol(ExpressionBase, pymbolic.primitives.Variable): + """A symbol to be used as the function argument of + :class:`pymbolic.primitives.Call`. - def __call__(self, *exprs): - from pytools.obj_array import with_object_array_or_scalar_n_args - from functools import partial - return with_object_array_or_scalar_n_args( - partial(pymbolic.primitives.Expression.__call__, self), - *exprs) + """ - mapper_method = "map_c_function" + mapper_method = "map_function_symbol" -sqrt = CFunction("sqrt") -exp = CFunction("exp") -sin = CFunction("sin") -cos = CFunction("cos") -bessel_j = CFunction("bessel_j") -bessel_y = CFunction("bessel_y") +sqrt = FunctionSymbol("sqrt") +exp = FunctionSymbol("exp") +sin = FunctionSymbol("sin") +cos = FunctionSymbol("cos") +bessel_j = FunctionSymbol("bessel_j") +bessel_y = FunctionSymbol("bessel_y") # }}} diff --git a/test/test_grudge.py b/test/test_grudge.py index 6a92fb61bf837d46141c716c76dcd5fc8e343b84..288922e07529aa70bffbc1f812efe582e146ff60 100644 --- a/test/test_grudge.py +++ b/test/test_grudge.py @@ -539,11 +539,11 @@ def test_bessel(ctx_factory): assert z < 1e-15 -def test_ExternalCall(ctx_factory): # noqa +def test_external_call(ctx_factory): cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) - def double(x): + def double(queue, x): return 2 * x from meshmode.mesh.generation import generate_regular_rect_mesh @@ -554,15 +554,20 @@ def test_ExternalCall(ctx_factory): # noqa discr = DGDiscretizationWithBoundaries(cl_ctx, mesh, order=1) ones = sym.Ones(sym.DD_VOLUME) - from pymbolic.primitives import Variable op = ( ones * 3 - + sym.ExternalCall( - Variable("double"), - (ones,), - sym.DD_VOLUME)) + + sym.FunctionSymbol("double")(ones)) - bound_op = bind(discr, op) + from grudge.function_registry import ( + base_function_registry, register_external_function) + + freg = register_external_function( + base_function_registry, + "double", + implementation=double, + dd=sym.DD_VOLUME) + + bound_op = bind(discr, op, function_registry=freg) result = bound_op(queue, double=double) assert (result == 5).get().all() diff --git a/unported-examples/gas_dynamics/lbm-simple.py b/unported-examples/gas_dynamics/lbm-simple.py index f99d5b76994b46f079b41af9f74bf69789ae18f5..2d3496da340fbda01adeb33b6a59290b76b1c698 100644 --- a/unported-examples/gas_dynamics/lbm-simple.py +++ b/unported-examples/gas_dynamics/lbm-simple.py @@ -63,12 +63,12 @@ def main(write_output=True, dtype=np.float32): from grudge.data import CompiledExpressionData def ic_expr(t, x, fields): - from grudge.symbolic import CFunction + from grudge.symbolic import FunctionSymbol from pymbolic.primitives import IfPositive from pytools.obj_array import make_obj_array - tanh = CFunction("tanh") - sin = CFunction("sin") + tanh = FunctionSymbol("tanh") + sin = FunctionSymbol("sin") rho = 1 u0 = 0.05 diff --git a/unported-examples/wave/wiggly.py b/unported-examples/wave/wiggly.py index f786074443d2123882bcaecf85972777b6694622..11c0121214982b70fb891b443dfec267ce5defb6 100644 --- a/unported-examples/wave/wiggly.py +++ b/unported-examples/wave/wiggly.py @@ -70,8 +70,8 @@ def main(write_output=True, from grudge.models.wave import StrongWaveOperator op = StrongWaveOperator(-1, discr.dimensions, source_f= - sym.CFunction("sin")(source_omega*sym.ScalarParameter("t")) - * sym.CFunction("exp")( + sym.FunctionSymbol("sin")(source_omega*sym.ScalarParameter("t")) + * sym.FunctionSymbol("exp")( -np.dot(sym_source_center_dist, sym_source_center_dist) / source_width**2), dirichlet_tag="boundary",