diff --git a/loopy/kernel/instruction.py b/loopy/kernel/instruction.py index bfc06cff54c746277c3fb3fe47f35d7df8f54fcc..4d3b25497f10e3f385734f520fdbaf44842339a8 100644 --- a/loopy/kernel/instruction.py +++ b/loopy/kernel/instruction.py @@ -1086,7 +1086,8 @@ class CallInstruction(MultiAssignmentBase): self.temp_var_types = (Optional(),) * len(self.assignees) else: self.temp_var_types = tuple( - _check_and_fix_temp_var_type(tvt) for tvt in temp_var_types) + _check_and_fix_temp_var_type(tvt, stacklevel=3) + for tvt in temp_var_types) # {{{ implement InstructionBase interface @@ -1497,22 +1498,29 @@ def _get_insn_hash_key(insn): # {{{ _check_and_fix_temp_var_type -def _check_and_fix_temp_var_type(temp_var_type): +def _check_and_fix_temp_var_type(temp_var_type, stacklevel=2): """Check temp_var_type for deprecated usage, and convert to the right value. """ - if temp_var_type is not None: - import loopy as lp - if temp_var_type is lp.auto: - warn("temp_var_type should be Optional(None) if " - "unspecified, not auto. This usage will be disallowed soon.", - DeprecationWarning, stacklevel=3) - temp_var_type = lp.Optional(None) - elif not isinstance(temp_var_type, lp.Optional): - warn("temp_var_type should be None or an instance of Optional. " - "Other values for temp_var_type will be disallowed soon.", - DeprecationWarning, stacklevel=3) - temp_var_type = lp.Optional(temp_var_type) + import loopy as lp + + if temp_var_type is None: + warn("temp_var_type should be Optional() if no temporary, not None. " + "This usage will be disallowed soon.", + DeprecationWarning, stacklevel=1 + stacklevel) + temp_var_type = lp.Optional() + + elif temp_var_type is lp.auto: + warn("temp_var_type should be Optional(None) if " + "unspecified, not auto. This usage will be disallowed soon.", + DeprecationWarning, stacklevel=1 + stacklevel) + temp_var_type = lp.Optional(None) + + elif not isinstance(temp_var_type, lp.Optional): + warn("temp_var_type should be an instance of Optional. " + "Other values for temp_var_type will be disallowed soon.", + DeprecationWarning, stacklevel=1 + stacklevel) + temp_var_type = lp.Optional(temp_var_type) return temp_var_type diff --git a/test/test_loopy.py b/test/test_loopy.py index 38d1cd6b0e5f2e9ccd64c6ddb41b161040e515e4..defd27dc0dd0d14789a8f569c297ac1f86572abc 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -2927,6 +2927,32 @@ def test_backwards_dep_printing_and_error(): print(knl) +def test_temp_var_type_deprecated_usage(): + import warnings + warnings.simplefilter("always") + + with pytest.warns(DeprecationWarning): + lp.Assignment("x", 1, temp_var_type=lp.auto) + + with pytest.warns(DeprecationWarning): + lp.Assignment("x", 1, temp_var_type=None) + + with pytest.warns(DeprecationWarning): + lp.Assignment("x", 1, temp_var_type=np.dtype(np.int32)) + + from loopy.symbolic import parse + + with pytest.warns(DeprecationWarning): + lp.CallInstruction("(x,)", parse("f(1)"), temp_var_types=(lp.auto,)) + + with pytest.warns(DeprecationWarning): + lp.CallInstruction("(x,)", parse("f(1)"), temp_var_types=(None,)) + + with pytest.warns(DeprecationWarning): + lp.CallInstruction("(x,)", parse("f(1)"), + temp_var_types=(np.dtype(np.int32),)) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])