diff --git a/loopy/type_inference.py b/loopy/type_inference.py index 699c045ea470e4af37f73e479cc566da82a4d86d..ad45cc172abd5386a543f6818386a045c06eb5b7 100644 --- a/loopy/type_inference.py +++ b/loopy/type_inference.py @@ -25,7 +25,10 @@ THE SOFTWARE. import six from pymbolic.mapper import CombineMapper +from pymbolic.primitives import Call, CallWithKwargs +from loopy.symbolic import IdentityMapper, ScopedFunction import numpy as np +import re from loopy.tools import is_integer from loopy.types import NumpyType @@ -34,6 +37,9 @@ from loopy.diagnostic import ( LoopyError, TypeInferenceFailure, DependencyTypeInferenceFailure) +from loopy.kernel.instruction import (MultiAssignmentBase, CInstruction, + _DataObliviousInstruction) + import logging logger = logging.getLogger(__name__) @@ -61,6 +67,7 @@ class TypeInferenceMapper(CombineMapper): self.new_assignments = new_assignments self.symbols_with_unknown_types = set() self.scoped_functions = kernel.scoped_functions + self.specialized_functions = {} def __call__(self, expr, return_tuple=False, return_dtype_set=False): kwargs = {} @@ -251,9 +258,7 @@ class TypeInferenceMapper(CombineMapper): return self.rec(expr.aggregate) def map_call(self, expr, return_tuple=False): - from pymbolic.primitives import Variable, Expression - from loopy.symbolic import SubArrayRef - from loopy.kernel.function_interface import ValueArgDescriptor + from pymbolic.primitives import Variable identifier = expr.function if isinstance(identifier, Variable): @@ -281,16 +286,9 @@ class TypeInferenceMapper(CombineMapper): self.scoped_functions[expr.function.name].with_types( arg_id_to_dtype)) - # need to colllect arg_id_to_descr from the Subarrayrefs - arg_id_to_descr = {} - for id, par in enumerate(expr.parameters): - if isinstance(par, SubArrayRef): - arg_id_to_descr[id] = par.get_arg_descr() - elif isinstance(par, Expression): - arg_id_to_descr[id] = ValueArgDescriptor() - else: - # should not come over here - raise LoopyError("Unexpected parameter given to call") + # storing the type specialized function so that it can be used for + # later use + self.specialized_functions[expr] = in_knl_callable new_arg_id_to_dtype = in_knl_callable.arg_id_to_dtype result_dtypes = [] @@ -488,11 +486,12 @@ def _infer_var_type(kernel, var_name, type_inf_mapper, subst_expander): dtype_sets.append(result) if not dtype_sets: - return None, type_inf_mapper.symbols_with_unknown_types + return None, type_inf_mapper.symbols_with_unknown_types, None result = type_inf_mapper.combine(dtype_sets) - return result, type_inf_mapper.symbols_with_unknown_types + return (result, type_inf_mapper.symbols_with_unknown_types, + type_inf_mapper.specialized_functions) # }}} @@ -517,6 +516,46 @@ class _DictUnionView: raise KeyError(key) +# {{{ FunctionType Specializer + + +# }}} + +# {{{ duplicating the funciton name + +def next_indexed_name(name): + FUNC_NAME = re.compile(r"^(?P\S+?)_(?P\d+?)$") + + match = FUNC_NAME.match(name) + + if match is None: + if name[-1] == '_': + return "{old_name}0".format(old_name=name) + else: + return "{old_name}_0".format(old_name=name) + + return "{alpha}_{num}".format(alpha=match.group('alpha'), + num=int(match.group('num'))+1) + +# }}} + + +# {{{ FunctionScopeChanger + +class FunctionScopeChanger(IdentityMapper): + def __init__(self, new_names): + self.new_names = new_names + + def map_call(self, expr): + return Call(ScopedFunction(self.new_names[expr]), + expr.parameters) + + def map_call_with_kwargs(self, expr): + return CallWithKwargs(ScopedFunction(self.new_names[expr]), + expr.parameters, expr.kw_parameters) +# }}} + + # {{{ infer_unknown_types def infer_unknown_types(kernel, expect_completion=False): @@ -590,6 +629,8 @@ def infer_unknown_types(kernel, expect_completion=False): from loopy.kernel.data import TemporaryVariable, KernelArgument + specialized_functions = {} + for var_chain in sccs: changed_during_last_queue_run = False queue = var_chain[:] @@ -613,7 +654,7 @@ def infer_unknown_types(kernel, expect_completion=False): debug("inferring type for %s %s", type(item).__name__, item.name) - result, symbols_with_unavailable_types = ( + result, symbols_with_unavailable_types, new_specialized_functions = ( _infer_var_type( kernel, item.name, type_inf_mapper, subst_expander)) @@ -634,6 +675,8 @@ def infer_unknown_types(kernel, expect_completion=False): new_arg_dict[name] = item.copy(dtype=new_dtype) else: raise LoopyError("unexpected item type in type inference") + specialized_functions = {**specialized_functions, + **new_specialized_functions} else: debug(" failure") @@ -676,11 +719,52 @@ def infer_unknown_types(kernel, expect_completion=False): logger.debug("type inference took {dur:.2f} seconds".format( dur=end_time - start_time)) - return unexpanded_kernel.copy( + pre_type_specialized_knl = unexpanded_kernel.copy( temporary_variables=new_temp_vars, args=[new_arg_dict[arg.name] for arg in kernel.args], ) + # {{{ type specialization + + # TODO: These 2 dictionaries are inverse mapping of each other and help to keep + # track of which ...(need to explain better) + scoped_names_to_functions = {} + scoped_functions_to_names = {} + pymbolic_calls_to_new_names = {} + + for pymbolic_call, knl_callable in specialized_functions.items(): + if knl_callable not in scoped_functions_to_names: + # need to make a new name deerived from the old name such that new + # name in not present in new_scoped_name_to_function + old_name = pymbolic_call.function.name + new_name = next_indexed_name(old_name) + while new_name not in scoped_names_to_functions: + new_name = next_indexed_name(new_name) + + scoped_names_to_functions[new_name] = knl_callable + scoped_functions_to_names[knl_callable] = new_name + + pymbolic_calls_to_new_names[pymbolic_call] = ( + scoped_functions_to_names[knl_callable]) + + # }}} + + new_insns = [] + scope_changer = FunctionScopeChanger(pymbolic_calls_to_new_names) + for insn in pre_type_specialized_knl.instructions: + if isinstance(insn, (MultiAssignmentBase, CInstruction)): + expr = scope_changer(insn.expression) + new_insns.append(insn.copy(expression=expr)) + pass + elif isinstance(insn, _DataObliviousInstruction): + new_insns.append(insn) + else: + raise NotImplementedError("Type Inference Specialization not" + "implemented for %s instruciton" % type(insn)) + + return pre_type_specialized_knl.copy(scope_functions=scoped_names_to_functions, + instructions=new_insns) + # }}}