diff --git a/loopy/expression.py b/loopy/expression.py index e3eb65dc5c6e100d427f78dfd7e8da41d21e9b17..64172b88a6cc2dda17b60a6427cfcf9bf430600f 100644 --- a/loopy/expression.py +++ b/loopy/expression.py @@ -22,14 +22,22 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ - +from typing import TYPE_CHECKING import numpy as np +import pymbolic.primitives as p from pymbolic.mapper import Mapper from loopy.codegen import UnvectorizableError from loopy.diagnostic import LoopyError +from loopy.symbolic import simplify_using_aff + + +if TYPE_CHECKING: + from loopy.kernel import LoopKernel + from loopy.kernel.data import Iname + from loopy.symbolic import LinearSubscript, Reduction # type_context may be: @@ -58,7 +66,7 @@ def dtype_to_type_context(target, dtype): # {{{ vectorizability checker -class VectorizabilityChecker(Mapper): +class VectorizabilityChecker(Mapper[bool, []]): """The return value from this mapper is a :class:`bool` indicating whether the result of the expression is vectorized along :attr:`vec_iname`. If the expression is not vectorizable, the mapper raises @@ -67,31 +75,32 @@ class VectorizabilityChecker(Mapper): .. attribute:: vec_iname """ - def __init__(self, kernel, vec_iname, vec_iname_length): + def __init__(self, + kernel: LoopKernel, + vec_iname: Iname, + vec_iname_length: int + ) -> None: self.kernel = kernel self.vec_iname = vec_iname self.vec_iname_length = vec_iname_length - @staticmethod - def combine(vectorizabilities): - from functools import reduce - from operator import and_ - return reduce(and_, vectorizabilities) - - def map_sum(self, expr): + def map_sum(self, expr: p.Sum) -> bool: return any(self.rec(child) for child in expr.children) - map_product = map_sum + def map_product(self, expr: p.Product) -> bool: + return any(self.rec(child) for child in expr.children) - def map_quotient(self, expr): + def map_quotient(self, expr: p.QuotientBase) -> bool: return (self.rec(expr.numerator) or self.rec(expr.denominator)) - def map_linear_subscript(self, expr): - return False + map_remainder = map_quotient + + def map_linear_subscript(self, expr: LinearSubscript) -> bool: + raise UnvectorizableError("linear subscripts cannot be vectorized") - def map_call(self, expr): + def map_call(self, expr: p.Call) -> bool: # FIXME: Should implement better vectorization check for function calls rec_pars = [ @@ -101,16 +110,11 @@ class VectorizabilityChecker(Mapper): return False - def map_subscript(self, expr): + def map_subscript(self, expr: p.Subscript) -> bool: + assert isinstance(expr.aggregate, p.Variable) name = expr.aggregate.name - var = self.kernel.arg_dict.get(name) - if var is None: - var = self.kernel.temporary_variables.get(name) - - if var is None: - raise LoopyError("unknown array variable in subscript: %s" - % name) + var = self.kernel.get_var_descriptor(name) from loopy.kernel.array import ArrayBase if not isinstance(var, ArrayBase): @@ -124,11 +128,16 @@ class VectorizabilityChecker(Mapper): from loopy.symbolic import get_dependencies possible = None + + assert isinstance(var.shape, tuple) + assert var.dim_tags is not None + for i in range(len(var.shape)): + idx_i = index[i] if ( isinstance(var.dim_tags[i], VectorArrayDimTag) - and isinstance(index[i], Variable) - and index[i].name == self.vec_iname): + and isinstance(idx_i, Variable) + and idx_i.name == self.vec_iname): if var.shape[i] != self.vec_iname_length: raise UnvectorizableError("vector length was mismatched") @@ -136,7 +145,7 @@ class VectorizabilityChecker(Mapper): possible = True else: - if self.vec_iname in get_dependencies(index[i]): + if self.vec_iname in get_dependencies(idx_i): raise UnvectorizableError("vectorizing iname '%s' occurs in " "unvectorized subscript axis %d (1-based) of " "expression '%s'" @@ -145,10 +154,11 @@ class VectorizabilityChecker(Mapper): return bool(possible) - def map_constant(self, expr): + def map_constant(self, expr: object) -> bool: + # Loopy does not have vector literals. return False - def map_variable(self, expr): + def map_variable(self, expr: p.Variable) -> bool: if expr.name == self.vec_iname: # Technically, this is doable. But we're not going there. raise UnvectorizableError() @@ -158,7 +168,7 @@ class VectorizabilityChecker(Mapper): map_tagged_variable = map_variable - def map_lookup(self, expr): + def map_lookup(self, expr: p.Lookup) -> bool: if self.rec(expr.aggregate): raise UnvectorizableError() @@ -170,13 +180,13 @@ class VectorizabilityChecker(Mapper): raise UnvectorizableError() - def map_logical_not(self, expr): + def map_logical_not(self, expr: object) -> bool: raise UnvectorizableError() map_logical_and = map_logical_not map_logical_or = map_logical_not - def map_reduction(self, expr): + def map_reduction(self, expr: Reduction) -> bool: # FIXME: Do this more carefully raise UnvectorizableError()