diff --git a/doc/reference.rst b/doc/reference.rst
index 8b1a4d6f9f4069f6b24eeed0590ad4743dbd9ff7..1067f6b1111d8e0ed25c95b3cc9250d4ff5f22f3 100644
--- a/doc/reference.rst
+++ b/doc/reference.rst
@@ -208,7 +208,7 @@ Argument types
 
 .. autofunction:: add_argument_dtypes
 
-.. autofunction:: infer_argument_dtypes
+.. autofunction:: infer_unknown_types
 
 .. autofunction:: add_and_infer_argument_dtypes
 
diff --git a/loopy/__init__.py b/loopy/__init__.py
index 186b269e293ae9cd6956e8545abebb9a25f48083..346edcec10ba2b9713f1fc1ad26a0a1ea6e9778d 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -61,14 +61,14 @@ from loopy.kernel.data import (
 from loopy.kernel import LoopKernel
 from loopy.kernel.tools import (
         get_dot_dependency_graph, add_argument_dtypes,
-        infer_argument_dtypes, add_and_infer_argument_dtypes)
+        add_and_infer_argument_dtypes)
 from loopy.kernel.creation import make_kernel
 from loopy.reduction import register_reduction_parser
 from loopy.subst import extract_subst, expand_subst
 from loopy.cse import precompute
 from loopy.padding import (split_arg_axis, find_padding_multiple,
         add_padding)
-from loopy.preprocess import preprocess_kernel, realize_reduction
+from loopy.preprocess import preprocess_kernel, realize_reduction, infer_unknown_types
 from loopy.schedule import generate_loop_schedules
 from loopy.codegen import generate_code
 from loopy.compiled import CompiledKernel, auto_test_vs_ref
