Skip to content
Snippets Groups Projects
Commit 214e865b authored by Matt Smith's avatar Matt Smith
Browse files

attempt to fix force_device_scalars handling

parent dbcec037
No related branches found
No related tags found
No related merge requests found
......@@ -165,18 +165,7 @@ def nodal_sum_loc(dcoll: DiscretizationCollection, dd, vec) -> Any:
actx = vec.array_context
if hasattr(actx, "_force_device_scalars"):
force_device_scalars = actx._force_device_scalars
else:
force_device_scalars = True
def device_sum(x):
result = actx.np.sum(x)
if not force_device_scalars:
result = actx.from_numpy(np.array(result))
return result
return sum([device_sum(grp_ary) for grp_ary in vec])
return sum([actx.np.sum(grp_ary) for grp_ary in vec])
def nodal_min(dcoll: DiscretizationCollection, dd, vec) -> Any:
......@@ -228,9 +217,13 @@ def nodal_min_loc(dcoll: DiscretizationCollection, dd, vec) -> Any:
result = as_device_scalar(result)
return result
return reduce(
result = reduce(
lambda acc, grp_ary: actx.np.minimum(acc, device_min(grp_ary)),
vec, as_device_scalar(np.inf))
if not force_device_scalars:
result = actx.to_numpy(result)[()]
return result
def nodal_max(dcoll: DiscretizationCollection, dd, vec) -> Any:
......@@ -282,9 +275,13 @@ def nodal_max_loc(dcoll: DiscretizationCollection, dd, vec) -> Any:
result = as_device_scalar(result)
return result
return reduce(
result = reduce(
lambda acc, grp_ary: actx.np.maximum(acc, device_max(grp_ary)),
vec, as_device_scalar(-np.inf))
if not force_device_scalars:
result = actx.to_numpy(result)[()]
return result
def integral(dcoll: DiscretizationCollection, dd, vec):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment