diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index 69767d5e689543ab5e2f6641c6697e457d7a0b2b..0bc3d5bc284cb3ae67e744e5cff3f196f9140ec8 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -34,7 +34,8 @@ from loopy.kernel.data import ( InstructionBase, MultiAssignmentBase, Assignment, SubstitutionRule) -from loopy.kernel.instruction import CInstruction, _DataObliviousInstruction +from loopy.kernel.instruction import (CInstruction, _DataObliviousInstruction, + CallInstruction) from loopy.diagnostic import LoopyError, warn_with_kernel import islpy as isl from islpy import dim_type @@ -2095,10 +2096,13 @@ def realize_slices_as_sub_array_refs(kernel): new_insns = [] for insn in kernel.instructions: - if isinstance(insn, (MultiAssignmentBase, CInstruction)): + if isinstance(insn, CallInstruction): new_expr = slice_replacer(insn.expression) - new_insns.append(insn.copy(expression=new_expr)) - elif isinstance(insn, _DataObliviousInstruction): + new_assignees = slice_replacer(insn.assignees) + new_insns.append(insn.copy(assignees=new_assignees, + expression=new_expr)) + elif isinstance(insn, (CInstruction, MultiAssignmentBase, + _DataObliviousInstruction)): new_insns.append(insn) else: raise NotImplementedError("parse_slices not implemented for %s" % diff --git a/loopy/kernel/instruction.py b/loopy/kernel/instruction.py index d9b6384c8939b153254e3a5822b9aa3cbf795e43..d2d0c54579402549622e5381448d153fead3e05c 100644 --- a/loopy/kernel/instruction.py +++ b/loopy/kernel/instruction.py @@ -1046,22 +1046,27 @@ class CallInstruction(MultiAssignmentBase): # }}} +def subscript_contains_slice(subscript): + from pymbolic.primitives import Subscript, Slice + assert isinstance(subscript, Subscript) + return any(isinstance(index, Slice) for index in subscript.index_tuple) + + def is_array_call(assignees, expression): - from pymbolic.primitives import Call, CallWithKwargs + from pymbolic.primitives import Call, CallWithKwargs, Subscript from loopy.symbolic import SubArrayRef if not isinstance(expression, (Call, CallWithKwargs)): return False - for assignee in assignees: - if isinstance(assignee, SubArrayRef): - return True - - for par in expression.parameters: - if isinstance(assignee, SubArrayRef): + for par in expression.parameters+assignees: + if isinstance(par, SubArrayRef): return True + elif isinstance(par, Subscript): + if subscript_contains_slice(par): + return True - # did not encounter SubArrayRef, hence must be a normal call + # did not encounter SubArrayRef/Slice, hence must be a normal call return False