diff --git a/loopy/types.py b/loopy/types.py index 8f0f310c305b3d5b24bd6e771b501bb6d9c69224..7deb7608d7ec8b45af31b1cabc666e87207b6fb8 100644 --- a/loopy/types.py +++ b/loopy/types.py @@ -211,6 +211,9 @@ def to_loopy_type(dtype, allow_auto=False, allow_none=False, for_atomic=False, raise LoopyError("do not know how to convert '%s' to an atomic type" % dtype) + if target and not dtype.target: + return dtype.with_target(target) + return dtype elif numpy_dtype is not None: diff --git a/test/test_loopy.py b/test/test_loopy.py index 80af89f3bb09f2d2bf394279115acfa0fe928e2b..647f0d8d5a344f99d591c4071a64d22be8d85388 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -2927,6 +2927,14 @@ def test_backwards_dep_printing_and_error(): print(knl) +def test_to_loopy_type_target_specification(): + arg = lp.GlobalArg('test', dtype=np.int32, shape=(1,)) + # convert dtype w/ to_loopy_type + from loopy.types import to_loopy_type + assert to_loopy_type(arg.dtype, target=lp.OpenCLTarget()).target == \ + lp.OpenCLTarget() + + def test_dump_binary(ctx_factory): ctx = ctx_factory()