diff --git a/loopy/transform/iname.py b/loopy/transform/iname.py index a061ca11198b7a1982cc826e2db44c21d641f95a..1f9108bb7713b4f4d121e50bf40c3031a7824d24 100644 --- a/loopy/transform/iname.py +++ b/loopy/transform/iname.py @@ -1242,16 +1242,24 @@ def remove_unused_inames(knl, inames=None): def remove_any_newly_unused_inames(transformation_func): def wrapper(knl, *args, **kwargs): - # determine which inames were already unused - inames_already_unused = knl.all_inames() - get_used_inames(knl) - # call transform - transformed_knl = transformation_func(knl, *args, **kwargs) + # check for remove_unused_inames argument, default: True + remove_unused_inames = kwargs.pop("remove_unused_inames", True) - # Remove inames that are unused due to transform - return remove_unused_inames( - transformed_knl, - transformed_knl.all_inames()-inames_already_unused) + if remove_unused_inames: + # determine which inames were already unused + inames_already_unused = knl.all_inames() - get_used_inames(knl) + + # call transform + transformed_knl = transformation_func(knl, *args, **kwargs) + + # Remove inames that are unused due to transform + return remove_unused_inames( + transformed_knl, + transformed_knl.all_inames()-inames_already_unused) + else: + # call transform + return transformation_func(knl, *args, **kwargs) return wrapper