# Reconstructed new-side definitions from a whitespace-collapsed git diff
# touching examples/hello-loopy.py, loopy/kernel/creation.py and
# test/test_loopy.py.  Only definitions that are complete in the diff are
# reproduced here; the make_kernel() and hello-loopy.py hunks are fragments of
# larger definitions and are not repeated.


# ---------------------------------------------------------------------------
# loopy/kernel/creation.py
#
# Expects ``WalkMapper`` and ``dim_type`` to be in scope in the enclosing
# module (their imports are not part of the diff).
# ---------------------------------------------------------------------------

# {{{ guess kernel args (if requested)

class IndexRankFinder(WalkMapper):
    """Collect the subscript ranks with which one named variable is indexed.

    Each visited subscript whose aggregate is the variable *arg_name* appends
    one entry to :attr:`index_ranks`: ``1`` if the index is not a tuple,
    otherwise the length of the index tuple.

    :arg arg_name: name of the variable whose subscripts are recorded.
    """

    def __init__(self, arg_name):
        self.arg_name = arg_name
        self.index_ranks = []

    def map_subscript(self, expr):
        # Delegate to the base-class handler first (presumably this continues
        # the walk into sub-expressions -- confirm against pymbolic).
        WalkMapper.map_subscript(self, expr)

        from pymbolic.primitives import Variable
        # Only plain variables are supported as subscript aggregates.
        assert isinstance(expr.aggregate, Variable)

        if expr.aggregate.name != self.arg_name:
            return

        if isinstance(expr.index, tuple):
            self.index_ranks.append(len(expr.index))
        else:
            self.index_ranks.append(1)


def guess_kernel_args_if_requested(domains, instructions, temporary_variables,
        kernel_args):
    """Expand the ``"..."`` marker in *kernel_args* into guessed arguments.

    If *kernel_args* contains no ``"..."`` entry, it is returned unchanged
    (the very same object).  Otherwise a new list is returned that holds the
    explicitly given arguments followed by one guessed argument per
    not-yet-declared name, in sorted name order:

    * a domain parameter becomes a ``ValueArg``,
    * a name that is written to by an instruction becomes a
      ``GlobalArg(..., shape="auto")``,
    * a read-only name that is never subscripted becomes a ``ValueArg``,
    * any other (read-only, subscripted) name becomes a
      ``GlobalArg(..., strides="auto")``.

    Names of inames, temporary variables and explicitly declared arguments
    are never guessed.

    :arg domains: domain objects; ``get_var_names(dim_type.set)`` and
        ``get_var_names(dim_type.param)`` are called on each.
    :arg instructions: instruction objects providing ``temp_var_type``,
        ``expression``, ``assignee`` and ``get_assignee_var_name()``.
        Iterated several times, so this must not be a one-shot iterator.
    :arg temporary_variables: mapping whose keys are temporary variable names.
    :arg kernel_args: sequence of argument objects (each with a ``name``
        attribute), optionally containing the string ``"..."``.
    :returns: *kernel_args* itself, or a new list as described above.
    """
    if "..." not in kernel_args:
        return kernel_args

    # Work on a private list (the caller's sequence, possibly a shared default
    # or a tuple, is never modified) and drop *every* marker, not just the
    # first one: a leftover "..." string would fail on ``arg.name`` below.
    kernel_args = [arg for arg in kernel_args if not (arg == "...")]

    # {{{ find names that are *not* arguments

    all_inames = set()
    for dom in domains:
        all_inames.update(dom.get_var_names(dim_type.set))

    # Iterating the mapping yields its keys (equivalent to iterkeys()).
    temp_var_names = set(temporary_variables)
    for insn in instructions:
        if insn.temp_var_type is not None:
            temp_var_names.add(insn.get_assignee_var_name())

    # }}}

    # {{{ find existing and new arg names

    existing_arg_names = set(arg.name for arg in kernel_args)
    not_new_arg_names = existing_arg_names | temp_var_names | all_inames

    from loopy.symbolic import get_dependencies
    all_names = set()
    all_written_names = set()
    for insn in instructions:
        all_written_names.add(insn.get_assignee_var_name())
        all_names.update(get_dependencies(insn.expression))
        all_names.update(get_dependencies(insn.assignee))

    all_params = set()
    for dom in domains:
        all_params.update(dom.get_var_names(dim_type.param))
    all_params -= all_inames

    # Subtract the known names *after* merging in the domain parameters, so
    # that a parameter the caller already declared is not appended twice.
    new_arg_names = (all_names | all_params) - not_new_arg_names

    # }}}

    def find_index_rank(name):
        """Return the subscript rank used for *name*, or 0 if never indexed."""
        rank_finder = IndexRankFinder(name)
        for insn in instructions:
            rank_finder(insn.expression)
            rank_finder(insn.assignee)

        if not rank_finder.index_ranks:
            return 0

        # All uses are expected to agree on the rank.
        from pytools import single_valued
        return single_valued(rank_finder.index_ranks)

    from loopy.kernel.data import ValueArg, GlobalArg
    for arg_name in sorted(new_arg_names):
        if arg_name in all_params:
            kernel_args.append(ValueArg(arg_name))

        elif arg_name in all_written_names:
            # It's not a temp var, and thereby not a domain parameter--the
            # only other writable type of variable is an argument.
            kernel_args.append(GlobalArg(arg_name, shape="auto"))

        elif find_index_rank(arg_name) == 0:
            # read-only, no indices
            kernel_args.append(ValueArg(arg_name))

        else:
            # NOTE(review): written arrays get shape="auto" while read-only
            # arrays get strides="auto" -- confirm the asymmetry is intended.
            kernel_args.append(GlobalArg(arg_name, strides="auto"))

    return kernel_args

# }}}


# ---------------------------------------------------------------------------
# test/test_loopy.py
#
# Expects ``lp`` (the loopy package) to be in scope in the enclosing module.
# ---------------------------------------------------------------------------

def test_arg_guessing(ctx_factory):
    """Smoke test: build a kernel without passing any argument declarations.

    The kernel and its highlighted generated code are printed; nothing is
    asserted beyond kernel creation and code generation not raising.
    """
    ctx = ctx_factory()

    # NOTE(review): whitespace inside the instruction string was reconstructed
    # from a collapsed diff -- confirm against the repository.
    knl = lp.make_kernel(ctx.devices[0], [
            "{[i,j]: 0<=i,j<n }",
            ],
            """
                a = 1.5 + sum((i,j), i*j)
                b[i, j] = i*j
                c[i+j, j] = b[j,i]
                """,
            assumptions="n>=1")

    # Single-argument print(...) behaves identically under Python 2.
    print(knl)
    print(lp.CompiledKernel(ctx, knl).get_highlighted_code())