diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index dc09e9051fa19666cc66945d6d8bbda605188767..0ffb017edd825258b23aea450cb5d2853bde1cd3 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -33,9 +33,8 @@ import pytato.scalar_expr as scalar_expr import pymbolic.primitives as prim from pymbolic import var -from typing import ( - Union, Optional, Mapping, Dict, Tuple, FrozenSet, Set, Callable, - Any, List) +from typing import (Union, Optional, Mapping, Dict, Tuple, FrozenSet, Set, + Any, List) from pytato.array import (Array, DictOfNamedArrays, ShapeType, IndexLambda, @@ -55,7 +54,7 @@ if getattr(sys, "PYTATO_BUILDING_SPHINX_DOCS", False): import pyopencl __doc__ = """ -.. autoclass:: LoopyExpressionContext +.. autoclass:: SharedLoopyExpressionContext .. autoclass:: ImplementedResult .. autoclass:: StoredResult .. autoclass:: InlinedResult @@ -68,7 +67,6 @@ __doc__ = """ .. autofunction:: domain_for_shape .. autofunction:: get_loopy_temporary .. autofunction:: add_store -.. autofunction:: rename_reductions .. autofunction:: normalize_outputs .. autofunction:: get_initial_codegen_state """ @@ -82,13 +80,13 @@ def loopy_substitute(expression: Any, variable_assigments: Mapping[str, Any]) -> # SymbolicIndex and ShapeType are semantically distinct but identical at the # type level. -ReductionBounds = Dict[str, Tuple[ScalarExpression, ScalarExpression]] +ReductionBounds = Mapping[str, Tuple[ScalarExpression, ScalarExpression]] -# {{{ LoopyExpressionContext +# {{{ LoopyExpressionContexts @dataclasses.dataclass(init=True, repr=False, eq=False) -class LoopyExpressionContext(object): +class SharedLoopyExpressionContext(object): """Mutable state used while generating :mod:`loopy` expressions. Wraps :class:`CodeGenState` with more expression-specific information. 
@@ -129,8 +127,6 @@ class LoopyExpressionContext(object): dataclasses.field(default_factory=frozenset) local_namespace: Mapping[str, Array] = \ dataclasses.field(default_factory=dict) - reduction_bounds: ReductionBounds = \ - dataclasses.field(default_factory=dict) def lookup(self, name: str) -> Array: return self.local_namespace[name] @@ -142,6 +138,22 @@ class LoopyExpressionContext(object): def update_depends_on(self, other: FrozenSet[str]) -> None: self._depends_on = self._depends_on | other + +@dataclasses.dataclass(frozen=True) +class UpstreamLoopyExpressionContext: + """ + Records context to be conveyed from a parent expression to its + sub-expressions. + """ + reduction_bounds: ReductionBounds = \ + dataclasses.field(default_factory=dict) + + def copy(self, *, + reduction_bounds: Optional[ReductionBounds] = None + ) -> UpstreamLoopyExpressionContext: + return UpstreamLoopyExpressionContext( + self.reduction_bounds if reduction_bounds is None else reduction_bounds) + # }}} @@ -156,7 +168,7 @@ class ImplementedResult(ABC): @abstractmethod def to_loopy_expression(self, indices: SymbolicIndex, - expr_context: LoopyExpressionContext) -> ScalarExpression: + expr_context: SharedLoopyExpressionContext) -> ScalarExpression: """Return a :mod:`loopy` expression for this result. :param indices: symbolic expressions for the indices of the array @@ -165,11 +177,6 @@ class ImplementedResult(ABC): - *depends_on* is populated with any dependencies needed for the generated expression. - - - *reduction_bounds* is populated with reduction bounds for the - reduction inames in the returned expression. If - *reduction_bounds* is nonempty, then the returned inames are - ensured to be disjoint from those present. 
""" # }}} @@ -188,7 +195,7 @@ class StoredResult(ImplementedResult): self.depends_on = depends_on def to_loopy_expression(self, indices: SymbolicIndex, - expr_context: LoopyExpressionContext) -> ScalarExpression: + expr_context: SharedLoopyExpressionContext) -> ScalarExpression: assert len(indices) == self.num_indices expr_context.update_depends_on(self.depends_on) if indices == (): @@ -209,36 +216,23 @@ class InlinedResult(ImplementedResult): """ def __init__(self, expr: ScalarExpression, num_indices: int, - reduction_bounds: ReductionBounds, depends_on: FrozenSet[str]): self.expr = expr self.num_indices = num_indices - self.reduction_bounds = dict(reduction_bounds) self.depends_on = depends_on @staticmethod def from_loopy_expression( loopy_expr: ScalarExpression, - loopy_expr_context: LoopyExpressionContext) -> InlinedResult: + loopy_expr_context: SharedLoopyExpressionContext) -> InlinedResult: return InlinedResult(loopy_expr, loopy_expr_context.num_indices, - loopy_expr_context.reduction_bounds, loopy_expr_context.depends_on) def to_loopy_expression(self, indices: SymbolicIndex, - expr_context: LoopyExpressionContext) -> ScalarExpression: + expr_context: SharedLoopyExpressionContext) -> ScalarExpression: assert len(indices) == self.num_indices substitutions = {f"_{d}": i for d, i in enumerate(indices)} - - reduction_start = len(expr_context.reduction_bounds) - - # Rename reductions in expression not to conflict with those in expr_context. - for i, (old_name, bounds) in enumerate(self.reduction_bounds.items()): - new_name = f"_r{i + reduction_start}" - assert new_name not in expr_context.reduction_bounds - substitutions[old_name] = var(new_name) - expr_context.reduction_bounds[new_name] = bounds - expr_context.update_depends_on(self.depends_on) return loopy_substitute(self.expr, substitutions) @@ -366,7 +360,7 @@ class CodeGenMapper(Mapper): # TODO: Respect tags. 
- loopy_expr_context = LoopyExpressionContext(state, + loopy_expr_context = SharedLoopyExpressionContext(state, local_namespace=expr.bindings, num_indices=expr.ndim) loopy_expr = self.exprgen_mapper(expr.expr, loopy_expr_context) @@ -479,7 +473,7 @@ class CodeGenMapper(Mapper): else: assert isinstance(arg, lp.ValueArg) and arg.is_input pt_arg = expr.bindings[arg.name] - loopy_expr_context = LoopyExpressionContext(state, + loopy_expr_context = SharedLoopyExpressionContext(state, local_namespace={}, num_indices=0) if isinstance(pt_arg, Array): assert pt_arg.ndim == 0 @@ -542,67 +536,76 @@ class InlinedExpressionGenMapper(scalar_expr.IdentityMapper): self.codegen_mapper = codegen_mapper def __call__(self, expr: ScalarExpression, - expr_context: LoopyExpressionContext) -> ScalarExpression: - return self.rec(expr, expr_context) + shared_ctx: SharedLoopyExpressionContext, + upstream_ctx: Optional[UpstreamLoopyExpressionContext] = None, + ) -> ScalarExpression: + if upstream_ctx is None: + upstream_ctx = UpstreamLoopyExpressionContext() + return self.rec(expr, shared_ctx, upstream_ctx) def map_subscript(self, expr: prim.Subscript, - expr_context: LoopyExpressionContext) -> ScalarExpression: + shared_ctx: SharedLoopyExpressionContext, + upstream_ctx: UpstreamLoopyExpressionContext, + ) -> ScalarExpression: assert isinstance(expr.aggregate, prim.Variable) result: ImplementedResult = self.codegen_mapper( - expr_context.lookup(expr.aggregate.name), expr_context.state) - return result.to_loopy_expression(self.rec(expr.index, expr_context), - expr_context) + shared_ctx.lookup(expr.aggregate.name), shared_ctx.state) + return result.to_loopy_expression(self.rec(expr.index, shared_ctx, + upstream_ctx), + shared_ctx) def map_variable(self, expr: prim.Variable, - expr_context: LoopyExpressionContext) -> ScalarExpression: + shared_ctx: SharedLoopyExpressionContext, + upstream_ctx: UpstreamLoopyExpressionContext, + ) -> ScalarExpression: elw_match = 
ELWISE_INDEX_RE.fullmatch(expr.name) - redn_match = REDUCTION_INDEX_RE.fullmatch(expr.name) if elw_match: # Found an index of the form _0, _1, ... index = int(elw_match.group(1)) - if not (0 <= index < expr_context.num_indices): + if not (0 <= index < shared_ctx.num_indices): raise ValueError(f"invalid index encountered: _{index}") return expr - elif redn_match: - if expr.name not in expr_context.reduction_bounds: - raise ValueError(f"invalid index encountered: '{expr}'.") + elif expr.name in upstream_ctx.reduction_bounds: return expr else: - array = expr_context.lookup(expr.name) + array = shared_ctx.lookup(expr.name) impl_result: ImplementedResult = self.codegen_mapper(array, - expr_context.state) - return impl_result.to_loopy_expression((), expr_context) + shared_ctx.state) + return impl_result.to_loopy_expression((), shared_ctx) def map_call(self, expr: prim.Call, - expr_context: LoopyExpressionContext) -> ScalarExpression: + shared_ctx: SharedLoopyExpressionContext, + upstream_ctx: UpstreamLoopyExpressionContext + ) -> ScalarExpression: if isinstance(expr.function, prim.Variable) and ( expr.function.name.startswith("pytato.c99.")): name_in_loopy = expr.function.name[11:] return prim.Call(prim.Variable(name_in_loopy), - self.rec(expr.parameters, expr_context)) + self.rec(expr.parameters, shared_ctx, upstream_ctx)) - return super().map_call(expr, expr_context) + return super().map_call(expr, shared_ctx, upstream_ctx) def map_reduce(self, expr: scalar_expr.Reduce, - expr_context: LoopyExpressionContext) -> ScalarExpression: + shared_ctx: SharedLoopyExpressionContext, + upstream_ctx: UpstreamLoopyExpressionContext + ) -> ScalarExpression: from loopy.symbolic import Reduction as LoopyReduction - state = expr_context.state + state = shared_ctx.state unique_names_mapping = { - old_name: prim.Variable( - state.var_name_gen(f"_pt_{expr.op}" + old_name)) + old_name: state.var_name_gen(f"_pt_{expr.op}" + old_name) for old_name in expr.bounds} - inner_expr = 
self.rec(expr.inner_expr, - LoopyExpressionContext( - state=state, - _depends_on=expr_context.depends_on, - local_namespace=expr_context.local_namespace, - num_indices=expr_context.num_indices, - reduction_bounds=expr.bounds)) # type: ignore - inner_expr = loopy_substitute(inner_expr, unique_names_mapping) + inner_expr = loopy_substitute(expr.inner_expr, + {k: prim.Variable(v) + for k, v in unique_names_mapping.items()}) + new_bounds = {unique_names_mapping[name]: bound_exprs + for name, bound_exprs in expr.bounds.items()} + + inner_expr = self.rec(inner_expr, shared_ctx, + upstream_ctx.copy(reduction_bounds=new_bounds)) try: loopy_redn = PYTATO_REDUCTION_TO_LOOPY_REDUCTION[expr.op] @@ -610,12 +613,12 @@ class InlinedExpressionGenMapper(scalar_expr.IdentityMapper): raise NotImplementedError(expr.op) inner_expr = LoopyReduction(loopy_redn, - tuple(v.name for v in unique_names_mapping.values()), - inner_expr) + tuple(unique_names_mapping.values()), + inner_expr) domain = domain_for_shape((), shape=(), reductions={ - unique_names_mapping[redn_iname].name: self.rec(bounds, expr_context) - for redn_iname, bounds in expr.bounds.items()}) + redn_iname: self.rec(bounds, shared_ctx, upstream_ctx) + for redn_iname, bounds in new_bounds.items()}) kernel = state.kernel state.update_kernel(kernel.copy(domains=kernel.domains+[domain])) @@ -630,7 +633,7 @@ def shape_to_scalar_expression(shape: ShapeType, cgen_mapper: CodeGenMapper, state: CodeGenState ) -> Tuple[ScalarExpression, ...]: - shape_context = LoopyExpressionContext(state, num_indices=0) + shape_context = SharedLoopyExpressionContext(state, num_indices=0) result: List[ScalarExpression] = [] for component in shape: if isinstance(component, int): @@ -726,14 +729,9 @@ def add_store(name: str, expr: Array, result: ImplementedResult, state.var_name_gen(f"{name}_dim{d}") for d in range(expr.ndim)) indices = tuple(prim.Variable(iname) for iname in inames) - loopy_expr_context = LoopyExpressionContext(state, num_indices=0) + 
loopy_expr_context = SharedLoopyExpressionContext(state, num_indices=0) loopy_expr = result.to_loopy_expression(indices, loopy_expr_context) - # Rename reduction variables to names suitable as inames. - loopy_expr = rename_reductions( - loopy_expr, loopy_expr_context, - lambda old_name: state.var_name_gen(f"{name}{old_name}")) - # Make the instruction from loopy.kernel.instruction import make_assignment if indices: @@ -749,8 +747,7 @@ def add_store(name: str, expr: Array, result: ImplementedResult, shape = shape_to_scalar_expression(expr.shape, cgen_mapper, state) # Get the domain. - domain = domain_for_shape(inames, shape, - loopy_expr_context.reduction_bounds) + domain = domain_for_shape(inames, shape, {}) # Update the kernel. kernel = state.kernel @@ -786,31 +783,6 @@ def get_loopy_temporary(name: str, expr: Array, cgen_mapper: CodeGenMapper, dtype=expr.dtype, address_space=address_space) - -def rename_reductions( - loopy_expr: ScalarExpression, - loopy_expr_context: LoopyExpressionContext, - var_name_gen: Callable[[str], str]) -> ScalarExpression: - """Rename the reduction variables in *loopy_expr* and *loopy_expr_context* - using the callable *var_name_gen.* - """ - new_reduction_inames = tuple( - var_name_gen(old_iname) - for old_iname in loopy_expr_context.reduction_bounds) - - substitutions = dict(zip( - loopy_expr_context.reduction_bounds, - map(var, new_reduction_inames))) - - result = loopy_substitute(loopy_expr, substitutions) - - new_reduction_bounds = { - substitutions[old_iname].name: bounds - for old_iname, bounds in loopy_expr_context.reduction_bounds.items()} - - loopy_expr_context.reduction_bounds = new_reduction_bounds - return result - # }}}