From fd4cef680fd4a67dd8aff2e829323e3ecd1e0ae5 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl <alexfikl@gmail.com> Date: Sun, 8 Dec 2024 14:47:28 +0200 Subject: [PATCH] loopy: fix expr empty subscript deprecation --- arraycontext/loopy.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/arraycontext/loopy.py b/arraycontext/loopy.py index d6f9078..7b1d6a0 100644 --- a/arraycontext/loopy.py +++ b/arraycontext/loopy.py @@ -83,12 +83,14 @@ def get_default_entrypoint(t_unit): def _get_scalar_func_loopy_program(actx, c_name, nargs, naxes): @memoize_in(actx, _get_scalar_func_loopy_program) def get(c_name, nargs, naxes): - from pymbolic import var + from pymbolic.primitives import Subscript, Variable var_names = [f"i{i}" for i in range(naxes)] size_names = [f"n{i}" for i in range(naxes)] - subscript = tuple(var(vname) for vname in var_names) + subscript = tuple(Variable(vname) for vname in var_names) + from islpy import make_zero_and_vars + v = make_zero_and_vars(var_names, params=size_names) domain = v[0].domain() for vname, sname in zip(var_names, size_names, strict=True): @@ -98,22 +100,22 @@ def _get_scalar_func_loopy_program(actx, c_name, nargs, naxes): import loopy as lp - from .loopy import make_loopy_program from arraycontext.transform_metadata import ElementwiseMapKernelTag + + def sub(name: str) -> Variable | Subscript: + return Variable(name)[subscript] if subscript else Variable(name) + return make_loopy_program( - [domain_bset], - [ + [domain_bset], [ lp.Assignment( - var("out")[subscript], - var(c_name)(*[ - var(f"inp{i}")[subscript] for i in range(nargs)])) - ], - [ - lp.GlobalArg("out", - dtype=None, shape=lp.auto, offset=lp.auto)] + [ - lp.GlobalArg(f"inp{i}", - dtype=None, shape=lp.auto, offset=lp.auto) - for i in range(nargs)] + [...], + sub("out"), + Variable(c_name)(*[sub(f"inp{i}") for i in range(nargs)])) + ], [ + lp.GlobalArg("out", dtype=None, shape=lp.auto, offset=lp.auto) + ] + [ + lp.GlobalArg(f"inp{i}", dtype=None, shape=lp.auto, offset=lp.auto) + for i in range(nargs) + ] + [...], name=f"actx_special_{c_name}", tags=(ElementwiseMapKernelTag(),)) -- GitLab