From d12b1bc1a10ec99e00b6cca292c52057e36d25f8 Mon Sep 17 00:00:00 2001 From: xywei Date: Fri, 6 Dec 2019 08:29:46 -0600 Subject: [PATCH 1/6] Add desc for domain_expr --- lappy/core/array.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/lappy/core/array.py b/lappy/core/array.py index 2e16078..c3c6406 100644 --- a/lappy/core/array.py +++ b/lappy/core/array.py @@ -253,6 +253,12 @@ class Array(LazyObjectBase): index names for iterating over the array. + .. attribute:: domain_expr + + set expression for constructing the loop domain. A set expression + has leafs of :class:`lappy.core.primitives.PwAff` variables and + integers. + .. attribute:: dtype data type, either None or convertible to loopy type. Lappy -- GitLab From 962f19bf79e6caeae4324bb6281a3c8d060145b9 Mon Sep 17 00:00:00 2001 From: xywei Date: Fri, 6 Dec 2019 17:01:39 -0600 Subject: [PATCH 2/6] Simple domain expr support (+, * of pwaffs) --- lappy/core/array.py | 10 +- lappy/core/compiler.py | 172 ++++++++++++++ lappy/core/mapper.py | 34 +++ lappy/core/primitives.py | 351 +++++++++++++++++++++++++++- lappy/core/ufuncs.py | 2 + test/test_loop_domain_expression.py | 25 ++ test/test_transpose.py | 3 + 7 files changed, 590 insertions(+), 7 deletions(-) create mode 100644 lappy/core/compiler.py diff --git a/lappy/core/array.py b/lappy/core/array.py index 32a9ee8..cce8311 100644 --- a/lappy/core/array.py +++ b/lappy/core/array.py @@ -257,7 +257,8 @@ class Array(LazyObjectBase): set expression for constructing the loop domain. A set expression has leafs of :class:`lappy.core.primitives.PwAff` variables and - integers. + integers. If domain_expr is None, the index domain of the array is + used. .. attribute:: dtype @@ -276,7 +277,7 @@ class Array(LazyObjectBase): _counter = 0 _name_prefix = '__lappy_array_' - def __init__(self, name=None, shape=None, **kwargs): + def __init__(self, name=None, shape=None, domain_expr=None, **kwargs): # default expr if 'expr' not in kwargs: kwargs['expr'] = None @@ -286,13 +287,14 @@ class Array(LazyObjectBase): super(Array, self).__init__(**kwargs) self._ndim = to_lappy_unsigned(kwargs.pop("ndim", None)) + self.domain_expr = domain_expr if self.ndim == 0: self.name = self.name.replace('array', 'number') if self.ndim is None: # infer ndim from given shape - # NOTE: use fixed ndim only (for the moment) + # NOTE: use fixed ndim only shape = to_lappy_shape(shape) self._ndim = to_lappy_unsigned(int(len(shape))) @@ -333,6 +335,7 @@ class Array(LazyObjectBase): repr_dict.update({ 'ndim': self.ndim, 'inames': self.inames, + 'domain': self.domain_expr, 'shape': self.shape, 'dtype': self.dtype, 'is_integral': self.is_integral, @@ -512,6 +515,7 @@ class Array(LazyObjectBase): inames=self.inames, shape=self._shape, dtype=self.dtype, expr=self.expr, is_integral=self.is_integral, + domain_expr=self.domain_expr, arguments=self.arguments.copy(), bound_arguments=self.bound_arguments.copy(), intermediaries=self.intermediaries.copy(), diff --git a/lappy/core/compiler.py b/lappy/core/compiler.py new file mode 100644 index 0000000..c1b2cad --- /dev/null +++ b/lappy/core/compiler.py @@ -0,0 +1,172 @@ +__copyright__ = "Copyright (C) 2019 Xiaoyu Wei" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" +import islpy as isl +from pymbolic.mapper import Mapper +from lappy.core.mapper import SetVariableCollector, SetParameterCollector +from lappy.core.primitives import EMPTY_SET, UNIVERSE_SET + +# {{{ loop domain compiler + + +class LoopDomainCompiler(Mapper): + """Compiles a set expression in to an ISL set. + """ + def map_constant(self, expr, var_dict): + if expr == 0: + return var_dict[0] + return expr + + def map_set_variable(self, expr, var_dict): + return var_dict[expr.name] + + def map_set_parameter(self, expr, var_dict): + return var_dict[expr.name] + + def map_sum(self, expr, var_dict): + res = var_dict[0] + for c in expr.children: + res += self.rec(c, var_dict) + return res + + def map_pwaff_comparison(self, expr, var_dict): + lhs = self.rec(expr.left, var_dict) + rhs = self.rec(expr.right, var_dict) + op = expr.operator + + print(type(lhs), type(rhs)) + + if isinstance(lhs, int) and isinstance(rhs, int): + exec('valid = (%s %s %s)' % (str(lhs), str(op), str(rhs))) + if valid: # noqa: F821 + return 1 # stands for the whole set + else: + return 0 # stands for the empty set + + if isinstance(lhs, int) and isinstance(rhs, isl.PwAff): + lhs, rhs = rhs, lhs + + if isinstance(lhs, isl.PwAff) and isinstance(rhs, int): + lhs = lhs - rhs + rhs = var_dict[0] + + if isinstance(lhs, isl.PwAff) and isinstance(rhs, isl.PwAff): + if op == '==': + return lhs.eq_set(rhs) + if op == '!=': + return lhs.ne_set(rhs) + if op == '<=': + return lhs.le_set(rhs) + if op == '<': + return lhs.lt_set(rhs) + if op == '>=': + return lhs.ge_set(rhs) + if op == '>': + return lhs.gt_set(rhs) + raise ValueError('unknown operator "%s"' % str(op)) + + else: + raise ValueError('cannot compile operator "%s" (lhs=%s, rhs=%s)' + % (op, str(expr.left), str(expr.right))) + + def map_pwaff_sum(self, expr, var_dict): + """Mapper for sum of pwaffs. + """ + return sum([self.rec(child, var_dict) for child in expr.children], 0) + + def map_pwaff_product(self, expr, var_dict): + """Mapper for product of pwaffs. + """ + res = 1 + for rc in [self.rec(child, var_dict) for child in expr.children]: + res *= rc + return res + + def map_set_union(self, expr, var_dict): + """Mapper for set union. + + Singular sets (empty set and its complement) are represented with ints. + """ + children = [self.rec(child, var_dict) for child in expr.children] + + for c in children: + if isinstance(c, int) and c == UNIVERSE_SET: + return UNIVERSE_SET + + children = [c for c in children if not isinstance(c, int)] + + if len(children) > 0: + res = children[0] + else: + res = UNIVERSE_SET + + for c in children: + assert isinstance(res, isl.Set) and isinstance(c, isl.Set) + res = res | c + + return res + + def map_set_intersection(self, expr, var_dict): + """Mapper for set intersection. + + Singular sets (empty set and its complement) are represented with ints. + """ + children = [self.rec(child, var_dict) for child in expr.children] + + for c in children: + if isinstance(c, int) and c == EMPTY_SET: + return EMPTY_SET + + children = [c for c in children if not isinstance(c, int)] + + if len(children) > 0: + res = children[0] + else: + res = UNIVERSE_SET + + for c in children: + assert isinstance(res, isl.Set) and isinstance(c, isl.Set) + res = res & c + + return res + + def __call__(self, expr): + """Returns a list of basic sets representing the loop domain. + + The variables and parameter are lexicographically ordered. + """ + var = sorted(list(SetVariableCollector()(expr))) + param = sorted(list(SetParameterCollector()(expr))) + var_dict = isl.make_zero_and_vars(var, param) + + domain_pre = super(LoopDomainCompiler, self).__call__(expr, var_dict) + if isinstance(domain_pre, isl.PwAff): + # unconstrained domain + raise ValueError("unconstrained (thus unbounded) loop domain") + + if isinstance(domain_pre, isl.Set): + domain = domain_pre.coalesce().get_basic_sets() + else: + raise RuntimeError() + + return domain + +# }}} End loop domain compiler diff --git a/lappy/core/mapper.py b/lappy/core/mapper.py index e790dfb..fa6f724 100644 --- a/lappy/core/mapper.py +++ b/lappy/core/mapper.py @@ -22,6 +22,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ from pymbolic.mapper import Collector +from pymbolic.mapper.stringifier import StringifyMapper, PREC_SUM class SetVariableCollector(Collector): @@ -30,7 +31,13 @@ class SetVariableCollector(Collector): def map_set_variable(self, expr): return {expr.name, } + map_set_zero = Collector.map_constant map_set_parameter = Collector.map_constant + map_pwaff_comparison = Collector.map_comparison + map_pwaff_sum = Collector.map_sum + map_pwaff_product = Collector.map_sum + map_set_union = Collector.map_sum + map_set_intersection = Collector.map_sum class SetParameterCollector(Collector): @@ -39,4 +46,31 @@ class SetParameterCollector(Collector): def map_set_parameter(self, expr): return {expr.name, } + map_set_zero = Collector.map_constant map_set_variable = Collector.map_constant + map_pwaff_comparison = Collector.map_comparison + map_pwaff_sum = Collector.map_sum + map_pwaff_product = Collector.map_sum + map_set_union = Collector.map_sum + map_set_intersection = Collector.map_sum + + +class SetExpressionStringifyMapper(StringifyMapper): + def map_set_variable(self, expr, *args, **kwargs): + return expr.name + + def map_set_parameter(self, expr, *args, **kwargs): + return expr.name + + def map_set_zero(self, expr, *args, **kwargs): + return expr.name + + def map_set_union(self, expr, enclosing_prec, *args, **kwargs): + return self.parenthesize_if_needed( + self.join_rec(" ∪ ", expr.children, PREC_SUM, *args, **kwargs), + enclosing_prec, PREC_SUM) + + def map_set_intersection(self, expr, enclosing_prec, *args, **kwargs): + return self.parenthesize_if_needed( + self.join_rec(" ∩ ", expr.children, PREC_SUM, *args, **kwargs), + enclosing_prec, PREC_SUM) diff --git a/lappy/core/primitives.py b/lappy/core/primitives.py index a13a275..f11e4c7 100644 --- a/lappy/core/primitives.py +++ b/lappy/core/primitives.py @@ -25,9 +25,11 @@ THE SOFTWARE. from six.moves import intern from pymbolic.primitives import AlgebraicLeaf -from pymbolic.primitives import Variable +from pymbolic.primitives import Expression from pymbolic.primitives import QuotientBase -from pymbolic.primitives import make_sym_vector # noqa: F401 +from pymbolic.primitives import Variable + +from lappy.core.mapper import SetExpressionStringifyMapper # {{{ Array computation primitives @@ -91,11 +93,297 @@ class Reduction(AlgebraicLeaf): # {{{ Integer set computation primitives -class PwAff(Variable): +EMPTY_SET = 0 +UNIVERSE_SET = 1 # complement of the empty set + + +class SetExpression(Expression): + """Expression with replaced stringifier. + + The expression tree has two layers, a lower layer of PwAffs and an upper layer of + Sets. + The leaf nodes are PwAffs, and the internal layer above the comparison nodes are + Sets. All nodes further above are also Sets. + """ + # {{{ disable unused methods for easier debugging + + def __add__(self, *args, **kwargs): + raise NotImplementedError() + + def __radd__(self, *args, **kwargs): + raise NotImplementedError() + + def __sub__(self, *args, **kwargs): + raise NotImplementedError() + + def __rsub__(self, *args, **kwargs): + raise NotImplementedError() + + def __mul__(self, *args, **kwargs): + raise NotImplementedError() + + def __rmul__(self, *args, **kwargs): + raise NotImplementedError() + + def __div__(self, *args, **kwargs): + raise NotImplementedError() + + def __floordiv__(self, *args, **kwargs): + raise NotImplementedError() + + def __rfloordiv__(self, *args, **kwargs): + raise NotImplementedError() + + def __mod__(self, *args, **kwargs): + raise NotImplementedError() + + def __rmod__(self, *args, **kwargs): + raise NotImplementedError() + + def __pow__(self, *args, **kwargs): + raise NotImplementedError() + + def __rpow__(self, *args, **kwargs): + raise NotImplementedError() + + def __lshift__(self, other): + raise NotImplementedError() + + def __rlshift__(self, other): + raise NotImplementedError() + + def __rshift__(self, other): + raise NotImplementedError() + + def __rrshift__(self, other): + raise NotImplementedError() + + def __inv__(self): + raise NotImplementedError() + + def __or__(self, other): + raise NotImplementedError() + + def __ror__(self, other): + raise NotImplementedError() + + def __xor__(self, other): + raise NotImplementedError() + + def __rxor__(self, other): + raise NotImplementedError() + + def __and__(self, other): + raise NotImplementedError() + + def __rand__(self, other): + raise NotImplementedError() + + def __call__(self, *args, **kwargs): + raise NotImplementedError() + + def eq(self, *args, **kwargs): + raise NotImplementedError() + + def ne(self, *args, **kwargs): + raise NotImplementedError() + + def le(self, *args, **kwargs): + raise NotImplementedError() + + def lt(self, *args, **kwargs): + raise NotImplementedError() + + def ge(self, *args, **kwargs): + raise NotImplementedError() + + def gt(self, *args, **kwargs): + raise NotImplementedError() + + def not_(self, *args, **kwargs): + raise NotImplementedError() + + # }}} End disabled methods + + def make_stringifier(self, originating_stringifier=None): + return SetExpressionStringifyMapper() + + def or_(self, other): + if isinstance(self, PwAff) or isinstance(other, PwAff): + raise ValueError("cannot compute set union of %s and %s" + % (str(self), str(other))) + if isinstance(other, SetUnion): + return SetUnion((self, ) + other.children) + else: + return SetUnion((self, other)) + + def and_(self, other): + if isinstance(self, PwAff) or isinstance(other, PwAff): + raise ValueError("cannot compute set intersection of %s and %s" + % (str(self), str(other))) + if isinstance(other, SetIntersection): + return SetIntersection((self, ) + other.children) + else: + return SetIntersection((self, other)) + + def _make_pwaff_comparison(self, op, other): + if not isinstance(self, PwAff) and isinstance(other, PwAff): + raise ValueError("cannot do pwaff comparison with %s and %s" + % (str(self), str(other))) + return PwAffComparison(self, op, other) + + +class PwAffComparison(SetExpression): + """PwAffs form sets via the comparison node. + + .. attribute:: left + .. attribute:: operator + + One of ``[">", ">=", "==", "!=", "<", "<="]``. + + .. attribute:: right + """ + + init_arg_names = ("left", "operator", "right") + + def __init__(self, left, operator, right): + self.left = left + self.right = right + if operator not in [">", ">=", "==", "!=", "<", "<="]: + raise RuntimeError("invalid operator") + self.operator = operator + + def __getinitargs__(self): + return self.left, self.operator, self.right + + mapper_method = intern("map_pwaff_comparison") + + +class _MultiChildSetExpression(SetExpression): + init_arg_names = ("children",) + + def __init__(self, children): + assert isinstance(children, tuple) + + self.children = children + + def __getinitargs__(self): + return self.children, + + +class SetUnion(_MultiChildSetExpression): + """ + .. attribute:: children + + A :class:`tuple`. + """ + def or_(self, other): + if isinstance(other, PwAff): + raise ValueError() + if isinstance(other, SetUnion): + return SetUnion(self.children + other.children) + else: + return SetUnion(self.children + (other, )) + + mapper_method = intern("map_set_union") + + +class SetIntersection(_MultiChildSetExpression): + """ + .. attribute:: children + + A :class:`tuple`. + """ + def and_(self, other): + if isinstance(other, PwAff): + raise ValueError() + if isinstance(other, SetIntersection): + return SetIntersection(self.children + other.children) + else: + return SetIntersection(self.children + (other, )) + + mapper_method = intern("map_set_intersection") + + +class PwAff(SetExpression): """PwAff: variables used to construct piecewise-affine expressions. + + .. attribute:: name """ + init_arg_names = ("name",) + mapper_method = intern('map_pwaff') + + def __init__(self, name): + assert name + self.name = intern(name) + + def __getinitargs__(self): + return self.name, + + def __lt__(self, other): + if isinstance(other, Variable): + return self.name.__lt__(other.name) + else: + return NotImplemented + + def __setstate__(self, val): + super(Variable, self).__setstate__(val) + + self.name = intern(self.name) + + def le(self, other): + return self._make_pwaff_comparison('<=', other) + + def lt(self, other): + return self._make_pwaff_comparison('<', other) + + def ge(self, other): + return self._make_pwaff_comparison('>=', other) + + def gt(self, other): + return self._make_pwaff_comparison('>', other) + + def eq(self, other): + return self._make_pwaff_comparison('==', other) + + def ne(self, other): + return self._make_pwaff_comparison('!=', other) - mapper_method = 'map_pwaff' + def __add__(self, other): + if not isinstance(other, (PwAff, int)): + return ValueError() + if isinstance(other, PwAffSum): + return PwAffSum((self,) + other.children) + else: + return PwAffSum((self, other)) + + def __radd__(self, other): + assert isinstance(other, int) + return PwAffSum((other, self)) + + def __mul__(self, other): + if not isinstance(other, (PwAff, int)): + return ValueError() + if isinstance(other, PwAffProduct): + return PwAffProduct((self,) + other.children) + else: + return PwAffProduct((self, other)) + + def __rmul__(self, other): + assert isinstance(other, int) + return PwAffProduct((other, self)) + + +class _MultiChildPwAff(PwAff): + init_arg_names = ("children",) + + def __init__(self, children): + assert isinstance(children, tuple) + + self.children = children + + def __getinitargs__(self): + return self.children, class SetVar(PwAff): @@ -111,4 +399,59 @@ class SetParam(PwAff): mapper_method = 'map_set_parameter' + +class SetZero(PwAff): + """The constant 0. + """ + + mapper_method = 'map_set_zero' + + +class PwAffSum(_MultiChildPwAff): + """Sum of PwAffs. + + .. attribute:: children + + A :class:`tuple`. + """ + + def __add__(self, other): + if not isinstance(other, (int, PwAff)): + return ValueError() + if isinstance(other, PwAffSum): + return PwAffSum(self.children + other.children) + return PwAffSum(self.children + (other,)) + + def __radd__(self, other): + if not isinstance(other, int): + return ValueError() + if isinstance(other, PwAffSum): + return PwAffSum(other.children + self.children) + return PwAffSum((other,) + self.children) + + mapper_method = intern("map_pwaff_sum") + + +class PwAffProduct(_MultiChildPwAff): + """Product of PwAffs. Up to one non-constant child is allowed. + + .. attribute:: children + + A :class:`tuple`. + """ + def __mul__(self, other): + if not isinstance(other, (PwAff, int)): + return ValueError() + if isinstance(other, PwAffProduct): + return PwAffProduct(self.children + other.children) + return PwAffProduct(self.children + (other,)) + + def __rmul__(self, other): + if not isinstance(other, int): + return ValueError() + return PwAffProduct((other,) + self.children) + + mapper_method = intern("map_pwaff_product") + + # }}} End Integer set computation primitives diff --git a/lappy/core/ufuncs.py b/lappy/core/ufuncs.py index 19065d8..8a958f9 100644 --- a/lappy/core/ufuncs.py +++ b/lappy/core/ufuncs.py @@ -121,6 +121,7 @@ class UnaryOperation(UFunc): 'name': name, 'inames': a.inames, 'expr': self.f(a.expr), 'value': None, + 'domain_expr': a.domain_expr, 'arguments': a.arguments.copy(), 'bound_arguments': a.bound_arguments.copy(), 'env': a.env.copy(), @@ -288,6 +289,7 @@ class BinaryOperation(UFunc): 'name': name, 'inames': a.inames, 'expr': self.f(a.expr, b.expr), 'value': None, + 'domain_expr': a.domain_expr, 'arguments': new_arglist, 'bound_arguments': new_bound_arglist, 'intermediaries': new_interm, diff --git a/test/test_loop_domain_expression.py b/test/test_loop_domain_expression.py index 0dd3975..2726c0b 100644 --- a/test/test_loop_domain_expression.py +++ b/test/test_loop_domain_expression.py @@ -23,8 +23,10 @@ THE SOFTWARE. """ import pytest +import islpy as isl from lappy.core.primitives import SetVar, SetParam from lappy.core.mapper import SetVariableCollector, SetParameterCollector +from lappy.core.compiler import LoopDomainCompiler @pytest.mark.parametrize('expr, var', [ @@ -45,3 +47,26 @@ def test_variable_collector(expr, var): def test_parameter_collector(expr, var): collector = SetParameterCollector() assert collector(expr) == var + + +@pytest.mark.parametrize('expr,expected_set', [ + ((SetVar('i') + SetVar('j')).ge(10).and_( + SetVar('i').le(SetParam('m'))).and_( + SetVar('j').lt(SetParam('n'))), + ["[m, n] -> { [i, j]: i <= m and j < n and i + j >= 10}"]), + (SetVar('i').lt(SetParam('n')).and_(SetVar('i').ge(0)), + ["[n] -> { [i]: i < n and i >= 0 }"]), + (SetVar('i').lt(SetParam('n')).and_((SetVar('i') * 12 + 3).ge(-2)), + ["[n] -> { [i]: i < n and i * 12 + 3 >= -2 }"]), + ]) +def test_loop_domain_compiler(expr, expected_set): + + compiler = LoopDomainCompiler() + domain = compiler(expr) + + print(domain) + expected_domain = [ + isl.BasicSet.read_from_str(isl.DEFAULT_CONTEXT, es) + for es in expected_set] + + assert domain == expected_domain diff --git a/test/test_transpose.py b/test/test_transpose.py index 7ea3af4..3ea5516 100644 --- a/test/test_transpose.py +++ b/test/test_transpose.py @@ -31,7 +31,10 @@ from pyopencl.tools import ( # noqa @pytest.mark.parametrize('test_shape,axes', [ + ((24, 13), None), ((2, 3, 4), None), + ((2, 3, 4), (1, 0, 2)), + ((2, 3, 4, 6), (0, 2, 1, 3)), ]) def test_transpose(ctx_factory, test_shape, axes, dtype=np.float64): ndim = len(test_shape) -- GitLab From be5c48e54f7a093eb37df33d2ab6ebdb759c4f15 Mon Sep 17 00:00:00 2001 From: xywei Date: Fri, 6 Dec 2019 22:43:47 -0600 Subject: [PATCH 3/6] Add floor_div and mod for pwaffs --- lappy/core/compiler.py | 8 +++++ lappy/core/mapper.py | 4 +++ lappy/core/primitives.py | 52 +++++++++++++++++++++++++++++ test/test_loop_domain_expression.py | 6 ++++ 4 files changed, 70 insertions(+) diff --git a/lappy/core/compiler.py b/lappy/core/compiler.py index c1b2cad..64f3513 100644 --- a/lappy/core/compiler.py +++ b/lappy/core/compiler.py @@ -100,6 +100,14 @@ class LoopDomainCompiler(Mapper): res *= rc return res + def map_pwaff_floor_div(self, expr, var_dict): + return (self.rec(expr.numerator, var_dict) + // self.rec(expr.denominator, var_dict)) + + def map_pwaff_remainder(self, expr, var_dict): + return (self.rec(expr.numerator, var_dict) + % self.rec(expr.denominator, var_dict)) + def map_set_union(self, expr, var_dict): """Mapper for set union. diff --git a/lappy/core/mapper.py b/lappy/core/mapper.py index fa6f724..024d844 100644 --- a/lappy/core/mapper.py +++ b/lappy/core/mapper.py @@ -38,6 +38,8 @@ class SetVariableCollector(Collector): map_pwaff_product = Collector.map_sum map_set_union = Collector.map_sum map_set_intersection = Collector.map_sum + map_pwaff_floor_div = Collector.map_quotient + map_pwaff_remainder = Collector.map_quotient class SetParameterCollector(Collector): @@ -53,6 +55,8 @@ class SetParameterCollector(Collector): map_pwaff_product = Collector.map_sum map_set_union = Collector.map_sum map_set_intersection = Collector.map_sum + map_pwaff_floor_div = Collector.map_quotient + map_pwaff_remainder = Collector.map_quotient class SetExpressionStringifyMapper(StringifyMapper): diff --git a/lappy/core/primitives.py b/lappy/core/primitives.py index f11e4c7..403c205 100644 --- a/lappy/core/primitives.py +++ b/lappy/core/primitives.py @@ -128,6 +128,12 @@ class SetExpression(Expression): def __div__(self, *args, **kwargs): raise NotImplementedError() + def __truediv__(self, *args, **kwargs): + return self.__div__(*args, **kwargs) + + def __rdiv__(self, *args, **kwargs): + raise NotImplementedError() + def __floordiv__(self, *args, **kwargs): raise NotImplementedError() @@ -361,6 +367,22 @@ class PwAff(SetExpression): assert isinstance(other, int) return PwAffSum((other, self)) + def __sub__(self, other): + return self + (-1) * other + + def __rsub__(self, other): + return self + (-1) * other + + def __floordiv__(self, other): + if not isinstance(other, int) and other > 0: + return ValueError() + return PwAffFloorDiv(self, other) + + def __mod__(self, other): + if not isinstance(other, int) and other > 0: + return ValueError() + return PwAffRemainder(self, other) + def __mul__(self, other): if not isinstance(other, (PwAff, int)): return ValueError() @@ -454,4 +476,34 @@ class PwAffProduct(_MultiChildPwAff): mapper_method = intern("map_pwaff_product") +class PwAffQuotientBase(PwAff): + init_arg_names = ("numerator", "denominator",) + + def __init__(self, numerator, denominator=1): + self.numerator = numerator + self.denominator = denominator + + def __getinitargs__(self): + return self.numerator, self.denominator + + mapper_method = None + + +class PwAffFloorDiv(PwAffQuotientBase): + """ + .. attribute:: numerator + .. attribute:: denominator + """ + + mapper_method = intern("map_pwaff_floor_div") + + +class PwAffRemainder(PwAffQuotientBase): + """ + .. attribute:: numerator + .. attribute:: denominator + """ + + mapper_method = intern("map_pwaff_remainder") + # }}} End Integer set computation primitives diff --git a/test/test_loop_domain_expression.py b/test/test_loop_domain_expression.py index 2726c0b..81551ff 100644 --- a/test/test_loop_domain_expression.py +++ b/test/test_loop_domain_expression.py @@ -58,6 +58,12 @@ def test_parameter_collector(expr, var): ["[n] -> { [i]: i < n and i >= 0 }"]), (SetVar('i').lt(SetParam('n')).and_((SetVar('i') * 12 + 3).ge(-2)), ["[n] -> { [i]: i < n and i * 12 + 3 >= -2 }"]), + (SetVar('i').lt(SetParam('n')).and_( + (SetVar('i') - 9).ge(8)), + ["[n] -> { [i]: i < n and i - 9 >= 8 }"]), + (SetVar('i').lt(SetParam('n')).and_( + (SetVar('i') // 9).ge(8)), + ["[n] -> { [i]: i < n and i / 9 >= 8 }"]), ]) def test_loop_domain_compiler(expr, expected_set): -- GitLab From 5713eabe96f144a6f3bc622637af49827827f8da Mon Sep 17 00:00:00 2001 From: xywei Date: Fri, 6 Dec 2019 22:59:21 -0600 Subject: [PATCH 4/6] Fix the stringifier for the loop domain expr --- lappy/core/mapper.py | 25 ++++++++++++++++++++++++- test/test_loop_domain_expression.py | 21 +++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/lappy/core/mapper.py b/lappy/core/mapper.py index 024d844..7cf8082 100644 --- a/lappy/core/mapper.py +++ b/lappy/core/mapper.py @@ -22,7 +22,9 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ from pymbolic.mapper import Collector -from pymbolic.mapper.stringifier import StringifyMapper, PREC_SUM +from pymbolic.mapper.stringifier import ( + StringifyMapper, + PREC_SUM, PREC_COMPARISON, PREC_NONE, PREC_PRODUCT) class SetVariableCollector(Collector): @@ -78,3 +80,24 @@ class SetExpressionStringifyMapper(StringifyMapper): return self.parenthesize_if_needed( self.join_rec(" ∩ ", expr.children, PREC_SUM, *args, **kwargs), enclosing_prec, PREC_SUM) + + def map_pwaff_comparison(self, expr, enclosing_prec, *args, **kwargs): + return self.format("{%s %s %s}", + self.rec(expr.left, PREC_COMPARISON, *args, **kwargs), + expr.operator, + self.rec(expr.right, PREC_COMPARISON, *args, **kwargs)) + + def map_pwaff_floor_div(self, expr, enclosing_prec, *args, **kwargs): + return self.format("%s // %s", + self.rec(expr.numerator, PREC_NONE, *args, **kwargs), + self.rec(expr.denominator, PREC_NONE, *args, **kwargs)) + + def map_pwaff_sum(self, expr, enclosing_prec, *args, **kwargs): + return self.parenthesize_if_needed( + self.join_rec(" + ", expr.children, PREC_SUM, *args, **kwargs), + enclosing_prec, PREC_SUM) + + def map_pwaff_product(self, expr, enclosing_prec, *args, **kwargs): + return self.parenthesize_if_needed( + self.join_rec("*", expr.children, PREC_PRODUCT, *args, **kwargs), + enclosing_prec, PREC_PRODUCT) diff --git a/test/test_loop_domain_expression.py b/test/test_loop_domain_expression.py index 81551ff..e81dcd5 100644 --- a/test/test_loop_domain_expression.py +++ b/test/test_loop_domain_expression.py @@ -76,3 +76,24 @@ def test_loop_domain_compiler(expr, expected_set): for es in expected_set] assert domain == expected_domain + + +@pytest.mark.parametrize('expr,expected_str', [ + ((SetVar('i') + SetVar('j')).ge(10).and_( + SetVar('i').le(SetParam('m'))).and_( + SetVar('j').lt(SetParam('n'))), + "{i + j >= 10} ∩ {i <= m} ∩ {j < n}"), + (SetVar('i').lt(SetParam('n')).and_(SetVar('i').ge(0)), + "{i < n} ∩ {i >= 0}"), + (SetVar('i').lt(SetParam('n')).and_((SetVar('i') * 12 + 3).ge(-2)), + "{i < n} ∩ {i*12 + 3 >= -2}"), + (SetVar('i').lt(SetParam('n')).and_( + (SetVar('i') - 9).ge(8)), + "{i < n} ∩ {i + -9 >= 8}"), + (SetVar('i').lt(SetParam('n')).and_( + (SetVar('i') // 9).ge(8)), + "{i < n} ∩ {i // 9 >= 8}"), + ]) +def test_loop_domain_expr_stringifier(expr, expected_str): + stringified = str(expr) + assert stringified == expected_str -- GitLab From a9d66cf0f01ab4245b758fe71c2efa5c235d8d36 Mon Sep 17 00:00:00 2001 From: xywei Date: Fri, 6 Dec 2019 23:06:30 -0600 Subject: [PATCH 5/6] Py2 unicode strings --- lappy/core/mapper.py | 5 +++-- test/test_loop_domain_expression.py | 11 ++++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/lappy/core/mapper.py b/lappy/core/mapper.py index 7cf8082..56d4065 100644 --- a/lappy/core/mapper.py +++ b/lappy/core/mapper.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- from __future__ import division, absolute_import, print_function copyright__ = "Copyright (C) 2017 Sophia Lin and Andreas Kloeckner" @@ -73,12 +74,12 @@ class SetExpressionStringifyMapper(StringifyMapper): def map_set_union(self, expr, enclosing_prec, *args, **kwargs): return self.parenthesize_if_needed( - self.join_rec(" ∪ ", expr.children, PREC_SUM, *args, **kwargs), + self.join_rec(u" ∪ ", expr.children, PREC_SUM, *args, **kwargs), enclosing_prec, PREC_SUM) def map_set_intersection(self, expr, enclosing_prec, *args, **kwargs): return self.parenthesize_if_needed( - self.join_rec(" ∩ ", expr.children, PREC_SUM, *args, **kwargs), + self.join_rec(u" ∩ ", expr.children, PREC_SUM, *args, **kwargs), enclosing_prec, PREC_SUM) def map_pwaff_comparison(self, expr, enclosing_prec, *args, **kwargs): diff --git a/test/test_loop_domain_expression.py b/test/test_loop_domain_expression.py index e81dcd5..9201244 100644 --- a/test/test_loop_domain_expression.py +++ b/test/test_loop_domain_expression.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- from __future__ import division, absolute_import, print_function __copyright__ = "Copyright (C) 2019 Xiaoyu Wei" @@ -82,17 +83,17 @@ def test_loop_domain_compiler(expr, expected_set): ((SetVar('i') + SetVar('j')).ge(10).and_( SetVar('i').le(SetParam('m'))).and_( SetVar('j').lt(SetParam('n'))), - "{i + j >= 10} ∩ {i <= m} ∩ {j < n}"), + u"{i + j >= 10} ∩ {i <= m} ∩ {j < n}"), (SetVar('i').lt(SetParam('n')).and_(SetVar('i').ge(0)), - "{i < n} ∩ {i >= 0}"), + u"{i < n} ∩ {i >= 0}"), (SetVar('i').lt(SetParam('n')).and_((SetVar('i') * 12 + 3).ge(-2)), - "{i < n} ∩ {i*12 + 3 >= -2}"), + u"{i < n} ∩ {i*12 + 3 >= -2}"), (SetVar('i').lt(SetParam('n')).and_( (SetVar('i') - 9).ge(8)), - "{i < n} ∩ {i + -9 >= 8}"), + u"{i < n} ∩ {i + -9 >= 8}"), (SetVar('i').lt(SetParam('n')).and_( (SetVar('i') // 9).ge(8)), - "{i < n} ∩ {i // 9 >= 8}"), + u"{i < n} ∩ {i // 9 >= 8}"), ]) def test_loop_domain_expr_stringifier(expr, expected_str): stringified = str(expr) -- GitLab From 9823d9c995ceca7b18ce60c344f4c34f4ad1498a Mon Sep 17 00:00:00 2001 From: xywei Date: Fri, 6 Dec 2019 23:19:20 -0600 Subject: [PATCH 6/6] Do not convert to str for Py2 test --- test/test_loop_domain_expression.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_loop_domain_expression.py b/test/test_loop_domain_expression.py index 9201244..2618b34 100644 --- a/test/test_loop_domain_expression.py +++ b/test/test_loop_domain_expression.py @@ -96,5 +96,5 @@ def test_loop_domain_compiler(expr, expected_set): u"{i < n} ∩ {i // 9 >= 8}"), ]) def test_loop_domain_expr_stringifier(expr, expected_str): - stringified = str(expr) + stringified = expr.make_stringifier()(expr) assert stringified == expected_str -- GitLab