Skip to content
Snippets Groups Projects
Commit e2e44a1d authored by Andreas Klöckner
Browse files

Check for non-use of hardware axes.

parent e606dd46
No related branches found
No related tags found
No related merge requests found
......@@ -6,10 +6,7 @@ For writeup:
TODO: Reimplement forced lengths
TODO: Try, fix reg. prefetch (DG example) / CSEs
ILP and reg. prefetch interact!
TODO: Custom reductions per red. axis
TODO: Functions
TODO: Common subexpressions
TODO: Array common subexpressions (shared and private!)
TODO: ILP arrays
FIXME: support non-reductive dimensions (what did I mean here?)
FIXME: write names should be assigned during scheduling
......@@ -96,6 +93,8 @@ TODO
Dealt with
^^^^^^^^^^
- Check for non-use of hardware axes
- Slab decomposition for parallel dimensions
- implement at the outermost nesting level regardless
- bound *all* tagged inames
......
......@@ -100,14 +100,44 @@ def realize_reduction(kernel, inames=None, reduction_tag=None):
def check_double_use_of_hw_dimensions(kernel):
from loopy.kernel import UniqueTag
def check_non_use_of_hw_axes(kernel):
    """Check that every instruction uses every hardware axis of *kernel*.

    The number of group and local axes is taken from the lengths of the two
    sequences returned by ``kernel.get_grid_sizes_as_exprs()``. For each
    instruction, the axes referenced by the tags of its inames are collected
    and compared against those full sets of axes.

    :arg kernel: an object providing ``get_grid_sizes_as_exprs()``,
        ``instructions`` and ``iname_to_tag``.
    :raises RuntimeError: if an iname still carries an automatic local tag,
        or if an instruction's group/local axes differ from the kernel's
        group/local axes.
    """
    group_size, local_size = kernel.get_grid_sizes_as_exprs()

    group_axes = set(range(len(group_size)))
    local_axes = set(range(len(local_size)))

    from loopy.kernel import TAG_LOCAL_IDX, TAG_AUTO_LOCAL_IDX, TAG_GROUP_IDX

    for insn in kernel.instructions:
        group_axes_used = set()
        local_axes_used = set()

        for iname in insn.all_inames():
            # Untagged inames yield None here and match none of the branches.
            tag = kernel.iname_to_tag.get(iname)

            if isinstance(tag, TAG_LOCAL_IDX):
                local_axes_used.add(tag.axis)
            elif isinstance(tag, TAG_GROUP_IDX):
                group_axes_used.add(tag.axis)
            elif isinstance(tag, TAG_AUTO_LOCAL_IDX):
                raise RuntimeError("auto local tag encountered")

        # The messages below previously left their '%s' placeholder
        # unformatted, so the failing instruction was never identified.
        # NOTE(review): assumes instructions expose an `id` attribute -- confirm.
        if group_axes != group_axes_used:
            raise RuntimeError(
                    "instruction '%s' does not use all hw group axes"
                    % insn.id)
        if local_axes != local_axes_used:
            raise RuntimeError(
                    "instruction '%s' does not use all hw local axes"
                    % insn.id)
def check_double_use_of_hw_axes(kernel):
from loopy.kernel import HardwareParallelTag
for insn in kernel.instructions:
insn_tag_keys = set()
for iname in insn.all_inames():
tag = kernel.iname_to_tag.get(iname)
if isinstance(tag, UniqueTag):
if isinstance(tag, HardwareParallelTag):
key = tag.key
if key in insn_tag_keys:
raise RuntimeError("instruction '%s' has two "
......@@ -669,9 +699,7 @@ def insert_barriers(kernel, schedule, level=0):
def generate_loop_schedules(kernel):
kernel = realize_reduction(kernel)
check_double_use_of_hw_dimensions(kernel)
check_double_use_of_hw_axes(kernel)
kernel = adjust_local_temp_var_storage(kernel)
# {{{ check that all CSEs have been realized
......@@ -687,8 +715,8 @@ def generate_loop_schedules(kernel):
# }}}
kernel = add_automatic_dependencies(kernel)
kernel = assign_automatic_axes(kernel)
check_non_use_of_hw_axes(kernel)
for gen_sched in generate_loop_schedules_internal(kernel):
gen_sched, owed_barriers = insert_barriers(kernel, gen_sched)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment