diff --git a/pytato/array.py b/pytato/array.py index 87f4e3d203c8939c521aff07275267517371ade7..42070f5098e21facd7ddf7d550771cb6f84f4ee1 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -118,7 +118,6 @@ Node constructors such as :class:`Placeholder.__init__` and .. autofunction:: make_placeholder .. autofunction:: make_size_param .. autofunction:: make_data_wrapper -.. autofunction:: make_slice """ # }}} @@ -471,10 +470,10 @@ class Array: .. attribute:: ndim """ - mapper_method: ClassVar[str] + _mapper_method: ClassVar[str] # A tuple of field names. Fields must be equality comparable and # hashable. Dicts of hashable keys and values are also permitted. - fields: ClassVar[Tuple[str, ...]] = ("shape", "dtype", "tags") + _fields: ClassVar[Tuple[str, ...]] = ("shape", "dtype", "tags") def __init__(self, tags: Optional[TagsType] = None): if tags is None: @@ -529,6 +528,7 @@ class Array: elif elem is None: raise NotImplementedError("newaxis is unsupported") else: + assert isinstance(elem, (int, slice)) slice_spec_expanded.append(elem) slice_spec_expanded.extend( @@ -559,7 +559,7 @@ class Array: else: raise ValueError("unknown index along dimension") - slice_ = make_slice(self, starts, sizes) + slice_ = _make_slice(self, starts, sizes) if len(kept_dims) != self.ndim: # Return an IndexLambda that elides the indexed-into dimensions @@ -573,6 +573,7 @@ class Array: if indices: expr = expr[tuple(indices)] + # FIXME: Flatten into a single IndexLambda return IndexLambda(self.namespace, expr, shape=tuple(shape), @@ -609,7 +610,7 @@ class Array: @memoize_method def __hash__(self) -> int: attrs = [] - for field in self.fields: + for field in self._fields: attr = getattr(self, field) if isinstance(attr, dict): attr = frozenset(attr.items()) @@ -624,7 +625,7 @@ class Array: and self.namespace is other.namespace and all( getattr(self, field) == getattr(other, field) - for field in self.fields)) + for field in self._fields)) def __ne__(self, other: Any) -> bool: return not self.__eq__(other) @@ -856,8 +857,8 @@ class IndexLambda(_SuppliedShapeAndDtypeMixin, Array): .. automethod:: is_reference """ - fields = Array.fields + ("expr", "bindings") - mapper_method = "map_index_lambda" + _fields = Array._fields + ("expr", "bindings") + _mapper_method = "map_index_lambda" def __init__(self, namespace: Namespace, @@ -937,9 +938,9 @@ class MatrixProduct(Array): .. [pep465] https://www.python.org/dev/peps/pep-0465/ """ - fields = Array.fields + ("x1", "x2") + _fields = Array._fields + ("x1", "x2") - mapper_method = "map_matrix_product" + _mapper_method = "map_matrix_product" def __init__(self, x1: Array, @@ -992,8 +993,8 @@ class Stack(Array): """ - fields = Array.fields + ("arrays", "axis") - mapper_method = "map_stack" + _fields = Array._fields + ("arrays", "axis") + _mapper_method = "map_stack" def __init__(self, arrays: Tuple[Array, ...], @@ -1033,7 +1034,7 @@ class IndexRemappingBase(Array): The input :class:`~pytato.Array` """ - fields = Array.fields + ("array",) + _fields = Array._fields + ("array",) def __init__(self, array: Array, @@ -1065,8 +1066,8 @@ class Roll(IndexRemappingBase): Shift axis. """ - fields = IndexRemappingBase.fields + ("shift", "axis") - mapper_method = "map_roll" + _fields = IndexRemappingBase._fields + ("shift", "axis") + _mapper_method = "map_roll" def __init__(self, array: Array, @@ -1093,8 +1094,8 @@ class AxisPermutation(IndexRemappingBase): A permutation of the input axes. """ - fields = IndexRemappingBase.fields + ("axes",) - mapper_method = "map_axis_permutation" + _fields = IndexRemappingBase._fields + ("axes",) + _mapper_method = "map_axis_permutation" def __init__(self, array: Array, @@ -1132,8 +1133,8 @@ class Slice(IndexRemappingBase): .. attribute:: begin .. attribute:: size """ - fields = IndexRemappingBase.fields + ("begin", "size") - mapper_method = "map_slice" + _fields = IndexRemappingBase._fields + ("begin", "size") + _mapper_method = "map_slice" def __init__(self, array: Array, @@ -1172,8 +1173,8 @@ class InputArgumentBase(Array): """ # The name uniquely identifies this object in the namespace. Therefore, - # subclasses don't have to update *fields*. - fields = ("name",) + # subclasses don't have to update *_fields*. + _fields = ("name",) def __init__(self, namespace: Namespace, @@ -1237,7 +1238,7 @@ class DataWrapper(InputArgumentBase): this array may not be updated in-place. """ - mapper_method = "map_data_wrapper" + _mapper_method = "map_data_wrapper" def __init__(self, namespace: Namespace, @@ -1268,7 +1269,7 @@ class Placeholder(_SuppliedShapeAndDtypeMixin, InputArgumentBase): user during evaluation. """ - mapper_method = "map_placeholder" + _mapper_method = "map_placeholder" def __init__(self, namespace: Namespace, @@ -1292,7 +1293,7 @@ class SizeParam(InputArgumentBase): expressions for array sizes. """ - mapper_method = "map_size_param" + _mapper_method = "map_size_param" @property def shape(self) -> ShapeType: @@ -1422,7 +1423,7 @@ def stack(arrays: Sequence[Array], axis: int = 0) -> Array: return Stack(tuple(arrays), axis) -def make_slice(array: Array, begin: Sequence[int], size: Sequence[int]) -> Array: +def _make_slice(array: Array, begin: Sequence[int], size: Sequence[int]) -> Array: """Extract a constant-sized slice from an array with constant offsets. :param array: input array @@ -1462,6 +1463,7 @@ def make_slice(array: Array, begin: Sequence[int], size: Sequence[int]) -> Array if sval < 0 or not (0 <= bval + sval <= ubound): raise ValueError("index '%d' of 'size' out of bounds" % i) + # FIXME: Generate IndexLambda when possible return Slice(array, tuple(begin), tuple(size)) diff --git a/pytato/codegen.py b/pytato/codegen.py index 5a62bb278dd2bee49640990d7f6630ef065372e0..ad4fbc545e3f04324fe94727a013a43d00360cd7 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -432,7 +432,7 @@ class CodeGenMapper(Mapper): shape = [] for component in expr.shape: shape.append(self.exprgen_mapper(component, shape_context)) - # Not supported yet. + # Data-dependent shape: Not supported yet. assert not shape_context.depends_on assert not shape_context.reduction_bounds @@ -652,7 +652,7 @@ def add_store(name: str, expr: Array, result: ImplementedResult, loopy_expr_context = LoopyExpressionContext(state, num_indices=0) loopy_expr = result.to_loopy_expression(indices, loopy_expr_context) - # Rename reductions to names suitable as inames. + # Rename reduction variables to names suitable as inames. loopy_expr = rename_reductions( loopy_expr, loopy_expr_context, lambda old_name: state.var_name_gen(f"{name}{old_name}")) @@ -712,8 +712,8 @@ def rename_reductions( loopy_expr: ScalarExpression, loopy_expr_context: LoopyExpressionContext, var_name_gen: Callable[[str], str]) -> ScalarExpression: - """Rename the reductions in *loopy_expr* and *loopy_expr_context* using the - callable *var_name_gen.* + """Rename the reduction variables in *loopy_expr* and *loopy_expr_context* + using the callable *var_name_gen.* """ new_reduction_inames = tuple( var_name_gen(old_iname) diff --git a/pytato/transform.py b/pytato/transform.py index 812135879e29c82aa04d854e1675641992fe816b..19739745550bd29cac72580af84773d07f610e3f 100644 --- a/pytato/transform.py +++ b/pytato/transform.py @@ -67,7 +67,7 @@ class Mapper: method: Callable[..., Array] try: - method = getattr(self, expr.mapper_method) + method = getattr(self, expr._mapper_method) except AttributeError: if isinstance(expr, Array): return self.handle_unsupported_array(expr, *args, **kwargs)