Skip to content
Snippets Groups Projects
Commit 9ecdddce authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Fix a variety of bugs in tagged subst rule use.

parent 655b07d2
No related branches found
No related tags found
No related merge requests found
......@@ -57,6 +57,8 @@ To-do
- Scalar insn priority
- If finding a maximum proves troublesome, move parameters into the domain
- : (as in, Matlab full-sclice) in prefetches
Future ideas
......
......@@ -396,12 +396,12 @@ def precompute(kernel, subst_use, dtype, sweep_inames=[],
else:
subst_name_as_expr = use
if isinstance(subst_name_as_expr, Variable):
new_subst_name = subst_name_as_expr.name
new_subst_tag = None
elif isinstance(subst_name_as_expr, TaggedVariable):
if isinstance(subst_name_as_expr, TaggedVariable):
new_subst_name = subst_name_as_expr.name
new_subst_tag = subst_name_as_expr.tag
elif isinstance(subst_name_as_expr, Variable):
new_subst_name = subst_name_as_expr.name
new_subst_tag = None
else:
raise ValueError("unexpected type of subst_name")
......@@ -473,7 +473,7 @@ def precompute(kernel, subst_use, dtype, sweep_inames=[],
else:
return None
if subst_tag != tag:
if subst_tag is None or subst_tag != tag:
# use fall-back identity mapper
return None
......
......@@ -535,10 +535,10 @@ class SubstitutionCallbackMapper(IdentityMapper):
def parse_name(self, expr):
from pymbolic.primitives import Variable
if isinstance(expr, Variable):
e_name, e_tag = expr.name, None
elif isinstance(expr, TaggedVariable):
if isinstance(expr, TaggedVariable):
e_name, e_tag = expr.name, expr.tag
elif isinstance(expr, Variable):
e_name, e_tag = expr.name, None
else:
return None
......@@ -568,7 +568,13 @@ class SubstitutionCallbackMapper(IdentityMapper):
map_tagged_variable = map_variable
def map_call(self, expr):
from pymbolic.primitives import Lookup
if isinstance(expr.function, Lookup):
raise RuntimeError("dotted name '%s' not allowed as "
"function identifier" % expr.function)
parsed_name = self.parse_name(expr.function)
if parsed_name is None:
return IdentityMapper.map_call(self, expr)
......
......@@ -32,7 +32,7 @@ def test_laplacian_stiffness(ctx_factory):
"dPsi(ij, dxi) := sum_float32(@ax_b,"
" jacInv[ax_b,dxi,K,q] * DPsi[ax_b,ij,q])",
"A[K, i, j] = sum_float32(q, w[q] * jacDet[K,q] * ("
"sum_float32(dx_axis, dPsi.one(i,dx_axis)*dPsi.two(j,dx_axis))))"
"sum_float32(dx_axis, dPsi$one(i,dx_axis)*dPsi$two(j,dx_axis))))"
],
[
lp.ArrayArg("jacInv", dtype, shape=(dim, dim, Nc_sym, Nq), order=order),
......@@ -79,7 +79,7 @@ def test_laplacian_stiffness(ctx_factory):
Ncloc = 16
knl = lp.split_dimension(knl, "K", Ncloc,
outer_iname="Ko", inner_iname="Kloc")
knl = lp.precompute(knl, "dPsi.one", np.float32, ["dx_axis"], default_tag=None)
knl = lp.precompute(knl, "dPsi$one", np.float32, ["dx_axis"], default_tag=None)
knl = lp.tag_dimensions(knl, {"j": "ilp.seq"})
return knl, ["Ko", "Kloc"]
......@@ -131,8 +131,7 @@ def test_laplacian_stiffness(ctx_factory):
lp.auto_test_vs_ref(seq_knl, ctx, kernel_gen,
op_count=0, op_label="GFlops",
parameters={"Nc": Nc}, print_ref_code=True,
timing_rounds=30)
parameters={"Nc": Nc}, print_ref_code=True)
......
......@@ -89,7 +89,7 @@ def get_suitable_size(ctx):
def check_float4(result, ref_result):
for comp in ["x", "y", "z", "w"]:
return np.allclose(ref_result[comp], result[comp], rtol=1e-3, atol=1e-3)
return np.allclose(ref_result[comp], result[comp], rtol=1e-3, atol=1e-3), None
def test_axpy(ctx_factory):
ctx = ctx_factory()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment