
!> Compute per-direction physical flux derivatives with a characteristic-wise
!> WENO scheme evaluated in generalized (curvilinear) coordinates.
!>
!> The fluxes are transformed to generalized coordinates, an interface flux is
!> built for every xi-face whose six-point stencil lies inside the grid, its
!> divided difference is accumulated into the two adjacent cells, and the
!> result is mapped back with convert_from_generalized.
!>
!> NOTE(review): nvars, ndim, nx, ny, nz are not declared here; presumably they
!> come from an enclosing module -- confirm.
!> NOTE(review): only the xi sweep (direction 1) is implemented; the eta and
!> zeta sweeps are still TODO, so components 2 and 3 stay zero before the
!> back-transformation.
subroutine compute_flux_derivatives(states, fluxes, metrics, metric_jacobians, flux_derivatives)
  real, intent(in)  :: states(nvars, nx, ny, nz)
  real, intent(in)  :: fluxes(nvars, ndim, nx, ny, nz)
  real, intent(in)  :: metrics(ndim, ndim, nx, ny, nz)
  real, intent(in)  :: metric_jacobians(nx, ny, nz)
  real, intent(out) :: flux_derivatives(nvars, ndim, nx, ny, nz)

  ! The interface flux at face i+1/2 uses stencil points i-2 .. i+3.
  integer, parameter :: stencil_lo = -2, stencil_hi = 3

  ! Whole-grid work arrays are allocatable so they live on the heap, not the stack.
  real, allocatable :: flux_derivatives_generalized(:, :, :, :, :)
  real, allocatable :: generalized_fluxes(:, :, :, :, :)

  real :: generalized_states_frozen(nvars, stencil_lo:stencil_hi)
  ! convert_to_generalized_frozen fills all ndim components at every stencil point.
  real :: generalized_fluxes_frozen(nvars, ndim, stencil_lo:stencil_hi)
  real :: lambda_pointwise(nvars, stencil_lo:stencil_hi)
  real :: lambda_roe(nvars), lambda(nvars)
  real :: R(nvars, nvars), R_inv(nvars, nvars)
  real :: characteristic_fluxes_pos(nvars, stencil_lo:stencil_hi)
  real :: characteristic_fluxes_neg(nvars, stencil_lo:stencil_hi)
  real :: flux(nvars)
  real :: combined_frozen_metrics
  real :: delta_xi, delta_eta, delta_zeta
  integer :: i, j, k

  ! Grid spacing in generalized coordinates.
  delta_xi = 1.0
  delta_eta = 1.0
  delta_zeta = 1.0

  ! NOTE(review): combined_frozen_metrics is passed to the WENO weights but was
  ! never assigned. 1.0 leaves the WENO epsilon at its unscaled 1e-6; replace
  ! with the intended frozen-metric magnitude.
  combined_frozen_metrics = 1.0

  allocate(flux_derivatives_generalized(nvars, ndim, nx, ny, nz))
  allocate(generalized_fluxes(nvars, ndim, nx, ny, nz))

  flux_derivatives_generalized = 0.0

  call convert_to_generalized(fluxes, metrics, metric_jacobians, generalized_fluxes)

  do k = 1, nz
    do j = 1, ny
      ! Face i+1/2 lies between cells i and i+1. Only faces whose whole stencil
      ! lies inside 1..nx are visited; closer faces would index outside the arrays.
      ! TODO(review): add ghost cells or a boundary closure -- as written, only
      ! cells 4..nx-3 receive both of their face contributions.
      do i = 1 - stencil_lo, nx - stencil_hi
        call pointwise_eigenvalues(states(:, i+stencil_lo:i+stencil_hi, j, k), lambda_pointwise)
        ! NOTE(review): neither eigen-routine receives metrics, so R, R_inv and
        ! the wave speeds cannot depend on the xi direction -- confirm.
        call roe_eigensystem(states(:, i:i+1, j, k), R, R_inv, lambda_roe)
        call lax_wavespeeds(lambda_pointwise, lambda_roe, lambda)

        call convert_to_generalized_frozen(states(:, i+stencil_lo:i+stencil_hi, j, k), &
                                           fluxes(:, :, i+stencil_lo:i+stencil_hi, j, k), &
                                           metrics(:, :, i:i+1, j, k), &
                                           metric_jacobians(i:i+1, j, k), &
                                           generalized_states_frozen, &
                                           generalized_fluxes_frozen)
        call split_characteristic_fluxes(generalized_states_frozen, &
                                         generalized_fluxes_frozen(:, 1, :), &
                                         R_inv, &
                                         lambda, &
                                         characteristic_fluxes_pos, &
                                         characteristic_fluxes_neg)

        call weno_flux(generalized_fluxes(:, 1, i+stencil_lo:i+stencil_hi, j, k), &
                       characteristic_fluxes_pos, characteristic_fluxes_neg, &
                       combined_frozen_metrics, R, flux)

        ! Divided difference across the face: +F(i+1/2) for cell i, -F(i+1/2) for cell i+1.
        flux_derivatives_generalized(:, 1, i, j, k) = &
          flux_derivatives_generalized(:, 1, i, j, k) + flux / delta_xi
        flux_derivatives_generalized(:, 1, i+1, j, k) = &
          flux_derivatives_generalized(:, 1, i+1, j, k) - flux / delta_xi
      end do
    end do
  end do

  ! two more loops:
  !   j is inner index, filling flux_derivatives_generalized(:,2,:,:,:), using delta_eta
  !   k is inner index, filling flux_derivatives_generalized(:,3,:,:,:), using delta_zeta

  call convert_from_generalized(flux_derivatives_generalized, &
                                metrics, &
                                metric_jacobians, &
                                flux_derivatives)
