diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index 91be54bac3296415fe00b3006e707965c92f1059..1c140474c33b6425c556198f46c26a2b9188e61a 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -359,13 +359,13 @@ class CodeGenMapper(Mapper): # TODO: Respect tags. - shared_ctx = PersistentExpressionContext(state) - upstream_ctx = LocalExpressionContext(local_namespace=expr.bindings, + prstnt_ctx = PersistentExpressionContext(state) + local_ctx = LocalExpressionContext(local_namespace=expr.bindings, num_indices=expr.ndim, reduction_bounds={}) - loopy_expr = self.exprgen_mapper(expr.expr, shared_ctx, upstream_ctx) + loopy_expr = self.exprgen_mapper(expr.expr, prstnt_ctx, local_ctx) - result = InlinedResult(loopy_expr, expr.ndim, shared_ctx.depends_on) + result = InlinedResult(loopy_expr, expr.ndim, prstnt_ctx.depends_on) state.results[expr] = result shape_to_scalar_expression(expr.shape, self, state) # walk over size params @@ -472,20 +472,20 @@ class CodeGenMapper(Mapper): else: assert isinstance(arg, lp.ValueArg) and arg.is_input pt_arg = expr.bindings[arg.name] - shared_ctx = PersistentExpressionContext(state) + prstnt_ctx = PersistentExpressionContext(state) if isinstance(pt_arg, Array): assert pt_arg.ndim == 0 pt_arg_rec = self.rec(pt_arg, state) - params.append(pt_arg_rec.to_loopy_expression((), shared_ctx)) + params.append(pt_arg_rec.to_loopy_expression((), prstnt_ctx)) depends_on.update(pt_arg_rec.depends_on) else: - upstream_ctx = LocalExpressionContext(reduction_bounds={}, + local_ctx = LocalExpressionContext(reduction_bounds={}, num_indices=0, local_namespace={}) params.append(self.exprgen_mapper(pt_arg, - shared_ctx, - upstream_ctx)) + prstnt_ctx, + local_ctx)) new_insn = make_assignment( tuple(assignees), @@ -540,61 +540,61 @@ class InlinedExpressionGenMapper(scalar_expr.IdentityMapper): self.codegen_mapper = codegen_mapper def __call__(self, expr: ScalarExpression, - shared_ctx: PersistentExpressionContext, - upstream_ctx: Optional[LocalExpressionContext], + prstnt_ctx: PersistentExpressionContext, + local_ctx: Optional[LocalExpressionContext], ) -> ScalarExpression: - return self.rec(expr, shared_ctx, upstream_ctx) + return self.rec(expr, prstnt_ctx, local_ctx) def map_subscript(self, expr: prim.Subscript, - shared_ctx: PersistentExpressionContext, - upstream_ctx: LocalExpressionContext, + prstnt_ctx: PersistentExpressionContext, + local_ctx: LocalExpressionContext, ) -> ScalarExpression: assert isinstance(expr.aggregate, prim.Variable) result: ImplementedResult = self.codegen_mapper( - upstream_ctx.lookup(expr.aggregate.name), shared_ctx.state) - return result.to_loopy_expression(self.rec(expr.index, shared_ctx, - upstream_ctx), - shared_ctx) + local_ctx.lookup(expr.aggregate.name), prstnt_ctx.state) + return result.to_loopy_expression(self.rec(expr.index, prstnt_ctx, + local_ctx), + prstnt_ctx) def map_variable(self, expr: prim.Variable, - shared_ctx: PersistentExpressionContext, - upstream_ctx: LocalExpressionContext, + prstnt_ctx: PersistentExpressionContext, + local_ctx: LocalExpressionContext, ) -> ScalarExpression: elw_match = ELWISE_INDEX_RE.fullmatch(expr.name) if elw_match: # Found an index of the form _0, _1, ... index = int(elw_match.group(1)) - if not (0 <= index < upstream_ctx.num_indices): + if not (0 <= index < local_ctx.num_indices): raise ValueError(f"invalid index encountered: _{index}") return expr - elif expr.name in upstream_ctx.reduction_bounds: + elif expr.name in local_ctx.reduction_bounds: return expr else: - array = upstream_ctx.lookup(expr.name) + array = local_ctx.lookup(expr.name) impl_result: ImplementedResult = self.codegen_mapper(array, - shared_ctx.state) - return impl_result.to_loopy_expression((), shared_ctx) + prstnt_ctx.state) + return impl_result.to_loopy_expression((), prstnt_ctx) def map_call(self, expr: prim.Call, - shared_ctx: PersistentExpressionContext, - upstream_ctx: LocalExpressionContext + prstnt_ctx: PersistentExpressionContext, + local_ctx: LocalExpressionContext ) -> ScalarExpression: if isinstance(expr.function, prim.Variable) and ( expr.function.name.startswith("pytato.c99.")): name_in_loopy = expr.function.name[11:] return prim.Call(prim.Variable(name_in_loopy), - self.rec(expr.parameters, shared_ctx, upstream_ctx)) + self.rec(expr.parameters, prstnt_ctx, local_ctx)) - return super().map_call(expr, shared_ctx, upstream_ctx) + return super().map_call(expr, prstnt_ctx, local_ctx) def map_reduce(self, expr: scalar_expr.Reduce, - shared_ctx: PersistentExpressionContext, - upstream_ctx: LocalExpressionContext + prstnt_ctx: PersistentExpressionContext, + local_ctx: LocalExpressionContext ) -> ScalarExpression: from loopy.symbolic import Reduction as LoopyReduction - state = shared_ctx.state + state = prstnt_ctx.state unique_names_mapping = { old_name: state.var_name_gen(f"_pt_{expr.op}" + old_name) @@ -606,8 +606,8 @@ class InlinedExpressionGenMapper(scalar_expr.IdentityMapper): new_bounds = {unique_names_mapping[name]: bound_exprs for name, bound_exprs in expr.bounds.items()} - inner_expr = self.rec(inner_expr, shared_ctx, - upstream_ctx.copy(reduction_bounds=new_bounds)) + inner_expr = self.rec(inner_expr, prstnt_ctx, + local_ctx.copy(reduction_bounds=new_bounds)) try: loopy_redn = PYTATO_REDUCTION_TO_LOOPY_REDUCTION[expr.op] @@ -619,7 +619,7 @@ class InlinedExpressionGenMapper(scalar_expr.IdentityMapper): inner_expr) domain = domain_for_shape((), shape=(), reductions={ - redn_iname: self.rec(bounds, shared_ctx, upstream_ctx) + redn_iname: self.rec(bounds, prstnt_ctx, local_ctx) for redn_iname, bounds in new_bounds.items()}) kernel = state.kernel state.update_kernel(kernel.copy(domains=kernel.domains+[domain]))