diff --git a/loopy/codegen/control.py b/loopy/codegen/control.py index cd1b00f9bff6bf587c4ac9eff7e73ca35e93f815..3378ed81ee56f97cc11f8f8998aeb67221061633 100644 --- a/loopy/codegen/control.py +++ b/loopy/codegen/control.py @@ -456,11 +456,10 @@ def build_loop_nest(codegen_state, schedule_index): prev_gen_code = gen_code def gen_code(inner_codegen_state): - from pymbolic.primitives import Variable condition_exprs = [ constraint_to_expr(cns) for cns in bounds_checks] + [ - Variable(pred_chk) for pred_chk in pred_checks] + pred_chk for pred_chk in pred_checks] prev_result = prev_gen_code(inner_codegen_state) diff --git a/loopy/kernel/instruction.py b/loopy/kernel/instruction.py index 18953d0081dfe4c970ba62bfa9a313b8e37f292b..af8194cd0b7f1a07fd22d7086c92b45c475b5312 100644 --- a/loopy/kernel/instruction.py +++ b/loopy/kernel/instruction.py @@ -88,10 +88,9 @@ class InstructionBase(Record): .. attribute:: predicates - a :class:`frozenset` of variable names the conjunction (logical and) of - whose truth values (as defined by C) determine whether this instruction - should be run. Each variable name may, optionally, be preceded by - an exclamation point, indicating negation. + a :class:`frozenset` of expressions. The conjunction (logical and) of + their truth values (as defined by C) determines whether this instruction + should be run. .. rubric:: Iname dependencies @@ -161,6 +160,24 @@ class InstructionBase(Record): within_inames = forced_iname_deps within_inames_is_final = forced_iname_deps_is_final + new_predicates = set() + for pred in predicates: + if isinstance(pred, str): + from pymbolic.primitives import LogicalNot + from loopy.symbolic import parse + if pred.startswith("!"): + from warnings import warn + warn("predicates starting with '!' are deprecated. " + "Simply use 'not' instead") + pred = LogicalNot(parse(pred[1:])) + else: + pred = parse(pred) + + new_predicates.add(pred) + + predicates = new_predicates + del new_predicates + # }}} if depends_on is None: @@ -259,7 +276,13 @@ class InstructionBase(Record): # {{{ abstract interface def read_dependency_names(self): - raise NotImplementedError + from loopy.symbolic import get_dependencies + result = frozenset() + + for pred in self.predicates: + result = result | get_dependencies(pred) + + return result def reduction_inames(self): raise NotImplementedError @@ -607,15 +630,13 @@ class MultiAssignmentBase(InstructionBase): @memoize_method def read_dependency_names(self): from loopy.symbolic import get_dependencies - result = get_dependencies(self.expression) + result = ( + super(MultiAssignmentBase, self).read_dependency_names() + | get_dependencies(self.expression)) + for subscript_deps in self.assignee_subscript_deps(): result = result | subscript_deps - processed_predicates = frozenset( - pred.lstrip("!") for pred in self.predicates) - - result = result | processed_predicates - return result @memoize_method @@ -755,7 +776,9 @@ class Assignment(MultiAssignmentBase): def with_transformed_expressions(self, f, *args): return self.copy( assignee=f(self.assignee, *args), - expression=f(self.expression, *args)) + expression=f(self.expression, *args), + predicates=frozenset( + f(pred, *args) for pred in self.predicates)) # }}} @@ -919,7 +942,9 @@ class CallInstruction(MultiAssignmentBase): def with_transformed_expressions(self, f, *args): return self.copy( assignees=f(self.assignees, *args), - expression=f(self.expression, *args)) + expression=f(self.expression, *args), + predicates=frozenset( + f(pred) for pred in self.predicates)) # }}} @@ -1098,7 +1123,9 @@ class CInstruction(InstructionBase): # {{{ abstract interface def read_dependency_names(self): - result = set(self.read_variables) + result = ( + super(MultiAssignmentBase, self).read_dependency_names() + | frozenset(self.read_variables)) from loopy.symbolic import get_dependencies for name, iname_expr in self.iname_exprs: @@ -1125,7 +1152,9 @@ class CInstruction(InstructionBase): iname_exprs=[ (name, f(expr, *args)) for name, expr in self.iname_exprs], - assignees=[f(a, *args) for a in self.assignees]) + assignees=[f(a, *args) for a in self.assignees], + predicates=frozenset( + f(pred) for pred in self.predicates)) # }}} @@ -1168,8 +1197,7 @@ class CInstruction(InstructionBase): class _DataObliviousInstruction(InstructionBase): # {{{ abstract interface - def read_dependency_names(self): - return frozenset() + # read_dependency_names inherited def reduction_inames(self): return frozenset() @@ -1181,7 +1209,9 @@ class _DataObliviousInstruction(InstructionBase): return frozenset() def with_transformed_expressions(self, f, *args): - return self + return self.copy( + predicates=frozenset( + f(pred) for pred in self.predicates)) # }}} diff --git a/loopy/version.py b/loopy/version.py index 12b2fedbec2669205de2a82a8d5eca42678e9353..76d130f87c459f3dd7bf4b26705f12b69c1924ca 100644 --- a/loopy/version.py +++ b/loopy/version.py @@ -32,4 +32,4 @@ except ImportError: else: _islpy_version = islpy.version.VERSION_TEXT -DATA_MODEL_VERSION = "v42-islpy%s" % _islpy_version +DATA_MODEL_VERSION = "v43-islpy%s" % _islpy_version diff --git a/test/test_loopy.py b/test/test_loopy.py index 49a369c2fe05c231945672ea08d62cd8e67c5180..ad1b0db7796ec15e923c2670e2d8a7f7a2006447 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -1407,6 +1407,29 @@ def test_index_cse(ctx_factory): print(lp.generate_code_v2(knl).device_code()) +def test_ilp_and_conditionals(ctx_factory): + ctx = ctx_factory() + + knl = lp.make_kernel('{[k]: 0<=k<n}}', + """ + for k + <> Tcond = T[k] < 0.5 + if Tcond + cp[k] = 2 * T[k] + Tcond + end + end + """) + + knl = lp.fix_parameters(knl, n=200) + knl = lp.add_and_infer_dtypes(knl, {"T": np.float32}) + + ref_knl = knl + + knl = lp.split_iname(knl, 'k', 2, inner_tag='ilp') + + lp.auto_test_vs_ref(ref_knl, ctx, knl) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])