diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index 2202190159cc3ed43fde5688f8b31fb3ac9f4dcf..eaf645a1c7e916f78818a4ab4c7fdad03d0802d6 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -156,17 +156,17 @@ class LocalExpressionContext: .. automethod:: lookup """ num_indices: int - local_namespace: Mapping[str, Array] + local_namespace: Mapping[str, ImplementedResult] reduction_bounds: ReductionBounds var_to_reduction_descr: Mapping[str, ReductionDescriptor] - def lookup(self, name: str) -> Array: + def lookup(self, name: str) -> ImplementedResult: return self.local_namespace[name] def copy(self, *, reduction_bounds: Optional[ReductionBounds] = None, num_indices: Optional[int] = None, - local_namespace: Optional[Mapping[str, Array]] = None, + local_namespace: Optional[Mapping[str, ImplementedResult]] = None, var_to_reduction_descr: Optional[ Mapping[str, ReductionDescriptor]] = None, ) -> LocalExpressionContext: @@ -347,7 +347,7 @@ class CodeGenMapper(Mapper): def __init__(self, array_tag_t_to_not_propagate: FrozenSet[Type[Tag]], axis_tag_t_to_not_propagate: FrozenSet[Type[Tag]]) -> None: - self.exprgen_mapper = InlinedExpressionGenMapper(self) + self.exprgen_mapper = InlinedExpressionGenMapper(axis_tag_t_to_not_propagate) self.array_tag_t_to_not_propagate = array_tag_t_to_not_propagate self.axis_tag_t_to_not_propagate = axis_tag_t_to_not_propagate self.has_loopy_call = False @@ -401,7 +401,9 @@ class CodeGenMapper(Mapper): prstnt_ctx = PersistentExpressionContext(state) local_ctx = LocalExpressionContext( - local_namespace=expr.bindings, + local_namespace={ + name: self.rec(subexpr, state) + for name, subexpr in expr.bindings.items()}, num_indices=expr.ndim, reduction_bounds={}, var_to_reduction_descr=expr.var_to_reduction_descr) @@ -622,10 +624,10 @@ class InlinedExpressionGenMapper(scalar_expr.IdentityMapper): The outputs of this mapper are scalar expressions suitable for wrapping in :class:`InlinedResult`. """ - codegen_mapper: CodeGenMapper + axis_tag_t_to_not_propagate: FrozenSet[Type[Tag]] - def __init__(self, codegen_mapper: CodeGenMapper): - self.codegen_mapper = codegen_mapper + def __init__(self, axis_tag_t_to_not_propagate: FrozenSet[Type[Tag]]): + self.axis_tag_t_to_not_propagate = axis_tag_t_to_not_propagate if TYPE_CHECKING: def __call__(self, expr: ScalarExpression, @@ -639,11 +641,8 @@ class InlinedExpressionGenMapper(scalar_expr.IdentityMapper): local_ctx: LocalExpressionContext, ) -> ScalarExpression: assert isinstance(expr.aggregate, prim.Variable) - result: ImplementedResult = self.codegen_mapper( - local_ctx.lookup(expr.aggregate.name), prstnt_ctx.state) - return result.to_loopy_expression(self.rec(expr.index, prstnt_ctx, - local_ctx), - prstnt_ctx) + return local_ctx.lookup(expr.aggregate.name).to_loopy_expression( + self.rec(expr.index, prstnt_ctx, local_ctx), prstnt_ctx) def map_variable(self, expr: prim.Variable, prstnt_ctx: PersistentExpressionContext, @@ -660,10 +659,7 @@ class InlinedExpressionGenMapper(scalar_expr.IdentityMapper): elif expr.name in local_ctx.reduction_bounds: return expr else: - array = local_ctx.lookup(expr.name) - impl_result: ImplementedResult = self.codegen_mapper(array, - prstnt_ctx.state) - return impl_result.to_loopy_expression((), prstnt_ctx) + return local_ctx.lookup(expr.name).to_loopy_expression((), prstnt_ctx) def map_call(self, expr: prim.Call, prstnt_ctx: PersistentExpressionContext, @@ -718,7 +714,7 @@ class InlinedExpressionGenMapper(scalar_expr.IdentityMapper): for name_in_expr, name_in_kernel in sorted(unique_names_mapping.items()): for tag in local_ctx.var_to_reduction_descr[name_in_expr].tags: if all(not isinstance(tag, tag_t) - for tag_t in self.codegen_mapper.axis_tag_t_to_not_propagate): + for tag_t in self.axis_tag_t_to_not_propagate): state.update_kernel(lp.tag_inames(state.kernel, {name_in_kernel: tag}))