end subroutine compute_flux_derivatives

!> Per-point eigenvalues for each of the nvars fields at the six stencil
!> points (offsets -2..3). lax_wavespeeds takes max |value| over them.
!>
!> TODO: not implemented. Until it is, stop with an error instead of returning
!> an undefined intent(out) array that would silently corrupt the wave speeds.
subroutine pointwise_eigenvalues(states, lambda_pointwise)
  real, intent(in)  :: states(nvars, -2:3)
  real, intent(out) :: lambda_pointwise(nvars, -2:3)

  error stop "pointwise_eigenvalues: not implemented"
end subroutine pointwise_eigenvalues

!> Eigensystem of the two states adjacent to an interface: right eigenvectors
!> R, their inverse R_inv, and the eigenvalues lambda_roe.
!>
!> The dummy argument list must name plain variables; the previous
!> `states(:, i:i+1, j, k)` was an array section and is not valid there.
!>
!> TODO: not implemented. Until it is, stop with an error instead of returning
!> undefined intent(out) arrays.
subroutine roe_eigensystem(states, R, R_inv, lambda_roe)
  real, intent(in)  :: states(nvars, 2)
  real, intent(out) :: R(nvars, nvars)
  real, intent(out) :: R_inv(nvars, nvars)
  real, intent(out) :: lambda_roe(nvars)

  error stop "roe_eigensystem: not implemented"
end subroutine roe_eigensystem

!> Wave speed per field, used for Lax-Friedrichs-type flux splitting.
!>
!> For each field v this is the largest |eigenvalue| seen either at the six
!> stencil points or in lambda_roe, multiplied by the safety factor kappa.
subroutine lax_wavespeeds(lambda_pointwise, lambda_roe, lambda)
  real, intent(in)  :: lambda_pointwise(nvars, -2:3)
  real, intent(in)  :: lambda_roe(nvars)
  real, intent(out) :: lambda(nvars)

  ! Safety factor applied on top of the maximum |eigenvalue|.
  real, parameter :: kappa = 1.1

  ! Reduce over the stencil dimension (dim=2) for all fields at once.
  lambda = kappa * max(maxval(abs(lambda_pointwise), dim=2), abs(lambda_roe))
end subroutine lax_wavespeeds

!> Express the stencil's states and fluxes in generalized coordinates using
!> metrics "frozen" at the interface between the two cells passed in.
!>
!> metrics_frozen is the average of metrics/J over the two cells, and
!> jacobian_frozen is the arithmetic mean of J. The same frozen values are
!> applied at all six stencil points (offsets -2..3).
!>
!> NOTE(review): states are divided by mean(J), whereas the metrics are
!> averaged as metrics/J. These are two different freezings of 1/J; confirm
!> which definition is intended.
subroutine convert_to_generalized_frozen(states, &
                                         fluxes, &
                                         metrics, &
                                         metric_jacobians, &
                                         generalized_states_frozen, &
                                         generalized_fluxes_frozen)
  real, intent(in)  :: states(nvars, -2:3)
  real, intent(in)  :: fluxes(nvars, ndim, -2:3)
  real, intent(in)  :: metrics(ndim, ndim, 2)
  real, intent(in)  :: metric_jacobians(2)
  real, intent(out) :: generalized_states_frozen(nvars, -2:3)
  real, intent(out) :: generalized_fluxes_frozen(nvars, ndim, -2:3)

  real :: metrics_frozen(ndim, ndim)
  real :: jacobian_frozen
  integer :: k

  metrics_frozen = 0.5 * (metrics(:, :, 1) / metric_jacobians(1) &
                        + metrics(:, :, 2) / metric_jacobians(2))
  jacobian_frozen = 0.5 * sum(metric_jacobians)

  generalized_states_frozen = states / jacobian_frozen

  ! For each field v: generalized_fluxes_frozen(v, :, k) = metrics_frozen * fluxes(v, :, k).
  ! Written as one matmul per stencil point. The previous loop assigned the
  ! whole output array on every iteration, so only the last (v, k) survived.
  do k = -2, 3
    generalized_fluxes_frozen(:, :, k) = matmul(fluxes(:, :, k), transpose(metrics_frozen))
  end do
end subroutine convert_to_generalized_frozen

!> Transform fluxes to generalized coordinates at every grid point:
!> generalized_fluxes(v, :, i, j, k) = metrics(:, :, i, j, k) * fluxes(v, :, i, j, k) / J(i, j, k).
subroutine convert_to_generalized(fluxes, &
                                  metrics, &
                                  metric_jacobians, &
                                  generalized_fluxes)
  real, intent(in)  :: fluxes(nvars, ndim, nx, ny, nz)
  real, intent(in)  :: metrics(ndim, ndim, nx, ny, nz)
  real, intent(in)  :: metric_jacobians(nx, ny, nz)
  real, intent(out) :: generalized_fluxes(nvars, ndim, nx, ny, nz)

  integer :: i, j, k

  do k = 1, nz
    do j = 1, ny
      do i = 1, nx
        ! All nvars fields at once: row v of fluxes * transpose(metrics) is metrics * fluxes(v, :).
        generalized_fluxes(:, :, i, j, k) = &
          matmul(fluxes(:, :, i, j, k), transpose(metrics(:, :, i, j, k))) &
          / metric_jacobians(i, j, k)
      end do
    end do
  end do
end subroutine convert_to_generalized

!> Map flux derivatives from generalized back to physical coordinates.
!>
!> At every grid point, row v of the result is J * inverse(metrics) applied to
!> the ndim generalized derivatives of field v.
!>
!> NOTE(review): `inverse` is not a Fortran intrinsic. It must be provided
!> elsewhere as a function returning the ndim x ndim inverse, or be replaced by
!> LAPACK getrf/getri -- confirm.
!> NOTE(review): each generalized direction's derivative mixes all physical flux
!> components; confirm that converting direction by direction is what callers need.
subroutine convert_from_generalized(flux_derivatives_generalized, &
                                    metrics, &
                                    metric_jacobians, &
                                    flux_derivatives)
  real, intent(in)  :: flux_derivatives_generalized(nvars, ndim, nx, ny, nz)
  real, intent(in)  :: metrics(ndim, ndim, nx, ny, nz)
  real, intent(in)  :: metric_jacobians(nx, ny, nz)
  real, intent(out) :: flux_derivatives(nvars, ndim, nx, ny, nz)

  real :: inverse_metrics(ndim, ndim)
  integer :: i, j, k

  do k = 1, nz
    do j = 1, ny
      do i = 1, nx
        ! Invert once per grid point instead of once per field.
        inverse_metrics = inverse(metrics(:, :, i, j, k))
        ! flux_derivatives(v, d) = J * sum_e inverse_metrics(d, e) * generalized(v, e), all v at once.
        flux_derivatives(:, :, i, j, k) = &
          matmul(flux_derivatives_generalized(:, :, i, j, k), transpose(inverse_metrics)) &
          * metric_jacobians(i, j, k)
      end do
    end do
  end do
end subroutine convert_from_generalized

!> Project the frozen generalized fluxes F and states U onto characteristic
!> fields (rows of R_inv) and split them per field m:
!>   f+(m, k) = 0.5 * ((R_inv F)(m, k) + lambda(m) * (R_inv U)(m, k))
!>   f-(m, k) = 0.5 * ((R_inv F)(m, k) - lambda(m) * (R_inv U)(m, k))
!> for every stencil point k in -2..3.
subroutine split_characteristic_fluxes(generalized_states_frozen, &
                                       generalized_fluxes_frozen, &
                                       R_inv, &
                                       lambda, &
                                       characteristic_fluxes_pos, &
                                       characteristic_fluxes_neg)
  real, intent(in)  :: generalized_states_frozen(nvars, -2:3)
  real, intent(in)  :: generalized_fluxes_frozen(nvars, -2:3)
  real, intent(in)  :: R_inv(nvars, nvars)
  real, intent(in)  :: lambda(nvars)
  real, intent(out) :: characteristic_fluxes_pos(nvars, -2:3)
  real, intent(out) :: characteristic_fluxes_neg(nvars, -2:3)

  real :: projected_fluxes(nvars, -2:3)
  real :: projected_states(nvars, -2:3)
  integer :: k

  ! Characteristic projections of all six stencil points at once.
  projected_fluxes = matmul(R_inv, generalized_fluxes_frozen)
  projected_states = matmul(R_inv, generalized_states_frozen)

  do k = -2, 3
    characteristic_fluxes_pos(:, k) = 0.5 * (projected_fluxes(:, k) + lambda * projected_states(:, k))
    characteristic_fluxes_neg(:, k) = 0.5 * (projected_fluxes(:, k) - lambda * projected_states(:, k))
  end do
end subroutine split_characteristic_fluxes

!> Interface flux for one face: a central part from the pointwise generalized
!> fluxes, plus the dissipation terms from the positive and negative
!> characteristic flux splits.
subroutine weno_flux(generalized_fluxes, characteristic_fluxes_pos, characteristic_fluxes_neg, &
                     combined_frozen_metrics, R, flux)
  real, intent(in)  :: generalized_fluxes(nvars, -2:3)
  real, intent(in)  :: characteristic_fluxes_pos(nvars, -2:3)
  real, intent(in)  :: characteristic_fluxes_neg(nvars, -2:3)
  real, intent(in)  :: combined_frozen_metrics
  real, intent(in)  :: R(nvars, nvars)
  real, intent(out) :: flux(nvars)

  ! Per-field parts; previously these were implicit scalars.
  real :: consistent(nvars)
  real :: dissipation_pos(nvars)
  real :: dissipation_neg(nvars)

  call consistent_part(generalized_fluxes, consistent)
  call dissipation_part_pos(characteristic_fluxes_pos, combined_frozen_metrics, R, dissipation_pos)
  call dissipation_part_neg(characteristic_fluxes_neg, combined_frozen_metrics, R, dissipation_neg)

  flux = consistent + dissipation_pos + dissipation_neg
end subroutine weno_flux

!> Central interface flux for every field, using the symmetric weights
!> (1, -8, 37, 37, -8, 1) / 60 on stencil points -2..3.
!>
!> The weighting is a matrix-vector product. The previous elementwise product
!> of a size-6 vector with the (nvars, 6) array was nonconformable, and the
!> sum collapsed all fields into one scalar.
subroutine consistent_part(generalized_fluxes, consistent)
  real, intent(in)  :: generalized_fluxes(nvars, -2:3)
  real, intent(out) :: consistent(nvars)

  real, parameter :: coefficients(6) = [1.0, -8.0, 37.0, 37.0, -8.0, 1.0]

  consistent = matmul(generalized_fluxes, coefficients) / 60.0
end subroutine consistent_part

!> Dissipation term from the positive characteristic fluxes, mapped back from
!> characteristic to conservative variables with R and scaled by -1/60.
!>
!> The dummy list now matches both the body and the 4-argument call in
!> weno_flux. It was a copy-pasted (states, fluxes, metrics, ...) list.
subroutine dissipation_part_pos(characteristic_fluxes, combined_frozen_metrics, R, dissipation_pos)
  real, intent(in)  :: characteristic_fluxes(nvars, -2:3)
  real, intent(in)  :: combined_frozen_metrics
  real, intent(in)  :: R(nvars, nvars)
  real, intent(out) :: dissipation_pos(nvars)

  real :: combined_fluxes(nvars)

  call weno_combination_pos(characteristic_fluxes, combined_frozen_metrics, combined_fluxes)

  dissipation_pos = -matmul(R, combined_fluxes) / 60.0
end subroutine dissipation_part_pos

