diff --git a/loopy/__init__.py b/loopy/__init__.py index 3b946f8f89c5440894ac726bbe45db321233939e..bf0a2be1bae22f93bad784ea59874fb305f898f8 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -67,8 +67,8 @@ from loopy.transform.instruction import ( from loopy.transform.data import ( add_prefetch, change_arg_to_image, tag_data_axes, set_array_dim_names, remove_unused_arguments, - alias_temporaries, set_argument_order - ) + alias_temporaries, set_argument_order, + rename_argument) from loopy.transform.subst import (extract_subst, assignment_to_subst, expand_subst, find_rules_matching, @@ -136,6 +136,7 @@ __all__ = [ "add_prefetch", "change_arg_to_image", "tag_data_axes", "set_array_dim_names", "remove_unused_arguments", "alias_temporaries", "set_argument_order", + "rename_argument", "find_instructions", "map_instructions", "set_instruction_priority", "add_dependency", @@ -170,7 +171,8 @@ __all__ = [ "generate_loop_schedules", "get_one_scheduled_kernel", "generate_code", "generate_body", - "get_op_poly", "sum_ops_to_dtypes", "get_gmem_access_poly", "get_DRAM_access_poly", + "get_op_poly", "sum_ops_to_dtypes", "get_gmem_access_poly", + "get_DRAM_access_poly", "get_barrier_poly", "stringify_stats_mapping", "sum_mem_access_to_bytes", "CompiledKernel", diff --git a/loopy/transform/data.py b/loopy/transform/data.py index 53f3479c2cf131d9c99b0fb21d33746131af17c6..f510af59873273486937ad01e8c3b9eba15aefda 100644 --- a/loopy/transform/data.py +++ b/loopy/transform/data.py @@ -461,4 +461,48 @@ def set_argument_order(kernel, arg_names): # }}} +# {{{ rename argument + +def rename_argument(kernel, old_name, new_name, existing_ok=False): + """ + .. versionadded:: 2016.2 + """ + + var_name_gen = kernel.get_var_name_generator() + + if old_name not in kernel.arg_dict: + raise LoopyError("old arg name '%s' does not exist" % old_name) + + does_exist = var_name_gen.is_name_conflicting(new_name) + + if does_exist and not existing_ok: + raise LoopyError("argument name '%s' conflicts with an existing identifier" + "--cannot rename" % new_name) + + from pymbolic import var + subst_dict = {old_name: var(new_name)} + + from loopy.symbolic import ( + RuleAwareSubstitutionMapper, + SubstitutionRuleMappingContext) + from pymbolic.mapper.substitutor import make_subst_func + rule_mapping_context = SubstitutionRuleMappingContext( + kernel.substitutions, var_name_gen) + smap = RuleAwareSubstitutionMapper(rule_mapping_context, + make_subst_func(subst_dict), + within=lambda knl, insn, stack: True) + + kernel = smap.map_kernel(kernel) + + new_args = [] + for arg in kernel.args: + if arg.name == old_name: + arg = arg.copy(name=new_name) + + new_args.append(arg) + + return kernel.copy(args=new_args) + +# }}} + # vim: foldmethod=marker diff --git a/loopy/transform/iname.py b/loopy/transform/iname.py index 20cf5b87b897d7a6b15de1de2fb4698e5ddf7bef..fb22df37a6215a8993f0a95cf88465c2ce8fe978 100644 --- a/loopy/transform/iname.py +++ b/loopy/transform/iname.py @@ -779,8 +779,11 @@ def rename_iname(knl, old_iname, new_iname, existing_ok=False, within=None): does_exist = var_name_gen.is_name_conflicting(new_iname) + if old_iname not in knl.all_inames(): + raise LoopyError("old iname '%s' does not exist" % old_iname) + if does_exist and not existing_ok: - raise ValueError("iname '%s' conflicts with an existing identifier" + raise LoopyError("iname '%s' conflicts with an existing identifier" "--cannot rename" % new_iname) if does_exist: @@ -824,11 +827,11 @@ def rename_iname(knl, old_iname, new_iname, existing_ok=False, within=None): from pymbolic.mapper.substitutor import make_subst_func rule_mapping_context = SubstitutionRuleMappingContext( knl.substitutions, var_name_gen) - ijoin = RuleAwareSubstitutionMapper(rule_mapping_context, + smap = RuleAwareSubstitutionMapper(rule_mapping_context, make_subst_func(subst_dict), within) knl = rule_mapping_context.finish_kernel( - ijoin.map_kernel(knl)) + smap.map_kernel(knl)) new_instructions = [] for insn in knl.instructions: diff --git a/test/test_loopy.py b/test/test_loopy.py index a23782c3cd18101ea892e4af143c52d008875209..667d7365d412d9be6d9a51c8d63c9470ffede2f9 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -2160,6 +2160,21 @@ def test_sci_notation_literal(ctx_factory): assert (np.abs(out.get() - 1e-12) < 1e-20).all() +def test_rename_argument(ctx_factory): + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + kernel = lp.make_kernel( + '''{ [i]: 0<=i<n }''', + '''out[i] = a + 2''') + + kernel = lp.rename_argument(kernel, "a", "b") + + evt, (out,) = kernel(queue, b=np.float32(12), n=20) + + assert (np.abs(out.get() - 14) < 1e-8).all() + + def test_to_batched(ctx_factory): ctx = ctx_factory() queue = cl.CommandQueue(ctx)