diff --git a/grudge/function_registry.py b/grudge/function_registry.py index 111443c41a7c12f155222bb06bb0ac636c0bde87..2f56eba4ad26d066039f5ac373619ac4ac10a656 100644 --- a/grudge/function_registry.py +++ b/grudge/function_registry.py @@ -33,6 +33,24 @@ from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_1 # noqa from pytools import RecordWithoutPickling, memoize_in +# {{{ helpers + +def should_use_numpy(arg): + from numbers import Number + if isinstance(arg, Number) or \ + isinstance(arg, np.ndarray) and arg.shape == (): + return True + return False + + +def cl_to_numpy_function_name(name): + return { + "atan2": "arctan2", + }.get(name, name) + +# }}} + + # {{{ function class FunctionNotFound(KeyError): @@ -74,12 +92,7 @@ class CElementwiseUnaryFunction(Function): def __call__(self, queue, arg): func_name = self.identifier - - from numbers import Number - if ( - isinstance(arg, Number) - or (isinstance(arg, np.ndarray) - and arg.shape == ())): + if should_use_numpy(arg): func = getattr(np, func_name) return func(arg) @@ -108,6 +121,38 @@ class CElementwiseUnaryFunction(Function): return out +class CElementwiseBinaryFunction(Function): + supports_codegen = True + + def get_result_dofdesc(self, arg_dds): + assert len(arg_dds) == 2 + from pytools import single_valued + return single_valued(arg_dds) + + def __call__(self, queue, arg0, arg1): + func_name = self.identifier + if should_use_numpy(arg0) and should_use_numpy(arg1): + func = getattr(np, cl_to_numpy_function_name(func_name)) + return func(arg0, arg1) + + from pymbolic.primitives import Variable + @memoize_in(self, "map_call_knl_%s" % func_name) + def knl(): + i = Variable("i") + knl = lp.make_kernel( + "{[i]: 0<=i<n}", + [ + lp.Assignment( + Variable("out")[i], + Variable(func_name)(Variable("a")[i], Variable("b")[i])) + ], default_offset=lp.auto) + + return lp.split_iname(knl, "i", 128, outer_tag="g.0", inner_tag="l.0") + + evt, (out,) = knl()(queue, a=arg0, b=arg1) + return out + + class CBesselFunction(Function): supports_codegen = True @@ -179,6 +224,7 @@ def _make_bfr(): bfr = bfr.register(CElementwiseUnaryFunction("fabs")) bfr = bfr.register(CElementwiseUnaryFunction("sin")) bfr = bfr.register(CElementwiseUnaryFunction("cos")) + bfr = bfr.register(CElementwiseBinaryFunction("atan2")) bfr = bfr.register(CBesselFunction("bessel_j")) bfr = bfr.register(CBesselFunction("bessel_y")) diff --git a/grudge/symbolic/primitives.py b/grudge/symbolic/primitives.py index f8bff1bf2af1403a652f37a32995fd3a70b88f6a..aa00e8e678be045fa5ae3938a295af5da3f9ff41 100644 --- a/grudge/symbolic/primitives.py +++ b/grudge/symbolic/primitives.py @@ -363,6 +363,7 @@ sqrt = FunctionSymbol("sqrt") exp = FunctionSymbol("exp") sin = FunctionSymbol("sin") cos = FunctionSymbol("cos") +atan2 = FunctionSymbol("atan2") bessel_j = FunctionSymbol("bessel_j") bessel_y = FunctionSymbol("bessel_y")