diff --git a/lappy/core/array.py b/lappy/core/array.py index ac057270009769b2f631c27b5feb0290fcc08dc5..cce8311583351a5bf022240c2ab7125b60b4d19d 100644 --- a/lappy/core/array.py +++ b/lappy/core/array.py @@ -253,6 +253,13 @@ class Array(LazyObjectBase): index names for iterating over the array. + .. attribute:: domain_expr + + set expression for constructing the loop domain. A set expression + has leafs of :class:`lappy.core.primitives.PwAff` variables and + integers. If domain_expr is None, the index domain of the array is + used. + .. attribute:: dtype data type, either None or convertible to loopy type. Lappy @@ -270,7 +277,7 @@ class Array(LazyObjectBase): _counter = 0 _name_prefix = '__lappy_array_' - def __init__(self, name=None, shape=None, **kwargs): + def __init__(self, name=None, shape=None, domain_expr=None, **kwargs): # default expr if 'expr' not in kwargs: kwargs['expr'] = None @@ -280,13 +287,14 @@ class Array(LazyObjectBase): super(Array, self).__init__(**kwargs) self._ndim = to_lappy_unsigned(kwargs.pop("ndim", None)) + self.domain_expr = domain_expr if self.ndim == 0: self.name = self.name.replace('array', 'number') if self.ndim is None: # infer ndim from given shape - # NOTE: use fixed ndim only (for the moment) + # NOTE: use fixed ndim only shape = to_lappy_shape(shape) self._ndim = to_lappy_unsigned(int(len(shape))) @@ -327,6 +335,7 @@ class Array(LazyObjectBase): repr_dict.update({ 'ndim': self.ndim, 'inames': self.inames, + 'domain': self.domain_expr, 'shape': self.shape, 'dtype': self.dtype, 'is_integral': self.is_integral, @@ -506,6 +515,7 @@ class Array(LazyObjectBase): inames=self.inames, shape=self._shape, dtype=self.dtype, expr=self.expr, is_integral=self.is_integral, + domain_expr=self.domain_expr, arguments=self.arguments.copy(), bound_arguments=self.bound_arguments.copy(), intermediaries=self.intermediaries.copy(), diff --git a/lappy/core/compiler.py b/lappy/core/compiler.py new file mode 100644 index 0000000000000000000000000000000000000000..64f3513f5e25001ba2924cdc6b2bd45c44590a48 --- /dev/null +++ b/lappy/core/compiler.py @@ -0,0 +1,180 @@ +__copyright__ = "Copyright (C) 2019 Xiaoyu Wei" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" +import islpy as isl +from pymbolic.mapper import Mapper +from lappy.core.mapper import SetVariableCollector, SetParameterCollector +from lappy.core.primitives import EMPTY_SET, UNIVERSE_SET + +# {{{ loop domain compiler + + +class LoopDomainCompiler(Mapper): + """Compiles a set expression in to an ISL set. + """ + def map_constant(self, expr, var_dict): + if expr == 0: + return var_dict[0] + return expr + + def map_set_variable(self, expr, var_dict): + return var_dict[expr.name] + + def map_set_parameter(self, expr, var_dict): + return var_dict[expr.name] + + def map_sum(self, expr, var_dict): + res = var_dict[0] + for c in expr.children: + res += self.rec(c, var_dict) + return res + + def map_pwaff_comparison(self, expr, var_dict): + lhs = self.rec(expr.left, var_dict) + rhs = self.rec(expr.right, var_dict) + op = expr.operator + + print(type(lhs), type(rhs)) + + if isinstance(lhs, int) and isinstance(rhs, int): + exec('valid = (%s %s %s)' % (str(lhs), str(op), str(rhs))) + if valid: # noqa: F821 + return 1 # stands for the whole set + else: + return 0 # stands for the empty set + + if isinstance(lhs, int) and isinstance(rhs, isl.PwAff): + lhs, rhs = rhs, lhs + + if isinstance(lhs, isl.PwAff) and isinstance(rhs, int): + lhs = lhs - rhs + rhs = var_dict[0] + + if isinstance(lhs, isl.PwAff) and isinstance(rhs, isl.PwAff): + if op == '==': + return lhs.eq_set(rhs) + if op == '!=': + return lhs.ne_set(rhs) + if op == '<=': + return lhs.le_set(rhs) + if op == '<': + return lhs.lt_set(rhs) + if op == '>=': + return lhs.ge_set(rhs) + if op == '>': + return lhs.gt_set(rhs) + raise ValueError('unknown operator "%s"' % str(op)) + + else: + raise ValueError('cannot compile operator "%s" (lhs=%s, rhs=%s)' + % (op, str(expr.left), str(expr.right))) + + def map_pwaff_sum(self, expr, var_dict): + """Mapper for sum of pwaffs. + """ + return sum([self.rec(child, var_dict) for child in expr.children], 0) + + def map_pwaff_product(self, expr, var_dict): + """Mapper for product of pwaffs. + """ + res = 1 + for rc in [self.rec(child, var_dict) for child in expr.children]: + res *= rc + return res + + def map_pwaff_floor_div(self, expr, var_dict): + return (self.rec(expr.numerator, var_dict) + // self.rec(expr.denominator, var_dict)) + + def map_pwaff_remainder(self, expr, var_dict): + return (self.rec(expr.numerator, var_dict) + % self.rec(expr.denominator, var_dict)) + + def map_set_union(self, expr, var_dict): + """Mapper for set union. + + Singular sets (empty set and its complement) are represented with ints. + """ + children = [self.rec(child, var_dict) for child in expr.children] + + for c in children: + if isinstance(c, int) and c == UNIVERSE_SET: + return UNIVERSE_SET + + children = [c for c in children if not isinstance(c, int)] + + if len(children) > 0: + res = children[0] + else: + res = UNIVERSE_SET + + for c in children: + assert isinstance(res, isl.Set) and isinstance(c, isl.Set) + res = res | c + + return res + + def map_set_intersection(self, expr, var_dict): + """Mapper for set intersection. + + Singular sets (empty set and its complement) are represented with ints. + """ + children = [self.rec(child, var_dict) for child in expr.children] + + for c in children: + if isinstance(c, int) and c == EMPTY_SET: + return EMPTY_SET + + children = [c for c in children if not isinstance(c, int)] + + if len(children) > 0: + res = children[0] + else: + res = UNIVERSE_SET + + for c in children: + assert isinstance(res, isl.Set) and isinstance(c, isl.Set) + res = res & c + + return res + + def __call__(self, expr): + """Returns a list of basic sets representing the loop domain. + + The variables and parameter are lexicographically ordered. + """ + var = sorted(list(SetVariableCollector()(expr))) + param = sorted(list(SetParameterCollector()(expr))) + var_dict = isl.make_zero_and_vars(var, param) + + domain_pre = super(LoopDomainCompiler, self).__call__(expr, var_dict) + if isinstance(domain_pre, isl.PwAff): + # unconstrained domain + raise ValueError("unconstrained (thus unbounded) loop domain") + + if isinstance(domain_pre, isl.Set): + domain = domain_pre.coalesce().get_basic_sets() + else: + raise RuntimeError() + + return domain + +# }}} End loop domain compiler diff --git a/lappy/core/mapper.py b/lappy/core/mapper.py index e790dfb5812a4647597808bc718047af6619305b..56d40657e3dd29105e13cf6d6f4ada69ad11fcbc 100644 --- a/lappy/core/mapper.py +++ b/lappy/core/mapper.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- from __future__ import division, absolute_import, print_function copyright__ = "Copyright (C) 2017 Sophia Lin and Andreas Kloeckner" @@ -22,6 +23,9 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ from pymbolic.mapper import Collector +from pymbolic.mapper.stringifier import ( + StringifyMapper, + PREC_SUM, PREC_COMPARISON, PREC_NONE, PREC_PRODUCT) class SetVariableCollector(Collector): @@ -30,7 +34,15 @@ class SetVariableCollector(Collector): def map_set_variable(self, expr): return {expr.name, } + map_set_zero = Collector.map_constant map_set_parameter = Collector.map_constant + map_pwaff_comparison = Collector.map_comparison + map_pwaff_sum = Collector.map_sum + map_pwaff_product = Collector.map_sum + map_set_union = Collector.map_sum + map_set_intersection = Collector.map_sum + map_pwaff_floor_div = Collector.map_quotient + map_pwaff_remainder = Collector.map_quotient class SetParameterCollector(Collector): @@ -39,4 +51,54 @@ class SetParameterCollector(Collector): def map_set_parameter(self, expr): return {expr.name, } + map_set_zero = Collector.map_constant map_set_variable = Collector.map_constant + map_pwaff_comparison = Collector.map_comparison + map_pwaff_sum = Collector.map_sum + map_pwaff_product = Collector.map_sum + map_set_union = Collector.map_sum + map_set_intersection = Collector.map_sum + map_pwaff_floor_div = Collector.map_quotient + map_pwaff_remainder = Collector.map_quotient + + +class SetExpressionStringifyMapper(StringifyMapper): + def map_set_variable(self, expr, *args, **kwargs): + return expr.name + + def map_set_parameter(self, expr, *args, **kwargs): + return expr.name + + def map_set_zero(self, expr, *args, **kwargs): + return expr.name + + def map_set_union(self, expr, enclosing_prec, *args, **kwargs): + return self.parenthesize_if_needed( + self.join_rec(u" ∪ ", expr.children, PREC_SUM, *args, **kwargs), + enclosing_prec, PREC_SUM) + + def map_set_intersection(self, expr, enclosing_prec, *args, **kwargs): + return self.parenthesize_if_needed( + self.join_rec(u" ∩ ", expr.children, PREC_SUM, *args, **kwargs), + enclosing_prec, PREC_SUM) + + def map_pwaff_comparison(self, expr, enclosing_prec, *args, **kwargs): + return self.format("{%s %s %s}", + self.rec(expr.left, PREC_COMPARISON, *args, **kwargs), + expr.operator, + self.rec(expr.right, PREC_COMPARISON, *args, **kwargs)) + + def map_pwaff_floor_div(self, expr, enclosing_prec, *args, **kwargs): + return self.format("%s // %s", + self.rec(expr.numerator, PREC_NONE, *args, **kwargs), + self.rec(expr.denominator, PREC_NONE, *args, **kwargs)) + + def map_pwaff_sum(self, expr, enclosing_prec, *args, **kwargs): + return self.parenthesize_if_needed( + self.join_rec(" + ", expr.children, PREC_SUM, *args, **kwargs), + enclosing_prec, PREC_SUM) + + def map_pwaff_product(self, expr, enclosing_prec, *args, **kwargs): + return self.parenthesize_if_needed( + self.join_rec("*", expr.children, PREC_PRODUCT, *args, **kwargs), + enclosing_prec, PREC_PRODUCT) diff --git a/lappy/core/primitives.py b/lappy/core/primitives.py index a13a275b6c3fa046fc15fcfde8d83ecccaadf9b9..403c2053e265cd822157b3bb8d3ec5c1c94b50e4 100644 --- a/lappy/core/primitives.py +++ b/lappy/core/primitives.py @@ -25,9 +25,11 @@ THE SOFTWARE. from six.moves import intern from pymbolic.primitives import AlgebraicLeaf -from pymbolic.primitives import Variable +from pymbolic.primitives import Expression from pymbolic.primitives import QuotientBase -from pymbolic.primitives import make_sym_vector # noqa: F401 +from pymbolic.primitives import Variable + +from lappy.core.mapper import SetExpressionStringifyMapper # {{{ Array computation primitives @@ -91,11 +93,319 @@ class Reduction(AlgebraicLeaf): # {{{ Integer set computation primitives -class PwAff(Variable): +EMPTY_SET = 0 +UNIVERSE_SET = 1 # complement of the empty set + + +class SetExpression(Expression): + """Expression with replaced stringifier. + + The expression tree has two layers, a lower layer of PwAffs and an upper layer of + Sets. + The leaf nodes are PwAffs, and the internal layer above the comparison nodes are + Sets. All nodes further above are also Sets. + """ + # {{{ disable unused methods for easier debugging + + def __add__(self, *args, **kwargs): + raise NotImplementedError() + + def __radd__(self, *args, **kwargs): + raise NotImplementedError() + + def __sub__(self, *args, **kwargs): + raise NotImplementedError() + + def __rsub__(self, *args, **kwargs): + raise NotImplementedError() + + def __mul__(self, *args, **kwargs): + raise NotImplementedError() + + def __rmul__(self, *args, **kwargs): + raise NotImplementedError() + + def __div__(self, *args, **kwargs): + raise NotImplementedError() + + def __truediv__(self, *args, **kwargs): + return self.__div__(*args, **kwargs) + + def __rdiv__(self, *args, **kwargs): + raise NotImplementedError() + + def __floordiv__(self, *args, **kwargs): + raise NotImplementedError() + + def __rfloordiv__(self, *args, **kwargs): + raise NotImplementedError() + + def __mod__(self, *args, **kwargs): + raise NotImplementedError() + + def __rmod__(self, *args, **kwargs): + raise NotImplementedError() + + def __pow__(self, *args, **kwargs): + raise NotImplementedError() + + def __rpow__(self, *args, **kwargs): + raise NotImplementedError() + + def __lshift__(self, other): + raise NotImplementedError() + + def __rlshift__(self, other): + raise NotImplementedError() + + def __rshift__(self, other): + raise NotImplementedError() + + def __rrshift__(self, other): + raise NotImplementedError() + + def __inv__(self): + raise NotImplementedError() + + def __or__(self, other): + raise NotImplementedError() + + def __ror__(self, other): + raise NotImplementedError() + + def __xor__(self, other): + raise NotImplementedError() + + def __rxor__(self, other): + raise NotImplementedError() + + def __and__(self, other): + raise NotImplementedError() + + def __rand__(self, other): + raise NotImplementedError() + + def __call__(self, *args, **kwargs): + raise NotImplementedError() + + def eq(self, *args, **kwargs): + raise NotImplementedError() + + def ne(self, *args, **kwargs): + raise NotImplementedError() + + def le(self, *args, **kwargs): + raise NotImplementedError() + + def lt(self, *args, **kwargs): + raise NotImplementedError() + + def ge(self, *args, **kwargs): + raise NotImplementedError() + + def gt(self, *args, **kwargs): + raise NotImplementedError() + + def not_(self, *args, **kwargs): + raise NotImplementedError() + + # }}} End disabled methods + + def make_stringifier(self, originating_stringifier=None): + return SetExpressionStringifyMapper() + + def or_(self, other): + if isinstance(self, PwAff) or isinstance(other, PwAff): + raise ValueError("cannot compute set union of %s and %s" + % (str(self), str(other))) + if isinstance(other, SetUnion): + return SetUnion((self, ) + other.children) + else: + return SetUnion((self, other)) + + def and_(self, other): + if isinstance(self, PwAff) or isinstance(other, PwAff): + raise ValueError("cannot compute set intersection of %s and %s" + % (str(self), str(other))) + if isinstance(other, SetIntersection): + return SetIntersection((self, ) + other.children) + else: + return SetIntersection((self, other)) + + def _make_pwaff_comparison(self, op, other): + if not isinstance(self, PwAff) and isinstance(other, PwAff): + raise ValueError("cannot do pwaff comparison with %s and %s" + % (str(self), str(other))) + return PwAffComparison(self, op, other) + + +class PwAffComparison(SetExpression): + """PwAffs form sets via the comparison node. + + .. attribute:: left + .. attribute:: operator + + One of ``[">", ">=", "==", "!=", "<", "<="]``. + + .. attribute:: right + """ + + init_arg_names = ("left", "operator", "right") + + def __init__(self, left, operator, right): + self.left = left + self.right = right + if operator not in [">", ">=", "==", "!=", "<", "<="]: + raise RuntimeError("invalid operator") + self.operator = operator + + def __getinitargs__(self): + return self.left, self.operator, self.right + + mapper_method = intern("map_pwaff_comparison") + + +class _MultiChildSetExpression(SetExpression): + init_arg_names = ("children",) + + def __init__(self, children): + assert isinstance(children, tuple) + + self.children = children + + def __getinitargs__(self): + return self.children, + + +class SetUnion(_MultiChildSetExpression): + """ + .. attribute:: children + + A :class:`tuple`. + """ + def or_(self, other): + if isinstance(other, PwAff): + raise ValueError() + if isinstance(other, SetUnion): + return SetUnion(self.children + other.children) + else: + return SetUnion(self.children + (other, )) + + mapper_method = intern("map_set_union") + + +class SetIntersection(_MultiChildSetExpression): + """ + .. attribute:: children + + A :class:`tuple`. + """ + def and_(self, other): + if isinstance(other, PwAff): + raise ValueError() + if isinstance(other, SetIntersection): + return SetIntersection(self.children + other.children) + else: + return SetIntersection(self.children + (other, )) + + mapper_method = intern("map_set_intersection") + + +class PwAff(SetExpression): """PwAff: variables used to construct piecewise-affine expressions. + + .. attribute:: name """ + init_arg_names = ("name",) + mapper_method = intern('map_pwaff') + + def __init__(self, name): + assert name + self.name = intern(name) + + def __getinitargs__(self): + return self.name, + + def __lt__(self, other): + if isinstance(other, Variable): + return self.name.__lt__(other.name) + else: + return NotImplemented + + def __setstate__(self, val): + super(Variable, self).__setstate__(val) + + self.name = intern(self.name) + + def le(self, other): + return self._make_pwaff_comparison('<=', other) + + def lt(self, other): + return self._make_pwaff_comparison('<', other) + + def ge(self, other): + return self._make_pwaff_comparison('>=', other) + + def gt(self, other): + return self._make_pwaff_comparison('>', other) + + def eq(self, other): + return self._make_pwaff_comparison('==', other) - mapper_method = 'map_pwaff' + def ne(self, other): + return self._make_pwaff_comparison('!=', other) + + def __add__(self, other): + if not isinstance(other, (PwAff, int)): + return ValueError() + if isinstance(other, PwAffSum): + return PwAffSum((self,) + other.children) + else: + return PwAffSum((self, other)) + + def __radd__(self, other): + assert isinstance(other, int) + return PwAffSum((other, self)) + + def __sub__(self, other): + return self + (-1) * other + + def __rsub__(self, other): + return self + (-1) * other + + def __floordiv__(self, other): + if not isinstance(other, int) and other > 0: + return ValueError() + return PwAffFloorDiv(self, other) + + def __mod__(self, other): + if not isinstance(other, int) and other > 0: + return ValueError() + return PwAffRemainder(self, other) + + def __mul__(self, other): + if not isinstance(other, (PwAff, int)): + return ValueError() + if isinstance(other, PwAffProduct): + return PwAffProduct((self,) + other.children) + else: + return PwAffProduct((self, other)) + + def __rmul__(self, other): + assert isinstance(other, int) + return PwAffProduct((other, self)) + + +class _MultiChildPwAff(PwAff): + init_arg_names = ("children",) + + def __init__(self, children): + assert isinstance(children, tuple) + + self.children = children + + def __getinitargs__(self): + return self.children, class SetVar(PwAff): @@ -111,4 +421,89 @@ class SetParam(PwAff): mapper_method = 'map_set_parameter' + +class SetZero(PwAff): + """The constant 0. + """ + + mapper_method = 'map_set_zero' + + +class PwAffSum(_MultiChildPwAff): + """Sum of PwAffs. + + .. attribute:: children + + A :class:`tuple`. + """ + + def __add__(self, other): + if not isinstance(other, (int, PwAff)): + return ValueError() + if isinstance(other, PwAffSum): + return PwAffSum(self.children + other.children) + return PwAffSum(self.children + (other,)) + + def __radd__(self, other): + if not isinstance(other, int): + return ValueError() + if isinstance(other, PwAffSum): + return PwAffSum(other.children + self.children) + return PwAffSum((other,) + self.children) + + mapper_method = intern("map_pwaff_sum") + + +class PwAffProduct(_MultiChildPwAff): + """Product of PwAffs. Up to one non-constant child is allowed. + + .. attribute:: children + + A :class:`tuple`. + """ + def __mul__(self, other): + if not isinstance(other, (PwAff, int)): + return ValueError() + if isinstance(other, PwAffProduct): + return PwAffProduct(self.children + other.children) + return PwAffProduct(self.children + (other,)) + + def __rmul__(self, other): + if not isinstance(other, int): + return ValueError() + return PwAffProduct((other,) + self.children) + + mapper_method = intern("map_pwaff_product") + + +class PwAffQuotientBase(PwAff): + init_arg_names = ("numerator", "denominator",) + + def __init__(self, numerator, denominator=1): + self.numerator = numerator + self.denominator = denominator + + def __getinitargs__(self): + return self.numerator, self.denominator + + mapper_method = None + + +class PwAffFloorDiv(PwAffQuotientBase): + """ + .. attribute:: numerator + .. attribute:: denominator + """ + + mapper_method = intern("map_pwaff_floor_div") + + +class PwAffRemainder(PwAffQuotientBase): + """ + .. attribute:: numerator + .. attribute:: denominator + """ + + mapper_method = intern("map_pwaff_remainder") + # }}} End Integer set computation primitives diff --git a/lappy/core/ufuncs.py b/lappy/core/ufuncs.py index 19065d89c70342d3628056730931581cfb987ac7..8a958f987e9f415dc38633a689f9eabc0640dfcb 100644 --- a/lappy/core/ufuncs.py +++ b/lappy/core/ufuncs.py @@ -121,6 +121,7 @@ class UnaryOperation(UFunc): 'name': name, 'inames': a.inames, 'expr': self.f(a.expr), 'value': None, + 'domain_expr': a.domain_expr, 'arguments': a.arguments.copy(), 'bound_arguments': a.bound_arguments.copy(), 'env': a.env.copy(), @@ -288,6 +289,7 @@ class BinaryOperation(UFunc): 'name': name, 'inames': a.inames, 'expr': self.f(a.expr, b.expr), 'value': None, + 'domain_expr': a.domain_expr, 'arguments': new_arglist, 'bound_arguments': new_bound_arglist, 'intermediaries': new_interm, diff --git a/test/test_loop_domain_expression.py b/test/test_loop_domain_expression.py index 0dd3975fa543c44e1283bd64575265773d0fb533..2618b346e6a220d9918acd2740c770e43c9f769e 100644 --- a/test/test_loop_domain_expression.py +++ b/test/test_loop_domain_expression.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- from __future__ import division, absolute_import, print_function __copyright__ = "Copyright (C) 2019 Xiaoyu Wei" @@ -23,8 +24,10 @@ THE SOFTWARE. """ import pytest +import islpy as isl from lappy.core.primitives import SetVar, SetParam from lappy.core.mapper import SetVariableCollector, SetParameterCollector +from lappy.core.compiler import LoopDomainCompiler @pytest.mark.parametrize('expr, var', [ @@ -45,3 +48,53 @@ def test_variable_collector(expr, var): def test_parameter_collector(expr, var): collector = SetParameterCollector() assert collector(expr) == var + + +@pytest.mark.parametrize('expr,expected_set', [ + ((SetVar('i') + SetVar('j')).ge(10).and_( + SetVar('i').le(SetParam('m'))).and_( + SetVar('j').lt(SetParam('n'))), + ["[m, n] -> { [i, j]: i <= m and j < n and i + j >= 10}"]), + (SetVar('i').lt(SetParam('n')).and_(SetVar('i').ge(0)), + ["[n] -> { [i]: i < n and i >= 0 }"]), + (SetVar('i').lt(SetParam('n')).and_((SetVar('i') * 12 + 3).ge(-2)), + ["[n] -> { [i]: i < n and i * 12 + 3 >= -2 }"]), + (SetVar('i').lt(SetParam('n')).and_( + (SetVar('i') - 9).ge(8)), + ["[n] -> { [i]: i < n and i - 9 >= 8 }"]), + (SetVar('i').lt(SetParam('n')).and_( + (SetVar('i') // 9).ge(8)), + ["[n] -> { [i]: i < n and i / 9 >= 8 }"]), + ]) +def test_loop_domain_compiler(expr, expected_set): + + compiler = LoopDomainCompiler() + domain = compiler(expr) + + print(domain) + expected_domain = [ + isl.BasicSet.read_from_str(isl.DEFAULT_CONTEXT, es) + for es in expected_set] + + assert domain == expected_domain + + +@pytest.mark.parametrize('expr,expected_str', [ + ((SetVar('i') + SetVar('j')).ge(10).and_( + SetVar('i').le(SetParam('m'))).and_( + SetVar('j').lt(SetParam('n'))), + u"{i + j >= 10} ∩ {i <= m} ∩ {j < n}"), + (SetVar('i').lt(SetParam('n')).and_(SetVar('i').ge(0)), + u"{i < n} ∩ {i >= 0}"), + (SetVar('i').lt(SetParam('n')).and_((SetVar('i') * 12 + 3).ge(-2)), + u"{i < n} ∩ {i*12 + 3 >= -2}"), + (SetVar('i').lt(SetParam('n')).and_( + (SetVar('i') - 9).ge(8)), + u"{i < n} ∩ {i + -9 >= 8}"), + (SetVar('i').lt(SetParam('n')).and_( + (SetVar('i') // 9).ge(8)), + u"{i < n} ∩ {i // 9 >= 8}"), + ]) +def test_loop_domain_expr_stringifier(expr, expected_str): + stringified = expr.make_stringifier()(expr) + assert stringified == expected_str diff --git a/test/test_transpose.py b/test/test_transpose.py index 7ea3af42a44a49ea8929413d369a11e285f6d2a6..3ea5516130752a3ce7e56f6b887821de161d2500 100644 --- a/test/test_transpose.py +++ b/test/test_transpose.py @@ -31,7 +31,10 @@ from pyopencl.tools import ( # noqa @pytest.mark.parametrize('test_shape,axes', [ + ((24, 13), None), ((2, 3, 4), None), + ((2, 3, 4), (1, 0, 2)), + ((2, 3, 4, 6), (0, 2, 1, 3)), ]) def test_transpose(ctx_factory, test_shape, axes, dtype=np.float64): ndim = len(test_shape)