{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import annotations\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import numpy.linalg as la\n",
    "\n",
    "import pyopencl as cl\n",
    "from pytools import add_tuples\n",
    "\n",
    "import sumpy.toys as t\n",
    "from sumpy.expansion.local import VolumeTaylorLocalExpansion\n",
    "from sumpy.expansion.multipole import VolumeTaylorMultipoleExpansion\n",
    "from sumpy.kernel import HelmholtzKernel, LaplaceKernel, YukawaKernel  # noqa: F401\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(seed=42)\n",
    "order = 4\n",
    "\n",
    "if 0:\n",
    "    knl = LaplaceKernel(2)\n",
    "    pde = [(1, (2, 0)), (1, (0, 2))]\n",
    "    extra_kernel_kwargs = {}\n",
    "\n",
    "else:\n",
    "    helm_k = 1.2\n",
    "    knl = HelmholtzKernel(2)\n",
    "    extra_kernel_kwargs = {\"k\": helm_k}\n",
    "\n",
    "    pde = [(1, (2, 0)), (1, (0, 2)), (helm_k**2, (0, 0))]\n",
    "\n",
    "mpole_expn = VolumeTaylorMultipoleExpansion(knl, order)\n",
    "local_expn = VolumeTaylorLocalExpansion(knl, order)\n",
    "\n",
    "cl_ctx = cl.create_some_context(answers=[\"port\"])\n",
    "\n",
    "tctx = t.ToyContext(\n",
    "    cl_ctx,\n",
    "    knl,\n",
    "    mpole_expn_class=type(mpole_expn),\n",
    "    local_expn_class=type(local_expn),\n",
    "    extra_kernel_kwargs=extra_kernel_kwargs,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pt_src = t.PointSources(tctx, rng.uniform(-0.5, 0.5, size=(2, 50)), np.ones(50))\n",
    "\n",
    "mexp = t.multipole_expand(pt_src, [0, 0], order)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mexp.coeffs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_pde_mat(expn, pde):\n",
    "    coeff_ids = expn.get_coefficient_identifiers()\n",
    "    id_to_index = expn._storage_loc_dict\n",
    "\n",
    "    # FIXME: specific to scalar PDEs\n",
    "    pde_mat = np.zeros((len(coeff_ids), len(coeff_ids)))\n",
    "\n",
    "    row = 0\n",
    "    for base_coeff_id in coeff_ids:\n",
    "        valid = True\n",
    "\n",
    "        for pde_coeff, coeff_id_offset in pde:\n",
    "            other_coeff = add_tuples(base_coeff_id, coeff_id_offset)\n",
    "            if other_coeff not in id_to_index:\n",
    "                valid = False\n",
    "                break\n",
    "\n",
    "            pde_mat[row, id_to_index[other_coeff]] = pde_coeff\n",
    "\n",
    "        if valid:\n",
    "            row += 1\n",
    "        else:\n",
    "            pde_mat[row] = 0\n",
    "\n",
    "    return pde_mat[:row]\n",
    "\n",
    "\n",
    "pde_mat = build_pde_mat(mpole_expn, pde)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_nullspace(mat, tol=1e-10):\n",
    "    _u, sig, vt = la.svd(pde_mat, full_matrices=True)\n",
    "    zerosig = np.where(np.abs(sig) < tol)[0]\n",
    "    if zerosig.size:\n",
    "        nullsp_start = zerosig[0]\n",
    "        assert np.array_equal(zerosig, np.arange(nullsp_start, pde_mat.shape[1]))\n",
    "    else:\n",
    "        nullsp_start = pde_mat.shape[0]\n",
    "\n",
    "    return vt[nullsp_start:].T\n",
    "\n",
    "\n",
    "nullsp = find_nullspace(pde_mat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "la.norm(pde_mat @ nullsp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_translation_mat(mexp, to_center):\n",
    "    n = len(mexp.coeffs)\n",
    "    result = np.zeros((n, n))\n",
    "\n",
    "    for j in range(n):\n",
    "        unit_coeffs = np.zeros(n)\n",
    "        unit_coeffs[j] = 1\n",
    "        unit_mexp = mexp.with_coeffs(unit_coeffs)\n",
    "\n",
    "        result[:, j] = t.multipole_expand(unit_mexp, to_center).coeffs\n",
    "\n",
    "    return result\n",
    "\n",
    "\n",
    "new_center = np.array([0, 0.5])\n",
    "tmat = build_translation_mat(mexp, new_center)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(tmat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nullsp.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if 0:\n",
    "    reduction_mat = nullsp.T\n",
    "    expansion_mat = nullsp\n",
    "elif 1:\n",
    "    chosen_indices_and_coeff_ids = [\n",
    "        (i, cid)\n",
    "        for i, cid in enumerate(mpole_expn.get_coefficient_identifiers())\n",
    "        if cid[0] < 2\n",
    "    ]\n",
    "    chosen_indices = [idx for idx, _ in chosen_indices_and_coeff_ids]\n",
    "\n",
    "    expansion_mat = np.zeros((\n",
    "        len(mpole_expn.get_coefficient_identifiers()),\n",
    "        len(chosen_indices_and_coeff_ids),\n",
    "    ))\n",
    "    for i, (idx, _) in enumerate(chosen_indices_and_coeff_ids):\n",
    "        expansion_mat[idx, i] = 1\n",
    "\n",
    "    reduction_mat = (nullsp @ la.inv(nullsp[chosen_indices])).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_coeffs(expn, coeffs, **kwargs):\n",
    "    x = [cid[0] for cid in expn.get_coefficient_identifiers()]\n",
    "    y = [cid[1] for cid in expn.get_coefficient_identifiers()]\n",
    "    plt.scatter(x, y, c=coeffs, **kwargs)\n",
    "    plt.colorbar()\n",
    "\n",
    "    for cid, coeff in zip(expn.get_coefficient_identifiers(), coeffs, strict=True):\n",
    "        plt.text(cid[0], cid[1] + 0.2, f\"{coeff:.1f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "proj_mexp = mexp.with_coeffs(expansion_mat @ reduction_mat @ mexp.coeffs)\n",
    "\n",
    "proj_resid = proj_mexp.coeffs - mexp.coeffs\n",
    "\n",
    "plot_coeffs(mpole_expn, np.log10(1e-15 + np.abs(proj_resid)), vmin=-15, vmax=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(t.l_inf(proj_mexp - mexp, 1.2, center=[3, 0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trans_unproj = t.multipole_expand(mexp, new_center)\n",
    "trans_proj = t.multipole_expand(proj_mexp, new_center)\n",
    "\n",
    "print(t.l_inf(trans_unproj - trans_proj, 1.2, center=[3, 0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(trans_proj.coeffs - trans_unproj.coeffs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "la.norm(reduction_mat @ (trans_proj.coeffs - trans_unproj.coeffs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t.l_inf(trans_unproj - pt_src, 1.2, center=[3, 0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t.l_inf(mexp - pt_src, 1.2, center=[3, 0])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}