Skip to content
Snippets Groups Projects
Commit 7463a9d6 authored by Alexandru Fikl's avatar Alexandru Fikl
Browse files

function-registry: add atan2

parent 54915dd7
No related branches found
No related tags found
1 merge request!55Add atan2 FunctionSymbol
Pipeline #22261 failed
...@@ -33,6 +33,24 @@ from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_1 # noqa ...@@ -33,6 +33,24 @@ from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_1 # noqa
from pytools import RecordWithoutPickling, memoize_in from pytools import RecordWithoutPickling, memoize_in
# {{{ helpers
def should_use_numpy(arg):
from numbers import Number
if isinstance(arg, Number) or \
isinstance(arg, np.ndarray) and arg.shape == ():
return True
return False
def cl_to_numpy_function_name(name):
return {
"atan2": "arctan2",
}.get(name, name)
# }}}
# {{{ function # {{{ function
class FunctionNotFound(KeyError): class FunctionNotFound(KeyError):
...@@ -74,12 +92,7 @@ class CElementwiseUnaryFunction(Function): ...@@ -74,12 +92,7 @@ class CElementwiseUnaryFunction(Function):
def __call__(self, queue, arg): def __call__(self, queue, arg):
func_name = self.identifier func_name = self.identifier
if should_use_numpy(arg):
from numbers import Number
if (
isinstance(arg, Number)
or (isinstance(arg, np.ndarray)
and arg.shape == ())):
func = getattr(np, func_name) func = getattr(np, func_name)
return func(arg) return func(arg)
...@@ -108,6 +121,38 @@ class CElementwiseUnaryFunction(Function): ...@@ -108,6 +121,38 @@ class CElementwiseUnaryFunction(Function):
return out return out
class CElementwiseBinaryFunction(Function):
supports_codegen = True
def get_result_dofdesc(self, arg_dds):
assert len(arg_dds) == 2
from pytools import single_valued
return single_valued(arg_dds)
def __call__(self, queue, arg0, arg1):
func_name = self.identifier
if should_use_numpy(arg0) and should_use_numpy(arg1):
func = getattr(np, cl_to_numpy_function_name(func_name))
return func(arg0, arg1)
from pymbolic.primitives import Variable
@memoize_in(self, "map_call_knl_%s" % func_name)
def knl():
i = Variable("i")
knl = lp.make_kernel(
"{[i]: 0<=i<n}",
[
lp.Assignment(
Variable("out")[i],
Variable(func_name)(Variable("a")[i], Variable("b")[i]))
], default_offset=lp.auto)
return lp.split_iname(knl, "i", 128, outer_tag="g.0", inner_tag="l.0")
evt, (out,) = knl()(queue, a=arg0, b=arg1)
return out
class CBesselFunction(Function): class CBesselFunction(Function):
supports_codegen = True supports_codegen = True
...@@ -179,6 +224,7 @@ def _make_bfr(): ...@@ -179,6 +224,7 @@ def _make_bfr():
bfr = bfr.register(CElementwiseUnaryFunction("fabs")) bfr = bfr.register(CElementwiseUnaryFunction("fabs"))
bfr = bfr.register(CElementwiseUnaryFunction("sin")) bfr = bfr.register(CElementwiseUnaryFunction("sin"))
bfr = bfr.register(CElementwiseUnaryFunction("cos")) bfr = bfr.register(CElementwiseUnaryFunction("cos"))
bfr = bfr.register(CElementwiseBinaryFunction("atan2"))
bfr = bfr.register(CBesselFunction("bessel_j")) bfr = bfr.register(CBesselFunction("bessel_j"))
bfr = bfr.register(CBesselFunction("bessel_y")) bfr = bfr.register(CBesselFunction("bessel_y"))
......
...@@ -363,6 +363,7 @@ sqrt = FunctionSymbol("sqrt") ...@@ -363,6 +363,7 @@ sqrt = FunctionSymbol("sqrt")
exp = FunctionSymbol("exp") exp = FunctionSymbol("exp")
sin = FunctionSymbol("sin") sin = FunctionSymbol("sin")
cos = FunctionSymbol("cos") cos = FunctionSymbol("cos")
atan2 = FunctionSymbol("atan2")
bessel_j = FunctionSymbol("bessel_j") bessel_j = FunctionSymbol("bessel_j")
bessel_y = FunctionSymbol("bessel_y") bessel_y = FunctionSymbol("bessel_y")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment