diff --git a/pymbolic/algorithm.py b/pymbolic/algorithm.py index a63aea94b727da44d61a0dcfd58287df1573da0c..207df9d16751f491b8015c5e62e0649bcc4ab3b4 100644 --- a/pymbolic/algorithm.py +++ b/pymbolic/algorithm.py @@ -99,7 +99,6 @@ def fft(x, sign=1, wrap_intermediate=lambda x: x): N1, N2 = find_factors(N) - # do the transform sub_ffts = [ wrap_intermediate( fft(x[n1::N1], sign, wrap_intermediate) @@ -123,7 +122,54 @@ def ifft(x, wrap_intermediate=lambda x:x): -def csr_matrix_multiply(S,x): +def sym_fft(x, sign=1): + """Perform an FFT on the numpy object array x. + + Remove near-zero floating point constants, insert + CommonSubexpression wrappers at opportune points. + """ + + from pymbolic.mapper import IdentityMapper, CSECachingMapperMixin + class NearZeroKiller(CSECachingMapperMixin, IdentityMapper): + map_common_subexpression_uncached = \ + IdentityMapper.map_common_subexpression + + def map_constant(self, expr): + if isinstance(expr, complex): + r = expr.real + i = expr.imag + if abs(r) < 1e-15: + r = 0 + if abs(i) < 1e-15: + i = 0 + if i == 0: + return r + else: + return complex(r, i) + else: + return expr + + import numpy + + def wrap_intermediate(x): + if len(x) > 1: + from pymbolic.primitives import CommonSubexpression + result = numpy.empty(len(x), dtype=object) + for i, x_i in enumerate(x): + result[i] = CommonSubexpression(x_i) + return result + else: + return x + + return NearZeroKiller()( + fft(wrap_intermediate(x), sign=sign, wrap_intermediate=wrap_intermediate)) + + + + + + +def csr_matrix_multiply(S, x): """Multiplies a scipy.sparse.csr_matrix S by an object-array vector x. """ h, w = S.shape diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index ff79551b68ad3b9bd49e61ea02389bf5976c3c69..7b8e97a3872c0e0b190ba252c990ee742ac15b1a 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -64,21 +64,14 @@ def test_fft(): numpy = py.test.importorskip("numpy") from pymbolic import var - from pymbolic.algorithm import fft + from pymbolic.algorithm import fft, sym_fft vars = numpy.array([var(chr(97+i)) for i in range(16)], dtype=object) print vars - def wrap_intermediate(x): - if len(x) > 1: - from hedge.optemplate import make_common_subexpression - return make_common_subexpression(x) - else: - return x - nzk = NearZeroKiller() - print nzk(fft(vars)) - traced_fft = nzk(fft(vars, wrap_intermediate=wrap_intermediate)) + print fft(vars) + traced_fft = sym_fft(vars) from pymbolic.mapper.stringifier import PREC_NONE from pymbolic.mapper.c_code import CCodeMapper