diff --git a/loopy/transform/data.py b/loopy/transform/data.py index 64332637910340d68cb035d64ad6f4f643c0b5c9..7b1deb7951392e2e0c46360f8fd979ebf5aedb37 100644 --- a/loopy/transform/data.py +++ b/loopy/transform/data.py @@ -353,6 +353,26 @@ def remove_unused_arguments(knl): for insn in exp_knl.instructions: refd_vars.update(insn.dependency_names()) + from loopy.kernel.array import ArrayBase, FixedStrideArrayDimTag + from loopy.symbolic import get_dependencies + from itertools import chain + + def tolerant_get_deps(expr): + if expr is None or expr is lp.auto: + return set() + return get_dependencies(expr) + + for ary in chain(knl.args, six.itervalues(knl.temporary_variables)): + if isinstance(ary, ArrayBase): + refd_vars.update( + tolerant_get_deps(ary.shape) + | tolerant_get_deps(ary.offset)) + + for dim_tag in ary.dim_tags: + if isinstance(dim_tag, FixedStrideArrayDimTag): + refd_vars.update( + tolerant_get_deps(dim_tag.stride)) + for arg in knl.args: if arg.name in refd_vars: new_args.append(arg)