From 23a454442a1ffcdf7dca26c3afd4cb92c46b549d Mon Sep 17 00:00:00 2001
From: "Timothy A. Smith" <tasmith4@illinois.edu>
Date: Fri, 23 Aug 2019 23:04:41 -0500
Subject: [PATCH] introduce new test to check whether bug is in WENO.F90 or
 Mathematica script

---
 test.py      | 43 +++++++++++++++++++++++++++++++++++++++++++
 utilities.py |  3 ++-
 2 files changed, 45 insertions(+), 1 deletion(-)

diff --git a/test.py b/test.py
index 53ab820..d8ae1bf 100644
--- a/test.py
+++ b/test.py
@@ -24,6 +24,49 @@ from data_for_test import (  # noqa: F401
 def test_weno_weight_computation(ctx_factory, flux_test_data_fixture):
     data = flux_test_data_fixture
 
+    def weno_weights(oscillation):
+        linear = np.array([0.1, 0.6, 0.3])
+        eps = 1e-6
+
+        raw_weights = np.empty((5,3))
+        for i in range(5):
+            for j in range(3):
+                raw_weights[i,j] = linear[j]/(oscillation[i,j] + eps)**2
+
+        weight_sum = raw_weights.sum(axis=1)
+        weights = np.empty((5,3))
+        for i in range(5):
+            for j in range(3):
+                weights[i,j] = raw_weights[i,j]/weight_sum[i]
+
+        return weights
+
+    prg = u.get_weno_program_with_root_kernel("weno_weights_pos")
+    queue = u.get_queue(ctx_factory)
+
+    weights_dev = u.empty_array_on_device(queue, data.nvars, 3)
+
+    prg(queue, nvars=data.nvars,
+            characteristic_fluxes=data.char_fluxes_pos,
+            combined_frozen_metrics=1.0,
+            w=weights_dev)
+
+    w = weno_weights(data.oscillation_pos)
+    u.compare_arrays(weights_dev.get(), w)
+
+    prg = u.get_weno_program_with_root_kernel("weno_weights_neg")
+    queue = u.get_queue(ctx_factory)
+
+    weights_dev = u.empty_array_on_device(queue, data.nvars, 3)
+
+    prg(queue, nvars=data.nvars,
+            characteristic_fluxes=data.char_fluxes_neg,
+            combined_frozen_metrics=1.0,
+            w=weights_dev)
+
+    w = weno_weights(data.oscillation_neg)
+    u.compare_arrays(weights_dev.get(), w)
+
 
 def test_weno_flux_uniform_grid(ctx_factory, flux_test_data_fixture):
     data = flux_test_data_fixture
diff --git a/utilities.py b/utilities.py
index c8bde8a..927763f 100644
--- a/utilities.py
+++ b/utilities.py
@@ -11,7 +11,8 @@ from pytest import approx
 # {{{ arrays
 
 def compare_arrays(a, b):
-    assert a == approx(b, rel=1e-5, abs=2e-5)
+    #assert a == approx(b, rel=1e-5, abs=2e-5)
+    assert a == approx(b, rel=1e-12, abs=1e-14)
 
 
 def random_array_on_device(queue, *shape):
-- 
GitLab