From 57acaed66bab67410e20470802783cb9c100cad8 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Tue, 15 Aug 2017 11:59:07 -0500 Subject: [PATCH] Variable evaluation: Cast scalars up to vector if needed --- grudge/execution.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/grudge/execution.py b/grudge/execution.py index 69e4b7a4..ace2dc8b 100644 --- a/grudge/execution.py +++ b/grudge/execution.py @@ -71,6 +71,9 @@ class ExecutionMapper(mappers.Evaluator, # {{{ expression mappings ------------------------------------------------- def map_ones(self, expr): + if expr.dd.is_scalar(): + return 1 + discr = self.get_discr(expr.dd) result = discr.empty(self.queue, allocator=self.bound_op.allocator) @@ -82,7 +85,30 @@ class ExecutionMapper(mappers.Evaluator, return discr.nodes()[expr.axis].with_queue(self.queue) def map_grudge_variable(self, expr): - return self.context[expr.name] + from numbers import Number + + value = self.context[expr.name] + if not expr.dd.is_scalar() and isinstance(value, Number): + discr = self.get_discr(expr.dd) + ary = discr.empty(self.queue) + ary.fill(value) + value = ary + + return value + + def map_subscript(self, expr): + value = super(ExecutionMapper, self).map_subscript(expr) + + if isinstance(expr.aggregate, sym.Variable): + dd = expr.aggregate.dd + + from numbers import Number + if not dd.is_scalar() and isinstance(value, Number): + discr = self.get_discr(dd) + ary = discr.empty(self.queue) + ary.fill(value) + value = ary + return value def map_call(self, expr): from pymbolic.primitives import Variable -- GitLab