diff --git a/loopy/codegen/expression.py b/loopy/codegen/expression.py index 0cb3341aa163b27b8a2d56d94290e8de3614e629..f1719a86bcf9d426eec5bd8104c6636a75fc4678 100644 --- a/loopy/codegen/expression.py +++ b/loopy/codegen/expression.py @@ -32,6 +32,7 @@ from pymbolic.mapper import CombineMapper import islpy as isl import pyopencl as cl import pyopencl.array # noqa +from pytools import Record from loopy.diagnostic import TypeInferenceFailure, DependencyTypeInferenceFailure @@ -249,6 +250,26 @@ def get_opencl_vec_member(idx): return "s%d" % idx +class SeenFunction(Record): + """ + .. attribute:: name + .. attribute:: c_name + .. attribute:: arg_dtypes + + a tuple of arg dtypes + """ + + def __init__(self, name, c_name, arg_dtypes): + Record.__init__(self, + name=name, + c_name=c_name, + arg_dtypes=arg_dtypes) + + def __hash__(self): + return hash((type(self),) + + tuple((f, getattr(self, f)) for f in type(self).fields)) + + class LoopyCCodeMapper(RecursiveMapper): def __init__(self, kernel, seen_dtypes, seen_functions, var_subst_map={}, allow_complex=False): @@ -511,7 +532,7 @@ class LoopyCCodeMapper(RecursiveMapper): def seen_func(name): idt = self.kernel.index_dtype - self.seen_functions.add((name, name, (idt, idt))) + self.seen_functions.add(SeenFunction(name, name, (idt, idt))) if den_nonneg: if num_nonneg: @@ -622,7 +643,7 @@ class LoopyCCodeMapper(RecursiveMapper): "for function '%s' not understood" % identifier) - self.seen_functions.add((identifier, c_name, par_dtypes)) + self.seen_functions.add(SeenFunction(identifier, c_name, par_dtypes)) if str_parameters is None: # /!\ FIXME For some functions (e.g. 'sin'), it makes sense to # propagate the type context here. But for many others, it does diff --git a/loopy/library/preamble.py b/loopy/library/preamble.py index 590320b058622bc89f188e98f92e891efcda8a4d..2b7be40af87b6d0fe374f64715f6cf0d5500c8ce 100644 --- a/loopy/library/preamble.py +++ b/loopy/library/preamble.py @@ -58,7 +58,7 @@ def default_preamble_generator(seen_dtypes, seen_functions): #include <pyopencl-complex.h> """) - c_funcs = set(c_name for name, c_name, arg_dtypes in seen_functions) + c_funcs = set(func.c_name for func in seen_functions) if "int_floor_div" in c_funcs: yield ("05_int_floor_div", """ #define int_floor_div(a,b) \ diff --git a/loopy/library/reduction.py b/loopy/library/reduction.py index e3ea9f47f00b931186bf4f72f405432da43a0528..bc83cf31f8411b40e8454df869d392a13c11a0ab 100644 --- a/loopy/library/reduction.py +++ b/loopy/library/reduction.py @@ -272,8 +272,8 @@ def reduction_function_mangler(func_id, arg_dtypes): def reduction_preamble_generator(seen_dtypes, seen_functions): - for func_id, c_name, arg_dtypes in seen_functions: - if isinstance(func_id, ArgExtFunction): - yield get_argext_preamble(func_id) + for func in seen_functions: + if isinstance(func.name, ArgExtFunction): + yield get_argext_preamble(func.name) # vim: fdm=marker