diff --git a/pytato/stringifier.py b/pytato/stringifier.py index eca3a0d54109f373a75842a26ca712719440a745..e3702792299fe862a6e66c2f0540d223c07caaf4 100644 --- a/pytato/stringifier.py +++ b/pytato/stringifier.py @@ -26,7 +26,7 @@ THE SOFTWARE. import numpy as np -from typing import Any, Dict, Tuple +from typing import Any, Dict, Tuple, cast from pytato.transform import Mapper from pytato.array import (Array, DataWrapper, DictOfNamedArrays, Axis, IndexLambda, ReductionDescriptor) @@ -80,9 +80,10 @@ class Reprifier(Mapper): elif isinstance(expr, (dict, Map)): return ("{" + ", ".join(f"{key!r}: {self.rec(val, depth)}" - for key, val in expr.items()) + for key, val + in sorted(expr.items(), + key=lambda k_x_v: cast(str, k_x_v[0]))) + "}") - return "(" + ", ".join(self.rec(el, depth) for el in expr) + ")" elif isinstance(expr, (frozenset, set)): return "{" + ", ".join(self.rec(el, depth) for el in expr) + "}" elif isinstance(expr, np.dtype): diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index eaf645a1c7e916f78818a4ab4c7fdad03d0802d6..d630b3763ffa0aed06c1e5820cf2e4922408790b 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -402,8 +402,8 @@ class CodeGenMapper(Mapper): prstnt_ctx = PersistentExpressionContext(state) local_ctx = LocalExpressionContext( local_namespace={ - name: self.rec(subexpr, state) - for name, subexpr in expr.bindings.items()}, + name: self.rec(expr.bindings[name], state) + for name in sorted(expr.bindings)}, num_indices=expr.ndim, reduction_bounds={}, var_to_reduction_descr=expr.var_to_reduction_descr) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 593eee0627439e0e29dbb9b8faa1736cc7ce58cd..ce033cde59caff12bcf5d8291f404f35edcc1899 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -261,9 +261,9 @@ class CopyMapper(CachedMapper[ArrayOrNames]): for s in situp) def map_index_lambda(self, expr: IndexLambda) -> Array: - bindings: Dict[str, Array] = { + bindings: Mapping[str, Array] = Map({ name: self.rec(subexpr) - for name, subexpr in sorted(expr.bindings.items())} + for name, subexpr in sorted(expr.bindings.items())}) return IndexLambda(expr=expr.expr, shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, @@ -762,7 +762,8 @@ class CombineMapper(Mapper, Generic[CombineT]): def map_call(self, expr: Call) -> CombineT: return self.combine(self.map_function_definition(expr.function), - *[self.rec(bnd) for bnd in expr.bindings.values()]) + *[self.rec(bnd) + for name, bnd in sorted(expr.bindings.items())]) def map_named_call_result(self, expr: NamedCallResult) -> CombineT: return self.rec(expr._container) @@ -1311,8 +1312,8 @@ class MPMSMaterializer(Mapper): new_expr = IndexLambda(expr.expr, expr.shape, expr.dtype, - {bnd_name: bnd.expr - for bnd_name, bnd in children_rec.items()}, + bindings={bnd_name: bnd.expr + for bnd_name, bnd in sorted(children_rec.items())}, axes=expr.axes, var_to_reduction_descr=expr.var_to_reduction_descr, tags=expr.tags) diff --git a/pytato/transform/einsum_distributive_law.py b/pytato/transform/einsum_distributive_law.py index b70721ac4e25dd3e00f3e9785ec400180b6eeefe..5ef5b504973c32f4b5ca37d4f8d1f0aff65ded44 100644 --- a/pytato/transform/einsum_distributive_law.py +++ b/pytato/transform/einsum_distributive_law.py @@ -220,12 +220,12 @@ class EinsumDistributiveLawMapper(Mapper): raise NotImplementedError(hlo) else: rec_expr = IndexLambda( - expr.expr, - expr.shape, - expr.dtype, - Map({name: self.rec(bnd, None) - for name, bnd in expr.bindings.items()}), - expr.var_to_reduction_descr, + expr=expr.expr, + shape=expr.shape, + dtype=expr.dtype, + bindings=Map({name: self.rec(bnd, None) + for name, bnd in sorted(expr.bindings.items())}), + var_to_reduction_descr=expr.var_to_reduction_descr, tags=expr.tags, axes=expr.axes, )