From aaee8804eaf4ee9a2da43708c16e6a1ba1db355f Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Thu, 12 Jan 2023 19:59:06 -0600 Subject: [PATCH] Base expressions on dataclasses --- doc/index.rst | 14 +- pymbolic/parser.py | 4 +- pymbolic/primitives.py | 707 +++++++++++++++++++---------------------- setup.py | 65 ++-- 4 files changed, 368 insertions(+), 422 deletions(-) diff --git a/doc/index.rst b/doc/index.rst index 6dce626..047a424 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -69,13 +69,13 @@ You can also easily define your own objects to use inside an expression: .. doctest:: - >>> from pymbolic.primitives import Expression - >>> class FancyOperator(Expression): - ... def __init__(self, operand): - ... self.operand = operand - ... - ... def __getinitargs__(self): - ... return (self.operand,) + >>> from pymbolic.primitives import Expression, augment_expression_dataclass + >>> from dataclasses import dataclass + >>> + >>> @augment_expression_dataclass + ... @dataclass(frozen=True) + ... class FancyOperator(Expression): + ... operand: Expression ... ... mapper_method = "map_fancy_operator" ... diff --git a/pymbolic/parser.py b/pymbolic/parser.py index e885c7d..1e6d527 100644 --- a/pymbolic/parser.py +++ b/pymbolic/parser.py @@ -20,6 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +import immutables import pytools.lex from pytools import memoize_method from sys import intern @@ -330,7 +331,8 @@ class Parser: args, kwargs = self.parse_arglist(pstate) if kwargs: - left_exp = primitives.CallWithKwargs(left_exp, args, kwargs) + left_exp = primitives.CallWithKwargs( + left_exp, args, immutables.Map(kwargs)) else: left_exp = primitives.Call(left_exp, args) diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index 329dd88..c98fe49 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -1,3 +1,5 @@ +from __future__ import annotations + __copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner" __license__ = """ @@ -20,9 +22,26 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from typing import Union, Mapping, Type, Optional, Callable, Any, Tuple from sys import intern -from abc import ABC, abstractmethod import pymbolic.traits as traits +from warnings import warn + +from dataclasses import dataclass, fields + + +# FIXME: This is a lie. Many more constant types (e.g. numpy and such) +# are in practical use and completely fine. We cannot really add in numpy +# as a special case (because pymbolic doesn't have a hard numpy dependency), +# and there isn't a usable numerical tower that we could rely on. As such, +# code abusing what constants are allowable will have to type-ignore those +# statements. Better ideas would be most welcome. +# +# References: +# https://github.com/python/mypy/issues/3186 +# https://discuss.python.org/t/numeric-generics-where-do-we-go-from-pep-3141-and-present-day-mypy/17155/14 +_ConstantT = Union[int, float, complex] +ExpressionT = Union[_ConstantT, "Expression", Tuple["ExpressionT", ...]] __doc__ = """ @@ -31,6 +50,14 @@ Expression base class .. autoclass:: Expression +.. class:: ExpressionT + + A type that can be used in type annotations whenever an expression + is desired. A :class:`typing.Union` of :class:`Expression` and + built-in scalar types. + +.. autofunction:: augment_expression_dataclass + Sums, products and such ----------------------- @@ -186,7 +213,7 @@ def disable_subscript_by_getitem(): pass -class Expression(ABC): +class Expression: """Superclass for parts of a mathematical expression. Overrides operators to implicitly construct :class:`Sum`, :class:`Product` and other expressions. @@ -213,9 +240,7 @@ class Expression(ABC): .. automethod:: make_stringifier .. automethod:: __eq__ - .. automethod:: is_equal .. automethod:: __hash__ - .. automethod:: get_hash .. automethod:: __str__ .. automethod:: __repr__ @@ -237,9 +262,8 @@ class Expression(ABC): # {{{ init arg names (override by subclass) - @abstractmethod def __getinitargs__(self): - pass + raise NotImplementedError @classmethod @property @@ -438,7 +462,8 @@ class Expression(ABC): def __call__(self, *args, **kwargs): if kwargs: - return CallWithKwargs(self, args, kwargs) + from immutables import Map + return CallWithKwargs(self, args, Map(kwargs)) else: return Call(self, args) @@ -536,10 +561,13 @@ class Expression(ABC): def __eq__(self, other): """Provides equality testing with quick positive and negative paths based on :func:`id` and :meth:`__hash__`. - - Subclasses should generally not override this method, but instead - provide an implementation of :meth:`is_equal`. """ + from warnings import warn + warn(f"Expression.__eq__ is used by {self.__class__}. This is deprecated. " + "Use equality comparison supplied by augment_expression_dataclass " + "instead. " + "This will stop working in 2024.", + DeprecationWarning, stacklevel=2) if self is other: return True elif hash(self) != hash(other): @@ -552,10 +580,13 @@ class Expression(ABC): def __hash__(self): """Provides caching for hash values. - - Subclasses should generally not override this method, but instead - provide an implementation of :meth:`get_hash`. """ + from warnings import warn + warn(f"Expression.__hash__ is used by {self.__class__}. This is deprecated. " + "Use hash functions supplied by augment_expression_dataclass instead. " + "This will stop working in 2024.", + DeprecationWarning, stacklevel=2) + try: return self._hash_value except AttributeError: @@ -569,7 +600,7 @@ class Expression(ABC): # Can't use trivial pickling: _hash_value cache must stay unset assert len(self.init_arg_names) == len(state), type(self) for name, value in zip(self.init_arg_names, state): - setattr(self, name, value) + object.__setattr__(self, name, value) # }}} @@ -681,6 +712,116 @@ class Expression(ABC): raise TypeError("expression types are not iterable") +# {{{ dataclasses support + +def augment_expression_dataclass(cls: Type[Expression]) -> Type[Expression]: + """A class decorator for :func:`dataclasses.dataclass`-derived + :class:`Expression` nodes. It adds cached hashing, equality comparisons + with ``self is other`` shortcuts as well as some methods/attributes + for backward compatibility (e.g. ``__getinitargs__``, ``init_arg_names``) + """ + attr_tuple = ", ".join(f"self.{fld.name}" for fld in fields(cls)) + if attr_tuple: + attr_tuple = f"({attr_tuple},)" + else: + attr_tuple = "()" + + fld_name_tuple = ", ".join(f"'{fld.name}'" for fld in fields(cls)) + if fld_name_tuple: + fld_name_tuple = f"({fld_name_tuple},)" + else: + fld_name_tuple = "()" + + comparison = " and ".join( + f"self.{fld.name} == other.{fld.name}" + for fld in fields(cls)) + + if not comparison: + comparison = "True" + + from pytools.codegen import remove_common_indentation + augment_code = remove_common_indentation( + f""" + from warnings import warn + + + def {cls.__name__}_eq(self, other): + if self is other: + return True + if self.__class__ is not other.__class__: + return False + if hash(self) != hash(other): + return False + if self.__class__ is not cls and self.init_arg_names != {fld_name_tuple}: + warn(f"{{self.__class__}} is derived from {cls}, which is now " + f"a dataclass. {{self.__class__}} should be converted to being " + "a dataclass as well. Non-dataclass subclasses " + "will stop working in 2024.", + DeprecationWarning) + + return self.is_equal(other) + + return self.__class__ == other.__class__ and {comparison} + + cls.__eq__ = {cls.__name__}_eq + + + def {cls.__name__}_hash(self): + try: + return self._hash_value + except AttributeError: + pass + + if self.__class__ is not cls and self.init_arg_names != {fld_name_tuple}: + warn(f"{{self.__class__}} is derived from {cls}, which is now " + f"a dataclass. {{self.__class__}} should be converted to being " + "a dataclass as well. Non-dataclass subclasses " + "will stop working in 2024.", + DeprecationWarning) + + hash_val = self.get_hash() + else: + hash_val = hash({attr_tuple}) + + object.__setattr__(self, "_hash_value", hash_val) + return hash_val + + cls.__hash__ = {cls.__name__}_hash + + + def {cls.__name__}_init_arg_names(self): + warn("__getinitargs__ is deprecated and will be removed in 2024. " + "Use dataclasses.fields instead.", + DeprecationWarning, stacklevel=2) + + return {fld_name_tuple} + + cls.init_arg_names = property({cls.__name__}_init_arg_names) + + + def {cls.__name__}_getinitargs(self): + warn("__getinitargs__ is deprecated and will be removed in 2024. " + "Use dataclasses.fields instead.", + DeprecationWarning, stacklevel=2) + + return {attr_tuple} + + cls.__getinitargs__ = {cls.__name__}_getinitargs + + + # FIXME Also implement pickling, with fallback + """) + + exec_dict = {"cls": cls, "_MODULE_SOURCE_CODE": augment_code} + exec(compile(augment_code, + f"", "exec"), + exec_dict) + + return cls + +# }}} + + class AlgebraicLeaf(Expression): """An expression that serves as a leaf for arithmetic evaluation. This may end up having child nodes still, but they're not reached by @@ -694,73 +835,50 @@ class Leaf(AlgebraicLeaf): pass +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) class Variable(Leaf): """ .. attribute:: name """ - init_arg_names = ("name",) + name: str - def __init__(self, name): - assert name - self.name = intern(name) - - def __getinitargs__(self): - return self.name, - - def __lt__(self, other): - if isinstance(other, Variable): - return self.name.__lt__(other.name) - else: - return NotImplemented - - def __setstate__(self, val): - super().__setstate__(val) - - self.name = intern(self.name) + # FIXME: Missing intern(): does it matter? mapper_method = intern("map_variable") +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) class Wildcard(Leaf): - def __getinitargs__(self): - return () - mapper_method = intern("map_wildcard") +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) class DotWildcard(Leaf): """ A wildcard that can be substituted for a single expression. """ - init_arg_names = ("name",) - - def __init__(self, name): - assert isinstance(name, str) - self.name = name - - def __getinitargs__(self): - return self.name, + name: str mapper_method = intern("map_dot_wildcard") +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) class StarWildcard(Leaf): """ A wildcard that can be substituted by a sequence of expressions of non-negative length. """ - init_arg_names = ("name",) - - def __init__(self, name): - assert isinstance(name, str) - self.name = name - - def __getinitargs__(self): - return self.name, + name: str mapper_method = intern("map_star_wildcard") +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) class FunctionSymbol(AlgebraicLeaf): """Represents the name of a function. @@ -768,14 +886,13 @@ class FunctionSymbol(AlgebraicLeaf): allow `Call` to check the number of arguments. """ - def __getinitargs__(self): - return () - mapper_method = intern("map_function_symbol") # {{{ structural primitives +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) class Call(AlgebraicLeaf): """A function invocation. @@ -789,29 +906,14 @@ class Call(AlgebraicLeaf): of which is a :class:`Expression` or a constant. """ - - init_arg_names = ("function", "parameters",) - - def __init__(self, function, parameters): - self.function = function - self.parameters = parameters - - try: - arg_count = self.function.arg_count - except AttributeError: - pass - else: - if len(self.parameters) != arg_count: - raise TypeError( - f"{self.function} called with wrong number of arguments " - f"(need {arg_count}, got {len(parameters)})") - - def __getinitargs__(self): - return self.function, self.parameters + function: ExpressionT + parameters: Tuple[ExpressionT, ...] mapper_method = intern("map_call") +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) class CallWithKwargs(AlgebraicLeaf): """A function invocation with keyword arguments. @@ -832,48 +934,15 @@ class CallWithKwargs(AlgebraicLeaf): constructor. """ - init_arg_names = ("function", "parameters", "kw_parameters") - - def __init__(self, function, parameters, kw_parameters): - self.function = function - self.parameters = parameters - - if isinstance(kw_parameters, dict): - self.kw_parameters = kw_parameters - else: - self.kw_parameters = dict(kw_parameters) - - try: - arg_count = self.function.arg_count - except AttributeError: - pass - else: - if len(self.parameters) != arg_count: - raise TypeError( - f"{self.function} called with wrong number of arguments " - f"(need {arg_count}, got {len(parameters)})") - - def __getinitargs__(self): - return (self.function, - self.parameters, - tuple(sorted(self.kw_parameters.items(), key=lambda item: item[0]))) - - def __setstate__(self, state): - # CallWithKwargs must override __setstate__ because during pickling the - # kw_parameters are converted to tuple, which needs to be converted - # back to dict. - assert len(self.init_arg_names) == len(state) - function, parameters, kw_parameters = state - - self.function = function - self.parameters = parameters - if not isinstance(kw_parameters, dict): - kw_parameters = dict(kw_parameters) - self.kw_parameters = kw_parameters + function: ExpressionT + parameters: Tuple[ExpressionT, ...] + kw_parameters: Mapping[str, ExpressionT] mapper_method = intern("map_call_with_kwargs") +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) class Subscript(AlgebraicLeaf): """An array subscript. @@ -884,15 +953,8 @@ class Subscript(AlgebraicLeaf): Return :attr:`index` wrapped in a single-element tuple, if it is not already a tuple. """ - - init_arg_names = ("aggregate", "index",) - - def __init__(self, aggregate, index): - self.aggregate = aggregate - self.index = index - - def __getinitargs__(self): - return self.aggregate, self.index + aggregate: ExpressionT + index: ExpressionT @property def index_tuple(self): @@ -904,19 +966,15 @@ class Subscript(AlgebraicLeaf): mapper_method = intern("map_subscript") +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) class Lookup(AlgebraicLeaf): """Access to an attribute of an *aggregate*, such as an attribute of a class. """ - init_arg_names = ("aggregate", "name",) - - def __init__(self, aggregate, name): - self.aggregate = aggregate - self.name = name - - def __getinitargs__(self): - return self.aggregate, self.name + aggregate: ExpressionT + name: str mapper_method = intern("map_lookup") @@ -925,25 +983,17 @@ class Lookup(AlgebraicLeaf): # {{{ arithmetic primitives -class _MultiChildExpression(Expression): - init_arg_names = ("children",) - - def __init__(self, children): - assert isinstance(children, tuple) - - self.children = children - - def __getinitargs__(self): - return self.children, - - -class Sum(_MultiChildExpression): +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) +class Sum(Expression): """ .. attribute:: children A :class:`tuple`. """ + children: Tuple[ExpressionT, ...] + def __add__(self, other): if not is_valid_operand(other): return NotImplemented @@ -986,13 +1036,17 @@ class Sum(_MultiChildExpression): mapper_method = intern("map_sum") -class Product(_MultiChildExpression): +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) +class Product(Expression): """ .. attribute:: children A :class:`tuple`. """ + children: Tuple[ExpressionT, ...] + def __mul__(self, other): if not is_valid_operand(other): return NotImplemented @@ -1026,15 +1080,11 @@ class Product(_MultiChildExpression): mapper_method = intern("map_product") +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) class QuotientBase(Expression): - init_arg_names = ("numerator", "denominator",) - - def __init__(self, numerator, denominator=1): - self.numerator = numerator - self.denominator = denominator - - def __getinitargs__(self): - return self.numerator, self.denominator + numerator: ExpressionT + denominator: ExpressionT @property def num(self): @@ -1050,21 +1100,19 @@ class QuotientBase(Expression): __nonzero__ = __bool__ +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) class Quotient(QuotientBase): """ .. attribute:: numerator .. attribute:: denominator """ - def is_equal(self, other): - from pymbolic.rational import Rational - return isinstance(other, (Rational, Quotient)) \ - and (self.numerator == other.numerator) \ - and (self.denominator == other.denominator) - mapper_method = intern("map_quotient") +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) class FloorDiv(QuotientBase): """ .. attribute:: numerator @@ -1074,6 +1122,8 @@ class FloorDiv(QuotientBase): mapper_method = intern("map_floor_div") +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) class Remainder(QuotientBase): """ .. attribute:: numerator @@ -1083,20 +1133,16 @@ class Remainder(QuotientBase): mapper_method = intern("map_remainder") +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) class Power(Expression): """ .. attribute:: base .. attribute:: exponent """ - init_arg_names = ("base", "exponent",) - - def __init__(self, base, exponent): - self.base = base - self.exponent = exponent - - def __getinitargs__(self): - return self.base, self.exponent + base: ExpressionT + exponent: ExpressionT mapper_method = intern("map_power") @@ -1105,17 +1151,15 @@ class Power(Expression): # {{{ shift operators +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) class _ShiftOperator(Expression): - init_arg_names = ("shiftee", "shift",) - - def __init__(self, shiftee, shift): - self.shiftee = shiftee - self.shift = shift - - def __getinitargs__(self): - return self.shiftee, self.shift + shiftee: ExpressionT + shift: ExpressionT +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) class LeftShift(_ShiftOperator): """ .. attribute:: shiftee @@ -1125,6 +1169,8 @@ class LeftShift(_ShiftOperator): mapper_method = intern("map_left_shift") +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) class RightShift(_ShiftOperator): """ .. attribute:: shiftee @@ -1138,48 +1184,54 @@ class RightShift(_ShiftOperator): # {{{ bitwise operators +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) class BitwiseNot(Expression): """ .. attribute:: child """ - init_arg_names = ("child",) - - def __init__(self, child): - self.child = child - - def __getinitargs__(self): - return (self.child,) + child: ExpressionT mapper_method = intern("map_bitwise_not") -class BitwiseOr(_MultiChildExpression): +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) +class BitwiseOr(Expression): """ .. attribute:: children A :class:`tuple`. """ + children: Tuple[ExpressionT, ...] + mapper_method = intern("map_bitwise_or") -class BitwiseXor(_MultiChildExpression): +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) +class BitwiseXor(Expression): """ .. attribute:: children A :class:`tuple`. """ + children: Tuple[ExpressionT, ...] mapper_method = intern("map_bitwise_xor") -class BitwiseAnd(_MultiChildExpression): +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) +class BitwiseAnd(Expression): """ .. attribute:: children A :class:`tuple`. """ + children: Tuple[ExpressionT, ...] mapper_method = intern("map_bitwise_and") @@ -1188,6 +1240,8 @@ class BitwiseAnd(_MultiChildExpression): # {{{ comparisons, logic, conditionals +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) class Comparison(Expression): """ .. attribute:: left @@ -1201,9 +1255,14 @@ class Comparison(Expression): Unlike other expressions, comparisons are not implicitly constructed by comparing :class:`Expression` objects. See :meth:`Expression.eq`. + + .. attribute:: operator_to_name + .. attribute:: name_to_operator """ - init_arg_names = ("left", "operator", "right") + left: ExpressionT + operator: str + right: ExpressionT operator_to_name = { "==": "eq", @@ -1215,66 +1274,65 @@ class Comparison(Expression): } name_to_operator = {name: op for op, name in operator_to_name.items()} - def __init__(self, left, operator, right): - """ - :arg operator: accepts the same values as :attr:`operator`, or the - standard Python comparison operator names - - .. versionchanged:: 2020.2 - - Now also accepts Python operator names. - """ - self.left = left - self.right = right - - operator = self.name_to_operator.get(operator, operator) - - if operator not in self.operator_to_name: - raise RuntimeError(f"invalid operator: '{operator}'") - self.operator = operator - - def __getinitargs__(self): - return self.left, self.operator, self.right + def __post_init__(self): + # FIXME Yuck, gross + if self.operator not in self.operator_to_name: + if self.operator in self.name_to_operator: + warn("Passing operators by name is deprecated and will stop working " + "in 2024. " + "Use the name_to_operator class attribute to translate in " + "calling code instead.", + DeprecationWarning, stacklevel=3) + + object.__setattr__( + self, "operator", self.name_to_operator[self.operator]) + else: + raise RuntimeError(f"invalid operator: '{self.operator}'") mapper_method = intern("map_comparison") +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) class LogicalNot(Expression): """ .. attribute:: child """ - init_arg_names = ("child",) - - def __init__(self, child): - self.child = child - - def __getinitargs__(self): - return (self.child,) + child: ExpressionT mapper_method = intern("map_logical_not") -class LogicalOr(_MultiChildExpression): +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) +class LogicalOr(Expression): """ .. attribute:: children A :class:`tuple`. """ + children: Tuple[ExpressionT, ...] + mapper_method = intern("map_logical_or") -class LogicalAnd(_MultiChildExpression): +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) +class LogicalAnd(Expression): """ .. attribute:: children A :class:`tuple`. """ + children: Tuple[ExpressionT, ...] mapper_method = intern("map_logical_and") +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) class If(Expression): """ .. attribute:: condition @@ -1282,52 +1340,36 @@ class If(Expression): .. attribute:: else_ """ - init_arg_names = ("condition", "then", "else_") - - def __init__(self, condition, then, else_): - self.condition = condition - self.then = then - self.else_ = else_ - - def __getinitargs__(self): - return self.condition, self.then, self.else_ + condition: ExpressionT + then: ExpressionT + else_: ExpressionT mapper_method = intern("map_if") -class IfPositive(Expression): - init_arg_names = ("criterion", "then", "else_") - - def __init__(self, criterion, then, else_): - from warnings import warn - warn("IfPositive is deprecated, use If( ... >0)", DeprecationWarning, - stacklevel=2) - - self.criterion = criterion - self.then = then - self.else_ = else_ - - def __getinitargs__(self): - return self.criterion, self.then, self.else_ - - mapper_method = intern("map_if_positive") - - -class _MinMaxBase(Expression): - init_arg_names = ("children",) +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) +class Min(Expression): + """ + .. attribute:: children - def __init__(self, children): - self.children = children + A :class:`tuple`. + """ + children: Tuple[ExpressionT, ...] - def __getinitargs__(self): - return (self.children,) + mapper_method = intern("map_min") -class Min(_MinMaxBase): - mapper_method = intern("map_min") +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) +class Max(Expression): + """ + .. attribute:: children + A :class:`tuple`. + """ + children: Tuple[ExpressionT, ...] -class Max(_MinMaxBase): mapper_method = intern("map_max") # }}} @@ -1335,88 +1377,6 @@ class Max(_MinMaxBase): # {{{ misc stuff -class Vector(Expression): - """An immutable sequence that you can compute with.""" - - init_arg_names = ("children",) - - def __init__(self, children): - assert isinstance(children, tuple) - self.children = children - - from warnings import warn - warn("pymbolic vectors are deprecated in favor of either " - "(a) numpy object arrays and " - "(b) pymbolic.geometric_algebra.MultiVector " - "(depending on the required semantics)", - DeprecationWarning) - - def __bool__(self): - for i in self.children: - if is_nonzero(i): - return False - return True - - __nonzero__ = __bool__ - - def __len__(self): - return len(self.children) - - def __getitem__(self, index): - if is_constant(index): - return self.children[index] - else: - return Expression.__getitem__(self, index) - - def __neg__(self): - return Vector(tuple([-x for x in self])) - - def __add__(self, other): - if len(other) != len(self): - raise ValueError("can't add values of differing lengths") - return Vector(tuple([x+y for x, y in zip(self, other)])) - - def __radd__(self, other): - if len(other) != len(self): - raise ValueError("can't add values of differing lengths") - return Vector(tuple([y+x for x, y in zip(self, other)])) - - def __sub__(self, other): - if len(other) != len(self): - raise ValueError("can't subtract values of differing lengths") - return Vector(tuple([x-y for x, y in zip(self, other)])) - - def __rsub__(self, other): - if len(other) != len(self): - raise ValueError("can't subtract values of differing lengths") - return Vector(tuple([y-x for x, y in zip(self, other)])) - - def __mul__(self, other): - return Vector(tuple([x*other for x in self])) - - def __rmul__(self, other): - return Vector(tuple([other*x for x in self])) - - def __div__(self, other): - # Py2 only - import operator - return Vector(tuple([ - operator.div(x, other) for x in self # pylint: disable=no-member - ])) - - def __truediv__(self, other): - import operator - return Vector(tuple([operator.truediv(x, other) for x in self])) - - def __floordiv__(self, other): - return Vector(tuple([x//other for x in self])) - - def __getinitargs__(self): - return self.children - - mapper_method = intern("map_vector") - - class cse_scope: # noqa """Determines the lifetime for the saved value of a :class:`CommonSubexpression`. @@ -1441,6 +1401,8 @@ class cse_scope: # noqa GLOBAL = "pymbolic_global" +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) class CommonSubexpression(Expression): """A helper for code generation and caching. Denotes a subexpression that should only be evaluated once. If, in code generation, it is assigned to @@ -1455,21 +1417,17 @@ class CommonSubexpression(Expression): See :class:`pymbolic.mapper.c_code.CCodeMapper` for an example. """ - init_arg_names = ("child", "prefix", "scope") - - def __init__(self, child, prefix=None, scope=None): - """ - :arg scope: Defaults to :attr:`cse_scope.EVALUATION` if given as *None*. - """ - if scope is None: - scope = cse_scope.EVALUATION - - self.child = child - self.prefix = prefix - self.scope = scope + child: ExpressionT + prefix: Optional[str] = None + scope: str = cse_scope.EVALUATION - def __getinitargs__(self): - return (self.child, self.prefix, self.scope) + def __post_init__(self): + if self.scope is None: + warn("CommonSubexpression.scope set to None. " + "This is deprecated and will stop working in 2024. " + "Use cse_scope.EVALUATION explicitly instead.", + DeprecationWarning, stacklevel=3) + object.__setattr__(self, "scope", cse_scope.EVALUATION) def get_extra_properties(self): """Return a dictionary of extra kwargs to be passed to the @@ -1484,51 +1442,40 @@ class CommonSubexpression(Expression): mapper_method = intern("map_common_subexpression") +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) class Substitution(Expression): """Work-alike of sympy's Subs.""" - init_arg_names = ("child", "variables", "values") - - def __init__(self, child, variables, values): - self.child = child - self.variables = variables - self.values = values - - def __getinitargs__(self): - return (self.child, self.variables, self.values) + child: ExpressionT + variables: Tuple[str, ...] + values: Tuple[ExpressionT, ...] mapper_method = intern("map_substitution") +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) class Derivative(Expression): """Work-alike of sympy's Derivative.""" - init_arg_names = ("child", "variables") - - def __init__(self, child, variables): - self.child = child - self.variables = variables - - def __getinitargs__(self): - return (self.child, self.variables) + child: ExpressionT + variables: Tuple[str, ...] mapper_method = intern("map_derivative") +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) class Slice(Expression): """A slice expression as in a[1:7].""" - init_arg_names = ("children",) - - def __init__(self, children): - assert isinstance(children, tuple) - self.children = children - - if len(children) > 3: - raise ValueError("slice with more than three arguments") - - def __getinitargs__(self): - return (self.children,) + children: Union[ + Tuple[()], + Tuple[ExpressionT], + Tuple[ExpressionT, ExpressionT], + Tuple[ExpressionT, ExpressionT, ExpressionT], + ] def __bool__(self): return True @@ -1561,6 +1508,8 @@ class Slice(Expression): mapper_method = intern("map_slice") +@augment_expression_dataclass +@dataclass(frozen=True, repr=False) class NaN(Expression): """ An expression node representing not-a-number as a floating point number. @@ -1582,13 +1531,7 @@ class NaN(Expression): type. It must also be suitable for use as the second argument of :func:`isinstance`. """ - init_arg_names = ("data_type", ) - - def __init__(self, data_type=None): - self.data_type = data_type - - def __getinitargs__(self): - return (self.data_type, ) + data_type: Optional[Callable[[float], Any]] = None mapper_method = intern("map_nan") @@ -1597,9 +1540,9 @@ class NaN(Expression): # {{{ intelligent factory functions -def make_variable(var_or_string): - if not isinstance(var_or_string, Expression): - return Variable(var_or_string) +def make_variable(var_or_string: Union[Expression, str]) -> Variable: + if isinstance(var_or_string, str): + return Variable(intern(var_or_string)) else: return var_or_string diff --git a/setup.py b/setup.py index e0d5f35..d906169 100644 --- a/setup.py +++ b/setup.py @@ -11,35 +11,36 @@ finally: exec(compile(version_file_contents, "pymbolic/version.py", "exec"), ver_dic) -setup(name="pymbolic", - version=ver_dic["VERSION_TEXT"], - description="A package for symbolic computation", - long_description=open("README.rst").read(), - classifiers=[ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "Intended Audience :: Other Audience", - "Intended Audience :: Science/Research", - "License :: OSI Approved :: MIT License", - "Natural Language :: English", - "Programming Language :: Python", - "Programming Language :: Python :: 3", - "Topic :: Scientific/Engineering", - "Topic :: Scientific/Engineering :: Mathematics", - "Topic :: Software Development :: Libraries", - "Topic :: Utilities", - ], - author="Andreas Kloeckner", - author_email="inform@tiker.net", - license="MIT", - url="http://mathema.tician.de/software/pymbolic", - - packages=find_packages(), - python_requires="~=3.8", - install_requires=[ - "pytools>=2022.1.14", - ], - extras_require={ - "test": ["pytest>=2.3"], - }, - ) +setup( + name="pymbolic", + version=ver_dic["VERSION_TEXT"], + description="A package for symbolic computation", + long_description=open("README.rst").read(), + classifiers=[ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Other Audience", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Natural Language :: English", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Software Development :: Libraries", + "Topic :: Utilities", + ], + author="Andreas Kloeckner", + author_email="inform@tiker.net", + license="MIT", + url="http://mathema.tician.de/software/pymbolic", + packages=find_packages(), + python_requires="~=3.8", + install_requires=[ + "pytools>=2022.1.14", + "immutables", + ], + extras_require={ + "test": ["pytest>=2.3"], + }, +) -- GitLab