From 53995111eb1dd3274966d1c41382e3fc72f8ebdc Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 18 Dec 2013 18:35:06 -0800
Subject: [PATCH] Allow string-o-names to specify argument order

---
Reviewer notes (free-form text in this position is ignored by git-am; the
blob hashes on the "index" line below predate these two changes):

 * convert_names_to_full_args compared names against "...." (four dots), so
   the "..." placeholder string was handed to make_new_arg and turned into
   an argument. It now compares against "...", like the other checks.
 * make_kernel strips whitespace from the comma-separated names, so that
   "a, b, ..." still yields an exact "..." placeholder entry.

 loopy/kernel/creation.py | 186 +++++++++++++++++++++++----------------
 1 file changed, 109 insertions(+), 77 deletions(-)

diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py
index f68bc5042..824a8b8f4 100644
--- a/loopy/kernel/creation.py
+++ b/loopy/kernel/creation.py
@@ -401,104 +401,130 @@ class IndexRankFinder(WalkMapper):
                 self.index_ranks.append(len(expr.index))
 
 
-def guess_kernel_args_if_requested(domains, instructions, temporary_variables,
-        subst_rules, kernel_args, default_offset):
-    # Ellipsis is syntactically allowed in Py3.
-    if "..." not in kernel_args and Ellipsis not in kernel_args:
-        return kernel_args
+class ArgumentGuesser:
+    def __init__(self, domains, instructions, temporary_variables,
+            subst_rules, default_offset):
+        self.domains = domains
+        self.instructions = instructions
+        self.temporary_variables = temporary_variables
+        self.subst_rules = subst_rules
+        self.default_offset = default_offset
+
+        from loopy.symbolic import SubstitutionRuleExpander
+        self.submap = SubstitutionRuleExpander(subst_rules)
+
+        self.all_inames = set()
+        for dom in domains:
+            self.all_inames.update(dom.get_var_names(dim_type.set))
+
+        all_params = set()
+        for dom in domains:
+            all_params.update(dom.get_var_names(dim_type.param))
+        self.all_params = all_params - self.all_inames
+
+        self.all_names = set()
+        self.all_written_names = set()
+        from loopy.symbolic import get_dependencies
+        for insn in instructions:
+            if isinstance(insn, ExpressionInstruction):
+                (assignee_var_name, _), = insn.assignees_and_indices()
+                self.all_written_names.add(assignee_var_name)
+                self.all_names.update(get_dependencies(
+                    self.submap(insn.assignee, insn.id)))
+                self.all_names.update(get_dependencies(
+                    self.submap(insn.expression, insn.id)))
 
-    kernel_args = [arg for arg in kernel_args
-            if arg is not Ellipsis and arg != "..."]
+    def find_index_rank(self, name):
+        irf = IndexRankFinder(name)
 
-    from loopy.symbolic import SubstitutionRuleExpander
-    submap = SubstitutionRuleExpander(subst_rules)
+        for insn in self.instructions:
+            insn.with_transformed_expressions(
+                    lambda expr: irf(self.submap(expr, insn.id)))
 
-    # {{{ find names that are *not* arguments
+        if not irf.index_ranks:
+            return 0
+        else:
+            from pytools import single_valued
+            return single_valued(irf.index_ranks)
 
-    all_inames = set()
-    for dom in domains:
-        all_inames.update(dom.get_var_names(dim_type.set))
+    def make_new_arg(self, arg_name):
+        arg_name = arg_name.strip()
 
-    temp_var_names = set(temporary_variables.iterkeys())
+        from loopy.kernel.data import ValueArg, GlobalArg
+        import loopy as lp
 
-    for insn in instructions:
-        if isinstance(insn, ExpressionInstruction):
-            if insn.temp_var_type is not None:
-                (assignee_var_name, _), = insn.assignees_and_indices()
-                temp_var_names.add(assignee_var_name)
+        if arg_name in self.all_params:
+            return ValueArg(arg_name)
 
-    # }}}
+        if arg_name in self.all_written_names:
+            # It's not a temp var, and thereby not a domain parameter--the only
+            # other writable type of variable is an argument.
 
-    # {{{ find existing and new arg names
+            return GlobalArg(arg_name,
+                    shape=lp.auto, offset=self.default_offset)
 
-    existing_arg_names = set()
-    for arg in kernel_args:
-        existing_arg_names.add(arg.name)
+        irank = self.find_index_rank(arg_name)
+        if irank == 0:
+            # read-only, no indices
+            return ValueArg(arg_name)
+        else:
+            return GlobalArg(arg_name, shape=lp.auto, offset=self.default_offset)
 
-    not_new_arg_names = existing_arg_names | temp_var_names | all_inames
+    def convert_names_to_full_args(self, kernel_args):
+        new_kernel_args = []
 
