Skip to content
Snippets Groups Projects
Commit 93b73161 authored by Isuru Fernando's avatar Isuru Fernando
Browse files

add fast spmv transpose as well

parent e937212a
No related branches found
No related tags found
1 merge request!68Specify the PDE symbolically
......@@ -266,19 +266,32 @@ def _spmv(spmat, x, sparse_vectors):
# }}}
def _fast_spmv(reconstruct_matrix, stored, sac):
res = [0] * len(reconstruct_matrix)
stored_idx = 0
for row, deps in reconstruct_matrix:
if len(deps) == 0:
res[row] = stored[stored_idx]
stored_idx += 1
else:
for k, v in deps.items():
res[row] += res[k] * v
new_sym = sym.Symbol(sac.assign_unique("expr", res[row]))
res[row] = new_sym
return res
def _fast_spmv(reconstruct_matrix, vec, sac, transpose=False):
if not transpose:
res = [0] * len(reconstruct_matrix)
stored_idx = 0
for row, deps in enumerate(reconstruct_matrix):
if len(deps) == 0:
res[row] = vec[stored_idx]
stored_idx += 1
else:
for k, v in deps:
res[row] += res[k] * v
new_sym = sym.Symbol(sac.assign_unique("expr", res[row]))
res[row] = new_sym
return res
else:
res = []
expr_all = vec.copy()
for row, deps in reversed(enumerate(reconstruct_matrix)):
if len(deps) == 0:
res.append(expr_all[row])
continue
new_sym = sym.Symbol(sac.assign_unique("expr", expr_all[row]))
for k, v in deps:
expr_all[k] += new_sym * v
res.reverse()
return res
class LinearPDEBasedExpansionTermsWrangler(ExpansionTermsWrangler):
......@@ -311,14 +324,18 @@ class LinearPDEBasedExpansionTermsWrangler(ExpansionTermsWrangler):
return _fast_spmv(reconstruct_matrix, stored_kernel_derivatives, sac)
def get_stored_mpole_coefficients_from_full(self, full_mpole_coefficients,
rscale):
rscale, sac=None):
# = M^T x, where M = coeff matrix
coeff_matrix, _, _ = self.get_coefficient_matrix(rscale)
result = [0] * len(self.stored_identifiers)
for row, coeff in enumerate(full_mpole_coefficients):
for col, val in coeff_matrix[row]:
result[col] += coeff * val
return result
coeff_matrix, reconstruct_matrix, use_reconstruct = \
self.get_coefficient_matrix(rscale)
if not use_reconstruct or sac is None:
result = [0] * len(self.stored_identifiers)
for row, coeff in enumerate(full_mpole_coefficients):
for col, val in coeff_matrix[row]:
result[col] += coeff * val
return result
return _fast_spmv(reconstruct_matrix, full_mpole_coefficients, sac,
tranpose=True)
@property
def stored_identifiers(self):
......@@ -370,20 +387,20 @@ class LinearPDEBasedExpansionTermsWrangler(ExpansionTermsWrangler):
reconstruct_matrix_with_rscale = []
count_nonzero_reconstruct = 0
for row, deps in six.iteritems(reconstruct_matrix):
for row, deps in enumerate(reconstruct_matrix):
# For eg: (u_xxx / rscale**3) = (u_yy / rscale**2) * coeff1 +
# (u_xx / rscale**2) * coeff2
# is converted to u_xxx = u_yy * (rscale * coeff1) +
# u_xx * (rscale * coeff2)
row_rscale = sum(full_coeffs[row])
matrix_row = []
deps_with_rscale = {}
deps_with_rscale = []
for k, coeff in deps:
diff = row_rscale - sum(full_coeffs[k])
mult = rscale**diff
deps_with_rscale[k] = coeff * mult
deps_with_rscale.append((k, coeff * mult))
count_nonzero_reconstruct += len(deps)
reconstruct_matrix_with_rscale.append((row, deps_with_rscale))
reconstruct_matrix_with_rscale.append(deps_with_rscale)
use_reconstruct = count_nonzero_reconstruct < count_nonzero_coeff
......@@ -422,8 +439,9 @@ class LinearPDEBasedExpansionTermsWrangler(ExpansionTermsWrangler):
(i, mi) in enumerate(stored_identifiers))
coeff_matrix_dict = defaultdict(lambda: defaultdict(lambda: 0))
reconstruct_matrix = defaultdict(list)
reconstruct_matrix = []
for i, mi in enumerate(mis):
reconstruct_matrix.append([])
if is_stored(mi):
coeff_matrix_dict[i][stored_ident_enumerate_dict[mi]] = 1
continue
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment