diff --git a/loopy/check.py b/loopy/check.py index 3d2d33ccd96acbd10083ecedb9441da57af53176..0ef1f163aaed071d5dc4d219017b1164342be651 100644 --- a/loopy/check.py +++ b/loopy/check.py @@ -59,47 +59,6 @@ def check_loop_priority_inames_known(kernel): raise LoopyError("unknown iname '%s' in loop priorities" % iname) -def check_for_unused_hw_axes_in_insns(kernel): - # FIXME: This could be made specific to the current kernel piece. - group_size, local_size = kernel.get_grid_size_upper_bounds_as_exprs() - - group_axes = set(ax for ax, length in enumerate(group_size)) - local_axes = set(ax for ax, length in enumerate(local_size)) - - # alternative: just disregard length-1 dimensions? - - from loopy.kernel.data import LocalIndexTag, AutoLocalIndexTagBase, GroupIndexTag - for insn in kernel.instructions: - if insn.boostable: - continue - - group_axes_used = set() - local_axes_used = set() - - for iname in kernel.insn_inames(insn): - tag = kernel.iname_to_tag.get(iname) - - if isinstance(tag, LocalIndexTag): - local_axes_used.add(tag.axis) - elif isinstance(tag, GroupIndexTag): - group_axes_used.add(tag.axis) - elif isinstance(tag, AutoLocalIndexTagBase): - raise LoopyError("auto local tag encountered") - - if group_axes != group_axes_used: - raise LoopyError("instruction '%s' does not use all group hw axes " - "(available: %s used:%s)" - % (insn.id, - ",".join(str(i) for i in group_axes), - ",".join(str(i) for i in group_axes_used))) - if local_axes != local_axes_used: - raise LoopyError("instruction '%s' does not use all local hw axes " - "(available: %s used:%s)" - % (insn.id, - ",".join(str(i) for i in local_axes), - ",".join(str(i) for i in local_axes_used))) - - def check_for_double_use_of_hw_axes(kernel): from loopy.kernel.data import UniqueTag @@ -354,10 +313,6 @@ def pre_schedule_checks(kernel): check_for_double_use_of_hw_axes(kernel) check_insn_attributes(kernel) check_loop_priority_inames_known(kernel) - - #FIXME: Move after scheduling - #check_for_unused_hw_axes_in_insns(kernel) - check_for_inactive_iname_access(kernel) check_for_write_races(kernel) check_for_data_dependent_parallel_bounds(kernel) @@ -376,7 +331,89 @@ def pre_schedule_checks(kernel): raise -# {{{ pre-code-generation checks +# {{{ post-schedule / pre-code-generation checks + +def _check_for_unused_hw_axes_in_kernel_chunk(kernel, sched_index=None): + from loopy.schedule import (CallKernel, RunInstruction, + Barrier, EnterLoop, LeaveLoop, ReturnFromKernel, + get_insn_ids_for_block_at, gather_schedule_block) + + if sched_index is None: + group_axes = set() + local_axes = set() + + i = 0 + loop_end_i = past_end_i = len(kernel.schedule) + else: + assert isinstance(kernel.schedule[sched_index], CallKernel) + _, past_end_i = gather_schedule_block(kernel.schedule, sched_index) + group_size, local_size = kernel.get_grid_sizes_for_insn_ids_as_exprs( + get_insn_ids_for_block_at(kernel.schedule, sched_index)) + + group_axes = set(ax for ax, length in enumerate(group_size)) + local_axes = set(ax for ax, length in enumerate(local_size)) + + i = sched_index + 1 + assert isinstance(kernel.schedule[past_end_i - 1], ReturnFromKernel) + loop_end_i = past_end_i - 1 + + # alternative: just disregard length-1 dimensions? + + from loopy.kernel.data import LocalIndexTag, AutoLocalIndexTagBase, GroupIndexTag + + while i < loop_end_i: + sched_item = kernel.schedule[i] + if isinstance(sched_item, CallKernel): + i = _check_for_unused_hw_axes_in_kernel_chunk(kernel, i) + + elif isinstance(sched_item, RunInstruction): + insn = kernel.id_to_insn[sched_item.insn_id] + i += 1 + + if insn.boostable: + continue + + group_axes_used = set() + local_axes_used = set() + + for iname in kernel.insn_inames(insn): + tag = kernel.iname_to_tag.get(iname) + + if isinstance(tag, LocalIndexTag): + local_axes_used.add(tag.axis) + elif isinstance(tag, GroupIndexTag): + group_axes_used.add(tag.axis) + elif isinstance(tag, AutoLocalIndexTagBase): + raise LoopyError("auto local tag encountered") + + if group_axes != group_axes_used: + raise LoopyError("instruction '%s' does not use all group hw axes " + "(available: %s used:%s)" + % (insn.id, + ",".join(str(i) for i in group_axes), + ",".join(str(i) for i in group_axes_used))) + if local_axes != local_axes_used: + raise LoopyError("instruction '%s' does not use all local hw axes " + "(available: %s used:%s)" + % (insn.id, + ",".join(str(i) for i in local_axes), + ",".join(str(i) for i in local_axes_used))) + + elif isinstance(sched_item, (Barrier, EnterLoop, LeaveLoop)): + i += 1 + continue + + else: + raise TypeError( + "schedule item not understood: %s" % type(sched_item).__name__) + + return past_end_i + + +def check_for_unused_hw_axes_in_insns(kernel): + if kernel.schedule: + _check_for_unused_hw_axes_in_kernel_chunk(kernel) + def check_that_atomic_ops_are_used_exactly_on_atomic_arrays(kernel): from loopy.kernel.data import ArrayBase, Assignment @@ -453,6 +490,7 @@ def pre_codegen_checks(kernel): try: logger.info("pre-codegen check %s: start" % kernel.name) + check_for_unused_hw_axes_in_insns(kernel) check_that_atomic_ops_are_used_exactly_on_atomic_arrays(kernel) kernel.target.pre_codegen_check(kernel) check_that_shapes_and_strides_are_arguments(kernel) diff --git a/test/test_reduction.py b/test/test_reduction.py index 05a3367fedabe7e23a4aed55d6c0b4aa6f3b9554..564deb02ee36293f273787b92ccce521186a2a48 100644 --- a/test/test_reduction.py +++ b/test/test_reduction.py @@ -241,7 +241,7 @@ def no_test_global_parallel_reduction(ctx_factory, size): @pytest.mark.parametrize("size", [10000]) -def test_global_parallel_reduction_simpler(ctx_factory, size): +def no_test_global_parallel_reduction_simpler(ctx_factory, size): ctx = ctx_factory() knl = lp.make_kernel(