Skip to content
Snippets Groups Projects
Commit 396da6d1 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Switch padding machinery over to ExpandingIdentityMapper.

parent 65f6d419
No related branches found
No related tags found
No related merge requests found
......@@ -50,15 +50,6 @@ To-do
- when are link_inames, duplicate_inames safe?
- ExpandingIdentityMapper
extract_subst -> needs WalkMapper
padding
replace make_unique_var_name [DONE]
join_inames [DONE]
duplicate_inames [DONE]
split_iname [DONE]
CSE [DONE]
- Data implementation tags
TODO initial bringup:
- implemented_arg_info
......@@ -143,6 +134,15 @@ Future ideas
Dealt with
^^^^^^^^^^
- ExpandingIdentityMapper
extract_subst -> needs WalkMapper [actually fine as is]
padding [DONE]
replace make_unique_var_name [DONE]
join_inames [DONE]
duplicate_inames [DONE]
split_iname [DONE]
CSE [DONE]
- rename iname
- delete unused inames
......
......@@ -25,21 +25,22 @@ THE SOFTWARE.
from loopy.symbolic import IdentityMapper
from loopy.symbolic import ExpandingIdentityMapper
class ArgAxisSplitHelper(IdentityMapper):
def __init__(self, arg_names, handler):
class ArgAxisSplitHelper(ExpandingIdentityMapper):
def __init__(self, rules, var_name_gen, arg_names, handler):
ExpandingIdentityMapper.__init__(self, rules, var_name_gen)
self.arg_names = arg_names
self.handler = handler
def map_subscript(self, expr):
def map_subscript(self, expr, expn_state):
if expr.aggregate.name in self.arg_names:
return self.handler(expr)
else:
return IdentityMapper.map_subscript(self, expr)
return ExpandingIdentityMapper.map_subscript(self, expr, expn_state)
......@@ -176,19 +177,18 @@ def split_arg_axis(kernel, args_and_axes, count):
return expr.aggregate[tuple(idx)]
aash = ArgAxisSplitHelper(set(arg_to_rest.iterkeys()), split_access_axis)
aash = ArgAxisSplitHelper(kernel.substitutions, var_name_gen,
set(arg_to_rest.iterkeys()), split_access_axis)
kernel = aash.map_kernel(kernel)
result = (kernel
.map_expressions(aash)
.copy(args=new_args))
kernel = kernel.copy(args=new_args)
from loopy import split_iname
for iname, (outer_iname, inner_iname) in split_vars.iteritems():
result = split_iname(result, iname, count,
kernel = split_iname(kernel, iname, count,
outer_iname=outer_iname, inner_iname=inner_iname)
return result
return kernel
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment