load("eigen");
load("itensor");
load("diag");

/* Signal an error with the message "Assertion violated" when `condition` is
   false; otherwise return true.
   NOTE(review): when the predicate cannot be decided, the outcome presumably
   depends on the prederror setting -- confirm before relying on it. */
assert(condition):=if not condition then error("Assertion violated") else true$

/* Return v . v, the non-commutative (dot) product of v with itself.
   For a list v this is the sum of the squares of its entries. */
norm_2_squared(v):=block([self_product],
  self_product: v . v,
  self_product);

/* Build an n-by-m matrix in which every entry equals c. */
constmatrix(n,m,c):=block([const_entry],
  /* genmatrix asks for the entry at each (row, column); both are ignored */
  const_entry: lambda([r, s], c),
  genmatrix(const_entry, n, m));

/* vstack is bound to append, so vstack(a, b) is append(a, b): for matrices
   this joins the rows of a and b, i.e. stacks a on top of b. */
vstack:append;

/* Put matrix b to the right of matrix a.
   Done by transposing both, joining the transposed rows with append, and
   transposing the result back. */
hstack(a,b):=block([a_t, b_t],
  a_t: transpose(a),
  b_t: transpose(b),
  transpose(append(a_t, b_t)));

/* Assemble the 2x2 block matrix
     [ a11 a12 ]
     [ a21 a22 ]
   by joining each block row with hstack and stacking the two with vstack. */
blockmat(a11,a12,a21,a22):=block([upper_blocks, lower_blocks],
  upper_blocks: hstack(a11, a12),
  lower_blocks: hstack(a21, a22),
  vstack(upper_blocks, lower_blocks))$

/* Return the 3-element list whose i-th entry is
     sum over j,k in 1..3 of levi_civita([i,j,k]) * f(j,k).
   f is called with two integer indices; with f(j,k) = a[j]*b[k] this is the
   cross product of a and b. */
crossfunc(f):=makelist(
  sum(
    sum(levi_civita([i,j,k])*f(j,k), k, 1, 3),
    j, 1, 3),
  i, 1, 3)$

/* Cross product of the 3-component indexables a and b, obtained by handing
   the pairwise products a[j]*b[k] to crossfunc. */
crossprod(a,b):=block([pair_product],
  pair_product: lambda([j,k], a[j]*b[k]),
  crossfunc(pair_product));

/* Return the list [v[start,1], ..., v[stop,1]], i.e. the first-column
   entries of v for rows start through stop (both inclusive). */
subrange(v, start, stop):=block([indices],
  indices: makelist(idx, idx, start, stop),
  map(lambda([idx], v[idx,1]), indices));

/* Return the list of entries A[k][k] for k = 1 .. length(A). */
extract_diagonal(A):=block([dim],
  dim: length(A),
  makelist(A[k][k], k, 1, dim));
/* Build the square matrix of size length(v) with the entries of v on the
   main diagonal and 0 everywhere else. */
make_diagonal_matrix(v):=block([dim, diag_entry],
  dim: length(v),
  diag_entry: lambda([i,j], if i=j then v[i] else 0),
  genmatrix(diag_entry, dim, dim));

/* ------------------------------------------------------------------------- */
/* Simplification for expressions stemming from hyperbolic systems */
/* ------------------------------------------------------------------------- */

/* Simplify x after rationally substituting 1 for n.n, then apply ratsimp.
   n is not a parameter: it is read from the calling environment.
   NOTE(review): presumably n is a unit normal vector -- confirm with callers. */
hypsimp(x):=block([with_unit_n],
  with_unit_n: ratsubst(1, n.n, x),
  ratsimp(with_unit_n))$

/* Like hypsimp, but first rationally substitutes last(n)^2 for
   1 - (sum of the squares of all other components of n).
   n is read from the calling environment, not passed in. */
fullhypsimp(x):=block([last_squared, one_minus_rest],
  last_squared: last(n)^2,
  one_minus_rest: 1 - sum(n[i]^2, i, 1, length(n)-1),
  hypsimp(ratsubst(last_squared, one_minus_rest, x)));

/* ------------------------------------------------------------------------- */
/* diagonalize a given hyperbolic operator A */
/* ------------------------------------------------------------------------- */

/* Diagonalize A and return the list [V, D, invV], where the columns of V are
   the eigenvectors of A, invV is its inverse and D = invV . A . V.
   Every intermediate result is passed through hypsimp, and the function
   asserts that hypsimp(V . invV) is the identity of size length(A). */
hypdiagonalize(A):=block([evec_list, evec_mat, evec_inv, diag_form],
  /* the second part of the eigenvectors() result is flattened with append
     into one list of eigenvectors */
  evec_list: hypsimp(apply(append, second(eigenvectors(A)))),
  /* eigenvectors become the rows of a matrix, transposed into columns */
  evec_mat: transpose(apply(matrix, evec_list)),
  evec_inv: hypsimp(invert(evec_mat)),
  assert(hypsimp(evec_mat . evec_inv) = ident(length(A))),
  diag_form: hypsimp(evec_inv . A . evec_mat),
  [evec_mat, diag_form, evec_inv]);

