diff --git a/grudge/function_registry.py b/grudge/function_registry.py index 111443c41a7c12f155222bb06bb0ac636c0bde87..2f56eba4ad26d066039f5ac373619ac4ac10a656 100644 --- a/grudge/function_registry.py +++ b/grudge/function_registry.py @@ -33,6 +33,24 @@ from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_1 # noqa from pytools import RecordWithoutPickling, memoize_in +# {{{ helpers + +def should_use_numpy(arg): + from numbers import Number + if isinstance(arg, Number) or \ + isinstance(arg, np.ndarray) and arg.shape == (): + return True + return False + + +def cl_to_numpy_function_name(name): + return { + "atan2": "arctan2", + }.get(name, name) + +# }}} + + # {{{ function class FunctionNotFound(KeyError): @@ -74,12 +92,7 @@ class CElementwiseUnaryFunction(Function): def __call__(self, queue, arg): func_name = self.identifier - - from numbers import Number - if ( - isinstance(arg, Number) - or (isinstance(arg, np.ndarray) - and arg.shape == ())): + if should_use_numpy(arg): func = getattr(np, func_name) return func(arg) @@ -108,6 +121,38 @@ class CElementwiseUnaryFunction(Function): return out +class CElementwiseBinaryFunction(Function): + supports_codegen = True + + def get_result_dofdesc(self, arg_dds): + assert len(arg_dds) == 2 + from pytools import single_valued + return single_valued(arg_dds) + + def __call__(self, queue, arg0, arg1): + func_name = self.identifier + if should_use_numpy(arg0) and should_use_numpy(arg1): + func = getattr(np, cl_to_numpy_function_name(func_name)) + return func(arg0, arg1) + + from pymbolic.primitives import Variable + @memoize_in(self, "map_call_knl_%s" % func_name) + def knl(): + i = Variable("i") + knl = lp.make_kernel( + "{[i]: 0<=i