@@ -88,7 +88,7 @@ __all__ = [
         "get_dot_dependency_graph", "add_argument_dtypes",
-        "infer_argument_dtypes", "add_and_infer_argument_dtypes",
+        "add_and_infer_argument_dtypes",
 
-        "preprocess_kernel", "realize_reduction",
+        "preprocess_kernel", "realize_reduction", "infer_unknown_types",
         "generate_loop_schedules", "generate_code",
         "CompiledKernel", "auto_test_vs_ref",
         "check_kernels",
diff --git a/loopy/compiled.py b/loopy/compiled.py
index b6d0df8d0f3b35a8d04428d9ad465e5c14668262..59630a9a41d800fd6da654dca0ea8eea60f4bc58 100644
--- a/loopy/compiled.py
+++ b/loopy/compiled.py
@@ -125,14 +125,14 @@ class CompiledKernel:
 
         from loopy.kernel.tools import (
                 add_argument_dtypes,
-                infer_argument_dtypes,
                 get_arguments_with_incomplete_dtype)
 
         if get_arguments_with_incomplete_dtype(kernel):
             if dtype_mapping_set is not None:
                 kernel = add_argument_dtypes(kernel, dict(dtype_mapping_set))
 
-            kernel = infer_argument_dtypes(kernel)
+            from loopy.preprocess import infer_unknown_types
+            kernel = infer_unknown_types(kernel, expect_completion=True)
 
         import loopy as lp
         if kernel.schedule is None:
diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py
index 2d0bbb40190cfddb5597d360507210b82a5bb863..56074508418d5c38e59bed4c1bec3cbe5fe23e52 100644
--- a/loopy/kernel/tools.py
+++ b/loopy/kernel/tools.py
@@ -65,73 +65,15 @@ def add_argument_dtypes(knl, dtype_dict):
 
     return knl.copy(args=new_args)
 
-def _infer_argument_dtypes_inner(knl):
-    new_args = []
-    did_something = False
-
-    writer_map = knl.writer_map()
-
-    from loopy.codegen.expression import (
-            TypeInferenceMapper, TypeInferenceFailure)
-    tim = TypeInferenceMapper(knl)
-
-    from loopy.symbolic import SubstitutionRuleExpander
-
-    submap = SubstitutionRuleExpander(knl.substitutions,
-            knl.get_var_name_generator())
-
-    for arg in knl.args:
-        if arg.dtype is None:
-            new_dtype = None
-
-            if arg.name in knl.all_params():
-                new_dtype = knl.index_dtype
-            else:
-                try:
-                    for write_insn_id in writer_map.get(arg.name, ()):
-                        write_insn = knl.id_to_insn[write_insn_id]
-
-                        new_tim_dtype = tim(
-                                submap(write_insn.expression, write_insn_id))
-                        if new_dtype is None:
-                            new_dtype = new_tim_dtype
-                        elif new_dtype != new_tim_dtype:
-                            # Now we know *nothing*.
-                            new_dtype = None
-                            break
-
-                except TypeInferenceFailure:
-                    # Even one type inference failure is enough to
-                    # make this dtype not safe to guess. Don't.
-                    pass
-
-            if new_dtype is not None:
-                did_something = True
-                arg = arg.copy(dtype=new_dtype)
-
-        new_args.append(arg)
-
-    return knl.copy(args=new_args), did_something
-
 def get_arguments_with_incomplete_dtype(knl):
     return [arg.name for arg in knl.args if arg.dtype is None]
 
 
-def infer_argument_dtypes(knl):
-    while True:
-        knl, did_something = _infer_argument_dtypes_inner(knl)
-        incomplete_args = get_arguments_with_incomplete_dtype(knl)
-
-        if incomplete_args:
-            if not did_something:
-                raise RuntimeError("not all argument dtypes are specified "
-                        "or could be inferred: " + ", ".join(incomplete_args))
-        else:
-            return knl
-
 def add_and_infer_argument_dtypes(knl, dtype_dict):
     knl = add_argument_dtypes(knl, dtype_dict)
-    return infer_argument_dtypes(knl)
+
+    from loopy.preprocess import infer_unknown_types
+    return infer_unknown_types(knl, expect_completion=True)
 
 # }}}
 
diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index bb168b347fb10962a7511c678d49170bc8cac815..0a23d0b3d800e29280018ec5f35d77c8f3eb19a4 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -28,32 +28,43 @@ THE SOFTWARE.
 import pyopencl as cl
 import pyopencl.characterize as cl_char
 
+import logging
+logger = logging.getLogger(__name__)
+
 
 
 
 # {{{ infer types
 
-def infer_temp_var_type(kernel, tv, type_inf_mapper, debug):
-    dtypes = []
+def _infer_var_type(kernel, var_name, type_inf_mapper, subst_expander):
+    """Return the dtype inferred for *var_name*, or *None* if unknown so far.
+
+    Kernel parameters get ``kernel.index_dtype``. Any other variable gets the
+    dtype shared by all of its writer expressions whose type can currently
+    be inferred. Raises :exc:`RuntimeError` if those writers disagree.
+    """
+    if var_name in kernel.all_params():
+        return kernel.index_dtype
 
-    writers = kernel.writer_map()[tv.name]
-    exprs = [kernel.id_to_insn[w].expression for w in writers]
+    dtypes = []
 
     from loopy.codegen.expression import DependencyTypeInferenceFailure
 
-    for expr in exprs:
+    # A variable that is never written has no writer_map entry.
+    for writer_insn_id in kernel.writer_map().get(var_name, ()):
+        expr = subst_expander(
+                kernel.id_to_insn[writer_insn_id].expression,
+                insn_id=writer_insn_id)
+
         try:
-            if debug:
-                print " via expr", expr
+            logger.debug(" via expr %s", expr)
             result = type_inf_mapper(expr)
-            if debug:
-                print " result", result
+            logger.debug(" result: %s", result)
 
             dtypes.append(result)
 
         except DependencyTypeInferenceFailure:
-            if debug:
-                print " failed"
+            logger.debug(" failed")
 
     if not dtypes:
         return None
@@ -61,14 +72,47 @@ def infer_temp_var_type(kernel, tv, type_inf_mapper, debug):
     from pytools import is_single_valued
     if not is_single_valued(dtypes):
         raise RuntimeError("ambiguous type inference for '%s'"
-                % tv.name)
+                % var_name)
 
     return dtypes[0]
 
