diff --git a/loopy/kernel/array.py b/loopy/kernel/array.py index 92fb232e92a6e891efcef7022e0dceedcb4f692f..34dc9e5c723b3979386c090bbedb98ec2bc9bc88 100644 --- a/loopy/kernel/array.py +++ b/loopy/kernel/array.py @@ -732,6 +732,23 @@ class ArrayBase(Record): order=order, **kwargs) + def __eq__(self, other): + from loopy.symbolic import ( + is_tuple_of_expressions_equal as istoee, + is_expression_equal as isee) + return ( + type(self) == type(other) + and self.name == other.name + and self.picklable_dtype == other.picklable_dtype + and istoee(self.shape, other.shape) + and self.dim_tags == other.dim_tags + and isee(self.offset, other.offset) + and self.order == other.order + ) + + def __ne__(self, other): + return not self.__eq__(other) + @property def dtype(self): from loopy.tools import PicklableDtype diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py index c5cecfde2fa4005669d1fca5f3439ca282f2c3c0..647eb5e36c2d05748c4882cb8b3e3d3d84d981ae 100644 --- a/loopy/kernel/data.py +++ b/loopy/kernel/data.py @@ -376,6 +376,14 @@ class TemporaryVariable(ArrayBase): def __str__(self): return self.stringify(include_typename=False) + def __eq__(self, other): + return ( + super(TemporaryVariable, self).__eq__(other) + and self.storage_shape == other.storage_shape + and self.base_indices == other.base_indices + and self.is_local == other.is_local + and self.base_storage == other.base_storage) + def update_persistent_hash(self, key_hash, key_builder): """Custom hash computation function for use with :class:`pytools.persistent_dict.PersistentDict`. diff --git a/loopy/statistics.py b/loopy/statistics.py index 834f482072a51386460e09d6c4f4d6a4406fa56a..7a064520b0d7da152f4d152568b83b68437961bd 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -605,4 +605,3 @@ def get_barrier_poly(knl): barrier_poly += isl.PwQPolynomial('{ 1 }') return barrier_poly - diff --git a/loopy/subst.py b/loopy/subst.py index a29e950a1f32d660eb10147c8638612078e816aa..5211be2d9a7e2163f87dbf27b131497ef485e0db 100644 --- a/loopy/subst.py +++ b/loopy/subst.py @@ -276,6 +276,8 @@ class AssignmentToSubstChanger(RuleAwareIdentityMapper): from pymbolic import var if index is None: return var(subst_name) + elif not isinstance(index, tuple): + return var(subst_name)(index) else: return var(subst_name)(*index) diff --git a/loopy/symbolic.py b/loopy/symbolic.py index d65440d57005183b2756040473080337257c5b52..d7a49fb22508c27d0a827189204480ab61cc7132 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -1186,4 +1186,47 @@ class AccessRangeMapper(WalkMapper): # }}} + +# {{{ is_expression_equal + +def is_expression_equal(a, b): + if a == b: + return True + + from pymbolic.primitives import Expression + if isinstance(a, Expression) or isinstance(b, Expression): + if a is None or b is None: + return False + + maybe_zero = a - b + from pymbolic import distribute + + d_result = distribute(maybe_zero) + return d_result == 0 + + else: + return False + + +def is_tuple_of_expressions_equal(a, b): + if a is None or b is None: + if a is None and b is None: + return True + return False + + if not isinstance(a, tuple): + a = (a,) + + if not isinstance(b, tuple): + b = (b,) + + if len(a) != len(b): + return False + + return all( + is_expression_equal(ai, bi) + for ai, bi in zip(a, b)) + +# }}} + # vim: foldmethod=marker diff --git a/test/test_loopy.py b/test/test_loopy.py index c8072108032e82517773709f5f5fd257928d3bd9..0142fd6d4962da8c555412ee43a2a349c3e6d521 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -29,6 +29,7 @@ import sys import numpy as np import loopy as lp import pyopencl as cl +import pyopencl.clmath # noqa import pyopencl.clrandom # noqa import pytest @@ -2228,6 +2229,67 @@ def test_indexof_vec(ctx_factory): #assert np.array_equal(out.ravel(order="C"), np.arange(25)) +def test_finite_difference_expr_subst(ctx_factory): + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + grid = np.linspace(0, 2*np.pi, 2048, endpoint=False) + h = grid[1] - grid[0] + u = cl.clmath.sin(cl.array.to_device(queue, grid)) + + fin_diff_knl = lp.make_kernel( + "{[i]: 1<=i<=n}", + "out[i] = -(f[i+1] - f[i-1])/h", + [lp.GlobalArg("out", shape="n+2"), "..."]) + + flux_knl = lp.make_kernel( + "{[j]: 1<=j<=n}", + "f[j] = u[j]**2/2", + [ + lp.GlobalArg("f", shape="n+2"), + lp.GlobalArg("u", shape="n+2"), + ]) + + fused_knl = lp.fuse_kernels([fin_diff_knl, flux_knl]) + + fused_knl = lp.set_options(fused_knl, write_cl=True) + evt, _ = fused_knl(queue, u=u, h=np.float32(1e-1)) + + fused_knl = lp.assignment_to_subst(fused_knl, "f") + + fused_knl = lp.set_options(fused_knl, write_cl=True) + + # This is the real test here: The automatically generated + # shape expressions are '2+n' and the ones above are 'n+2'. + # Is loopy smart enough to understand that these are equal? + evt, _ = fused_knl(queue, u=u, h=np.float32(1e-1)) + + fused0_knl = lp.affine_map_inames(fused_knl, "i", "inew", "inew+1=i") + + gpu_knl = lp.split_iname( + fused0_knl, "inew", 128, outer_tag="g.0", inner_tag="l.0") + + precomp_knl = lp.precompute( + gpu_knl, "f_subst", "inew_inner", fetch_bounding_box=True) + + precomp_knl = lp.tag_inames(precomp_knl, {"j_0_outer": "unr"}) + precomp_knl = lp.set_options(precomp_knl, return_dict=True) + evt, _ = precomp_knl(queue, u=u, h=h) + + +def test_is_expression_equal(): + from loopy.symbolic import is_expression_equal + from pymbolic import var + + x = var("x") + y = var("y") + + assert is_expression_equal(x+2, 2+x) + + assert is_expression_equal((x+2)**2, x**2 + 4*x + 4) + assert is_expression_equal((x+y)**2, x**2 + 2*x*y + y**2) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])