diff --git a/loopy/check.py b/loopy/check.py index aeec19c7831c6e46d4d3cef99dab095ad7f195ed..ad7717226d4e41aa9712a6dc6b33ced77e9c356f 100644 --- a/loopy/check.py +++ b/loopy/check.py @@ -30,8 +30,8 @@ import islpy as isl from loopy.symbolic import WalkMapper from loopy.diagnostic import LoopyError, WriteRaceConditionWarning, warn_with_kernel from loopy.type_inference import TypeInferenceMapper -from loopy.kernel.instruction import (MultiAssignmentBase, CInstruction, - _DataObliviousInstruction) +from loopy.kernel.instruction import (MultiAssignmentBase, CallInstruction, + CInstruction, _DataObliviousInstruction) import logging logger = logging.getLogger(__name__) @@ -83,7 +83,8 @@ def check_for_integer_subscript_indices(kernel): idx_int_checker = SubscriptIndicesIsIntChecker(kernel) for insn in kernel.instructions: if isinstance(insn, MultiAssignmentBase): - idx_int_checker(insn.expression) + idx_int_checker(insn.expression, return_tuple=isinstance(insn, + CallInstruction), return_dtype_set=True) [idx_int_checker(assignee) for assignee in insn.assignees] elif isinstance(insn, (CInstruction, _DataObliviousInstruction)): pass