!> Combine the three third differences of the positive characteristic fluxes
!> with their nonlinear weights w, field by field:
!>   (20 w1 - 1) D1 - (10 (w1 + w2) - 5) D2 + D3
subroutine weno_combination_pos(characteristic_fluxes, combined_frozen_metrics, combined_fluxes)
  real, intent(in)  :: characteristic_fluxes(nvars, -2:3)
  real, intent(in)  :: combined_frozen_metrics
  real, intent(out) :: combined_fluxes(nvars)

  real :: w(nvars, 3)
  real :: flux_differences(nvars, 3)

  call weno_weights_pos(characteristic_fluxes, combined_frozen_metrics, w)
  call flux_differences_pos(characteristic_fluxes, flux_differences)

  ! The '&' continuations are required: without them the expression ends
  ! after the first term and the next two lines are syntax errors.
  combined_fluxes = (20.0 * w(:, 1) - 1.0) * flux_differences(:, 1) &
                  - (10.0 * (w(:, 1) + w(:, 2)) - 5.0) * flux_differences(:, 2) &
                  + flux_differences(:, 3)
end subroutine weno_combination_pos

!> Nonlinear WENO weights for the positive characteristic fluxes.
!>
!> For each field i, the candidate stencils are offsets (-2..0), (-1..1) and
!> (0..2). Their Jiang-Shu smoothness indicators turn the ideal weights
!> (0.1, 0.6, 0.3) into w(i, :), which sums to 1.
subroutine weno_weights_pos(characteristic_fluxes, combined_frozen_metrics, w)
  real, intent(in)  :: characteristic_fluxes(nvars, -2:3)
  real, intent(in)  :: combined_frozen_metrics
  real, intent(out) :: w(nvars, 3)

  real, parameter :: ideal_weights(3) = [0.1, 0.6, 0.3]
  ! Must be real literals: integer 1/4 evaluates to 0 and 13/12 to 1.
  real, parameter :: slope_weight = 0.25
  real, parameter :: curvature_weight = 13.0 / 12.0
  integer, parameter :: p = 2

  real :: eps
  real :: smoothness(3), alpha(3)
  integer :: i

  ! Avoids division by zero on an exactly flat stencil; scaled by the frozen metrics.
  eps = 1.0e-6 * combined_frozen_metrics

  do i = 1, nvars
    smoothness(1) = slope_weight     * dot_product([1.0, -4.0, 3.0], characteristic_fluxes(i, -2:0))**2 &
                  + curvature_weight * dot_product([1.0, -2.0, 1.0], characteristic_fluxes(i, -2:0))**2
    smoothness(2) = slope_weight     * dot_product([-1.0, 0.0, 1.0], characteristic_fluxes(i, -1:1))**2 &
                  + curvature_weight * dot_product([1.0, -2.0, 1.0], characteristic_fluxes(i, -1:1))**2
    smoothness(3) = slope_weight     * dot_product([-3.0, 4.0, -1.0], characteristic_fluxes(i, 0:2))**2 &
                  + curvature_weight * dot_product([1.0, -2.0, 1.0], characteristic_fluxes(i, 0:2))**2

    alpha = ideal_weights / (smoothness + eps)**p
    w(i, :) = alpha / sum(alpha)
  end do
end subroutine weno_weights_pos

!> Third differences of the positive characteristic fluxes over the four-point
!> windows (-2..1), (-1..2) and (0..3), for every field:
!>   D_s = -f(s-3) + 3 f(s-2) - 3 f(s-1) + f(s),  s = 1, 2, 3
!>
!> Uses matmul so each field keeps its own difference. The previous
!> sum(vector * array) was nonconformable and collapsed all fields into one scalar.
subroutine flux_differences_pos(characteristic_fluxes, flux_differences)
  real, intent(in)  :: characteristic_fluxes(nvars, -2:3)
  real, intent(out) :: flux_differences(nvars, 3)

  real, parameter :: third_difference(4) = [-1.0, 3.0, -3.0, 1.0]
  integer :: s

  do s = 1, 3
    flux_differences(:, s) = matmul(characteristic_fluxes(:, s - 3:s), third_difference)
  end do
end subroutine flux_differences_pos

!> Dissipation term from the negative characteristic fluxes, mapped back from
!> characteristic to conservative variables with R and scaled by +1/60.
!>
!> The dummy list now matches both the body and the 4-argument call in
!> weno_flux. characteristic_fluxes also gets its nvars dimension back; it was
!> declared as (-2:3) only.
subroutine dissipation_part_neg(characteristic_fluxes, combined_frozen_metrics, R, dissipation_neg)
  real, intent(in)  :: characteristic_fluxes(nvars, -2:3)
  real, intent(in)  :: combined_frozen_metrics
  real, intent(in)  :: R(nvars, nvars)
  real, intent(out) :: dissipation_neg(nvars)

  real :: combined_fluxes(nvars)

  call weno_combination_neg(characteristic_fluxes, combined_frozen_metrics, combined_fluxes)

  dissipation_neg = matmul(R, combined_fluxes) / 60.0
