From 2ad41ce726325c07125fe52347eba1ab4fec0d88 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 1 Jul 2009 13:54:53 -0400
Subject: [PATCH] Add symbolic FFT to pymbolic.algorithm. Update tests.

---
 pymbolic/algorithm.py | 50 +++++++++++++++++++++++++++++++++++++++++--
 test/test_pymbolic.py | 13 +++--------
 2 files changed, 51 insertions(+), 12 deletions(-)

diff --git a/pymbolic/algorithm.py b/pymbolic/algorithm.py
index a63aea9..207df9d 100644
--- a/pymbolic/algorithm.py
+++ b/pymbolic/algorithm.py
@@ -99,7 +99,6 @@ def fft(x, sign=1, wrap_intermediate=lambda x: x):
 
     N1, N2 = find_factors(N)
 
-    # do the transform
     sub_ffts = [
             wrap_intermediate(
                 fft(x[n1::N1], sign, wrap_intermediate)
@@ -123,7 +122,54 @@ def ifft(x, wrap_intermediate=lambda x:x):
 
 
 
-def csr_matrix_multiply(S,x):
+def sym_fft(x, sign=1):
+    """Perform an FFT on the numpy object array x.
+
+    Remove near-zero floating point constants, insert
+    CommonSubexpression wrappers at opportune points.
+    """
+
+    from pymbolic.mapper import IdentityMapper, CSECachingMapperMixin
+    class NearZeroKiller(CSECachingMapperMixin, IdentityMapper):
+        map_common_subexpression_uncached = \
+                IdentityMapper.map_common_subexpression
+
+        def map_constant(self, expr):
+            if isinstance(expr, complex):
+                r = expr.real
+                i = expr.imag
+                if abs(r) < 1e-15:
+                    r = 0
+                if abs(i) < 1e-15:
+                    i = 0
+                if i == 0:
+                    return r
+                else:
+                    return complex(r, i)
+            else:
+                return expr
+
+    import numpy
+
+    def wrap_intermediate(x):
+        if len(x) > 1:
+            from pymbolic.primitives import CommonSubexpression
+            result = numpy.empty(len(x), dtype=object)
+            for i, x_i in enumerate(x):
+                result[i] = CommonSubexpression(x_i)
+            return result
+        else:
+            return x
+
+    return NearZeroKiller()(
+            fft(wrap_intermediate(x), sign=sign, wrap_intermediate=wrap_intermediate))
+
+
+
+
+
+
+def csr_matrix_multiply(S, x):
     """Multiplies a scipy.sparse.csr_matrix S by an object-array vector x.
     """
     h, w = S.shape
diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py
index ff79551..7b8e97a 100644
--- a/test/test_pymbolic.py
+++ b/test/test_pymbolic.py
@@ -64,21 +64,14 @@ def test_fft():
     numpy = py.test.importorskip("numpy")
 
     from pymbolic import var
-    from pymbolic.algorithm import fft
+    from pymbolic.algorithm import fft, sym_fft
 
     vars = numpy.array([var(chr(97+i)) for i in range(16)], dtype=object)
     print vars
 
-    def wrap_intermediate(x):
-        if len(x) > 1:
-            from hedge.optemplate import make_common_subexpression
-            return make_common_subexpression(x)
-        else:
-            return x
-
     nzk = NearZeroKiller()
-    print nzk(fft(vars))
-    traced_fft = nzk(fft(vars, wrap_intermediate=wrap_intermediate))
+    print fft(vars)
+    traced_fft = sym_fft(vars)
 
     from pymbolic.mapper.stringifier import PREC_NONE
     from pymbolic.mapper.c_code import CCodeMapper
-- 
GitLab