diff --git a/loopy/compiled.py b/loopy/compiled.py index 03cb09ab104644f40137c65103f6334d24d66461..02730d8688117d9f58407d243546d04328e365b1 100644 --- a/loopy/compiled.py +++ b/loopy/compiled.py @@ -499,6 +499,9 @@ def generate_array_arg_setup(gen, kernel, impl_arg_info, options): "%% (%s.shape, %s))" % (arg.name, arg.name, strify_tuple(arg.unvec_shape))) + if kernel_arg.shape is None: + pass + if any(shape_axis is None for shape_axis in kernel_arg.shape): gen("if len(%s.shape) != %s:" % (arg.name, len(arg.unvec_shape))) @@ -514,7 +517,7 @@ def generate_array_arg_setup(gen, kernel, impl_arg_info, options): with Indentation(gen): gen(shape_mismatch_msg) - elif kernel_arg.shape is not None: + else: # not None, no Nones in tuple gen("if %s.shape != %s:" % (arg.name, strify(arg.unvec_shape))) with Indentation(gen):