Skip to content
Snippets Groups Projects
Commit e0ac49ba authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Merge branch 'atan2-function' into 'master'

Add atan2 FunctionSymbol

See merge request inducer/grudge!55
parents 54915dd7 7463a9d6
No related branches found
No related tags found
No related merge requests found
......@@ -33,6 +33,24 @@ from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_1 # noqa
from pytools import RecordWithoutPickling, memoize_in
# {{{ helpers
def should_use_numpy(arg):
from numbers import Number
if isinstance(arg, Number) or \
isinstance(arg, np.ndarray) and arg.shape == ():
return True
return False
def cl_to_numpy_function_name(name):
return {
"atan2": "arctan2",
}.get(name, name)
# }}}
# {{{ function
class FunctionNotFound(KeyError):
......@@ -74,12 +92,7 @@ class CElementwiseUnaryFunction(Function):
def __call__(self, queue, arg):
func_name = self.identifier
from numbers import Number
if (
isinstance(arg, Number)
or (isinstance(arg, np.ndarray)
and arg.shape == ())):
if should_use_numpy(arg):
func = getattr(np, func_name)
return func(arg)
......@@ -108,6 +121,38 @@ class CElementwiseUnaryFunction(Function):
return out
class CElementwiseBinaryFunction(Function):
supports_codegen = True
def get_result_dofdesc(self, arg_dds):
assert len(arg_dds) == 2
from pytools import single_valued
return single_valued(arg_dds)
def __call__(self, queue, arg0, arg1):
func_name = self.identifier
if should_use_numpy(arg0) and should_use_numpy(arg1):
func = getattr(np, cl_to_numpy_function_name(func_name))
return func(arg0, arg1)
from pymbolic.primitives import Variable
@memoize_in(self, "map_call_knl_%s" % func_name)
def knl():
i = Variable("i")
knl = lp.make_kernel(
"{[i]: 0<=i<n}",
[
lp.Assignment(
Variable("out")[i],
Variable(func_name)(Variable("a")[i], Variable("b")[i]))
], default_offset=lp.auto)
return lp.split_iname(knl, "i", 128, outer_tag="g.0", inner_tag="l.0")
evt, (out,) = knl()(queue, a=arg0, b=arg1)
return out
class CBesselFunction(Function):
supports_codegen = True
......@@ -179,6 +224,7 @@ def _make_bfr():
bfr = bfr.register(CElementwiseUnaryFunction("fabs"))
bfr = bfr.register(CElementwiseUnaryFunction("sin"))
bfr = bfr.register(CElementwiseUnaryFunction("cos"))
bfr = bfr.register(CElementwiseBinaryFunction("atan2"))
bfr = bfr.register(CBesselFunction("bessel_j"))
bfr = bfr.register(CBesselFunction("bessel_y"))
......
......@@ -363,6 +363,7 @@ sqrt = FunctionSymbol("sqrt")
exp = FunctionSymbol("exp")
sin = FunctionSymbol("sin")
cos = FunctionSymbol("cos")
atan2 = FunctionSymbol("atan2")
bessel_j = FunctionSymbol("bessel_j")
bessel_y = FunctionSymbol("bessel_y")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment