diff --git a/sumpy/expansion/__init__.py b/sumpy/expansion/__init__.py index f680347c3e7656f90c5325ab4a00ceea4977f798..68a99197c991c5630bca8dbf6867e801f1531fc3 100644 --- a/sumpy/expansion/__init__.py +++ b/sumpy/expansion/__init__.py @@ -356,7 +356,7 @@ class LinearRecurrenceBasedExpansionTermsWrangler(ExpansionTermsWrangler): @memoize_method def _get_stored_ids_and_coeff_mat(self): from six import iteritems - from sumpy.tools import nullspace + from sumpy.tools import nullspace, solve_symbolic tol = 1e-13 stored_identifiers = [] @@ -382,6 +382,7 @@ class LinearRecurrenceBasedExpansionTermsWrangler(ExpansionTermsWrangler): eq[coeff_ident_enumerate_dict[c]] = coeff else: pde_mat.append(eq) + if len(pde_mat) > 0: r""" Find a matrix :math:`s` such that :math:`K = S^T K_{[r]}` @@ -399,7 +400,7 @@ class LinearRecurrenceBasedExpansionTermsWrangler(ExpansionTermsWrangler): n = nullspace(pde_mat) idx = self.get_reduced_coeffs() assert len(idx) >= n.shape[1] - s = n.T[:, idx].solve(n.T) + s = solve_symbolic(n.T[:,idx], n.T) stored_identifiers = [mis[i] for i in idx] else: s = np.eye(len(mis)) diff --git a/sumpy/tools.py b/sumpy/tools.py index c5438320517642aad4ef7acb590511e42290d3f1..996f75f08b7812cea060fe018c8fc731e8489449 100644 --- a/sumpy/tools.py +++ b/sumpy/tools.py @@ -657,45 +657,56 @@ def my_syntactic_subs(expr, subst_dict): return expr -def rref(mat): - rows = len(mat) - cols = len(mat[0]) - col = 0 - pivot_cols = [] - for row in range(rows): - if col >= cols: +def rref(m): + l = np.array(m, dtype=object) + index = 0 + nrows = l.shape[0] + ncols = l.shape[1] + pivot_cols = [] + for i in range(ncols): + if index == nrows: break - i = row - while mat[i][col] == 0: - i += 1 - if i == rows: - i = row - col += 1 - if col == cols: - return mat, pivot_cols - - pivot_cols.append(col) - mat[i], mat[row] = mat[row], mat[i] - - piv = mat[row][col] - for c in range(col, cols): - mat[row][c] /= piv - - for r in range(rows): - if r == row: + pivot = nrows + for k in range(index, nrows): + if l[k, i] != 0 and pivot == nrows: + pivot = k + if abs(l[k, i]) == 1: + pivot = k + break + if pivot == nrows: + continue + if pivot != index: + l[pivot,:], l[index,:] = l[index,:].copy(), l[pivot,:].copy() + + pivot_cols.append(i) + scale = l[index, i] + t = l[index,:]//scale + not_exact = (t * scale != l[index, :]) + if (np.any(not_exact)): + for j in range(ncols): + if not_exact[j]: + t[j] = sym.sympify(l[index, j])/scale + + l[index,:] = t + + for j in range(nrows): + if (j == index): continue - piv = mat[r][col] - for c in range(col, cols): - mat[r][c] -= piv * mat[row][c] - col += 1 - return mat, pivot_cols + + scale = l[j, i] + if scale != 0: + l[j,:] = l[j,:] - l[index,:]*scale + + index = index + 1 + + return l, pivot_cols def nullspace(m): - m2 = [[sym.sympify(col) for col in row] for row in m] - mat, pivot_cols = rref(m2) - cols = len(mat[0]) + mat, pivot_cols = rref(m) + pivot_cols = list(pivot_cols) + cols = mat.shape[1] free_vars = [i for i in range(cols) if i not in pivot_cols] @@ -705,8 +716,17 @@ def nullspace(m): vec[free_var] = 1 for piv_row, piv_col in enumerate(pivot_cols): for pos in pivot_cols[piv_row+1:] + [free_var]: - vec[piv_col] -= mat[piv_row][pos] + vec[piv_col] -= mat[piv_row,pos] n.append(vec) - return sym.Matrix(n).T + return np.array(n, dtype=object).T + + +def solve_symbolic(A, b): + if isinstance(A, sym.Matrix): + big = A.row_join(b) + else: + big = np.hstack((A, b)) + red = rref(big)[0] + return red[:,big.shape[0]:] # vim: fdm=marker