From 7898c60d28abe5000da37551210a4b773230db43 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Tue, 20 Jul 2021 13:01:54 -0500 Subject: [PATCH] traverse the DAG in a deterministic manner --- pytato/codegen.py | 2 +- pytato/transform.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pytato/codegen.py b/pytato/codegen.py index 997439f..45ea962 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -147,7 +147,7 @@ class CodeGenPreprocessor(CopyMapper): bindings = {name: (self.rec(subexpr) if isinstance(subexpr, Array) else subexpr) - for name, subexpr in expr.bindings.items()} + for name, subexpr in sorted(expr.bindings.items())} return LoopyCall(translation_unit=translation_unit, bindings=bindings, diff --git a/pytato/transform.py b/pytato/transform.py index 05cd3eb..4a838b5 100644 --- a/pytato/transform.py +++ b/pytato/transform.py @@ -113,7 +113,7 @@ class CopyMapper(Mapper): def map_index_lambda(self, expr: IndexLambda) -> Array: bindings: Dict[str, Array] = { name: self.rec(subexpr) - for name, subexpr in expr.bindings.items()} + for name, subexpr in sorted(expr.bindings.items())} return IndexLambda(expr=expr.expr, shape=tuple(self.rec(s) if isinstance(s, Array) else s for s in expr.shape), @@ -185,7 +185,7 @@ class CopyMapper(Mapper): def map_loopy_call(self, expr: LoopyCall) -> LoopyCall: bindings = {name: (self.rec(subexpr) if isinstance(subexpr, Array) else subexpr) - for name, subexpr in expr.bindings.items()} + for name, subexpr in sorted(expr.bindings.items())} return LoopyCall(translation_unit=expr.translation_unit, bindings=bindings, @@ -228,7 +228,7 @@ class CombineMapper(Mapper, Generic[CombineT]): def map_index_lambda(self, expr: IndexLambda) -> CombineT: return self.combine(*(self.rec(bnd) - for bnd in expr.bindings.values()), + for _, bnd in sorted(expr.bindings.items())), *(self.rec(s) for s in expr.shape if isinstance(s, Array))) @@ -276,7 +276,7 @@ class CombineMapper(Mapper, Generic[CombineT]): def map_loopy_call(self, expr: LoopyCall) -> CombineT: return self.combine(*(self.rec(ary) - for ary in expr.bindings.values() + for _, ary in sorted(expr.bindings.items()) if isinstance(ary, Array))) # }}} @@ -433,7 +433,7 @@ class WalkMapper(Mapper): if not self.visit(expr): return - for child in expr.bindings.values(): + for _, child in sorted(expr.bindings.items()): self.rec(child) for dim in expr.shape: @@ -502,7 +502,7 @@ class WalkMapper(Mapper): if not self.visit(expr): return - for child in expr.bindings.values(): + for _, child in sorted(expr.bindings.items()): if isinstance(child, Array): self.rec(child) -- GitLab