diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 16c9fd4822a091b8986819da6dd5c8facdb05026..23617c48b2a5cfce23ad9effe2fd04843fd4d5d0 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -56,6 +56,7 @@ from pymbolic.mapper.constant_folder import \ ConstantFoldingMapper as ConstantFoldingMapperBase from pymbolic.parser import Parser as ParserBase +from loopy.diagnostic import LoopyError import islpy as isl from islpy import dim_type @@ -106,6 +107,9 @@ class IdentityMapperMixin(object): def map_type_annotation(self, expr, *args): return type(expr)(expr.type, self.rec(expr.child)) + def map_sub_array_ref(self, expr, *args): + return SubArrayRef(expr.swept_inames, expr.subscript) + map_type_cast = map_type_annotation map_linear_subscript = IdentityMapperBase.map_subscript @@ -169,6 +173,13 @@ class WalkMapper(WalkMapperBase): map_scoped_function = WalkMapperBase.map_variable + def map_sub_array_ref(self, expr, *args): + if not self.visit(expr): + return + + self.rec(expr.swept_inames, *args) + self.rec(expr.subscript, *args) + class CallbackMapper(CallbackMapperBase, IdentityMapper): map_reduction = CallbackMapperBase.map_constant @@ -241,6 +252,11 @@ class StringifyMapper(StringifyMapperBase): def map_scoped_function(self, expr, prec): return "ScopedFunction('%s')" % expr.name + def map_sub_array_ref(self, expr, prec): + return "SubArrayRef({inames}, ({subscr}))".format( + inames=self.rec(expr.swept_inames, prec), + subscr=self.rec(expr.subscript, prec)) + class UnidirectionalUnifier(UnidirectionalUnifierBase): def map_reduction(self, expr, other, unis): @@ -293,6 +309,10 @@ class DependencyMapper(DependencyMapperBase): def map_loopy_function_identifier(self, expr): return set() + def map_sub_array_ref(self, expr, *args): + deps = self.rec(expr.subscript, *args) + return deps - set(iname for iname in expr.swept_inames) + map_linear_subscript = DependencyMapperBase.map_subscript def map_type_cast(self, expr): @@ -660,6 +680,79 @@ class ScopedFunction(p.Variable): def stringifier(self): return StringifyMapper + +class SubArrayRef(p.Expression): + """Represents a generalized sliced notation of an array. + + .. attribute:: swept_inames + + These are a tuple of sweeping inames over the array. + + .. attribute:: subscript + + The subscript whose adress space is to be referenced + """ + + init_arg_names = ("swept_inames", "subscript") + + def __init__(self, swept_inames=None, subscript=None): + + # {{{ sanity checks + + if not isinstance(swept_inames, tuple): + assert isinstance(swept_inames, p.Variable) + swept_inames = (swept_inames,) + + assert isinstance(swept_inames, tuple) + + for iname in swept_inames: + assert isinstance(iname, p.Variable) + assert isinstance(subscript, p.Subscript) + + # }}} + + self.swept_inames = swept_inames + self.subscript = subscript + + def get_begin_subscript(self): + starting_inames = [] + for iname in self.subscript.index_tuple: + if iname in self.swept_inames: + starting_inames.append(parse('0')) + else: + starting_inames.append(iname) + return p.Subscript(self.subscript.aggregate, tuple(starting_inames)) + + def get_inner_dim_tags(self, arg_dim_tags): + """ Gives the dim tags for the inner inames. + This would be used for stride calculation in the child kernel. + This might need to go, once we start calculating the stride length + using the upper and lower bounds of the involved inames. + """ + from loopy.kernel.array import FixedStrideArrayDimTag as DimTag + inner_dim_tags = [] + for dim_tag, iname in zip(arg_dim_tags, self.subscript.index_tuple): + if iname in self.swept_inames: + inner_dim_tags.append(DimTag(dim_tag.stride)) + + return inner_dim_tags + + def __getinitargs__(self): + return (self.swept_inames, self.subscript) + + def get_hash(self): + return hash((self.__class__, self.swept_inames, self.subscript)) + + def is_equal(self, other): + return (other.__class__ == self.__class__ + and other.subscript == self.subscript + and other.swept_inames == self.swept_inames) + + def stringifier(self): + return StringifyMapper + + mapper_method = intern("map_sub_array_ref") + # }}} @@ -1122,6 +1215,14 @@ class FunctionToPrimitiveMapper(IdentityMapper): else: return IdentityMapper.map_call(self, expr) + def map_call_with_kwargs(self, expr): + for par in expr.kw_parameters.values(): + if not isinstance(par, SubArrayRef): + raise LoopyError("Keyword Arguments is only supported for" + " array arguments--use positional order to specify" + " the order of the arguments in the call.") + return IdentityMapper.map_call_with_kwargs(self, expr) + # {{{ customization to pymbolic parser @@ -1152,7 +1253,9 @@ class LoopyParser(ParserBase): return float(val) # generic float def parse_prefix(self, pstate): - from pymbolic.parser import _PREC_UNARY, _less, _greater, _identifier + from pymbolic.parser import (_PREC_UNARY, _less, _greater, _identifier, + _openbracket, _closebracket, _colon) + if pstate.is_next(_less): pstate.advance() if pstate.is_next(_greater): @@ -1168,6 +1271,18 @@ class LoopyParser(ParserBase): return TypeAnnotation( typename, self.parse_expression(pstate, _PREC_UNARY)) + + elif pstate.is_next(_openbracket): + pstate.advance() + pstate.expect_not_end() + swept_inames = self.parse_expression(pstate) + pstate.expect(_closebracket) + pstate.advance() + pstate.expect(_colon) + pstate.advance() + subscript = self.parse_expression(pstate, _PREC_UNARY) + return SubArrayRef(swept_inames, subscript) + else: return super(LoopyParser, self).parse_prefix(pstate) @@ -1767,6 +1882,10 @@ class BatchedAccessRangeMapper(WalkMapper): def map_type_cast(self, expr, inames): return self.rec(expr.child, inames) + def map_sub_array_ref(self, expr, inames): + total_inames = inames | set([iname.name for iname in expr.swept_inames]) + return self.rec(expr.subscript, total_inames) + class AccessRangeMapper(object): """**IMPORTANT**