Skip to content
Commits on Source (3)
......@@ -287,14 +287,12 @@ def set_up_hw_parallel_loops(codegen_state, schedule_index, next_func,
result = []
bounds = kernel.get_iname_bounds(iname)
domain = kernel.get_inames_domain(iname)
# It's ok to find a bound that's too "loose". The conditional
# generators will mop up after us.
from loopy.isl_helpers import static_min_of_pw_aff
lower_bound = static_min_of_pw_aff(bounds.lower_bound_pw_aff,
constants_only=False)
from loopy.kernel.tools import get_hw_axis_base_for_codegen
lower_bound = get_hw_axis_base_for_codegen(kernel, iname)
# These bounds are 'implemented' by the hardware. Make sure
# that the downstream conditional generators realize that.
......
......@@ -2115,4 +2115,19 @@ def get_outer_params(domains):
# }}}
def get_hw_axis_base_for_codegen(kernel: LoopKernel, iname: str) -> isl.Aff:
    """
    Return the lower bound of the hardware-parallel iname *iname*, to serve
    as the offsetting expression for that hardware axis during code
    generation.

    :arg kernel: the kernel whose iname bounds are queried.
    :arg iname: name of an iname in *kernel*. It must carry a
        :class:`loopy.kernel.data.HardwareConcurrentTag`; this is checked
        with an ``assert``.
    :returns: the result of :func:`loopy.isl_helpers.static_min_of_pw_aff`
        applied to the iname's ``lower_bound_pw_aff`` with
        ``constants_only=False``.
    """
    from loopy.kernel.data import HardwareConcurrentTag
    from loopy.isl_helpers import static_min_of_pw_aff

    # A hardware axis base is only requested for hardware-parallel inames.
    assert kernel.iname_tags_of_type(iname, HardwareConcurrentTag)

    lower_bound_pw_aff = kernel.get_iname_bounds(iname).lower_bound_pw_aff

    # NOTE(review): constants_only=False presumably allows a non-constant
    # (parametric) bound to be returned -- confirm against
    # static_min_of_pw_aff.
    return static_min_of_pw_aff(lower_bound_pw_aff, constants_only=False)
# vim: foldmethod=marker
......@@ -346,9 +346,10 @@ def _check_for_access_races(map_a, insn_a, map_b, insn_b, knl, callables_table):
*unequal* global ids that access the same address.
"""
import pymbolic.primitives as p
from loopy.symbolic import isl_set_from_expr
from loopy.symbolic import isl_set_from_expr, aff_from_expr, aff_to_expr
from loopy.kernel.data import (filter_iname_tags_by_type,
HardwareConcurrentTag)
from loopy.kernel.tools import get_hw_axis_base_for_codegen
gsize, lsize = knl.get_grid_size_upper_bounds(callables_table,
return_dict=True)
......@@ -357,9 +358,10 @@ def _check_for_access_races(map_a, insn_a, map_b, insn_b, knl, callables_table):
# Step 1.1: Project out inames which are also the map's dims, but are not among
# the insn's within_inames
# Step 1.2: Project out sequential inames in the access maps
# Step 1.3: Rename the dims with their iname tags i.e. (g.i or l.i)
# Step 1.4: Name the ith output dims as _lp_dim{i}
# Step 1.2: Perform any offsetting required to the hw axes iname terms
# Step 1.3: Project out sequential inames in the access maps
# Step 1.4: Rename the dims with their iname tags i.e. (g.i or l.i)
# Step 1.5: Name the ith output dims as _lp_dim{i}
updated_maps = []
......@@ -381,6 +383,36 @@ def _check_for_access_races(map_a, insn_a, map_b, insn_b, knl, callables_table):
if dt == isl.dim_type.in_:
tag, = filter_iname_tags_by_type(knl.inames[name].tags,
HardwareConcurrentTag)
iname_lower_bound = get_hw_axis_base_for_codegen(knl, name)
if not iname_lower_bound.plain_is_zero():
# Hardware inames with nonzero base have an offset applied in
# code generation:
# https://github.com/inducer/loopy/blob/4e0b1c7635afe1473c8636377f8e7ef6d78dfd46/loopy/codegen/loop.py#L293-L297
# https://github.com/inducer/loopy/issues/600#issuecomment-1104066735
map_ = map_.add_dims(isl.dim_type.out, 1)
map_ = map_.move_dims(
isl.dim_type.in_, pos+1,
isl.dim_type.out, map_.dim(isl.dim_type.out)-1,
1
)
map_ = map_.set_dim_name(isl.dim_type.in_, pos+1, name+"'")
lbound_offset_expr_aff = aff_from_expr(
map_.domain().space,
(p.Variable(name+"'")
+ aff_to_expr(iname_lower_bound)
- p.Variable(name))
)
lbound_offset_as_domain = lbound_offset_expr_aff.zero_basic_set()
map_ = map_.intersect_domain(lbound_offset_as_domain)
map_ = map_.project_out(dt, pos, 1)
assert map_.get_dim_name(dt, pos) == name+"'"
map_ = map_.set_dim_name(dt, pos, name)
map_ = map_.set_dim_name(dt, pos, str(tag))
for i_l in lsize:
......
......@@ -3626,6 +3626,24 @@ def test_modulo_vs_type_context(ctx_factory):
t_unit(queue)
def test_barrier_non_zero_hw_lbound():
    """
    Check that a barrier is placed between ``w_a`` and ``w_b`` when two
    inames tagged ``l.0`` have different lower bounds (1 for ``i``, 0 for
    ``j``).
    """
    domains = [
        "{[i]: 1<=i<17}",
        "{[j]: 0<=j<16}",
    ]
    instructions = """
<> a[i] = i {id=w_a}
<> b[j] = 2*a[j] {id=w_b}
"""
    t_unit = lp.make_kernel(domains, instructions)

    # Both inames are mapped onto the same local hardware axis.
    t_unit = lp.tag_inames(t_unit, {"i": "l.0", "j": "l.0"})
    t_unit = lp.preprocess_kernel(t_unit)

    linearized_knl = lp.get_one_linearized_kernel(
        t_unit.default_entrypoint,
        t_unit.callables_table,
    )

    assert barrier_between(linearized_knl, "w_a", "w_b")
if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
......