diff --git a/MEMO b/MEMO index 13d8e6afe025f4a5f8836f16e90f02742c943bb0..839979e9aae8edbf25ce6019281ce314f4057730 100644 --- a/MEMO +++ b/MEMO @@ -50,15 +50,6 @@ To-do - when are link_inames, duplicate_inames safe? -- ExpandingIdentityMapper - extract_subst -> needs WalkMapper - padding - replace make_unique_var_name [DONE] - join_inames [DONE] - duplicate_inames [DONE] - split_iname [DONE] - CSE [DONE] - - Data implementation tags TODO initial bringup: - implemented_arg_info @@ -143,6 +134,15 @@ Future ideas Dealt with ^^^^^^^^^^ +- ExpandingIdentityMapper + extract_subst -> needs WalkMapper [actually fine as is] + padding [DONE] + replace make_unique_var_name [DONE] + join_inames [DONE] + duplicate_inames [DONE] + split_iname [DONE] + CSE [DONE] + - rename iname - delete unused inames diff --git a/loopy/padding.py b/loopy/padding.py index ccfcf4b1b6c49217805b7283cd8ff359b9e59245..9fbc4b59c8d54dc1777e78e445363b3293b28be3 100644 --- a/loopy/padding.py +++ b/loopy/padding.py @@ -25,21 +25,22 @@ THE SOFTWARE. -from loopy.symbolic import IdentityMapper +from loopy.symbolic import ExpandingIdentityMapper -class ArgAxisSplitHelper(IdentityMapper): - def __init__(self, arg_names, handler): +class ArgAxisSplitHelper(ExpandingIdentityMapper): + def __init__(self, rules, var_name_gen, arg_names, handler): + ExpandingIdentityMapper.__init__(self, rules, var_name_gen) self.arg_names = arg_names self.handler = handler - def map_subscript(self, expr): + def map_subscript(self, expr, expn_state): if expr.aggregate.name in self.arg_names: return self.handler(expr) else: - return IdentityMapper.map_subscript(self, expr) + return ExpandingIdentityMapper.map_subscript(self, expr, expn_state) @@ -176,19 +177,18 @@ def split_arg_axis(kernel, args_and_axes, count): return expr.aggregate[tuple(idx)] - aash = ArgAxisSplitHelper(set(arg_to_rest.iterkeys()), split_access_axis) + aash = ArgAxisSplitHelper(kernel.substitutions, var_name_gen, + set(arg_to_rest.iterkeys()), split_access_axis) + kernel = aash.map_kernel(kernel) - result = (kernel - .map_expressions(aash) - .copy(args=new_args)) + kernel = kernel.copy(args=new_args) from loopy import split_iname - for iname, (outer_iname, inner_iname) in split_vars.iteritems(): - result = split_iname(result, iname, count, + kernel = split_iname(kernel, iname, count, outer_iname=outer_iname, inner_iname=inner_iname) - return result + return kernel