-    all_names = set()
-    all_written_names = set()
-    from loopy.symbolic import get_dependencies
-    for insn in instructions:
-        if isinstance(insn, ExpressionInstruction):
-            (assignee_var_name, _), = insn.assignees_and_indices()
-            all_written_names.add(assignee_var_name)
-            all_names.update(get_dependencies(submap(insn.assignee, insn.id)))
-            all_names.update(get_dependencies(submap(insn.expression, insn.id)))
+        for arg in kernel_args:
+            if isinstance(arg, str) and arg != "...":
+                new_kernel_args.append(self.make_new_arg(arg))
+            else:
+                new_kernel_args.append(arg)
 
-    from loopy.kernel.data import ArrayBase
-    for arg in kernel_args:
-        if isinstance(arg, ArrayBase):
-            if isinstance(arg.shape, tuple):
-                all_names.update(get_dependencies(arg.shape))
+        return new_kernel_args
 
-    all_params = set()
-    for dom in domains:
-        all_params.update(dom.get_var_names(dim_type.param))
-    all_params = all_params - all_inames
+    def guess_kernel_args_if_requested(self, kernel_args):
+        # Ellipsis is syntactically allowed in Py3.
+        if "..." not in kernel_args and Ellipsis not in kernel_args:
+            return kernel_args
 
-    new_arg_names = (all_names | all_params) - not_new_arg_names
+        kernel_args = [arg for arg in kernel_args
+                if arg is not Ellipsis and arg != "..."]
 
-    # }}}
+        # {{{ find names that are *not* arguments
 
-    def find_index_rank(name):
-        irf = IndexRankFinder(name)
+        temp_var_names = set(self.temporary_variables.iterkeys())
 
-        for insn in instructions:
-            insn.with_transformed_expressions(
-                    lambda expr: irf(submap(expr, insn.id)))
+        for insn in self.instructions:
+            if isinstance(insn, ExpressionInstruction):
+                if insn.temp_var_type is not None:
+                    (assignee_var_name, _), = insn.assignees_and_indices()
+                    temp_var_names.add(assignee_var_name)
 
-        if not irf.index_ranks:
-            return 0
-        else:
-            from pytools import single_valued
-            return single_valued(irf.index_ranks)
+        # }}}
 
-    from loopy.kernel.data import ValueArg, GlobalArg
-    import loopy as lp
-    for arg_name in sorted(new_arg_names):
-        if arg_name in all_params:
-            kernel_args.append(ValueArg(arg_name))
-            continue
+        # {{{ find existing and new arg names
 
-        if arg_name in all_written_names:
-            # It's not a temp var, and thereby not a domain parameter--the only
-            # other writable type of variable is an argument.
+        existing_arg_names = set()
+        for arg in kernel_args:
+            existing_arg_names.add(arg.name)
 
-            kernel_args.append(
-                    GlobalArg(arg_name, shape=lp.auto, offset=default_offset))
-            continue
+        not_new_arg_names = existing_arg_names | temp_var_names | self.all_inames
 
-        irank = find_index_rank(arg_name)
-        if irank == 0:
-            # read-only, no indices
-            kernel_args.append(ValueArg(arg_name))
-        else:
-            kernel_args.append(
-                    GlobalArg(arg_name, shape=lp.auto, offset=default_offset))
+        from loopy.kernel.data import ArrayBase
+        from loopy.symbolic import get_dependencies
+        for arg in kernel_args:
+            if isinstance(arg, ArrayBase):
+                if isinstance(arg.shape, tuple):
+                    self.all_names.update(
+                            get_dependencies(arg.shape))
 
-    return kernel_args
+        new_arg_names = (self.all_names | self.all_params) - not_new_arg_names
+
+        # }}}
+
+        for arg_name in sorted(new_arg_names):
+            kernel_args.append(self.make_new_arg(arg_name))
+
+        return kernel_args
 
 # }}}
 
@@ -978,6 +1004,9 @@ def make_kernel(device, domains, instructions, kernel_data=["..."], **kwargs):
 
     from loopy.kernel.data import TemporaryVariable, ArrayBase
 
+    if isinstance(kernel_data, str):
+        kernel_data = [name.strip() for name in kernel_data.split(",")]
+
     kernel_args = []
     temporary_variables = {}
     for dat in kernel_data:
@@ -1038,10 +1067,13 @@ def make_kernel(device, domains, instructions, kernel_data=["..."], **kwargs):
 
     domains = parse_domains(isl_context, domains, defines)
 
-    kernel_args = guess_kernel_args_if_requested(domains, instructions,
-            temporary_variables, substitutions, kernel_args,
+    arg_guesser = ArgumentGuesser(domains, instructions,
+            temporary_variables, substitutions,
             default_offset)
 
+    kernel_args = arg_guesser.convert_names_to_full_args(kernel_args)
+    kernel_args = arg_guesser.guess_kernel_args_if_requested(kernel_args)
+
     from loopy.kernel import LoopKernel
     knl = LoopKernel(device, domains, instructions, kernel_args,
             temporary_variables=temporary_variables,
-- 
GitLab