From 363ece00689185f7a86019a021cfbc407e8e5cce Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sun, 17 May 2015 21:47:48 -0500
Subject: [PATCH] Add DiscretizationConnection.full_resample_matrix

---
 meshmode/discretization/connection.py | 47 +++++++++++++++++++++++++++
 test/test_meshmode.py                 |  5 +++
 2 files changed, 52 insertions(+)

diff --git a/meshmode/discretization/connection.py b/meshmode/discretization/connection.py
index 9bbd1013..251ef6c9 100644
--- a/meshmode/discretization/connection.py
+++ b/meshmode/discretization/connection.py
@@ -137,6 +137,53 @@ class DiscretizationConnection(object):
                 mp.simplex_onb(self.from_discr.dim, from_grp.order),
                 ibatch.result_unit_nodes, from_grp.unit_nodes)
 
+    def full_resample_matrix(self, queue):
+        @memoize_method_nested
+        def knl():
+            import loopy as lp
+            knl = lp.make_kernel(
+                """{[k,i,j]:
+                    0<=k<nelements and
+                    0<=i<n_to_nodes and
+                    0<=j<n_from_nodes}""",
+                "result[itgt_base + target_element_indices[k]*n_to_nodes + i, \
+                        isrc_base + source_element_indices[k]*n_from_nodes + j] \
+                    = resample_mat[i, j]",
+                [
+                    lp.GlobalArg("result", None,
+                        shape="nnodes_tgt, nnodes_src",
+                        offset=lp.auto),
+                    lp.ValueArg("itgt_base,isrc_base", np.int32),
+                    lp.ValueArg("nnodes_tgt,nnodes_src", np.int32),
+                    "...",
+                    ],
+                name="oversample_mat")
+
+            knl = lp.split_iname(knl, "i", 16, inner_tag="l.0")
+            return lp.tag_inames(knl, dict(k="g.0"))
+
+        result = cl.array.zeros(
+                queue,
+                (self.to_discr.nnodes, self.from_discr.nnodes),
+                dtype=self.to_discr.real_dtype)
+
+        for i_grp, (tgrp, sgrp, cgrp) in enumerate(
+                zip(self.to_discr.groups, self.from_discr.groups, self.groups)):
+            for i_batch, batch in enumerate(cgrp.batches):
+                if len(batch.source_element_indices):
+                    if not len(batch.source_element_indices):
+                        continue
+
+                    knl()(queue,
+                            resample_mat=self._resample_matrix(i_grp, i_batch),
+                            result=result,
+                            itgt_base=tgrp.node_nr_base,
+                            isrc_base=sgrp.node_nr_base,
+                            source_element_indices=batch.source_element_indices,
+                            target_element_indices=batch.target_element_indices)
+
+        return result
+
     def __call__(self, queue, vec):
         @memoize_method_nested
         def knl():
diff --git a/test/test_meshmode.py b/test/test_meshmode.py
index 7a635362..086653ef 100644
--- a/test/test_meshmode.py
+++ b/test/test_meshmode.py
@@ -102,6 +102,11 @@ def test_boundary_interpolation(ctx_getter):
         bdry_f = 0.1*cl.clmath.sin(30*bdry_x)
         bdry_f_2 = bdry_connection(queue, f)
 
+        mat = bdry_connection.full_resample_matrix(queue).get(queue)
+        bdry_f_2_by_mat = mat.dot(f.get())
+
+        assert(la.norm(bdry_f_2.get(queue=queue) - bdry_f_2_by_mat)) < 1e-14
+
         err = la.norm((bdry_f-bdry_f_2).get(), np.inf)
         eoc_rec.add_data_point(h, err)
 
-- 
GitLab