diff --git a/loopy/kernel/instruction.py b/loopy/kernel/instruction.py index 6b0fa64db2a3741c983fc3250f3d65ace6075db7..bfc06cff54c746277c3fb3fe47f35d7df8f54fcc 100644 --- a/loopy/kernel/instruction.py +++ b/loopy/kernel/instruction.py @@ -933,14 +933,7 @@ class Assignment(MultiAssignmentBase): self.assignee = assignee self.expression = expression - import loopy as lp - if temp_var_type is lp.auto: - warn("temp_var_type should be Optional(None) if " - "unspecified, not auto. This usage will be disallowed soon.", - DeprecationWarning, stacklevel=2) - temp_var_type = lp.Optional(None) - - self.temp_var_type = temp_var_type + self.temp_var_type = _check_and_fix_temp_var_type(temp_var_type) self.atomicity = atomicity # {{{ implement InstructionBase interface @@ -1092,18 +1085,8 @@ class CallInstruction(MultiAssignmentBase): if temp_var_types is None: self.temp_var_types = (Optional(),) * len(self.assignees) else: - import loopy as lp - processed_temp_var_types = [] - for temp_var_type in temp_var_types: - if temp_var_type is lp.auto: - warn("temp_var_type should be Optional(None) if " - "unspecified, not auto. " - "This usage will be disallowed soon.", - DeprecationWarning, stacklevel=2) - temp_var_type = lp.Optional(None) - processed_temp_var_types.append(temp_var_type) - - self.temp_var_types = tuple(processed_temp_var_types) + self.temp_var_types = tuple( + _check_and_fix_temp_var_type(tvt) for tvt in temp_var_types) # {{{ implement InstructionBase interface @@ -1512,4 +1495,28 @@ def _get_insn_hash_key(insn): # }}} +# {{{ _check_and_fix_temp_var_type + +def _check_and_fix_temp_var_type(temp_var_type): + """Check temp_var_type for deprecated usage, and convert to the right value. + """ + + if temp_var_type is not None: + import loopy as lp + if temp_var_type is lp.auto: + warn("temp_var_type should be Optional(None) if " + "unspecified, not auto. This usage will be disallowed soon.", + DeprecationWarning, stacklevel=3) + temp_var_type = lp.Optional(None) + elif not isinstance(temp_var_type, lp.Optional): + warn("temp_var_type should be None or an instance of Optional. " + "Other values for temp_var_type will be disallowed soon.", + DeprecationWarning, stacklevel=3) + temp_var_type = lp.Optional(temp_var_type) + + return temp_var_type + +# }}} + + # vim: foldmethod=marker