diff --git a/doc/reference.rst b/doc/reference.rst index 529e866773e2f0c0b5178f347154d794185ea056..a3051d9bd0aa092cd54009ec41f80e6cde8822a9 100644 --- a/doc/reference.rst +++ b/doc/reference.rst @@ -193,7 +193,7 @@ Wrangling inames .. autofunction:: duplicate_inames -.. autofunction:: link_inames +.. undocumented .. autofunction:: link_inames .. autofunction:: rename_iname @@ -205,6 +205,11 @@ Wrangling inames .. autofunction:: split_reduction_outward +Dealing with Parameters +^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: fix_parameter + Dealing with Substitution Rules ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/loopy/__init__.py b/loopy/__init__.py index c72c7da1705d8e33fcc4938fa9fe78e0b0591958..fcc19be9cb5cf69f2e89eda49d65d95ccb607688 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -1152,4 +1152,68 @@ def split_reduction_outward(kernel, inames, within=None): # }}} + +# {{{ fix_parameter + +def fix_parameter(kernel, name, value): + def process_set(s): + var_dict = s.get_var_dict() + + try: + dt, idx = var_dict[name] + except KeyError: + return s + + value_aff = isl.Aff.zero_on_domain(s.space) + value + + from loopy.isl_helpers import iname_rel_aff + name_equal_value_aff = iname_rel_aff(s.space, name, "==", value_aff) + + s = (s + .add_constraint( + isl.Constraint.equality_from_aff(name_equal_value_aff)) + .project_out(dt, idx, 1)) + + return s + + new_domains = [process_set(dom) for dom in kernel.domains] + + from pymbolic.mapper.substitutor import make_subst_func + subst_func = make_subst_func({name: value}) + + from loopy.symbolic import SubstitutionMapper + subst_map = SubstitutionMapper(subst_func) + + from loopy.kernel.array import ArrayBase + new_args = [] + for arg in kernel.args: + if arg.name == name: + # remove from argument list + continue + + if not isinstance(arg, ArrayBase): + new_args.append(arg) + else: + new_args.append(arg.map_exprs(subst_map)) + + new_temp_vars = {} + for tv in kernel.temporary_variables.itervalues(): + new_temp_vars[tv.name] = tv.map_exprs(subst_map) + + from loopy.context_matching import parse_stack_match + within = parse_stack_match(None) + + from loopy.symbolic import ExpandingSubstitutionMapper + esubst_map = ExpandingSubstitutionMapper( + kernel.substitutions, kernel.get_var_name_generator(), + subst_func, within=within) + return (esubst_map.map_kernel(kernel) + .copy( + domains=new_domains, + args=new_args, + assumptions=process_set(kernel.assumptions), + )) + +# }}} + # vim: foldmethod=marker diff --git a/loopy/isl_helpers.py b/loopy/isl_helpers.py index f3e89e0091073b330492648c1457131c9b7402ed..f8c52b2a0c0802b991aa86362e93d0ecf5bcd900 100644 --- a/loopy/isl_helpers.py +++ b/loopy/isl_helpers.py @@ -156,19 +156,21 @@ def iname_rel_aff(space, iname, rel, aff): """*aff*'s domain space is allowed to not match *space*.""" dt, pos = space.get_var_dict()[iname] - assert dt == isl.dim_type.set + assert dt in [isl.dim_type.set, isl.dim_type.param] + if dt == isl.dim_type.set: + dt = isl.dim_type.in_ from islpy import align_spaces aff = align_spaces(aff, isl.Aff.zero_on_domain(space)) if rel in ["==", "<="]: - return aff.add_coefficient(isl.dim_type.in_, pos, -1) + return aff.add_coefficient(dt, pos, -1) elif rel == ">=": - return aff.neg().add_coefficient(isl.dim_type.in_, pos, 1) + return aff.neg().add_coefficient(dt, pos, 1) elif rel == "<": - return (aff-1).add_coefficient(isl.dim_type.in_, pos, -1) + return (aff-1).add_coefficient(dt, pos, -1) elif rel == ">": - return (aff+1).neg().add_coefficient(isl.dim_type.in_, pos, 1) + return (aff+1).neg().add_coefficient(dt, pos, 1) else: raise ValueError("unknown value of 'rel': %s" % rel) diff --git a/loopy/preprocess.py b/loopy/preprocess.py index c9adbf0c9354add723a2966ef18599813a05e85b..d91458ad85b46dc98f219a50242b59a3f0f6ed2b 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -105,8 +105,8 @@ def infer_unknown_types(kernel, expect_completion=False): logger.debug("%s: %s" % (kernel.name, s)) if kernel.substitutions: - from warnings import warn - warn("type inference called when substitution " + from warnings import warn as py_warn + py_warn("type inference called when substitution " "rules are still unexpanded, expanding", LoopyWarning, stacklevel=2) diff --git a/test/test_loopy.py b/test/test_loopy.py index 8c4ecb2dc8ca93cbd6de0afbee49a59093e36bea..c4420a0214fe1c35f69f0b85805e4fdb59803448 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -1367,8 +1367,9 @@ def test_vector_types(ctx_factory, vec_len): lp.GlobalArg("a", np.float32, shape=lp.auto), lp.GlobalArg("out", np.float32, shape=lp.auto), "..." - ], - defines=dict(vec_len=vec_len)) + ]) + + knl = lp.fix_parameter(knl, "vec_len", vec_len) ref_knl = knl