From 2b409273b5657c6881d37fe10a36d9fdbd748a07 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Fri, 19 Jun 2009 14:54:42 -0400 Subject: [PATCH] Add numpy.dtype mangling utilities. --- pytools/__init__.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/pytools/__init__.py b/pytools/__init__.py index 41888d0..97f98f6 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -1251,3 +1251,48 @@ def _test(): if __name__ == "__main__": _test() + + + + +# numpy dtype mangling -------------------------------------------------------- +def common_dtype(dtypes): + return argmax2((dtype, dtype.num) for dtype in dtypes) + + + + + +def to_uncomplex_dtype(dtype): + import numpy + if dtype == numpy.complex64: + return numpy.float32 + elif dtype == numpy.complex128: + return numpy.float64 + if dtype == numpy.float32: + return numpy.float32 + elif dtype == numpy.float64: + return numpy.float64 + else: + raise TypeError("unrecgonized dtype '%s'" % dtype) + + + + +def match_precision(dtype, dtype_to_match): + import numpy + + tgt_is_double = dtype_to_match in [ + numpy.float64, numpy.complex128] + + dtype_is_complex = complex in dtype.type.__mro__ + if dtype_is_complex: + if tgt_is_double: + return numpy.dtype(numpy.complex128) + else: + return numpy.dtype(numpy.complex64) + else: + if tgt_is_double: + return numpy.dtype(numpy.float64) + else: + return numpy.dtype(numpy.float32) -- GitLab