diff --git a/loopy/compiled.py b/loopy/compiled.py index 24d6ff5e51be51322ba5f332280695a80506c39c..2859a20476e20b8f9bfe059ff82a0cc47f4dcfb6 100644 --- a/loopy/compiled.py +++ b/loopy/compiled.py @@ -43,17 +43,20 @@ def _arg_matches_spec(arg, val, other_args): import loopy as lp if isinstance(arg, lp.GlobalArg): from pymbolic import evaluate - shape = evaluate(arg.shape, other_args) - strides = evaluate(arg.numpy_strides, other_args) if arg.dtype != val.dtype: raise TypeError("dtype mismatch on argument '%s' " "(got: %s, expected: %s)" % (arg.name, val.dtype, arg.dtype)) - if shape != val.shape: - raise TypeError("shape mismatch on argument '%s' " - "(got: %s, expected: %s)" - % (arg.name, val.shape, shape)) + + if arg.shape is not None: + shape = evaluate(arg.shape, other_args) + if shape != val.shape: + raise TypeError("shape mismatch on argument '%s' " + "(got: %s, expected: %s)" + % (arg.name, val.shape, shape)) + + strides = evaluate(arg.numpy_strides, other_args) if strides != tuple(val.strides): raise ValueError("strides mismatch on argument '%s' " "(got: %s, expected: %s)"