From 7463a9d660d5b426e65f47ea8240acea7608e259 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Sat, 2 May 2020 16:32:03 -0500 Subject: [PATCH] function-registry: add atan2 --- grudge/function_registry.py | 58 +++++++++++++++++++++++++++++++---- grudge/symbolic/primitives.py | 1 + 2 files changed, 53 insertions(+), 6 deletions(-) diff --git a/grudge/function_registry.py b/grudge/function_registry.py index 111443c4..2f56eba4 100644 --- a/grudge/function_registry.py +++ b/grudge/function_registry.py @@ -33,6 +33,24 @@ from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_1 # noqa from pytools import RecordWithoutPickling, memoize_in +# {{{ helpers + +def should_use_numpy(arg): + from numbers import Number + if isinstance(arg, Number) or \ + isinstance(arg, np.ndarray) and arg.shape == (): + return True + return False + + +def cl_to_numpy_function_name(name): + return { + "atan2": "arctan2", + }.get(name, name) + +# }}} + + # {{{ function class FunctionNotFound(KeyError): @@ -74,12 +92,7 @@ class CElementwiseUnaryFunction(Function): def __call__(self, queue, arg): func_name = self.identifier - - from numbers import Number - if ( - isinstance(arg, Number) - or (isinstance(arg, np.ndarray) - and arg.shape == ())): + if should_use_numpy(arg): func = getattr(np, func_name) return func(arg) @@ -108,6 +121,38 @@ class CElementwiseUnaryFunction(Function): return out +class CElementwiseBinaryFunction(Function): + supports_codegen = True + + def get_result_dofdesc(self, arg_dds): + assert len(arg_dds) == 2 + from pytools import single_valued + return single_valued(arg_dds) + + def __call__(self, queue, arg0, arg1): + func_name = self.identifier + if should_use_numpy(arg0) and should_use_numpy(arg1): + func = getattr(np, cl_to_numpy_function_name(func_name)) + return func(arg0, arg1) + + from pymbolic.primitives import Variable + @memoize_in(self, "map_call_knl_%s" % func_name) + def knl(): + i = Variable("i") + knl = lp.make_kernel( + "{[i]: 0<=i