-def infer_types_of_temporaries(kernel):
-    """Infer types on temporaries."""
+class _DictUnionView(object):
+    """Read-only view that looks up a key in several mappings in turn.
+
+    *children* is a sequence of mappings; the first child containing the
+    key supplies the value, so earlier children shadow later ones.
+    """
+
+    def __init__(self, children):
+        self.children = children
+
+    def get(self, key):
+        """Return the value for *key*, or None if no child has it."""
+        try:
+            return self[key]
+        except KeyError:
+            return None
+
+    def __getitem__(self, key):
+        for ch in self.children:
+            try:
+                return ch[key]
+            except KeyError:
+                pass
+
+        raise KeyError(key)
+
+def infer_unknown_types(kernel, expect_completion=False):
+    """Infer types on temporaries and arguments.
+
+    :arg expect_completion: if *True*, raise :exc:`RuntimeError` when the
+        type of some variable cannot be determined; if *False*, leave such
+        variables untouched and return once no further progress is made.
+    :returns: a copy of *kernel* with the inferred dtypes filled in.
+    """
 
     new_temp_vars = kernel.temporary_variables.copy()
+    new_arg_dict = kernel.arg_dict.copy()
 
     # {{{ fill queue
 
@@ -80,56 +124,73 @@ def infer_types_of_temporaries(kernel):
         if tv.dtype is lp.auto:
             queue.append(tv)
 
+    for arg in kernel.args:
+        if arg.dtype is None:
+            queue.append(arg)
+
     # }}}
 
     from loopy.codegen.expression import TypeInferenceMapper
-    type_inf_mapper = TypeInferenceMapper(kernel, new_temp_vars)
+    type_inf_mapper = TypeInferenceMapper(kernel,
+            _DictUnionView([
+                new_temp_vars,
+                new_arg_dict
+                ]))
 
-    # {{{ work on type inference queue
+    from loopy.symbolic import SubstitutionRuleExpander
+    subst_expander = SubstitutionRuleExpander(kernel.substitutions,
+            kernel.get_var_name_generator())
 
-    from loopy.kernel.data import TemporaryVariable
+    # {{{ work on type inference queue
 
-    debug = 0
+    from loopy.kernel.data import TemporaryVariable, KernelArgument
 
-    first_failure = None
+    failed_names = set()
     while queue:
         item = queue.pop(0)
 
-        if isinstance(item, TemporaryVariable):
-            if debug:
-                print "inferring type for tempvar", item.name
+        logger.debug("inferring type for %s %s", type(item).__name__, item.name)
 
-            result = infer_temp_var_type(kernel, item, type_inf_mapper, debug)
+        result = _infer_var_type(kernel, item.name, type_inf_mapper, subst_expander)
 
-            failed = result is None
-            if not failed:
-                if debug:
-                    print " success", result
+        failed = result is None
+        if not failed:
+            logger.debug(" success: %s", result)
+            if isinstance(item, TemporaryVariable):
                 new_temp_vars[item.name] = item.copy(dtype=result)
+            elif isinstance(item, KernelArgument):
+                new_arg_dict[item.name] = item.copy(dtype=result)
             else:
-                if debug:
-                    print " failure", result
+                raise RuntimeError("unexpected item type in type inference")
         else:
-            raise RuntimeError("unexpected item type in type inference")
+            logger.debug(" failure")
 
         if failed:
-            if item is first_failure:
+            if item.name in failed_names:
+                # Every item still queued has failed since the last success,
+                # so another round cannot make progress.
+                if not expect_completion:
+                    # Best effort: leave the remaining types unresolved.
+                    break
+
                 # this item has failed before, give up.
                 raise RuntimeError("could not determine type of '%s'" % item.name)
 
-            if first_failure is None:
-                # remember the first failure for this round through the queue
-                first_failure = item
+            # remember that this item failed
+            failed_names.add(item.name)
 
             # can't infer type yet, put back into queue
             queue.append(item)
         else:
-            # we've made progress, reset failure marker
-            first_failure = None
+            # we've made progress, reset failure markers
+            failed_names = set()
 
     # }}}
 
-    return kernel.copy(temporary_variables=new_temp_vars)
+    return kernel.copy(
+            temporary_variables=new_temp_vars,
+            args=[new_arg_dict[arg.name] for arg in kernel.args],
+            )
 
 # }}}
 
@@ -911,7 +972,7 @@ def preprocess_kernel(kernel):
     # Type inference doesn't handle substitutions. Get them out of the
     # way.
 
-    kernel = infer_types_of_temporaries(kernel)
+    kernel = infer_unknown_types(kernel, expect_completion=False)
 
     # Ordering restriction:
     # realize_reduction must happen after type inference because it needs