From 174485cfe41c8ef8353dc7b8755ddfb05665e580 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Fri, 7 Jun 2013 23:23:27 -0400 Subject: [PATCH] PEP8 loopy.symbolic --- loopy/symbolic.py | 52 +++++++++++++++++++++++++++++++++++++---------- 1 file changed, 41 insertions(+), 11 deletions(-) diff --git a/loopy/symbolic.py b/loopy/symbolic.py index b6a389e46..de821eea4 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -25,8 +25,6 @@ THE SOFTWARE. """ - - from pytools import memoize, memoize_method, Record import pytools.lex @@ -37,7 +35,6 @@ from pymbolic.primitives import ( from pymbolic.mapper import ( CombineMapper as CombineMapperBase, IdentityMapper as IdentityMapperBase, - RecursiveMapper, WalkMapper as WalkMapperBase, CallbackMapper as CallbackMapperBase, ) @@ -61,8 +58,6 @@ import re import numpy as np - - # {{{ loopy-specific primitives class FunctionIdentifier(Leaf): @@ -71,6 +66,7 @@ class FunctionIdentifier(Leaf): mapper_method = intern("map_loopy_function_identifier") + class TypedCSE(CommonSubexpression): def __init__(self, child, prefix=None, dtype=None): CommonSubexpression.__init__(self, child, prefix) @@ -101,6 +97,7 @@ class TaggedVariable(Variable): mapper_method = intern("map_tagged_variable") + class Reduction(AlgebraicLeaf): def __init__(self, operation, inames, expr): assert isinstance(inames, tuple) @@ -136,6 +133,7 @@ class Reduction(AlgebraicLeaf): mapper_method = intern("map_reduction") + class LinearSubscript(AlgebraicLeaf): def __init__(self, aggregate, index): self.aggregate = aggregate @@ -151,6 +149,7 @@ class LinearSubscript(AlgebraicLeaf): # }}} + # {{{ mappers with support for loopy-specific primitives class IdentityMapperMixin(object): @@ -166,9 +165,11 @@ class IdentityMapperMixin(object): map_linear_subscript = IdentityMapperBase.map_subscript + class IdentityMapper(IdentityMapperBase, IdentityMapperMixin): pass + class WalkMapper(WalkMapperBase): def map_reduction(self, expr, *args): if not self.visit(expr): @@ -183,18 +184,22 @@ class WalkMapper(WalkMapperBase): map_linear_subscript = WalkMapperBase.map_subscript + class CallbackMapper(CallbackMapperBase, IdentityMapper): map_reduction = CallbackMapperBase.map_constant + class CombineMapper(CombineMapperBase): def map_reduction(self, expr): return self.rec(expr.expr) map_linear_subscript = CombineMapperBase.map_subscript + class SubstitutionMapper(SubstitutionMapperBase, IdentityMapperMixin): pass + class StringifyMapper(StringifyMapperBase): def map_reduction(self, expr, prec): return "reduce(%s, [%s], %s)" % ( @@ -225,6 +230,7 @@ class DependencyMapper(DependencyMapperBase): map_linear_subscript = DependencyMapperBase.map_subscript + class UnidirectionalUnifier(UnidirectionalUnifierBase): def map_reduction(self, expr, other, unis): if not isinstance(other, type(expr)): @@ -253,6 +259,7 @@ class UnidirectionalUnifier(UnidirectionalUnifierBase): # }}} + # {{{ identity mapper that expands subst rules on the fly def parse_tagged_name(expr): @@ -263,6 +270,7 @@ def parse_tagged_name(expr): else: raise RuntimeError("subst rule name not understood: %s" % expr) + class ExpansionState(Record): """ :ivar stack: a tuple representing the current expansion stack, as a tuple @@ -271,6 +279,7 @@ class ExpansionState(Record): :ivar arg_context: a dict representing current argument values """ + class SubstitutionRuleRenamer(IdentityMapper): def __init__(self, renames): self.renames = renames @@ -304,12 +313,14 @@ class SubstitutionRuleRenamer(IdentityMapper): else: return TaggedVariable(new_name, tag) + def rename_subst_rules_in_instructions(insns, renames): subst_renamer = SubstitutionRuleRenamer(renames) return [ insn.copy(expression=subst_renamer(insn.expression)) for insn in insns] + class ExpandingIdentityMapper(IdentityMapper): """Note: the third argument dragged around by this mapper is the current expansion expansion state. @@ -476,6 +487,7 @@ class ExpandingIdentityMapper(IdentityMapper): substitutions=new_substs, instructions=rename_subst_rules_in_instructions(new_insns, renames)) + class ExpandingSubstitutionMapper(ExpandingIdentityMapper): def __init__(self, rules, make_unique_var_name, subst_func, within): ExpandingIdentityMapper.__init__(self, rules, make_unique_var_name) @@ -492,6 +504,7 @@ class ExpandingSubstitutionMapper(ExpandingIdentityMapper): # }}} + # {{{ substitution rule expander class SubstitutionRuleExpander(ExpandingIdentityMapper): @@ -531,6 +544,7 @@ class SubstitutionRuleExpander(ExpandingIdentityMapper): # }}} + # {{{ functions to primitives, parsing class VarToTaggedVarMapper(IdentityMapper): @@ -542,6 +556,7 @@ class VarToTaggedVarMapper(IdentityMapper): return TaggedVariable(expr.name[:dollar_idx], expr.name[dollar_idx+1:]) + class FunctionToPrimitiveMapper(IdentityMapper): """Looks for invocations of a function called 'cse' or 'reduce' and turns those into the actual pymbolic primitives used for that. @@ -624,6 +639,7 @@ _close_dbl_bracket = intern("close_dbl_bracket") TRAILING_FLOAT_TAG_RE = re.compile("^(.*?)([a-zA-Z]*)$") + class LoopyParser(ParserBase): lex_table = [ (_open_dbl_bracket, pytools.lex.RE(r"\[\[")), @@ -644,7 +660,7 @@ class LoopyParser(ParserBase): elif tag == frozenset("d"): return np.float64(val) else: - return float(val) # generic float + return float(val) # generic float def parse_postfix(self, pstate, min_precedence, left_exp): from pymbolic.parser import _PREC_CALL @@ -660,12 +676,14 @@ class LoopyParser(ParserBase): # }}} + def parse(expr_str): return VarToTaggedVarMapper()( FunctionToPrimitiveMapper()(LoopyParser()(expr_str))) # }}} + # {{{ coefficient collector class CoefficientCollector(CoefficientCollectorBase): @@ -676,6 +694,7 @@ class CoefficientCollector(CoefficientCollectorBase): # }}} + # {{{ variable index expression collector class ArrayAccessFinder(CombineMapper): @@ -696,13 +715,15 @@ class ArrayAccessFinder(CombineMapper): from pymbolic.primitives import Variable assert isinstance(expr.aggregate, Variable) - if self.tgt_vector_name is None or expr.aggregate.name == self.tgt_vector_name: + if self.tgt_vector_name is None \ + or expr.aggregate.name == self.tgt_vector_name: return set([expr]) | self.rec(expr.index) else: return CombineMapper.map_subscript(self, expr) # }}} + # {{{ aff <-> expr conversion def aff_to_expr(aff, except_name=None, error_on_name=None): @@ -745,8 +766,6 @@ def aff_to_expr(aff, except_name=None, error_on_name=None): return result // denom - - def pw_aff_to_expr(pw_aff, int_ok=False): if isinstance(pw_aff, int): if not int_ok: @@ -763,6 +782,7 @@ def pw_aff_to_expr(pw_aff, int_ok=False): (set, aff), = pieces return aff_to_expr(aff) + def aff_from_expr(space, expr, vars_to_zero=set()): zero = isl.Aff.zero_on_domain(isl.LocalSpace.from_space(space)) context = {} @@ -780,13 +800,16 @@ def aff_from_expr(space, expr, vars_to_zero=set()): # }}} + # {{{ expression <-> constraint conversion def eq_constraint_from_expr(space, expr): - return isl.Constraint.equality_from_aff(aff_from_expr(space,expr)) + return isl.Constraint.equality_from_aff(aff_from_expr(space, expr)) + def ineq_constraint_from_expr(space, expr): - return isl.Constraint.inequality_from_aff(aff_from_expr(space,expr)) + return isl.Constraint.inequality_from_aff(aff_from_expr(space, expr)) + def constraint_to_expr(cns, except_name=None): # Looks like this is ok after all--get_aff() performs some magic. @@ -800,6 +823,7 @@ def constraint_to_expr(cns, except_name=None): # }}} + # {{{ Reduction callback mapper class ReductionCallbackMapper(IdentityMapper): @@ -814,6 +838,7 @@ class ReductionCallbackMapper(IdentityMapper): # }}} + # {{{ index dependency finding class IndexVariableFinder(CombineMapper): @@ -855,6 +880,7 @@ class IndexVariableFinder(CombineMapper): # }}} + # {{{ wildcard -> unique variable mapper class WildcardToUniqueVariableMapper(IdentityMapper): @@ -867,6 +893,7 @@ class WildcardToUniqueVariableMapper(IdentityMapper): # }}} + # {{{ prime ("'") adder class PrimeAdder(IdentityMapper): @@ -888,6 +915,7 @@ class PrimeAdder(IdentityMapper): # }}} + @memoize def get_dependencies(expr): from loopy.symbolic import DependencyMapper @@ -895,6 +923,7 @@ def get_dependencies(expr): return frozenset(dep.name for dep in dep_mapper(expr)) + # {{{ get access range def get_access_range(domain, subscript): @@ -932,6 +961,7 @@ def get_access_range(domain, subscript): # }}} + # {{{ access range mapper class AccessRangeMapper(WalkMapper): -- GitLab