diff --git a/examples/hello-loopy.py b/examples/hello-loopy.py index a6b8e6d8b8ecdeabbd164d31352958975edf9ec9..a35d7272de26a4993f0d13682f8aba38091ea792 100644 --- a/examples/hello-loopy.py +++ b/examples/hello-loopy.py @@ -15,14 +15,15 @@ a = cl.array.arange(queue, n, dtype=np.float32) # ----------------------------------------------------------------------------- # generation (loopy bits start here) # ----------------------------------------------------------------------------- -knl = lp.make_kernel(ctx.devices[0], - "{[i]: 0<=i<n}", # "loop domain"-- what values does i take? - "out[i] = 2*a[i]", # "instructions" to be executed across the domain - [ # argument declarations - lp.GlobalArg("out", shape="n"), - lp.GlobalArg("a", shape="n"), - lp.ValueArg("n"), - ]) +knl = lp.make_kernel( + ctx.devices[0], + "{ [i]: 0<=i<n }", + "out[i] = 2*a[i]", + [ # argument declarations + lp.GlobalArg("out"), + lp.GlobalArg("a"), + lp.ValueArg("n"), + ]) # ----------------------------------------------------------------------------- # transformation diff --git a/loopy/__init__.py b/loopy/__init__.py index 41ce634c4ced81971010f2bd319bb0ab155ea5e7..7a1ac8d264b909a1f2da5362d8c0633171747f38 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -50,6 +50,7 @@ class LoopyAdvisory(UserWarning): # {{{ imported user interface from loopy.kernel.data import ( + auto_shape, auto_strides, ValueArg, ScalarArg, GlobalArg, ArrayArg, ConstantArg, ImageArg, default_function_mangler, single_arg_function_mangler, opencl_function_mangler, @@ -72,7 +73,9 @@ from loopy.codegen import generate_code from loopy.compiled import CompiledKernel, auto_test_vs_ref from loopy.check import check_kernels -__all__ = ["ValueArg", "ScalarArg", "GlobalArg", "ArrayArg", "ConstantArg", "ImageArg", +__all__ = [ + "auto_shape", "auto_strides", + "ValueArg", "ScalarArg", "GlobalArg", "ArrayArg", "ConstantArg", "ImageArg", "LoopKernel", "Instruction", "default_function_mangler", "single_arg_function_mangler", diff --git a/loopy/codegen/expression.py b/loopy/codegen/expression.py index 902ed7ef2aafc6e94df3dec0103d51fc5685a181..f8b3d6775798f2a4bff95b6e420ed14b63fca298 100644 --- a/loopy/codegen/expression.py +++ b/loopy/codegen/expression.py @@ -379,25 +379,21 @@ class LoopyCCodeMapper(RecursiveMapper): if not isinstance(expr.index, tuple): index_expr = (index_expr,) - if arg.strides is not None: - ary_strides = arg.strides - else: - ary_strides = (1,) + if arg.strides is None: + raise RuntimeError("index access to '%s' requires known " + "strides" % arg.name) - if len(ary_strides) != len(index_expr): + if len(arg.strides) != len(index_expr): raise RuntimeError("subscript to '%s' in '%s' has the wrong " "number of indices (got: %d, expected: %d)" % ( expr.aggregate.name, expr, - len(index_expr), len(ary_strides))) - - if len(index_expr) == 0: - return "*" + expr.aggregate.name + len(index_expr), len(arg.strides))) from pymbolic.primitives import Subscript return base_impl( Subscript(expr.aggregate, arg.offset+sum( stride*expr_i for stride, expr_i in zip( - ary_strides, index_expr))), + arg.strides, index_expr))), enclosing_prec, type_context) diff --git a/loopy/kernel/__init__.py b/loopy/kernel/__init__.py index b07a487fc192b6fc6ab9ae1063e730872e6e591f..2cd1bce37eccfb02468896aeb8deed397f31974a 100644 --- a/loopy/kernel/__init__.py +++ b/loopy/kernel/__init__.py @@ -54,46 +54,28 @@ class CannotBranchDomainTree(RuntimeError): # {{{ loop kernel object class LoopKernel(Record): - """ + """These correspond more or less directly to arguments of + :func:`loopy.make_kernel`. + :ivar device: :class:`pyopencl.Device` - :ivar domains: :class:`islpy.BasicSet` + :ivar domains: a list of :class:`islpy.BasicSet` instances :ivar instructions: :ivar args: :ivar schedule: :ivar name: - :ivar preambles: a list of (tag, code) tuples that identify preamble snippets. - Each tag's snippet is only included once, at its first occurrence. - The preambles will be inserted in order of their tags. - :ivar preamble_generators: a list of functions of signature - (seen_dtypes, seen_functions) where seen_functions is a set of - (name, c_name, arg_dtypes), generating extra entries for `preambles`. - :ivar assumptions: the initial implemented_domain, captures assumptions - on the parameters. (an isl.Set) - :ivar local_sizes: A dictionary from integers to integers, mapping - workgroup axes to their sizes, e.g. *{0: 16}* forces axis 0 to be - length 16. + :ivar preambles: + :ivar preamble_generators: + :ivar assumptions: + :ivar local_sizes: :ivar temporary_variables: :ivar iname_to_tag: - :ivar substitutions: a mapping from substitution names to :class:`SubstitutionRule` - objects - :ivar function_manglers: list of functions of signature (name, arg_dtypes) - returning a tuple (result_dtype, c_name) - or a tuple (result_dtype, c_name, arg_dtypes), - where c_name is the C-level function to be called. - :ivar symbol_manglers: list of functions of signature (name) returning - a tuple (result_dtype, c_name), where c_name is the C-level symbol to be - evaluated. - :ivar defines: a dictionary of replacements to be made in instructions given - as strings before parsing. A macro instance intended to be replaced should - look like "MACRO" in the instruction code. The expansion given in this - parameter is allowed to be a list. In this case, instructions are generated - for *each* combination of macro values. - - These defines may also be used in the domain and in argument shapes and - strides. They are expanded only upon kernel creation. + :ivar function_manglers: + :ivar symbol_manglers: The following arguments are not user-facing: + :ivar substitutions: a mapping from substitution names to :class:`SubstitutionRule` + objects :ivar iname_slab_increments: a dictionary mapping inames to (lower_incr, upper_incr) tuples that will be separated out in the execution to generate 'bulk' slabs with fewer conditionals. @@ -121,7 +103,6 @@ class LoopKernel(Record): single_arg_function_mangler, ], symbol_manglers=[opencl_symbol_mangler], - defines={}, # non-user-facing iname_slab_increments={}, @@ -203,25 +184,6 @@ class LoopKernel(Record): # }}} - # {{{ expand macros in arg shapes - - from loopy.kernel.data import ShapedArg - from loopy.kernel.creation import expand_defines_in_expr - - processed_args = [] - for arg in args: - for arg_name in arg.name.split(","): - new_arg = arg.copy(name=arg_name) - if isinstance(arg, ShapedArg): - if arg.shape is not None: - new_arg = new_arg.copy(shape=expand_defines_in_expr(arg.shape, defines)) - if arg.strides is not None: - new_arg = new_arg.copy(strides=expand_defines_in_expr(arg.strides, defines)) - - processed_args.append(new_arg) - - # }}} - index_dtype = np.dtype(index_dtype) if index_dtype.kind != 'i': raise TypeError("index_dtype must be an integer") @@ -235,7 +197,7 @@ class LoopKernel(Record): Record.__init__(self, device=device, domains=domains, instructions=instructions, - args=processed_args, + args=args, schedule=schedule, name=name, preambles=preambles, diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index 3b706c80f53b8c74b67dca439b7de11539af3339..94bc195a43eede61f8bf6f8cfc7130625e348ea0 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -28,7 +28,7 @@ THE SOFTWARE. import numpy as np -from loopy.symbolic import IdentityMapper +from loopy.symbolic import IdentityMapper, WalkMapper from loopy.kernel.data import Instruction, SubstitutionRule import islpy as isl from islpy import dim_type @@ -85,76 +85,6 @@ class MakeUnique: # }}} -# {{{ domain parsing - -EMPTY_SET_DIMS_RE = re.compile(r"^\s*\{\s*\:") -SET_DIMS_RE = re.compile(r"^\s*\{\s*\[([a-zA-Z0-9_, ]+)\]\s*\:") - -def _find_inames_in_set(dom_str): - empty_match = EMPTY_SET_DIMS_RE.match(dom_str) - if empty_match is not None: - return set() - - match = SET_DIMS_RE.match(dom_str) - if match is None: - raise RuntimeError("invalid syntax for domain '%s'" % dom_str) - - result = set(iname.strip() for iname in match.group(1).split(",") - if iname.strip()) - - return result - -EX_QUANT_RE = re.compile(r"\bexists\s+([a-zA-Z0-9])\s*\:") - -def _find_existentially_quantified_inames(dom_str): - return set(ex_quant.group(1) for ex_quant in EX_QUANT_RE.finditer(dom_str)) - -def parse_domains(ctx, domains, defines): - if isinstance(domains, str): - domains = [domains] - - result = [] - used_inames = set() - - for dom in domains: - if isinstance(dom, str): - dom, = expand_defines(dom, defines) - - if not dom.lstrip().startswith("["): - # i.e. if no parameters are already given - parameters = (_gather_isl_identifiers(dom) - - _find_inames_in_set(dom) - - _find_existentially_quantified_inames(dom)) - dom = "[%s] -> %s" % (",".join(parameters), dom) - - try: - dom = isl.BasicSet.read_from_str(ctx, dom) - except: - print "failed to parse domain '%s'" % dom - raise - else: - assert isinstance(dom, (isl.Set, isl.BasicSet)) - # assert dom.get_ctx() == ctx - - for i_iname in xrange(dom.dim(dim_type.set)): - iname = dom.get_dim_name(dim_type.set, i_iname) - - if iname is None: - raise RuntimeError("domain '%s' provided no iname at index " - "%d (redefined iname?)" % (dom, i_iname)) - - if iname in used_inames: - raise RuntimeError("domain '%s' redefines iname '%s' " - "that is part of a previous domain" % (dom, iname)) - - used_inames.add(iname) - - result.append(dom) - - return result - -# }}} - # {{{ expand defines WORD_RE = re.compile(r"\b([a-zA-Z0-9_]+)\b") @@ -346,6 +276,79 @@ def parse_if_necessary(insn, defines): # }}} +# {{{ domain parsing + +EMPTY_SET_DIMS_RE = re.compile(r"^\s*\{\s*\:") +SET_DIMS_RE = re.compile(r"^\s*\{\s*\[([a-zA-Z0-9_, ]+)\]\s*\:") + +def _find_inames_in_set(dom_str): + empty_match = EMPTY_SET_DIMS_RE.match(dom_str) + if empty_match is not None: + return set() + + match = SET_DIMS_RE.match(dom_str) + if match is None: + raise RuntimeError("invalid syntax for domain '%s'" % dom_str) + + result = set(iname.strip() for iname in match.group(1).split(",") + if iname.strip()) + + return result + +EX_QUANT_RE = re.compile(r"\bexists\s+([a-zA-Z0-9])\s*\:") + +def _find_existentially_quantified_inames(dom_str): + return set(ex_quant.group(1) for ex_quant in EX_QUANT_RE.finditer(dom_str)) + +def parse_domains(ctx, domains, defines): + if isinstance(domains, str): + domains = [domains] + + result = [] + used_inames = set() + + for dom in domains: + if isinstance(dom, str): + dom, = expand_defines(dom, defines) + + if not dom.lstrip().startswith("["): + # i.e. if no parameters are already given + parameters = (_gather_isl_identifiers(dom) + - _find_inames_in_set(dom) + - _find_existentially_quantified_inames(dom)) + dom = "[%s] -> %s" % (",".join(parameters), dom) + + try: + dom = isl.BasicSet.read_from_str(ctx, dom) + except: + print "failed to parse domain '%s'" % dom + raise + else: + assert isinstance(dom, (isl.Set, isl.BasicSet)) + # assert dom.get_ctx() == ctx + + for i_iname in xrange(dom.dim(dim_type.set)): + iname = dom.get_dim_name(dim_type.set, i_iname) + + if iname is None: + raise RuntimeError("domain '%s' provided no iname at index " + "%d (redefined iname?)" % (dom, i_iname)) + + if iname in used_inames: + raise RuntimeError("domain '%s' redefines iname '%s' " + "that is part of a previous domain" % (dom, iname)) + + used_inames.add(iname) + + result.append(dom) + + return result + +# }}} + +def guess_kernel_args_if_requested(domains, instructions, kernel_args): + return kernel_args + # {{{ tag reduction inames as sequential def tag_reduction_inames_as_sequential(knl): @@ -592,23 +595,168 @@ def check_for_reduction_inames_duplication_requests(kernel): # }}} -# {{{ kernel creation top-level +# {{{ -def make_kernel(device, domains, instructions, kernel_args=[], *args, **kwargs): - """User-facing kernel creation entrypoint.""" +def apply_default_order_to_args(kernel, default_order): + from loopy.kernel.data import ShapedArg - for forbidden_kwarg in [ - "substitutions", - "iname_slab_increments", - "applied_iname_rewrites", - "cache_manager", - "isl_context", - ]: - if forbidden_kwarg in kwargs: - raise RuntimeError("'%s' is not part of user-facing interface" - % forbidden_kwarg) + processed_args = [] + for arg in kernel.args: + if isinstance(arg, ShapedArg): + arg = arg.copy(order=default_order) + processed_args.append(arg) + + return kernel.copy(args=processed_args) + +# }}} + +# {{{ duplicate arguments and expand defines in shapes + +def dup_args_and_expand_defines_in_shapes(kernel, defines): + from loopy.kernel.data import ShapedArg, auto_shape, auto_strides + from loopy.kernel.creation import expand_defines_in_expr + + processed_args = [] + for arg in kernel.args: + for arg_name in arg.name.split(","): + new_arg = arg.copy(name=arg_name) + if isinstance(arg, ShapedArg): + if arg.shape is not None and arg.shape is not auto_shape: + new_arg = new_arg.copy(shape=expand_defines_in_expr(arg.shape, defines)) + if arg.strides is not None and arg.strides is not auto_strides: + new_arg = new_arg.copy(strides=expand_defines_in_expr(arg.strides, defines)) + + processed_args.append(new_arg) + + return kernel.copy(args=processed_args) + +# }}} + +# {{{ guess argument shapes + +class _AccessRangeMapper(WalkMapper): + def __init__(self, arg_name): + self.arg_name = arg_name + self.access_range = None + + def map_subscript(self, expr, domain): + WalkMapper.map_subscript(self, expr, domain) + + from pymbolic.primitives import Variable + assert isinstance(expr.aggregate, Variable) + + if expr.aggregate.name != self.arg_name: + return + + subscript = expr.index + if not isinstance(subscript, tuple): + subscript = (subscript,) + + from loopy.symbolic import get_dependencies, get_access_range - defines = kwargs.get("defines", {}) + if not get_dependencies(subscript) <= set(domain.get_var_dict()): + raise RuntimeError("cannot determine access range for '%s': " + "undetermined index in '%s'" + % (self.arg_name, ", ".join(str(i) for i in subscript))) + + access_range = get_access_range(domain, subscript) + + if self.access_range is None: + self.access_range = access_range + else: + if (self.access_range.dim(dim_type.set) + != access_range.dim(dim_type.set)): + raise RuntimeError( + "error while determining shape of argument '%s': " + "varying number of indices encountered" + % self.arg_name) + + self.access_range = self.access_range | access_range + +def guess_arg_shape_if_requested(kernel, default_order): + new_args = [] + + from loopy.kernel.data import ShapedArg, auto_shape, auto_strides + + for arg in kernel.args: + if isinstance(arg, ShapedArg) and ( + arg.shape is auto_shape or arg.strides is auto_strides): + armap = _AccessRangeMapper(arg.name) + + for insn in kernel.instructions: + domain = kernel.get_inames_domain(kernel.insn_inames(insn)) + armap(insn.assignee, domain) + armap(insn.expression, domain) + + if armap.access_range is None: + # no subscripts found, let's call it a scalar + shape = () + else: + from loopy.isl_helpers import static_max_of_pw_aff + from loopy.symbolic import pw_aff_to_expr + + shape = tuple( + pw_aff_to_expr(static_max_of_pw_aff( + kernel.cache_manager.dim_max(armap.access_range, i) + 1, + constants_only=False)) + for i in xrange(armap.access_range.dim(dim_type.set))) + + if arg.shape is auto_shape: + arg = arg.copy(shape=shape) + if arg.strides is auto_strides: + from loopy.kernel.data import make_strides + arg = arg.copy(strides=make_strides(shape, default_order)) + + new_args.append(arg) + + return kernel.copy(args=new_args) + +# }}} + +# {{{ kernel creation top-level + +def make_kernel(device, domains, instructions, kernel_args=[], **kwargs): + """User-facing kernel creation entrypoint. + + :arg device: :class:`pyopencl.Device` + :arg domains: :class:`islpy.BasicSet` + :arg instructions: + :arg kernel_args: + + The following keyword arguments are recognized: + + :arg preambles: a list of (tag, code) tuples that identify preamble snippets. + Each tag's snippet is only included once, at its first occurrence. + The preambles will be inserted in order of their tags. + :arg preamble_generators: a list of functions of signature + (seen_dtypes, seen_functions) where seen_functions is a set of + (name, c_name, arg_dtypes), generating extra entries for *preambles*. + :arg defines: a dictionary of replacements to be made in instructions given + as strings before parsing. A macro instance intended to be replaced should + look like "MACRO" in the instruction code. The expansion given in this + parameter is allowed to be a list. In this case, instructions are generated + for *each* combination of macro values. + + These defines may also be used in the domain and in argument shapes and + strides. They are expanded only upon kernel creation. + :arg default_order: "C" (default) or "F" + :arg function_manglers: list of functions of signature (name, arg_dtypes) + returning a tuple (result_dtype, c_name) + or a tuple (result_dtype, c_name, arg_dtypes), + where c_name is the C-level function to be called. + :arg symbol_manglers: list of functions of signature (name) returning + a tuple (result_dtype, c_name), where c_name is the C-level symbol to be + evaluated. + :arg assumptions: the initial implemented_domain, captures assumptions + on the parameters. (an isl.Set) + :arg local_sizes: A dictionary from integers to integers, mapping + workgroup axes to their sizes, e.g. *{0: 16}* forces axis 0 to be + length 16. + :arg temporary_variables: + """ + + defines = kwargs.pop("defines", {}) + default_order = kwargs.pop("default_order", "C") # {{{ instruction/subst parsing @@ -645,8 +793,10 @@ def make_kernel(device, domains, instructions, kernel_args=[], *args, **kwargs): domains = parse_domains(isl_context, domains, defines) + kernel_args = guess_kernel_args_if_requested(domains, instructions, kernel_args) + from loopy.kernel import LoopKernel - knl = LoopKernel(device, domains, instructions, kernel_args, *args, **kwargs) + knl = LoopKernel(device, domains, instructions, kernel_args, **kwargs) check_for_nonexistent_iname_deps(knl) check_for_reduction_inames_duplication_requests(knl) @@ -654,6 +804,9 @@ def make_kernel(device, domains, instructions, kernel_args=[], *args, **kwargs): knl = tag_reduction_inames_as_sequential(knl) knl = create_temporaries(knl) knl = expand_cses(knl) + knl = dup_args_and_expand_defines_in_shapes(knl, defines) + knl = guess_arg_shape_if_requested(knl, default_order) + knl = apply_default_order_to_args(knl, default_order) # ------------------------------------------------------------------------- # Ordering dependency: diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py index 80e53b17d541bbba399f452bc266e2d808d3162d..8248b4ca993c34010a9e81d1da74483fc0a14d16 100644 --- a/loopy/kernel/data.py +++ b/loopy/kernel/data.py @@ -136,8 +136,26 @@ def parse_tag(tag): # {{{ arguments +class auto_shape: + pass + +class auto_strides: + pass + +def make_strides(shape, order): + from pyopencl.compyte.array import ( + f_contiguous_strides, + c_contiguous_strides) + + if order == "F": + return f_contiguous_strides(1, shape) + elif order == "C": + return c_contiguous_strides(1, shape) + else: + raise ValueError("invalid order: %s" % order) + class ShapedArg(Record): - def __init__(self, name, dtype=None, shape=None, strides=None, order="C", + def __init__(self, name, dtype=None, shape=None, strides=None, order=None, offset=0): """ All of the following are optional. Specify either strides or shape. @@ -166,23 +184,26 @@ class ShapedArg(Record): return tuple(parse_if_necessary(xi) for xi in x) - if strides is not None: + if strides == "auto": + strides = auto_strides + if shape == "auto": + shape = auto_shape + + strides_known = strides is not None and strides is not auto_strides + shape_known = shape is not None and shape is not auto_shape + + if strides_known: strides = process_tuple(strides) - if shape is not None: + if shape_known: shape = process_tuple(shape) - if strides is None and shape is not None: - from pyopencl.compyte.array import ( - f_contiguous_strides, - c_contiguous_strides) - - if order == "F": - strides = f_contiguous_strides(1, shape) - elif order == "C": - strides = c_contiguous_strides(1, shape) - else: - raise ValueError("invalid order: %s" % order) + if not strides_known and shape_known: + if len(shape) == 1: + # don't need order to know that + strides = (1,) + elif order is not None: + strides = make_strides(shape, order) Record.__init__(self, name=name, diff --git a/loopy/symbolic.py b/loopy/symbolic.py index b2a06cda6ad3895cff6451a624d8a9110f256a6d..91f7e4907f281b586317c221eef04cbe99002d39 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -168,15 +168,15 @@ class IdentityMapper(IdentityMapperBase, IdentityMapperMixin): pass class WalkMapper(WalkMapperBase): - def map_reduction(self, expr): + def map_reduction(self, expr, *args): if not self.visit(expr): return - self.rec(expr.expr) + self.rec(expr.expr, *args) map_tagged_variable = WalkMapperBase.map_variable - def map_loopy_function_identifier(self, expr): + def map_loopy_function_identifier(self, expr, *args): self.visit(expr) map_linear_subscript = WalkMapperBase.map_subscript diff --git a/test/test_loopy.py b/test/test_loopy.py index 3694b0390b942521a3a30a02e413dbcbd4fbf92e..8c034d6f55699eb4facfa6e187a755f3d8ce036e 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -202,14 +202,10 @@ def test_wg_too_small(ctx_factory): kernel_gen = lp.generate_loop_schedules(knl) kernel_gen = lp.check_kernels(kernel_gen) + import pytest for gen_knl in kernel_gen: - try: + with pytest.raises(RuntimeError): lp.CompiledKernel(ctx, gen_knl).get_code() - except RuntimeError, e: - assert "implemented and desired" in str(e) - pass # expected! - else: - assert False # expecting an error @@ -644,14 +640,13 @@ def test_dependent_loop_bounds(ctx_factory): ], [ "<> row_len = a_rowstarts[i+1] - a_rowstarts[i]", - "ax[i] = sum(jj, a_values[a_rowstarts[i]+jj])", + "a_sum[i] = sum(jj, a_values[[a_rowstarts[i]+jj]])", ], [ - lp.GlobalArg("a_rowstarts", np.int32), - lp.GlobalArg("a_indices", np.int32), + lp.GlobalArg("a_rowstarts", np.int32, shape="auto"), + lp.GlobalArg("a_indices", np.int32, shape="auto"), lp.GlobalArg("a_values", dtype), - lp.GlobalArg("x", dtype), - lp.GlobalArg("ax", dtype), + lp.GlobalArg("a_sum", dtype, shape="auto"), lp.ValueArg("n", np.int32), ], assumptions="n>=1 and row_len>=1") @@ -676,14 +671,13 @@ def test_dependent_loop_bounds_2(ctx_factory): [ "<> row_start = a_rowstarts[i]", "<> row_len = a_rowstarts[i+1] - row_start", - "ax[i] = sum(jj, a_values[row_start+jj])", + "ax[i] = sum(jj, a_values[[row_start+jj]])", ], [ - lp.GlobalArg("a_rowstarts", np.int32), - lp.GlobalArg("a_indices", np.int32), + lp.GlobalArg("a_rowstarts", np.int32, shape="auto"), + lp.GlobalArg("a_indices", np.int32, shape="auto"), lp.GlobalArg("a_values", dtype), - lp.GlobalArg("x", dtype), - lp.GlobalArg("ax", dtype), + lp.GlobalArg("ax", dtype, shape="auto"), lp.ValueArg("n", np.int32), ], assumptions="n>=1 and row_len>=1") @@ -718,7 +712,7 @@ def test_dependent_loop_bounds_3(ctx_factory): "a[i,jj] = 1", ], [ - lp.GlobalArg("a_row_lengths", np.int32), + lp.GlobalArg("a_row_lengths", np.int32, shape="auto"), lp.GlobalArg("a", dtype, shape=("n,n"), order="C"), lp.ValueArg("n", np.int32), ]) @@ -1029,17 +1023,35 @@ def test_write_parameter(ctx_factory): ], assumptions="n>=1") - try: + import pytest + with pytest.raises(RuntimeError): lp.CompiledKernel(ctx, knl).get_code() - except RuntimeError, e: - assert "may not be written" in str(e) - pass # expected! - else: - assert False # expecting an error +def test_arg_shape_guessing(ctx_factory): + ctx = ctx_factory() + + knl = lp.make_kernel(ctx.devices[0], [ + "{[i,j]: 0<=i,j<n }", + ], + """ + a = 1.5 + sum((i,j), i*j) + b[i, j] = i*j + c[i+j, j] = b[j,i] + """, + [ + lp.GlobalArg("a", shape=lp.auto_shape), + lp.GlobalArg("b", shape=lp.auto_shape), + lp.GlobalArg("c", shape=lp.auto_shape), + lp.ValueArg("n"), + ], + assumptions="n>=1") + + print knl + print lp.CompiledKernel(ctx, knl).get_highlighted_code() + if __name__ == "__main__": import sys if len(sys.argv) > 1: