Skip to content
Snippets Groups Projects
Commit d3903747 authored by Andreas Klöckner's avatar Andreas Klöckner Committed by Andreas Klöckner
Browse files

Type VectorizabilityChecker

parent f39da3bd
No related branches found
No related tags found
No related merge requests found
......@@ -22,14 +22,22 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from typing import TYPE_CHECKING
import numpy as np
import pymbolic.primitives as p
from pymbolic.mapper import Mapper
from loopy.codegen import UnvectorizableError
from loopy.diagnostic import LoopyError
from loopy.symbolic import simplify_using_aff
if TYPE_CHECKING:
from loopy.kernel import LoopKernel
from loopy.kernel.data import Iname
from loopy.symbolic import LinearSubscript, Reduction
# type_context may be:
......@@ -58,7 +66,7 @@ def dtype_to_type_context(target, dtype):
# {{{ vectorizability checker
class VectorizabilityChecker(Mapper):
class VectorizabilityChecker(Mapper[bool, []]):
"""The return value from this mapper is a :class:`bool` indicating whether
the result of the expression is vectorized along :attr:`vec_iname`.
If the expression is not vectorizable, the mapper raises
......@@ -67,31 +75,32 @@ class VectorizabilityChecker(Mapper):
.. attribute:: vec_iname
"""
def __init__(self, kernel, vec_iname, vec_iname_length):
def __init__(self,
kernel: LoopKernel,
vec_iname: Iname,
vec_iname_length: int
) -> None:
self.kernel = kernel
self.vec_iname = vec_iname
self.vec_iname_length = vec_iname_length
@staticmethod
def combine(vectorizabilities):
from functools import reduce
from operator import and_
return reduce(and_, vectorizabilities)
def map_sum(self, expr):
def map_sum(self, expr: p.Sum) -> bool:
return any(self.rec(child) for child in expr.children)
map_product = map_sum
def map_product(self, expr: p.Product) -> bool:
return any(self.rec(child) for child in expr.children)
def map_quotient(self, expr):
def map_quotient(self, expr: p.QuotientBase) -> bool:
return (self.rec(expr.numerator)
or
self.rec(expr.denominator))
def map_linear_subscript(self, expr):
return False
map_remainder = map_quotient
def map_linear_subscript(self, expr: LinearSubscript) -> bool:
raise UnvectorizableError("linear subscripts cannot be vectorized")
def map_call(self, expr):
def map_call(self, expr: p.Call) -> bool:
# FIXME: Should implement better vectorization check for function calls
rec_pars = [
......@@ -101,16 +110,11 @@ class VectorizabilityChecker(Mapper):
return False
def map_subscript(self, expr):
def map_subscript(self, expr: p.Subscript) -> bool:
assert isinstance(expr.aggregate, p.Variable)
name = expr.aggregate.name
var = self.kernel.arg_dict.get(name)
if var is None:
var = self.kernel.temporary_variables.get(name)
if var is None:
raise LoopyError("unknown array variable in subscript: %s"
% name)
var = self.kernel.get_var_descriptor(name)
from loopy.kernel.array import ArrayBase
if not isinstance(var, ArrayBase):
......@@ -124,11 +128,16 @@ class VectorizabilityChecker(Mapper):
from loopy.symbolic import get_dependencies
possible = None
assert isinstance(var.shape, tuple)
assert var.dim_tags is not None
for i in range(len(var.shape)):
idx_i = index[i]
if (
isinstance(var.dim_tags[i], VectorArrayDimTag)
and isinstance(index[i], Variable)
and index[i].name == self.vec_iname):
and isinstance(idx_i, Variable)
and idx_i.name == self.vec_iname):
if var.shape[i] != self.vec_iname_length:
raise UnvectorizableError("vector length was mismatched")
......@@ -136,7 +145,7 @@ class VectorizabilityChecker(Mapper):
possible = True
else:
if self.vec_iname in get_dependencies(index[i]):
if self.vec_iname in get_dependencies(idx_i):
raise UnvectorizableError("vectorizing iname '%s' occurs in "
"unvectorized subscript axis %d (1-based) of "
"expression '%s'"
......@@ -145,10 +154,11 @@ class VectorizabilityChecker(Mapper):
return bool(possible)
def map_constant(self, expr):
def map_constant(self, expr: object) -> bool:
# Loopy does not have vector literals.
return False
def map_variable(self, expr):
def map_variable(self, expr: p.Variable) -> bool:
if expr.name == self.vec_iname:
# Technically, this is doable. But we're not going there.
raise UnvectorizableError()
......@@ -158,7 +168,7 @@ class VectorizabilityChecker(Mapper):
map_tagged_variable = map_variable
def map_lookup(self, expr):
def map_lookup(self, expr: p.Lookup) -> bool:
if self.rec(expr.aggregate):
raise UnvectorizableError()
......@@ -170,13 +180,13 @@ class VectorizabilityChecker(Mapper):
raise UnvectorizableError()
def map_logical_not(self, expr):
def map_logical_not(self, expr: object) -> bool:
raise UnvectorizableError()
map_logical_and = map_logical_not
map_logical_or = map_logical_not
def map_reduction(self, expr):
def map_reduction(self, expr: Reduction) -> bool:
# FIXME: Do this more carefully
raise UnvectorizableError()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment