From 17c2451c55fc8f58a811186506f05501e79c1ac5 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Sun, 17 Jan 2021 18:03:14 -0600 Subject: [PATCH] bunch of callables related changes: - simplifies interface to with_types - simplifies interface to with_descr - simplifies the logic within CallableKernel.with_descrs - gets rid of ManglerCallable - introduces InKernelCallable.with_added_arg --- loopy/kernel/function_interface.py | 392 ++++++++++++--------------- loopy/library/function.py | 11 +- loopy/library/random123.py | 19 +- loopy/library/reduction.py | 35 +-- loopy/preprocess.py | 56 +++- loopy/target/c/__init__.py | 47 ++-- loopy/target/c/codegen/expression.py | 19 +- loopy/target/cuda.py | 5 +- loopy/target/opencl.py | 82 +++++- loopy/target/pyopencl.py | 4 +- loopy/target/python.py | 8 - loopy/type_inference.py | 174 ++++-------- test/testlib.py | 17 +- 13 files changed, 419 insertions(+), 450 deletions(-) diff --git a/loopy/kernel/function_interface.py b/loopy/kernel/function_interface.py index 1120dd2bb..9eb707e81 100644 --- a/loopy/kernel/function_interface.py +++ b/loopy/kernel/function_interface.py @@ -20,15 +20,13 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -import islpy as isl from pytools import ImmutableRecord from loopy.diagnostic import LoopyError from loopy.tools import update_persistent_hash from loopy.kernel import LoopKernel -from loopy.kernel.data import ValueArg, ArrayArg, ConstantArg -from loopy.symbolic import (SubstitutionMapper, DependencyMapper) -from pymbolic.primitives import Variable +from loopy.kernel.data import ValueArg, ArrayArg +from loopy.symbolic import DependencyMapper, WalkMapper __doc__ = """ @@ -39,7 +37,6 @@ __doc__ = """ .. autoclass:: InKernelCallable .. autoclass:: CallableKernel .. autoclass:: ScalarCallable -.. 
autoclass:: ManglerCallable """ @@ -77,6 +74,9 @@ class ArrayArgDescriptor(ImmutableRecord): A tuple of instances of :class:`loopy.kernel.array.ArrayDimImplementationTag` + + .. automethod:: map_expr + .. automethod:: depends_on """ fields = {"shape", "address_space", "dim_tags"} @@ -102,11 +102,19 @@ class ArrayArgDescriptor(ImmutableRecord): dim_tags=dim_tags) def map_expr(self, f): + """ + Returns an instance of :class:`ArrayArgDescriptor` with its shapes, strides, mapped by *f*. + """ new_shape = tuple(f(axis_len) for axis_len in self.shape) new_dim_tags = tuple(dim_tag.map_expr(f) for dim_tag in self.dim_tags) return self.copy(shape=new_shape, dim_tags=new_dim_tags) def depends_on(self): + """ + Returns :class:`frozenset` of all the variable names the + :class:`ArrayArgDescriptor` depends on. + """ from loopy.kernel.data import auto result = DependencyMapper(composite_leaves=False)([lngth for lngth in self.shape if lngth not in [None, auto]]) | ( @@ -124,13 +132,50 @@ class ArrayArgDescriptor(ImmutableRecord): key_builder.rec(key_hash, self.dim_tags) +class ExpressionIsScalarChecker(WalkMapper): + def __init__(self, kernel): + self.kernel = kernel + + def map_sub_array_ref(self, expr): + raise LoopyError("Sub-array refs can only be used as call's parameters" + f" or assignees. 
'{expr}' violates this.") + + def map_call(self, expr): + for child in expr.parameters: + self.rec(child) + + def map_call_with_kwargs(self, expr): + for child in expr.parameters + tuple(expr.kw_parameters.values()): + self.rec(child) + + def map_subscript(self, expr): + for child in expr.index_tuple: + self.rec(child) + + def map_variable(self, expr): + from loopy.kernel.data import TemporaryVariable, ArrayArg + if expr.name in self.kernel.all_inames(): + # inames are scalar + return + + var = self.kernel.arg_dict.get(expr.name, None) or ( + self.kernel.temporary_variables.get(expr.name, None)) + + if var is not None: + if isinstance(var, (ArrayArg, TemporaryVariable)) and ( + var.shape != ()): + raise LoopyError("Array regions can only be passed as sub-array refs.") + + def map_slice(self, expr): + raise LoopyError("Array regions can only be passed as sub-array refs.") + + def get_arg_descriptor_for_expression(kernel, expr): """ :returns: a :class:`ArrayArgDescriptor` or a :class:`ValueArgDescriptor` describing the argument expression *expr* which occurs in a call in the code of *kernel*. 
""" - from pymbolic.primitives import Variable from loopy.symbolic import (SubArrayRef, pw_aff_to_expr, SweptInameStrideCollector) from loopy.kernel.data import TemporaryVariable, ArrayArg @@ -186,24 +231,8 @@ def get_arg_descriptor_for_expression(kernel, expr): address_space=aspace, dim_tags=sub_dim_tags, shape=sub_shape) - - elif isinstance(expr, Variable): - arg = kernel.get_var_descriptor(expr.name) - from loopy.kernel.array import ArrayBase - - if isinstance(arg, ValueArg) or (isinstance(arg, ArrayBase) - and arg.shape == ()): - return ValueArgDescriptor() - elif isinstance(arg, (ArrayArg, TemporaryVariable)): - raise LoopyError("may not pass entire array " - "'%s' in call statement in kernel '%s'" - % (expr.name, kernel.name)) - else: - raise LoopyError("unsupported argument type " - "'%s' of '%s' in call statement" - % (type(arg).__name__, expr.name)) - else: + ExpressionIsScalarChecker(kernel)(expr) return ValueArgDescriptor() # }}} @@ -242,8 +271,8 @@ class GridOverrideForCalleeKernel(ImmutableRecord): Helper class to set the :attr:`loopy.kernel.LoopKernel.override_get_grid_size_for_insn_ids` of the callee kernels. Refer to - :func:`loopy.kernel.function_interface.GridOverrideForCalleeKernel.__call__`, - :func:`loopy.kernel.function_interface.CallbleKernel.with_hw_axes_sizes`. + :meth:`loopy.kernel.function_interface.GridOverrideForCalleeKernel.__call__`, + :meth:`loopy.kernel.function_interface.CallbleKernel.with_hw_axes_sizes`. .. 
attribute:: global_size @@ -325,7 +354,7 @@ class InKernelCallable(ImmutableRecord): update_persistent_hash = update_persistent_hash - def with_types(self, arg_id_to_dtype, caller_kernel, callables_table): + def with_types(self, arg_id_to_dtype, callables_table): """ :arg arg_id_to_type: a mapping from argument identifiers (integers for positional arguments, names for keyword @@ -345,12 +374,15 @@ class InKernelCallable(ImmutableRecord): raise NotImplementedError() - def with_descrs(self, arg_id_to_descr, caller_kernel, callables_table, expr): + def with_descrs(self, arg_id_to_descr, callables_table): """ - :arg arg_id_to_descr: a mapping from argument identifiers - (integers for positional arguments, names for keyword - arguments) to :class:`loopy.ArrayArgDescriptor` instances. - Unspecified/unknown types are not represented in *arg_id_to_descr*. + :arg arg_id_to_descr: a mapping from argument identifiers (integers for + positional arguments, names for keyword arguments) to + :class:`loopy.ArrayArgDescriptor` instances. Unspecified/unknown + descriptors are not represented in *arg_id_to_descr*. + + All the expressions in arg_id_to_descr must have variables that belong + to the callable's namespace. Return values are denoted by negative integers, with the first returned value identified as *-1*. @@ -439,6 +471,13 @@ class InKernelCallable(ImmutableRecord): return hash(tuple(self.fields)) + def with_added_arg(self, arg_dtype, arg_descr): + """ + Registers a new argument to the callable and returns the name of the + argument in the callable's namespace. + """ + raise NotImplementedError() + # }}} @@ -451,8 +490,7 @@ class ScalarCallable(InKernelCallable): .. note:: The :meth:`ScalarCallable.with_types` is intended to assist with type - specialization of the function and is expected to be supplemented in the - derived subclasses. + specialization of the function and sub-classes must define it. 
""" fields = {"name", "arg_id_to_dtype", "arg_id_to_descr", "name_in_target"} @@ -474,16 +512,16 @@ class ScalarCallable(InKernelCallable): return (self.arg_id_to_dtype, self.arg_id_to_descr, self.name_in_target) - def with_types(self, arg_id_to_dtype, caller_kernel, callables_table): + def with_types(self, arg_id_to_dtype, callables_table): raise LoopyError("No type inference information present for " "the function %s." % (self.name)) - def with_descrs(self, arg_id_to_descr, caller_kernel, callables_table, expr): + def with_descrs(self, arg_id_to_descr, callables_table): arg_id_to_descr[-1] = ValueArgDescriptor() return ( self.copy(arg_id_to_descr=arg_id_to_descr), - callables_table, ()) + callables_table) def with_hw_axes_sizes(self, global_size, local_size): return self.copy() @@ -584,9 +622,6 @@ class ScalarCallable(InKernelCallable): # assignee is returned whenever the size of assignees is non zero. first_assignee_is_returned = len(insn.assignees) > 0 - # TODO: Maybe this interface a bit confusing. Should we allow this - # method to directly return a cgen.Assign or cgen.ExpressionStatement? 
- return var(self.name_in_target)(*c_parameters), first_assignee_is_returned def generate_preambles(self, target): @@ -595,6 +630,9 @@ class ScalarCallable(InKernelCallable): # }}} + def with_added_arg(self, arg_dtype, arg_descr): + raise LoopyError("Cannot add args to scalar callables.") + # }}} @@ -645,8 +683,7 @@ class CallableKernel(InKernelCallable): def name(self): return self.subkernel.name - def with_types(self, arg_id_to_dtype, caller_kernel, - callables_table): + def with_types(self, arg_id_to_dtype, callables_table): kw_to_pos, pos_to_kw = get_kw_pos_association(self.subkernel) new_args = [] @@ -684,124 +721,116 @@ class CallableKernel(InKernelCallable): return self.copy(subkernel=specialized_kernel, arg_id_to_dtype=new_arg_id_to_dtype), callables_table - def with_descrs(self, arg_id_to_descr, caller_kernel, callables_table, - expr=None): - # tune the subkernel so that we have the matching shapes and - # dim_tags - - # {{{ map the arg_descrs so that all the variables are from the callees - # perspective - - domain_dependent_vars = frozenset().union( - *(frozenset(dom.get_var_names(isl.dim_type.param)) for dom in - self.subkernel.domains)) - - # FIXME: This is ill-formed, because par can be an expression, e.g. - # 2*i+2 or 2*(i+1). A key feature of expression is that structural - # equality and semantic equality are not the same, so even if the - # SubstitutionMapper allowed non-variables, it would have to solve the - # (considerable) problem of expression equivalence. 
- - import numbers - substs = {} - assumptions = {} - - if expr: - for arg, par in zip(self.subkernel.args, expr.parameters): - if isinstance(arg, ValueArg) and arg.name in domain_dependent_vars: - if isinstance(par, Variable): - if par in substs: - assumptions[arg.name] = substs[par].name - else: - substs[par] = Variable(arg.name) - elif isinstance(par, numbers.Number): - assumptions[arg.name] = par - - def subst_func(expr): - if expr in substs: - return substs[expr] - else: - return expr - - subst_mapper = SubstitutionMapper(subst_func) - - arg_id_to_descr = {arg_id: descr.map_expr(subst_mapper) - for arg_id, descr in arg_id_to_descr.items()} + def with_descrs(self, arg_id_to_descr, callables_table): - # }}} + # arg_id_to_descr expressions provided are from the caller's namespace, + # need to register - dependents = frozenset().union(*(descr.depends_on() for descr in - arg_id_to_descr.values())) - unknown_deps = dependents - self.subkernel.all_variable_names() + kw_to_pos, pos_to_kw = get_kw_pos_association(self.subkernel) - if expr is None: - assert unknown_deps == frozenset() - # FIXME: Need to make sure that we make the name of the variables - # unique, and then run a subst_mapper + kw_to_callee_idx = {arg.name: i + for i, arg in enumerate(self.subkernel.args)} new_args = self.subkernel.args[:] - kw_to_pos, pos_to_kw = get_kw_pos_association(self.subkernel) for arg_id, descr in arg_id_to_descr.items(): if isinstance(arg_id, int): arg_id = pos_to_kw[arg_id] - assert isinstance(arg_id, str) + + callee_arg = new_args[kw_to_callee_idx[arg_id]] + + # {{{ checks + + if isinstance(callee_arg, ValueArg) and ( + isinstance(descr, ArrayArgDescriptor)): + raise LoopyError(f"In call to {self.subkernel.name}, '{arg_id}' " + "expected to be a scalar, got an array region.") + + if isinstance(callee_arg, ArrayArg) and ( + isinstance(descr, ValueArgDescriptor)): + raise LoopyError(f"In call to {self.subkernel.name}, '{arg_id}' " + "expected to be an array, got a scalar.") + + 
if (isinstance(descr, ArrayArgDescriptor) + and isinstance(callee_arg.shape, tuple) + and len(callee_arg.shape) != len(descr.shape)): + raise LoopyError(f"In call to {self.subkernel.name}, '{arg_id}'" + " has a dimensionality mismatch, expected " + f"{len(callee_arg.shape)}, got {len(descr.shape)}") + + # }}} if isinstance(descr, ArrayArgDescriptor): - if not isinstance(self.subkernel.arg_dict[arg_id], (ArrayArg, - ConstantArg)): - raise LoopyError("Array passed to scalar argument " - "'%s' of the function '%s' (in '%s')." % ( - arg_id, self.subkernel.name, - caller_kernel.name)) - if self.subkernel.arg_dict[arg_id].shape and ( - len(self.subkernel.arg_dict[arg_id].shape) != - len(descr.shape)): - raise LoopyError("Dimension mismatch for argument " - " '%s' of the function '%s' (in '%s')." % ( - arg_id, self.subkernel.name, - caller_kernel.name)) - - new_arg = self.subkernel.arg_dict[arg_id].copy( - shape=descr.shape, - dim_tags=descr.dim_tags, - address_space=descr.address_space) - # replacing the new arg with the arg of the same name - new_args = [new_arg if arg.name == arg_id else arg for arg in - new_args] - elif isinstance(descr, ValueArgDescriptor): - if not isinstance(self.subkernel.arg_dict[arg_id], ValueArg): - raise LoopyError("Scalar passed to array argument " - "'%s' of the callable '%s' (in '%s')" % ( - arg_id, self.subkernel.name, - caller_kernel.name)) + callee_arg = callee_arg.copy(shape=descr.shape, + dim_tags=descr.dim_tags, + address_space=descr.address_space) else: - raise LoopyError("Descriptor must be either an instance of " - "ArrayArgDescriptor or ValueArgDescriptor -- got %s" % - type(descr)) - - descriptor_specialized_knl = self.subkernel.copy(args=new_args) - # add the variables on which the strides/shapes depend but not provided - # as arguments - args_added_knl = descriptor_specialized_knl.copy( - args=descriptor_specialized_knl.args - + [ValueArg(dep) for dep in unknown_deps]) + # do nothing for a scalar arg. 
+ assert isinstance(descr, ValueArgDescriptor) + + new_args[kw_to_callee_idx[arg_id]] = callee_arg + + subkernel = self.subkernel.copy(args=new_args) + from loopy.preprocess import traverse_to_infer_arg_descr - from loopy.transform.parameter import assume - args_added_knl, callables_table = ( - traverse_to_infer_arg_descr(args_added_knl, + subkernel, callables_table = ( + traverse_to_infer_arg_descr(subkernel, callables_table)) - if assumptions: - assumption_str = " and ".join([f"{key}={val}" - for key, val in assumptions.items()]) - args_added_knl = assume(args_added_knl, assumption_str) + # {{{ update the arg descriptors - return ( - self.copy( - subkernel=args_added_knl, - arg_id_to_descr=arg_id_to_descr), - callables_table, tuple(Variable(dep) for dep in unknown_deps)) + for arg in subkernel.args: + kw = arg.name + if isinstance(arg, ArrayArg): + arg_id_to_descr[kw] = ( + ArrayArgDescriptor(shape=arg.shape, + dim_tags=arg.dim_tags, + address_space=arg.address_space)) + else: + assert isinstance(arg, ValueArg) + arg_id_to_descr[kw] = ValueArgDescriptor() + + arg_id_to_descr[kw_to_pos[kw]] = arg_id_to_descr[kw] + + # }}} + + return (self.copy(subkernel=subkernel, + arg_id_to_descr=arg_id_to_descr), + callables_table) + + def with_added_arg(self, arg_dtype, arg_descr): + var_name = self.subkernel.get_var_name_generator()(based_on="_lpy_arg") + + if isinstance(arg_descr, ValueArgDescriptor): + subknl = self.subkernel.copy( + args=self.subkernel.args+[ + ValueArg(var_name, arg_dtype, self.subkernel.target)]) + + kw_to_pos, pos_to_kw = get_kw_pos_association(subknl) + + if self.arg_id_to_dtype is None: + arg_id_to_dtype = {} + else: + arg_id_to_dtype = self.arg_id_to_dtype.copy() + if self.arg_id_to_descr is None: + arg_id_to_descr = {} + else: + arg_id_to_descr = self.arg_id_to_descr.copy() + + arg_id_to_dtype[var_name] = arg_dtype + arg_id_to_descr[var_name] = arg_descr + arg_id_to_dtype[kw_to_pos[var_name]] = arg_dtype + arg_id_to_descr[kw_to_pos[var_name]] = 
arg_descr + + return (self.copy(subkernel=subknl, + arg_id_to_dtype=arg_id_to_dtype, + arg_id_to_descr=arg_id_to_descr), + var_name) + + else: + # don't think this should ever be needed + raise NotImplementedError("with_added_arg not implemented for array" + " types arguments.") def with_packing_for_args(self): from loopy.kernel.data import AddressSpace @@ -892,81 +921,4 @@ class CallableKernel(InKernelCallable): # }}} -# {{{ mangler callable - -class ManglerCallable(ScalarCallable): - """ - A callable whose characteristic is defined by a function mangler. - - .. attribute:: function_mangler - - A function of signature ``(kernel, name , arg_dtypes)`` and returns an - instance of ``loopy.CallMangleInfo``. - """ - fields = {"name", "function_mangler", "arg_id_to_dtype", "arg_id_to_descr", - "name_in_target"} - init_arg_names = ("name", "function_mangler", "arg_id_to_dtype", - "arg_id_to_descr", "name_in_target") - hash_fields = ("name", "arg_id_to_dtype", "arg_id_to_descr", - "name_in_target") - - def __init__(self, name, function_mangler, arg_id_to_dtype=None, - arg_id_to_descr=None, name_in_target=None): - - self.function_mangler = function_mangler - - super().__init__( - name=name, - arg_id_to_dtype=arg_id_to_dtype, - arg_id_to_descr=arg_id_to_descr, - name_in_target=name_in_target) - - def __getinitargs__(self): - return (self.name, self.function_mangler, self.arg_id_to_dtype, - self.arg_id_to_descr, self.name_in_target) - - def with_types(self, arg_id_to_dtype, kernel, callables_table): - if self.arg_id_to_dtype is not None: - # specializing an already specialized function. - for arg_id, dtype in arg_id_to_dtype.items(): - # only checking for the ones which have been provided - # if does not match, returns an error. 
- if self.arg_id_to_dtype[arg_id] != arg_id_to_dtype[arg_id]: - raise LoopyError("Overwriting a specialized" - " function is illegal--maybe start with new instance of" - " ManglerCallable?") - - sorted_keys = sorted(arg_id_to_dtype.keys()) - arg_dtypes = tuple(arg_id_to_dtype[key] for key in sorted_keys if - key >= 0) - - mangle_result = self.function_mangler(kernel, self.name, - arg_dtypes) - if mangle_result: - new_arg_id_to_dtype = dict(enumerate(mangle_result.arg_dtypes)) - new_arg_id_to_dtype.update({-i-1: dtype for i, dtype in - enumerate(mangle_result.result_dtypes)}) - return ( - self.copy(name_in_target=mangle_result.target_name, - arg_id_to_dtype=new_arg_id_to_dtype), - callables_table) - else: - # The function mangler does not agree with the arg id to dtypes - # provided. Indicating that is illegal. - raise LoopyError("Function %s not coherent with the provided types." % ( - self.name)) - - def mangle_result(self, kernel): - """ - Returns an instance of :class:`loopy.kernel.data.CallMangleInfo` for - the given pair :attr:`function_mangler` and :attr:`arg_id_to_dtype`. - """ - sorted_keys = sorted(self.arg_id_to_dtype.keys()) - arg_dtypes = tuple(self.arg_id_to_dtype[key] for key in sorted_keys if - key >= 0) - - return self.function_mangler(kernel, self.name, arg_dtypes) - -# }}} - # vim: foldmethod=marker diff --git a/loopy/library/function.py b/loopy/library/function.py index bea9a4a70..73241152f 100644 --- a/loopy/library/function.py +++ b/loopy/library/function.py @@ -22,10 +22,11 @@ THE SOFTWARE. 
from loopy.kernel.function_interface import ScalarCallable from loopy.diagnostic import LoopyError +import numpy as np class MakeTupleCallable(ScalarCallable): - def with_types(self, arg_id_to_dtype, kernel, callables_table): + def with_types(self, arg_id_to_dtype, callables_table): new_arg_id_to_dtype = arg_id_to_dtype.copy() for i in range(len(arg_id_to_dtype)): if i in arg_id_to_dtype and arg_id_to_dtype[i] is not None: @@ -34,22 +35,22 @@ class MakeTupleCallable(ScalarCallable): return (self.copy(arg_id_to_dtype=new_arg_id_to_dtype, name_in_target="loopy_make_tuple"), callables_table) - def with_descrs(self, arg_id_to_descr, caller_kernel, callables_table, expr): + def with_descrs(self, arg_id_to_descr, callables_table): from loopy.kernel.function_interface import ValueArgDescriptor new_arg_id_to_descr = {(id, ValueArgDescriptor()): (-id-1, ValueArgDescriptor()) for id in arg_id_to_descr.keys()} return ( self.copy(arg_id_to_descr=new_arg_id_to_descr), - callables_table, ()) + callables_table) class IndexOfCallable(ScalarCallable): - def with_types(self, arg_id_to_dtype, kernel, callables_table): + def with_types(self, arg_id_to_dtype, callables_table): new_arg_id_to_dtype = {i: dtype for i, dtype in arg_id_to_dtype.items() if dtype is not None} - new_arg_id_to_dtype[-1] = kernel.index_dtype + new_arg_id_to_dtype[-1] = np.int32 return (self.copy(arg_id_to_dtype=new_arg_id_to_dtype), callables_table) diff --git a/loopy/library/random123.py b/loopy/library/random123.py index c2e64fc55..14199b279 100644 --- a/loopy/library/random123.py +++ b/loopy/library/random123.py @@ -168,7 +168,18 @@ class Random123Callable(ScalarCallable): Records information about for the random123 functions. 
""" - def with_types(self, arg_id_to_dtype, kernel, callables_table): + def __init__(self, name, arg_id_to_dtype=None, + arg_id_to_descr=None, name_in_target=None, target=None): + + super().__init__( + name=name, + arg_id_to_dtype=arg_id_to_dtype, + arg_id_to_descr=arg_id_to_descr, + name_in_target=name_in_target) + + self.target = target + + def with_types(self, arg_id_to_dtype, callables_table): if 0 not in arg_id_to_dtype or 1 not in arg_id_to_dtype or ( arg_id_to_dtype[0] is None or arg_id_to_dtype[1] is None): @@ -178,7 +189,7 @@ class Random123Callable(ScalarCallable): callables_table) name = self.name - target = kernel.target + target = self.target rng_variant = FUNC_NAMES_TO_RNG[name] @@ -230,7 +241,7 @@ class Random123Callable(ScalarCallable): return -def get_random123_callables(): - return {id_: Random123Callable(id_) for id_ in FUNC_NAMES_TO_RNG} +def get_random123_callables(target): + return {id_: Random123Callable(id_, target=target) for id_ in FUNC_NAMES_TO_RNG} # vim: foldmethod=marker diff --git a/loopy/library/reduction.py b/loopy/library/reduction.py index fa6c0cd89..1d53d06b0 100644 --- a/loopy/library/reduction.py +++ b/loopy/library/reduction.py @@ -53,7 +53,7 @@ class ReductionOperation: equality-comparable. 
""" - def result_dtypes(self, target, *arg_dtypes): + def result_dtypes(self, *arg_dtypes): """ :arg arg_dtypes: may be None if not known :returns: None if not known, otherwise the returned type @@ -112,10 +112,11 @@ class ScalarReductionOperation(ReductionOperation): def arg_count(self): return 1 - def result_dtypes(self, kernel, arg_dtype): + def result_dtypes(self, arg_dtype): if self.forced_result_type is not None: - return (self.parse_result_type( - kernel.target, self.forced_result_type),) + raise NotImplementedError() + # return (self.parse_result_type( + # kernel.target, self.forced_result_type),) if arg_dtype is None: return None @@ -224,7 +225,7 @@ class MaxReductionOperation(ScalarReductionOperation): # type specialize the callable max_scalar_callable, callables_table = max_scalar_callable.with_types( - {0: dtype, 1: dtype}, None, callables_table) + {0: dtype, 1: dtype}, callables_table) # populate callables_table func_id, callables_table = update_table(callables_table, "max", @@ -246,7 +247,7 @@ class MinReductionOperation(ScalarReductionOperation): # type specialize the callable min_scalar_callable, callables_table = min_scalar_callable.with_types( - {0: dtype, 1: dtype}, None, callables_table) + {0: dtype, 1: dtype}, callables_table) # populate callables_table func_id, callables_table = update_table(callables_table, "min", @@ -325,7 +326,7 @@ class _SegmentedScalarReductionOperation(ReductionOperation): make_tuple_callable, callables_table = make_tuple_callable.with_types( dict(enumerate([scalar_dtype, segment_flag_dtype])), - None, callables_table) + callables_table) func_id, callables_table = update_table( callables_table, "make_tuple", make_tuple_callable) @@ -333,8 +334,8 @@ class _SegmentedScalarReductionOperation(ReductionOperation): return ResolvedFunction(func_id)(scalar_neutral_element, segment_flag_dtype.numpy_dtype.type(0)), callables_table - def result_dtypes(self, kernel, scalar_dtype, segment_flag_dtype): - return 
(self.inner_reduction.result_dtypes(kernel, scalar_dtype) + def result_dtypes(self, scalar_dtype, segment_flag_dtype): + return (self.inner_reduction.result_dtypes(scalar_dtype) + (segment_flag_dtype,)) def __str__(self): @@ -355,7 +356,7 @@ class _SegmentedScalarReductionOperation(ReductionOperation): segmented_scalar_callable, callables_table = ( segmented_scalar_callable.with_types( {0: dtypes[0], 1: dtypes[1], 2: dtypes[0], 3: dtypes[1]}, - None, callables_table)) + callables_table)) # populate callables_table from loopy.program import update_table @@ -414,7 +415,7 @@ class _ArgExtremumReductionOperation(ReductionOperation): scalar_dtype.numpy_dtype.type.__name__, index_dtype.numpy_dtype.type.__name__) - def result_dtypes(self, kernel, scalar_dtype, index_dtype): + def result_dtypes(self, scalar_dtype, index_dtype): return (scalar_dtype, index_dtype) def neutral_element(self, scalar_dtype, index_dtype, callables_table, @@ -430,7 +431,7 @@ class _ArgExtremumReductionOperation(ReductionOperation): make_tuple_callable, callables_table = make_tuple_callable.with_types( dict(enumerate([scalar_dtype, index_dtype])), - None, callables_table) + callables_table) # populate callables_table func_id, callables_table = update_table(callables_table, "make_tuple", @@ -459,7 +460,7 @@ class _ArgExtremumReductionOperation(ReductionOperation): arg_ext_scalar_callable, callables_table = ( arg_ext_scalar_callable.with_types( {0: dtypes[0], 1: dtypes[1], 2: dtypes[0], 3: dtypes[1]}, - None, callables_table)) + callables_table)) # populate callables_table from loopy.program import update_table @@ -549,10 +550,10 @@ def parse_reduction_op(name): # {{{ reduction specific callables class ReductionCallable(ScalarCallable): - def with_types(self, arg_id_to_dtype, kernel, callables_table): + def with_types(self, arg_id_to_dtype, callables_table): scalar_dtype = arg_id_to_dtype[0] index_dtype = arg_id_to_dtype[1] - result_dtypes = self.name.reduction_op.result_dtypes(kernel, scalar_dtype, 
+ result_dtypes = self.name.reduction_op.result_dtypes(scalar_dtype, index_dtype) new_arg_id_to_dtype = arg_id_to_dtype.copy() new_arg_id_to_dtype[-1] = result_dtypes[0] @@ -563,13 +564,13 @@ class ReductionCallable(ScalarCallable): return self.copy(arg_id_to_dtype=new_arg_id_to_dtype, name_in_target=name_in_target), callables_table - def with_descrs(self, arg_id_to_descr, caller_kernel, callables_table, expr): + def with_descrs(self, arg_id_to_descr, callables_table): from loopy.kernel.function_interface import ValueArgDescriptor new_arg_id_to_descr = arg_id_to_descr.copy() new_arg_id_to_descr[-1] = ValueArgDescriptor() return ( self.copy(arg_id_to_descr=arg_id_to_descr), - callables_table, ()) + callables_table) def generate_preambles(self, target): if isinstance(self.name, ArgExtOp): diff --git a/loopy/preprocess.py b/loopy/preprocess.py index e377adc28..20ed08402 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -2084,8 +2084,13 @@ class ArgDescrInferenceMapper(RuleAwareIdentityMapper): self.callables_table = callables_table def map_call(self, expr, expn_state, assignees=None): - from pymbolic.primitives import Call, CallWithKwargs + from pymbolic.primitives import Call, CallWithKwargs, Variable + from loopy.kernel.function_interface import ValueArgDescriptor from loopy.symbolic import ResolvedFunction + from loopy.kernel.array import ArrayBase + from loopy.kernel.data import ValueArg + from pymbolic.mapper.substitutor import make_subst_func + from loopy.symbolic import SubstitutionMapper if not isinstance(expr.function, ResolvedFunction): # ignore if the call is not to a ResolvedFunction @@ -2105,13 +2110,45 @@ class ArgDescrInferenceMapper(RuleAwareIdentityMapper): arg_id: get_arg_descriptor_for_expression( self.caller_kernel, arg) for arg_id, arg in arg_id_to_val.items()} + in_knl_callable = self.callables_table[expr.function.name] + + # {{{ translating descriptor expressions to the callable's namespace + + deps_as_params = [] + subst_map = {} + 
+ + deps = frozenset().union(*(descr.depends_on() + for descr in arg_id_to_descr.values())) + + assert deps <= self.caller_kernel.all_variable_names() + + for dep in deps: + caller_arg = self.caller_kernel.arg_dict.get(dep, None) + caller_arg = self.caller_kernel.temporary_variables.get(dep, caller_arg) + + if not (isinstance(caller_arg, ValueArg) or (isinstance(caller_arg, + ArrayBase) and caller_arg.shape == ())): + raise NotImplementedError(f"Obtained '{dep}' as a dependency for" + f" call '{expr.function.name}' which is not a scalar.") + + in_knl_callable, callee_name = in_knl_callable.with_added_arg( + caller_arg.dtype, ValueArgDescriptor()) + + subst_map[dep] = Variable(callee_name) + deps_as_params.append(Variable(dep)) + + mapper = SubstitutionMapper(make_subst_func(subst_map)) + arg_id_to_descr = {id_: descr.map_expr(mapper) + for id_, descr in arg_id_to_descr.items()} + + # }}} # specializing the function according to the parameter description - in_knl_callable = self.callables_table[expr.function.name] - new_in_knl_callable, self.callables_table, new_vars = ( + new_in_knl_callable, self.callables_table = ( in_knl_callable.with_descrs( - arg_id_to_descr, self.caller_kernel, - self.callables_table, expr)) + arg_id_to_descr, self.callables_table)) + + # find the deps of the new in kernel callable and add those arguments to self.callables_table, new_func_id = ( self.callables_table.with_callable( @@ -2122,9 +2159,10 @@ class ArgDescrInferenceMapper(RuleAwareIdentityMapper): return Call( ResolvedFunction(new_func_id), tuple(self.rec(child, expn_state) - for child in expr.parameters)+new_vars) + for child in expr.parameters) + + tuple(deps_as_params)) else: - # FIXME: Order for vars when kwards are present? + # FIXME: Order for vars when kwargs are present? 
assert isinstance(expr, CallWithKwargs) return CallWithKwargs( ResolvedFunction(new_func_id), @@ -2231,8 +2269,8 @@ def infer_arg_descr(program): arg_id_to_descr[arg.name] = ValueArgDescriptor() else: raise NotImplementedError() - new_callable, clbl_inf_ctx, _ = program.callables_table[e].with_descrs( - arg_id_to_descr, None, clbl_inf_ctx) + new_callable, clbl_inf_ctx = program.callables_table[e].with_descrs( + arg_id_to_descr, clbl_inf_ctx) clbl_inf_ctx, new_name = clbl_inf_ctx.with_callable(e, new_callable) renamed_entrypoints.add(new_name.name) diff --git a/loopy/target/c/__init__.py b/loopy/target/c/__init__.py index 8babd6fec..5fe9e3842 100644 --- a/loopy/target/c/__init__.py +++ b/loopy/target/c/__init__.py @@ -466,7 +466,7 @@ class CMathCallable(ScalarCallable): C-Target. """ - def with_types(self, arg_id_to_dtype, caller_kernel, callables_table): + def with_types(self, arg_id_to_dtype, callables_table): name = self.name if name in ["abs", "min", "max"]: @@ -497,18 +497,16 @@ class CMathCallable(ScalarCallable): elif dtype.kind == "c": raise LoopyTypeError(f"{name} does not support type {dtype}") - from loopy.target.opencl import OpenCLTarget - if not isinstance(caller_kernel.target, OpenCLTarget): - # for CUDA, C Targets the name must be modified - if dtype == np.float64: - pass # fabs - elif dtype == np.float32: - name = name + "f" # fabsf - elif dtype == np.float128: # pylint:disable=no-member - name = name + "l" # fabsl - else: - raise LoopyTypeError("{} does not support type {}".format(name, - dtype)) + # for CUDA, C Targets the name must be modified + if dtype == np.float64: + pass # fabs + elif dtype == np.float32: + name = name + "f" # fabsf + elif dtype == np.float128: # pylint:disable=no-member + name = name + "l" # fabsl + else: + raise LoopyTypeError("{} does not support type {}".format(name, + dtype)) return ( self.copy(name_in_target=name, @@ -521,9 +519,6 @@ class CMathCallable(ScalarCallable): for id in arg_id_to_dtype: if not -1 <= id <= 1: 
- #FIXME: Do we need to raise here?: - # The pattern we generally follow is that if we don't find - # a function, then we just return None raise LoopyError("%s can take only two arguments." % name) if 0 not in arg_id_to_dtype or 1 not in arg_id_to_dtype or ( @@ -542,17 +537,15 @@ class CMathCallable(ScalarCallable): raise LoopyTypeError("%s does not support complex numbers") elif dtype.kind == "f": - from loopy.target.opencl import OpenCLTarget - if not isinstance(caller_kernel.target, OpenCLTarget): - if dtype == np.float64: - pass # fmin - elif dtype == np.float32: - name = name + "f" # fminf - elif dtype == np.float128: # pylint:disable=no-member - name = name + "l" # fminl - else: - raise LoopyTypeError("%s does not support type %s" - % (name, dtype)) + if dtype == np.float64: + pass # fmin + elif dtype == np.float32: + name = name + "f" # fminf + elif dtype == np.float128: # pylint:disable=no-member + name = name + "l" # fminl + else: + raise LoopyTypeError("%s does not support type %s" + % (name, dtype)) dtype = NumpyType(dtype) return ( self.copy(name_in_target=name, diff --git a/loopy/target/c/codegen/expression.py b/loopy/target/c/codegen/expression.py index 23f6e92f3..70f046c9d 100644 --- a/loopy/target/c/codegen/expression.py +++ b/loopy/target/c/codegen/expression.py @@ -451,23 +451,6 @@ class ExpressionToCExpressionMapper(IdentityMapper): "for constant '%s'" % expr) def map_call(self, expr, type_context): - - identifier_name = ( - self.codegen_state.callables_table[expr.function.name].name) - - from loopy.kernel.function_interface import ManglerCallable - if isinstance(self.codegen_state.callables_table[expr.function.name], - ManglerCallable): - from loopy.codegen import SeenFunction - in_knl_callable = ( - self.codegen_state.callables_table[ - expr.function.name]) - mangle_result = in_knl_callable.mangle_result(self.kernel) - self.codegen_state.seen_functions.add( - SeenFunction(identifier_name, - mangle_result.target_name, - 
mangle_result.arg_dtypes)) - return ( self.codegen_state.callables_table[ expr.function.name].emit_call( @@ -666,7 +649,7 @@ class ExpressionToCExpressionMapper(IdentityMapper): from loopy.codegen import SeenFunction clbl = self.codegen_state.ast_builder.known_callables["pow"] clbl = clbl.with_types({0: tgt_dtype, 1: exponent_dtype}, - self.kernel, self.codegen_state.callables_table)[0] + self.codegen_state.callables_table)[0] self.codegen_state.seen_functions.add( SeenFunction( clbl.name, clbl.name_in_target, diff --git a/loopy/target/cuda.py b/loopy/target/cuda.py index 54b1006ad..ee99f27e7 100644 --- a/loopy/target/cuda.py +++ b/loopy/target/cuda.py @@ -121,13 +121,10 @@ _CUDA_SPECIFIC_FUNCTIONS = { class CudaCallable(ScalarCallable): - def cuda_with_types(self, arg_id_to_dtype, caller_kernel, - callables_table): + def cuda_with_types(self, arg_id_to_dtype, callables_table): name = self.name - # FIXME: dot is not implemented yet. - if name in _CUDA_SPECIFIC_FUNCTIONS: num_args = _CUDA_SPECIFIC_FUNCTIONS[name] diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py index 22fa78a55..affe9ff5b 100644 --- a/loopy/target/opencl.py +++ b/loopy/target/opencl.py @@ -171,10 +171,71 @@ class OpenCLCallable(ScalarCallable): :class:`loopy.target.c.CMathCallable`. 
""" - def with_types(self, arg_id_to_dtype, caller_kernel, callables_table): + def with_types(self, arg_id_to_dtype, callables_table): name = self.name - if name in ["max", "min"]: + # unary functions + if name in ["fabs", "acos", "asin", "atan", "cos", "cosh", "sin", "sinh", + "tan", "tanh", "exp", "log", "log10", "sqrt", "ceil", "floor", + "erf", "erfc"]: + + for id in arg_id_to_dtype: + if not -1 <= id <= 0: + raise LoopyError(f"'{name}' can take only one argument.") + + if 0 not in arg_id_to_dtype or arg_id_to_dtype[0] is None: + # the types provided aren't mature enough to specialize the + # callable + return ( + self.copy(arg_id_to_dtype=arg_id_to_dtype), + callables_table) + + dtype = arg_id_to_dtype[0] + dtype = dtype.numpy_dtype + + if dtype.kind in ("u", "i"): + # ints and unsigned casted to float32 + dtype = np.float32 + elif dtype.kind == "c": + raise LoopyTypeError(f"{name} does not support type {dtype}") + + return ( + self.copy(name_in_target=name, + arg_id_to_dtype={0: NumpyType(dtype), -1: + NumpyType(dtype)}), + callables_table) + # binary functions + elif name in ["fmax", "fmin", "atan2", "copysign"]: + + for id in arg_id_to_dtype: + if not -1 <= id <= 1: + #FIXME: Do we need to raise here?: + # The pattern we generally follow is that if we don't find + # a function, then we just return None + raise LoopyError("%s can take only two arguments." 
% name) + + if 0 not in arg_id_to_dtype or 1 not in arg_id_to_dtype or ( + arg_id_to_dtype[0] is None or arg_id_to_dtype[1] is None): + # the types provided aren't mature enough to specialize the + # callable + return ( + self.copy(arg_id_to_dtype=arg_id_to_dtype), + callables_table) + + dtype = np.find_common_type( + [], [dtype.numpy_dtype for id, dtype in arg_id_to_dtype.items() + if id >= 0]) + + if dtype.kind == "c": + raise LoopyTypeError("%s does not support complex numbers") + + dtype = NumpyType(dtype) + return ( + self.copy(name_in_target=name, + arg_id_to_dtype={-1: dtype, 0: dtype, 1: dtype}), + callables_table) + + elif name in ["max", "min"]: for id in arg_id_to_dtype: if not -1 <= id <= 1: raise LoopyError("%s can take only 2 arguments." % name) @@ -200,7 +261,7 @@ class OpenCLCallable(ScalarCallable): raise LoopyError("%s function not supported for the types %s" % (name, common_dtype)) - if name == "dot": + elif name == "dot": for id in arg_id_to_dtype: if not -1 <= id <= 1: raise LoopyError(f"'{name}' can take only 2 arguments.") @@ -220,7 +281,7 @@ class OpenCLCallable(ScalarCallable): NumpyType(scalar_dtype), 0: dtype, 1: dtype}), callables_table) - if name == "pow": + elif name == "pow": for id in arg_id_to_dtype: if not -1 <= id <= 1: raise LoopyError(f"'{name}' can take only 2 arguments.") @@ -244,7 +305,7 @@ class OpenCLCallable(ScalarCallable): 0: common_dtype, 1: common_dtype}), callables_table) - if name in _CL_SIMPLE_MULTI_ARG_FUNCTIONS: + elif name in _CL_SIMPLE_MULTI_ARG_FUNCTIONS: num_args = _CL_SIMPLE_MULTI_ARG_FUNCTIONS[name] for id in arg_id_to_dtype: if not -1 <= id < num_args: @@ -275,7 +336,7 @@ class OpenCLCallable(ScalarCallable): arg_id_to_dtype=updated_arg_id_to_dtype), callables_table) - if name in VECTOR_LITERAL_FUNCS: + elif name in VECTOR_LITERAL_FUNCS: base_tp_name, dtype, count = VECTOR_LITERAL_FUNCS[name] for id in arg_id_to_dtype: @@ -313,8 +374,13 @@ def get_opencl_callables(): Returns an instance of 
:class:`InKernelCallable` if the function defined by *identifier* is known in OpenCL. """ - opencl_function_ids = {"max", "min", "dot", "pow"} | set( - _CL_SIMPLE_MULTI_ARG_FUNCTIONS) | set(VECTOR_LITERAL_FUNCS) + opencl_function_ids = ( + {"max", "min", "dot", "pow", "abs", "acos", "asin", + "atan", "cos", "cosh", "sin", "sinh", "pow", "atan2", "tanh", "exp", + "log", "log10", "sqrt", "ceil", "floor", "max", "min", "fmax", "fmin", + "fabs", "tan", "erf", "erfc"} + | set(_CL_SIMPLE_MULTI_ARG_FUNCTIONS) + | set(VECTOR_LITERAL_FUNCS)) return {id_: OpenCLCallable(name=id_) for id_ in opencl_function_ids} diff --git a/loopy/target/pyopencl.py b/loopy/target/pyopencl.py index 59b90ef90..a192520c4 100644 --- a/loopy/target/pyopencl.py +++ b/loopy/target/pyopencl.py @@ -201,7 +201,7 @@ class PyOpenCLCallable(ScalarCallable): Records information about the callables which are not covered by :class:`loopy.target.opencl.OpenCLCallable` """ - def with_types(self, arg_id_to_dtype, caller_kernel, callables_table): + def with_types(self, arg_id_to_dtype, callables_table): name = self.name @@ -816,7 +816,7 @@ class PyOpenCLCASTBuilder(OpenCLCASTBuilder): from loopy.library.random123 import get_random123_callables callables = super().known_callables callables.update(get_pyopencl_callables()) - callables.update(get_random123_callables()) + callables.update(get_random123_callables(self.target)) return callables def preamble_generators(self): diff --git a/loopy/target/python.py b/loopy/target/python.py index 03910e120..c7f20ff55 100644 --- a/loopy/target/python.py +++ b/loopy/target/python.py @@ -90,16 +90,8 @@ class ExpressionToPythonMapper(StringifyMapper): raise LoopyError( "indexof, indexof_vec not yet supported in Python") - from loopy.kernel.function_interface import ManglerCallable clbl = self.codegen_state.callables_table[ expr.function.name] - if isinstance(clbl, ManglerCallable): - from loopy.codegen import SeenFunction - mangle_result = clbl.mangle_result(self.kernel) - 
self.codegen_state.seen_functions.add( - SeenFunction(identifier_name, - mangle_result.target_name, - mangle_result.arg_dtypes)) str_parameters = None number_of_assignees = len([key for key in diff --git a/loopy/type_inference.py b/loopy/type_inference.py index 422404411..4410a2676 100644 --- a/loopy/type_inference.py +++ b/loopy/type_inference.py @@ -426,137 +426,75 @@ class TypeInferenceMapper(CombineMapper): tuple(enumerate(expr.parameters)) + tuple(kw_parameters.items())} # specializing the known function wrt type - if isinstance(expr.function, ResolvedFunction): - in_knl_callable = self.clbl_inf_ctx[expr.function.name] - - # {{{ checking that there is no overwriting of types of in_knl_callable - - if in_knl_callable.arg_id_to_dtype is not None: - - # specializing an already specialized function. - for id, dtype in arg_id_to_dtype.items(): - if id in in_knl_callable.arg_id_to_dtype and ( - in_knl_callable.arg_id_to_dtype[id] != - arg_id_to_dtype[id]): - - # {{{ ignoring the the cases when there is a discrepancy - # between np.uint and np.int + in_knl_callable = self.clbl_inf_ctx[expr.function.name] - import numpy as np - if in_knl_callable.arg_id_to_dtype[id].dtype.type == ( - np.uint32) and ( - arg_id_to_dtype[id].dtype.type == np.int32): - continue - if in_knl_callable.arg_id_to_dtype[id].dtype.type == ( - np.uint64) and ( - arg_id_to_dtype[id].dtype.type == - np.int64): - continue + # {{{ checking that there is no overwriting of types of in_knl_callable - if np.can_cast(arg_id_to_dtype[id].dtype.type, - in_knl_callable.arg_id_to_dtype[id].dtype.type): - continue + if in_knl_callable.arg_id_to_dtype is not None: - # }}} + # specializing an already specialized function. 
+ for id, dtype in arg_id_to_dtype.items(): + if id in in_knl_callable.arg_id_to_dtype and ( + in_knl_callable.arg_id_to_dtype[id] != + arg_id_to_dtype[id]): - raise LoopyError("Overwriting a specialized function " - "is illegal--maybe start with new instance of " - "InKernelCallable?") + # {{{ ignoring the the cases when there is a discrepancy + # between np.uint and np.int - # }}} + import numpy as np + if in_knl_callable.arg_id_to_dtype[id].dtype.type == ( + np.uint32) and ( + arg_id_to_dtype[id].dtype.type == np.int32): + continue + if in_knl_callable.arg_id_to_dtype[id].dtype.type == ( + np.uint64) and ( + arg_id_to_dtype[id].dtype.type == + np.int64): + continue - in_knl_callable, self.clbl_inf_ctx = ( - in_knl_callable.with_types( - arg_id_to_dtype, self.kernel, - self.clbl_inf_ctx)) - - in_knl_callable = in_knl_callable.with_target(self.kernel.target) - - # storing the type specialized function so that it can be used for - # later use - self.clbl_inf_ctx, new_function_id = ( - self.clbl_inf_ctx.with_callable( - expr.function.function, - in_knl_callable)) - - if isinstance(expr, Call): - self.old_calls_to_new_calls[expr] = new_function_id - else: - assert isinstance(expr, CallWithKwargs) - self.old_calls_to_new_calls[expr] = new_function_id + if np.can_cast(arg_id_to_dtype[id].dtype.type, + in_knl_callable.arg_id_to_dtype[id].dtype.type): + continue - new_arg_id_to_dtype = in_knl_callable.arg_id_to_dtype + # }}} - if new_arg_id_to_dtype is None: - return [] + raise LoopyError("Overwriting a specialized function " + "is illegal--maybe start with new instance of " + "InKernelCallable?") - # collecting result dtypes in order of the assignees - if -1 in new_arg_id_to_dtype and new_arg_id_to_dtype[-1] is not None: - if return_tuple: - return [get_return_types_as_tuple(new_arg_id_to_dtype)] - else: - return [new_arg_id_to_dtype[-1]] + # }}} - elif isinstance(expr.function, Variable): - # Since, the function is not "scoped", attempt to infer using - # 
kernel.function_manglers + in_knl_callable, self.clbl_inf_ctx = ( + in_knl_callable.with_types( + arg_id_to_dtype, + self.clbl_inf_ctx)) - # {{{ trying to infer using function manglers + in_knl_callable = in_knl_callable.with_target(self.kernel.target) - arg_dtypes = tuple(none_if_empty(self.rec(par)) for par in - expr.parameters) + # storing the type specialized function so that it can be used for + # later use + self.clbl_inf_ctx, new_function_id = ( + self.clbl_inf_ctx.with_callable( + expr.function.function, + in_knl_callable)) - # finding the function_mangler which would be associated with the - # realized function. + if isinstance(expr, Call): + self.old_calls_to_new_calls[expr] = new_function_id + else: + assert isinstance(expr, CallWithKwargs) + self.old_calls_to_new_calls[expr] = new_function_id - mangle_result = None - for function_mangler in self.kernel.function_manglers: - mangle_result = function_mangler(self.kernel, identifier, - arg_dtypes) - if mangle_result: - # found a match. - break + new_arg_id_to_dtype = in_knl_callable.arg_id_to_dtype - if mangle_result is not None: - from loopy.kernel.function_interface import ManglerCallable - - # creating arg_id_to_dtype from arg_dtypes - arg_id_to_dtype = {i: dt.with_target(self.kernel.target) - for i, dt in enumerate(mangle_result.arg_dtypes)} - arg_id_to_dtype.update({-i-1: - dtype.with_target(self.kernel.target) for i, dtype in enumerate( - mangle_result.result_dtypes)}) - - # creating the ManglerCallable object corresponding to the - # function. - in_knl_callable = ManglerCallable( - identifier, function_mangler, arg_id_to_dtype, - name_in_target=mangle_result.target_name) - # FIXME: we have not tested how it works with mangler callable - # yet. 
- self.clbl_inf_ctx, new_function_id = ( - self.clbl_inf_ctx.with_callable( - expr.function, in_knl_callable)) - - if isinstance(expr, Call): - self.old_calls_to_new_calls[expr] = new_function_id - else: - assert isinstance(expr, CallWithKwargs) - self.old_calls_to_new_calls = new_function_id + if new_arg_id_to_dtype is None: + return [] - # Returning the type. + # collecting result dtypes in order of the assignees + if -1 in new_arg_id_to_dtype and new_arg_id_to_dtype[-1] is not None: if return_tuple: - if mangle_result is not None: - return [mangle_result.result_dtypes] + return [get_return_types_as_tuple(new_arg_id_to_dtype)] else: - if mangle_result is not None: - if len(mangle_result.result_dtypes) != 1 and not return_tuple: - raise LoopyError("functions with more or fewer than one " - "return value may only be used in direct " - "assignments") - - return [mangle_result.result_dtypes[0]] - # }}} + return [new_arg_id_to_dtype[-1]] return [] @@ -678,10 +616,10 @@ class TypeInferenceMapper(CombineMapper): rec_results = self.rec(expr.expr) if return_tuple: - return [expr.operation.result_dtypes(self.kernel, *rec_result) + return [expr.operation.result_dtypes(*rec_result) for rec_result in rec_results] else: - return [expr.operation.result_dtypes(self.kernel, rec_result)[0] + return [expr.operation.result_dtypes(rec_result)[0] for rec_result in rec_results] def map_sub_array_ref(self, expr): @@ -1111,13 +1049,11 @@ def infer_unknown_types(program, expect_completion=False): renamed_entrypoints = set() for e in program.entrypoints: - # FIXME: Need to add docs which say that we need not add the current - # callable to the clbl_inf_ctx while writing the "with_types" logger.debug(f"Entering entrypoint: {e}") arg_id_to_dtype = {arg.name: arg.dtype for arg in program[e].args if arg.dtype not in (None, auto)} new_callable, clbl_inf_ctx = program.callables_table[e].with_types( - arg_id_to_dtype, None, clbl_inf_ctx) + arg_id_to_dtype, clbl_inf_ctx) clbl_inf_ctx, new_name = 
clbl_inf_ctx.with_callable(e, new_callable) renamed_entrypoints.add(new_name.name) @@ -1174,7 +1110,7 @@ def infer_arg_and_reduction_dtypes_for_reduction_expression( raise LoopyError("failed to determine type of accumulator for " "reduction '%s'" % expr) - reduction_dtypes = expr.operation.result_dtypes(kernel, *arg_dtypes) + reduction_dtypes = expr.operation.result_dtypes(*arg_dtypes) reduction_dtypes = tuple( dt.with_target(kernel.target) if dt is not lp.auto else dt diff --git a/test/testlib.py b/test/testlib.py index 034a0188e..7009e8f5a 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -138,7 +138,7 @@ class SeparateTemporariesPreambleTestPreambleGenerator( class Log2Callable(lp.ScalarCallable): - def with_types(self, arg_id_to_dtype, kernel, callables_table): + def with_types(self, arg_id_to_dtype, callables_table): if 0 not in arg_id_to_dtype or arg_id_to_dtype[0] is None: # the types provided aren't mature enough to specialize the @@ -153,14 +153,13 @@ class Log2Callable(lp.ScalarCallable): # ints and unsigned casted to float32 dtype = np.float32 - from loopy.target.opencl import OpenCLTarget - name_in_target = "log2" - if not isinstance(kernel.target, OpenCLTarget): - # for CUDA, C Targets the name must be modified - if dtype == np.float32: - name_in_target = "log2f" - elif dtype == np.float128: - name_in_target = "log2l" + if dtype.type == np.float32: + name_in_target = "log2f" + elif dtype.type == np.float64: + name_in_target = "log2" + pass + else: + raise TypeError(f"log2: unexpected type {dtype}") from loopy.types import NumpyType return ( -- GitLab