/* ------------------------------------------------------------------------- */
/* compute upwind flux for a given operator with eigenvalues evs, sorted
 * in ascending order.
 * Sign assumptions for all variables occurring in evs must be in place.
 */
/* ------------------------------------------------------------------------- */
hyp_upwind_flux(evs, D):=block(
  /* Compute the upwind flux for the operator D.
   *
   * evs: list of the eigenvalues of D, sorted in ascending order. Sign
   *      assumptions for every variable occurring in evs must be in place,
   *      because the code branches on evs[i] > 0, evs[i] < 0 and evs[i] = 0.
   * D:   the operator matrix; length(D) is used as the number of state
   *      components.
   *
   * Every variable occurring in evs gets a "minus-side" copy (suffix m) and
   * a "plus-side" copy (suffix p); evsm/Dm and evsp/Dp are evs/D written in
   * those copies. One jump condition per eigenvalue links consecutive
   * states, the resulting system is solved, and the flux is returned with
   * the outer states written as sm[i,1] and sp[i,1].
   *
   * Raises an error if the system does not have exactly one solution, or if
   * no eigenvalue is zero and no sign change is found in evs.
   *
   * All working variables (including evsm, evsp, result and flux) are local,
   * so repeated calls do not interfere with each other or with globals.
   */
  [evvars, evsm, evsp, Dp, Dm, midstates, states, unknowns, result],
  /* keep the helper defined below private to this call */
  local(add_evvars_suffix),

  evvars:listofvars(evs),

  /* rename every variable of evs occurring in x by appending suffix */
  add_evvars_suffix(suffix, x):=subst(makelist(v=concat(v, suffix), v, evvars), x),

  /* the suffix and state-name symbols are quoted so that global bindings of
     m, p, s, sm, sp or flux cannot corrupt the generated names */
  evsm:add_evvars_suffix('m, evs),
  evsp:add_evvars_suffix('p, evs),

  Dm:add_evvars_suffix('m, D),
  Dp:add_evvars_suffix('p, D),

  /* intermediate states s<state><i>, one between each pair of consecutive
     eigenvalues */
  midstates:makelist(makelist(concat('s,state,i), i, 1, length(D)),
      state, 1, length(evs)-1),

  /* full state sequence: minus-side state, intermediate states, plus-side
     state; eigenvalue i links states[i] and states[i+1] */
  states:append(
    [makelist(concat('sm, i), i, 1, length(D))],
    midstates,
    [makelist(concat('sp,i), i, 1, length(D))]),

  unknowns:apply(append, midstates),

  result:if member(0, evs) then
    block([biasedD, veceqns, eqns, soln],
      /* operators applied to the states right and left of jump i */
      biasedD:makelist(
        if evs[i] = 0 then [Dp,Dm]
        else if evs[i] > 0 then [Dp,Dp]
        else [Dm,Dm],
        i, 1, length(evs)),

      veceqns:apply(append, makelist(
        -(if evs[i] > 0 then evsp[i] else evsm[i]) *(states[i+1]-states[i])
        +(biasedD[i][1].states[i+1]-biasedD[i][2].states[i]),
        i,1,length(evs))),

      eqns:makelist(veceqns[i,1], i, 1, length(veceqns)),

      soln:solve(eqns, unknowns),
      assert(length(soln)=1),

      /* The flux is Dp applied to the state right of the zero-speed jump.
         states[i+1] is the same state as midstates[i] for i < length(evs)
         and, unlike midstates[i], is also defined when the zero eigenvalue
         is the last one. */
      for i: 1 thru length(evs) do
        if evs[i] = 0 then return(Dp.subst(soln[1], states[i+1]))
    )
  else
    block([straddle_idx, flux, Dstates, veceqns, eqns, soln],
      /* index i with evs[i] < 0 < evs[i+1]; the loop yields done if none */
      straddle_idx:for i: 1 thru length(evs)-1 do
        if (evs[i] < 0) and (evs[i+1] > 0) then return(i),

      if straddle_idx = 'done then
        error("hyp_upwind_flux: no sign change found in the eigenvalues"),

      flux:makelist(concat('flux,i),i,1,length(D)),

      unknowns:append(unknowns, flux),

      /* flux of each state; the state straddling the sign change carries
         the unknown flux */
      Dstates:append(
        [Dm.first(states)],
        makelist(
          if i = straddle_idx then flux
          else if evs[i] > 0 then Dp.midstates[i]
          else Dm.midstates[i],
          i, 1, length(midstates)),
        [Dp.last(states)]),

      veceqns:apply(append, makelist(
        -(if evs[i] > 0 then evsp[i] else evsm[i]) *(states[i+1]-states[i])
        +(Dstates[i+1]-Dstates[i]),
        i,1,length(evs))),

      eqns:makelist(veceqns[i,1], i, 1, length(veceqns)),

      /* diagnostic output of the equation system, kept from the original */
      print(covect(eqns)),
      soln:solve(eqns, unknowns),
      assert(length(soln)=1),

      subst(soln[1], flux)
    ),

  /* express the outer states as entries of the column vectors sm and sp */
  subst(
    append(
      makelist(concat('sm, i)=sm[i,1], i, 1, length(D)),
      makelist(concat('sp, i)=sp[i,1], i, 1, length(D))
      ),
    result)
  );