end subroutine dissipation_part_neg

!> Combine the three third differences of the negative characteristic fluxes
!> with their nonlinear weights w, field by field:
!>   (20 w1 - 1) D1 - (10 (w1 + w2) - 5) D2 + D3
subroutine weno_combination_neg(characteristic_fluxes, combined_frozen_metrics, combined_fluxes)
  real, intent(in)  :: characteristic_fluxes(nvars, -2:3)
  real, intent(in)  :: combined_frozen_metrics
  real, intent(out) :: combined_fluxes(nvars)

  real :: w(nvars, 3)
  real :: flux_differences(nvars, 3)

  call weno_weights_neg(characteristic_fluxes, combined_frozen_metrics, w)
  call flux_differences_neg(characteristic_fluxes, flux_differences)

  ! The '&' continuations are required: without them the expression ends
  ! after the first term and the next two lines are syntax errors.
  combined_fluxes = (20.0 * w(:, 1) - 1.0) * flux_differences(:, 1) &
                  - (10.0 * (w(:, 1) + w(:, 2)) - 5.0) * flux_differences(:, 2) &
                  + flux_differences(:, 3)
end subroutine weno_combination_neg

!> Nonlinear WENO weights for the negative characteristic fluxes.
!>
!> These are the mirror image of the positive weights about the interface.
!> Candidate stencils are read in reverse order: (3,2,1), (2,1,0), (1,0,-1).
!> Their Jiang-Shu smoothness indicators turn the ideal weights
!> (0.1, 0.6, 0.3) into w(i, :), which sums to 1.
subroutine weno_weights_neg(characteristic_fluxes, combined_frozen_metrics, w)
  real, intent(in)  :: characteristic_fluxes(nvars, -2:3)
  real, intent(in)  :: combined_frozen_metrics
  real, intent(out) :: w(nvars, 3)

  real, parameter :: ideal_weights(3) = [0.1, 0.6, 0.3]
  ! Must be real literals: integer 1/4 evaluates to 0 and 13/12 to 1.
  real, parameter :: slope_weight = 0.25
  real, parameter :: curvature_weight = 13.0 / 12.0
  integer, parameter :: p = 2

  real :: eps
  real :: smoothness(3), alpha(3)
  integer :: i

  ! Avoids division by zero on an exactly flat stencil; scaled by the frozen metrics.
  eps = 1.0e-6 * combined_frozen_metrics

  ! Descending sections need an explicit stride of -1; e.g. 3:1 alone is zero-sized.
  do i = 1, nvars
    smoothness(1) = slope_weight     * dot_product([1.0, -4.0, 3.0], characteristic_fluxes(i, 3:1:-1))**2 &
                  + curvature_weight * dot_product([1.0, -2.0, 1.0], characteristic_fluxes(i, 3:1:-1))**2
    smoothness(2) = slope_weight     * dot_product([-1.0, 0.0, 1.0], characteristic_fluxes(i, 2:0:-1))**2 &
                  + curvature_weight * dot_product([1.0, -2.0, 1.0], characteristic_fluxes(i, 2:0:-1))**2
    smoothness(3) = slope_weight     * dot_product([-3.0, 4.0, -1.0], characteristic_fluxes(i, 1:-1:-1))**2 &
                  + curvature_weight * dot_product([1.0, -2.0, 1.0], characteristic_fluxes(i, 1:-1:-1))**2

    alpha = ideal_weights / (smoothness + eps)**p
    w(i, :) = alpha / sum(alpha)
  end do
end subroutine weno_weights_neg

!> Third differences of the negative characteristic fluxes over the four-point
!> windows (0..3), (-1..2) and (-2..1), for every field:
!>   D_s = -f(1-s) + 3 f(2-s) - 3 f(3-s) + f(4-s),  s = 1, 2, 3
!>
!> Uses matmul so each field keeps its own difference. The previous
!> sum(vector * array) was nonconformable and collapsed all fields into one scalar.
subroutine flux_differences_neg(characteristic_fluxes, flux_differences)
  real, intent(in)  :: characteristic_fluxes(nvars, -2:3)
  real, intent(out) :: flux_differences(nvars, 3)

  real, parameter :: third_difference(4) = [-1.0, 3.0, -3.0, 1.0]
  integer :: s

  do s = 1, 3
    flux_differences(:, s) = matmul(characteristic_fluxes(:, 1 - s:4 - s), third_difference)
  end do
end subroutine flux_differences_neg
