diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py
index 79adb922d69c6eb81212fd061e667b6529bc6852..e83515d31f1c61e52569d8d0754ce79e7a7f602f 100644
--- a/loopy/codegen/__init__.py
+++ b/loopy/codegen/__init__.py
@@ -368,6 +368,7 @@ class PreambleInfo(ImmutableRecord):
     .. attribute:: seen_dtypes
     .. attribute:: seen_functions
     .. attribute:: seen_atomic_dtypes
+    .. attribute:: codegen_state
     """
 
 
@@ -496,7 +497,9 @@ def generate_code_v2(kernel):
             seen_dtypes=seen_dtypes,
             seen_functions=seen_functions,
             # a set of LoopyTypes (!)
-            seen_atomic_dtypes=seen_atomic_dtypes)
+            seen_atomic_dtypes=seen_atomic_dtypes,
+            codegen_state=codegen_state
+            )
 
     preamble_generators = (kernel.preamble_generators
             + kernel.target.get_device_ast_builder().preamble_generators())
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 397d4832b0004725d7fc1559569c352861cb033e..bd9b25991b2814ea57363182d8cd404a53d0bacd 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -2509,6 +2509,151 @@ def test_execution_backend_can_cache_dtypes(ctx_factory):
     knl(queue)
 
 
+def test_preamble_with_separate_temporaries(ctx_factory):
+    """Check that a preamble generator can emit its own global temporary.
+
+    The generator uses ``preamble_info.codegen_state`` to declare and
+    initialize the ``lookup`` array read by the custom ``indirect`` function.
+    """
+    from loopy.kernel.data import temp_var_scope as scopes
+    # create a function mangler
+
+    func_name = 'indirect'
+    func_arg_dtypes = (np.int32, np.int32, np.int32)
+    func_result_dtypes = (np.int32,)
+
+    def __indirectmangler(kernel, name, arg_dtypes):
+        """
+        A function that will return a :class:`loopy.kernel.data.CallMangleInfo`
+        to interface with the calling :class:`loopy.LoopKernel`
+        """
+        if name != func_name:
+            return None
+
+        from loopy.types import to_loopy_type
+        from loopy.kernel.data import CallMangleInfo
+
+        def __compare(d1, d2):
+            # compare dtypes ignoring atomic
+            return to_loopy_type(d1, for_atomic=True) == \
+                to_loopy_type(d2, for_atomic=True)
+
+        # check types, starting with the argument count
+        if len(func_arg_dtypes) != len(arg_dtypes):
+            raise Exception('Unexpected number of arguments provided to mangler '
+                            '{}, expected {}, got {}'.format(
+                                func_name, len(func_arg_dtypes), len(arg_dtypes)))
+
+        for i, (d1, d2) in enumerate(zip(func_arg_dtypes, arg_dtypes)):
+            if not __compare(d1, d2):
+                raise Exception('Argument at index {} for mangler {} does not '
+                                'match expected dtype. Expected {}, got {}'.
+                                format(i, func_name, str(d1), str(d2)))
+
+        # get target for creation
+        target = arg_dtypes[0].target
+        return CallMangleInfo(
+            target_name=func_name,
+            result_dtypes=tuple(to_loopy_type(x, target=target) for x in
+                                func_result_dtypes),
+            arg_dtypes=arg_dtypes)
+
+    # create the preamble generator
+    def create_preamble(arr):
+        def __indirectpreamble(preamble_info):
+            # find a function matching our name
+            func_match = next(
+                (x for x in preamble_info.seen_functions
+                 if x.name == func_name), None)
+            desc = 'custom_funcs_indirect'
+            if func_match is not None:
+                from loopy.types import to_loopy_type
+                # check types
+                if tuple(to_loopy_type(x) for x in func_arg_dtypes) == \
+                        func_match.arg_dtypes:
+                    # if match, create our temporary
+                    var = lp.TemporaryVariable(
+                        'lookup', initializer=arr, dtype=arr.dtype,
+                        shape=arr.shape, scope=scopes.GLOBAL, read_only=True)
+                    # and code
+                    code = """
+                    int {name}(int start, int end, int match)
+                    {{
+                        int result = start;
+                        for (int i = start + 1; i < end; ++i)
+                        {{
+                            if (lookup[i] == match)
+                                result = i;
+                        }}
+                        return result;
+                    }}
+                    """.format(name=func_name)
+
+                    # generate temporary variable code
+                    from cgen import Initializer
+                    from loopy.target.c import generate_array_literal
+                    codegen_state = preamble_info.codegen_state.copy(
+                        is_generating_device_code=True)
+                    kernel = preamble_info.kernel
+                    ast_builder = codegen_state.ast_builder
+                    target = kernel.target
+                    decl_info, = var.decl_info(
+                        target, index_dtype=kernel.index_dtype)
+                    decl = ast_builder.wrap_global_constant(
+                        ast_builder.get_temporary_decl(
+                            codegen_state, None, var,
+                            decl_info))
+                    if var.initializer is not None:
+                        decl = Initializer(decl, generate_array_literal(
+                            codegen_state, var, var.initializer))
+                    # return generated code
+                    yield (desc, '\n'.join([str(decl), code]))
+        return __indirectpreamble
+
+    # and finally create a test
+    n = 10
+    # for each entry come up with a random number of data points
+    num_data = np.asarray(np.random.randint(2, 10, size=n), dtype=np.int32)
+    # turn into offsets
+    offsets = np.asarray(np.hstack(([0], np.cumsum(num_data))), dtype=np.int32)
+    # create lookup data
+    lookup = np.empty(0)
+    for i in num_data:
+        lookup = np.hstack((lookup, np.arange(i)))
+    lookup = np.asarray(lookup, dtype=np.int32)
+    # and create data array with one entry per lookup slot; the product of
+    # the counts could reach 9**10 elements and exhaust memory
+    data = np.random.rand(lookup.size)
+
+    # make kernel
+    kernel = lp.make_kernel(
+        '{[i]: 0 <= i < n}',
+        """
+        for i
+            <>ind = indirect(offsets[i], offsets[i + 1], 1)
+            out[i] = data[ind]
+        end
+        """,
+        [lp.GlobalArg('out', shape=('n',)),
+         lp.TemporaryVariable(
+             'offsets', shape=(offsets.size,), initializer=offsets,
+             scope=scopes.GLOBAL, read_only=True),
+         lp.GlobalArg('data', shape=(data.size,), dtype=np.float64)],
+        )
+    # fix params, and add manglers / preamble
+    kernel = lp.fix_parameters(kernel, **{'n': n})
+    kernel = lp.register_preamble_generators(kernel, [create_preamble(lookup)])
+    kernel = lp.register_function_manglers(kernel, [__indirectmangler])
+
+    print(lp.generate_code(kernel)[0])
+    # run on the context supplied by the test fixture
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+    # check that it actually performs the lookup correctly
+    assert np.allclose(kernel(
+        queue, data=data.flatten('C'))[1][0], data[offsets[:-1] + 1])
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])