From 031615c977a614c863fb6eb279693236bad45856 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Thu, 20 Aug 2020 17:05:32 -0500 Subject: [PATCH 1/5] Privatize _make_slice --- pytato/array.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 87f4e3d..f6f6471 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -118,7 +118,6 @@ Node constructors such as :class:`Placeholder.__init__` and .. autofunction:: make_placeholder .. autofunction:: make_size_param .. autofunction:: make_data_wrapper -.. autofunction:: make_slice """ # }}} @@ -559,7 +558,7 @@ class Array: else: raise ValueError("unknown index along dimension") - slice_ = make_slice(self, starts, sizes) + slice_ = _make_slice(self, starts, sizes) if len(kept_dims) != self.ndim: # Return an IndexLambda that elides the indexed-into dimensions @@ -1422,7 +1421,7 @@ def stack(arrays: Sequence[Array], axis: int = 0) -> Array: return Stack(tuple(arrays), axis) -def make_slice(array: Array, begin: Sequence[int], size: Sequence[int]) -> Array: +def _make_slice(array: Array, begin: Sequence[int], size: Sequence[int]) -> Array: """Extract a constant-sized slice from an array with constant offsets. :param array: input array -- GitLab From a8be07145a8ad20398056a2b21b8ee2d0591ec0b Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Thu, 20 Aug 2020 17:07:03 -0500 Subject: [PATCH 2/5] Add underscore to Array._mapper_method --- pytato/array.py | 20 ++++++++++---------- pytato/transform.py | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index f6f6471..02e48e3 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -470,7 +470,7 @@ class Array: .. attribute:: ndim """ - mapper_method: ClassVar[str] + _mapper_method: ClassVar[str] # A tuple of field names. Fields must be equality comparable and # hashable. Dicts of hashable keys and values are also permitted. fields: ClassVar[Tuple[str, ...]] = ("shape", "dtype", "tags") @@ -856,7 +856,7 @@ class IndexLambda(_SuppliedShapeAndDtypeMixin, Array): """ fields = Array.fields + ("expr", "bindings") - mapper_method = "map_index_lambda" + _mapper_method = "map_index_lambda" def __init__(self, namespace: Namespace, @@ -938,7 +938,7 @@ class MatrixProduct(Array): """ fields = Array.fields + ("x1", "x2") - mapper_method = "map_matrix_product" + _mapper_method = "map_matrix_product" def __init__(self, x1: Array, @@ -992,7 +992,7 @@ class Stack(Array): """ fields = Array.fields + ("arrays", "axis") - mapper_method = "map_stack" + _mapper_method = "map_stack" def __init__(self, arrays: Tuple[Array, ...], @@ -1065,7 +1065,7 @@ class Roll(IndexRemappingBase): Shift axis. """ fields = IndexRemappingBase.fields + ("shift", "axis") - mapper_method = "map_roll" + _mapper_method = "map_roll" def __init__(self, array: Array, @@ -1093,7 +1093,7 @@ class AxisPermutation(IndexRemappingBase): A permutation of the input axes. """ fields = IndexRemappingBase.fields + ("axes",) - mapper_method = "map_axis_permutation" + _mapper_method = "map_axis_permutation" def __init__(self, array: Array, @@ -1132,7 +1132,7 @@ class Slice(IndexRemappingBase): .. attribute:: size """ fields = IndexRemappingBase.fields + ("begin", "size") - mapper_method = "map_slice" + _mapper_method = "map_slice" def __init__(self, array: Array, @@ -1236,7 +1236,7 @@ class DataWrapper(InputArgumentBase): this array may not be updated in-place. """ - mapper_method = "map_data_wrapper" + _mapper_method = "map_data_wrapper" def __init__(self, namespace: Namespace, @@ -1267,7 +1267,7 @@ class Placeholder(_SuppliedShapeAndDtypeMixin, InputArgumentBase): user during evaluation. """ - mapper_method = "map_placeholder" + _mapper_method = "map_placeholder" def __init__(self, namespace: Namespace, @@ -1291,7 +1291,7 @@ class SizeParam(InputArgumentBase): expressions for array sizes. """ - mapper_method = "map_size_param" + _mapper_method = "map_size_param" @property def shape(self) -> ShapeType: diff --git a/pytato/transform.py b/pytato/transform.py index 8121358..1973974 100644 --- a/pytato/transform.py +++ b/pytato/transform.py @@ -67,7 +67,7 @@ class Mapper: method: Callable[..., Array] try: - method = getattr(self, expr.mapper_method) + method = getattr(self, expr._mapper_method) except AttributeError: if isinstance(expr, Array): return self.handle_unsupported_array(expr, *args, **kwargs) -- GitLab From e3a024c2c62ec8347e76067a1686e8b6333fff03 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Thu, 20 Aug 2020 17:08:00 -0500 Subject: [PATCH 3/5] Add underscore to Array._fields --- pytato/array.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 02e48e3..17355f2 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -473,7 +473,7 @@ class Array: _mapper_method: ClassVar[str] # A tuple of field names. Fields must be equality comparable and # hashable. Dicts of hashable keys and values are also permitted. - fields: ClassVar[Tuple[str, ...]] = ("shape", "dtype", "tags") + _fields: ClassVar[Tuple[str, ...]] = ("shape", "dtype", "tags") def __init__(self, tags: Optional[TagsType] = None): if tags is None: @@ -608,7 +608,7 @@ class Array: @memoize_method def __hash__(self) -> int: attrs = [] - for field in self.fields: + for field in self._fields: attr = getattr(self, field) if isinstance(attr, dict): attr = frozenset(attr.items()) @@ -623,7 +623,7 @@ class Array: and self.namespace is other.namespace and all( getattr(self, field) == getattr(other, field) - for field in self.fields)) + for field in self._fields)) def __ne__(self, other: Any) -> bool: return not self.__eq__(other) @@ -855,7 +855,7 @@ class IndexLambda(_SuppliedShapeAndDtypeMixin, Array): .. automethod:: is_reference """ - fields = Array.fields + ("expr", "bindings") + _fields = Array._fields + ("expr", "bindings") _mapper_method = "map_index_lambda" def __init__(self, @@ -936,7 +936,7 @@ class MatrixProduct(Array): .. [pep465] https://www.python.org/dev/peps/pep-0465/ """ - fields = Array.fields + ("x1", "x2") + _fields = Array._fields + ("x1", "x2") _mapper_method = "map_matrix_product" @@ -991,7 +991,7 @@ class Stack(Array): """ - fields = Array.fields + ("arrays", "axis") + _fields = Array._fields + ("arrays", "axis") _mapper_method = "map_stack" def __init__(self, @@ -1032,7 +1032,7 @@ class IndexRemappingBase(Array): The input :class:`~pytato.Array` """ - fields = Array.fields + ("array",) + _fields = Array._fields + ("array",) def __init__(self, array: Array, @@ -1064,7 +1064,7 @@ class Roll(IndexRemappingBase): Shift axis. """ - fields = IndexRemappingBase.fields + ("shift", "axis") + _fields = IndexRemappingBase._fields + ("shift", "axis") _mapper_method = "map_roll" def __init__(self, @@ -1092,7 +1092,7 @@ class AxisPermutation(IndexRemappingBase): A permutation of the input axes. """ - fields = IndexRemappingBase.fields + ("axes",) + _fields = IndexRemappingBase._fields + ("axes",) _mapper_method = "map_axis_permutation" def __init__(self, @@ -1131,7 +1131,7 @@ class Slice(IndexRemappingBase): .. attribute:: begin .. attribute:: size """ - fields = IndexRemappingBase.fields + ("begin", "size") + _fields = IndexRemappingBase._fields + ("begin", "size") _mapper_method = "map_slice" def __init__(self, @@ -1171,8 +1171,8 @@ class InputArgumentBase(Array): """ # The name uniquely identifies this object in the namespace. Therefore, - # subclasses don't have to update *fields*. - fields = ("name",) + # subclasses don't have to update *_fields*. + _fields = ("name",) def __init__(self, namespace: Namespace, -- GitLab From 400fe784b5ef333ccc0c8bf6e93e63c6085e08e1 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Thu, 20 Aug 2020 17:08:18 -0500 Subject: [PATCH 4/5] Add an assert and some FIXMEs --- pytato/array.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytato/array.py b/pytato/array.py index 17355f2..42070f5 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -528,6 +528,7 @@ class Array: elif elem is None: raise NotImplementedError("newaxis is unsupported") else: + assert isinstance(elem, (int, slice)) slice_spec_expanded.append(elem) slice_spec_expanded.extend( @@ -572,6 +573,7 @@ class Array: if indices: expr = expr[tuple(indices)] + # FIXME: Flatten into a single IndexLambda return IndexLambda(self.namespace, expr, shape=tuple(shape), @@ -1461,6 +1463,7 @@ def _make_slice(array: Array, begin: Sequence[int], size: Sequence[int]) -> Arra if sval < 0 or not (0 <= bval + sval <= ubound): raise ValueError("index '%d' of 'size' out of bounds" % i) + # FIXME: Generate IndexLambda when possible return Slice(array, tuple(begin), tuple(size)) -- GitLab From 88ccedc9b1398a8c39f5736c8df9b85ea4ad6e61 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Thu, 20 Aug 2020 17:12:09 -0500 Subject: [PATCH 5/5] Codegen comments: reductions -> reduction variables --- pytato/codegen.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytato/codegen.py b/pytato/codegen.py index 5a62bb2..ad4fbc5 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -432,7 +432,7 @@ class CodeGenMapper(Mapper): shape = [] for component in expr.shape: shape.append(self.exprgen_mapper(component, shape_context)) - # Not supported yet. + # Data-dependent shape: Not supported yet. assert not shape_context.depends_on assert not shape_context.reduction_bounds @@ -652,7 +652,7 @@ def add_store(name: str, expr: Array, result: ImplementedResult, loopy_expr_context = LoopyExpressionContext(state, num_indices=0) loopy_expr = result.to_loopy_expression(indices, loopy_expr_context) - # Rename reductions to names suitable as inames. + # Rename reduction variables to names suitable as inames. loopy_expr = rename_reductions( loopy_expr, loopy_expr_context, lambda old_name: state.var_name_gen(f"{name}{old_name}")) @@ -712,8 +712,8 @@ def rename_reductions( loopy_expr: ScalarExpression, loopy_expr_context: LoopyExpressionContext, var_name_gen: Callable[[str], str]) -> ScalarExpression: - """Rename the reductions in *loopy_expr* and *loopy_expr_context* using the - callable *var_name_gen.* + """Rename the reduction variables in *loopy_expr* and *loopy_expr_context* + using the callable *var_name_gen.* """ new_reduction_inames = tuple( var_name_gen(old_iname) -- GitLab