From 35242d15c40cd0fb8bd90ed736d466677e02dbc7 Mon Sep 17 00:00:00 2001 From: Karl Rupp Date: Fri, 19 Jul 2013 09:23:45 -0500 Subject: [PATCH] Now using op_mat_mat_prod instead of op_prod for matrix-matrix products. The reason for splitting the common op_prod was the limitations encountered within the scheduler. With a common op_prod one first needs to deep-inspect the leaves in order to find out whether we are dealing with a matrix-vector and matrix-matrix product. By splitting op_prod, this deep inspection is no longer necessary. (Similar splittings for separating the matrix-vector product from outer vector products are likely to be applied later) --- viennacl/forwards.h | 2 ++ viennacl/linalg/prod.hpp | 16 ++++----- viennacl/matrix.hpp | 48 +++++++++++++------------- viennacl/scheduler/forwards.h | 6 ++-- viennacl/tools/matrix_size_deducer.hpp | 6 ++-- 5 files changed, 41 insertions(+), 37 deletions(-) diff --git a/viennacl/forwards.h b/viennacl/forwards.h index 4e12db14..4dad770a 100644 --- a/viennacl/forwards.h +++ b/viennacl/forwards.h @@ -72,6 +72,8 @@ namespace viennacl struct op_mult {}; /** @brief A tag class representing matrix-vector products */ struct op_prod {}; + /** @brief A tag class representing matrix-matrix products */ + struct op_mat_mat_prod {}; /** @brief A tag class representing division */ struct op_div {}; diff --git a/viennacl/linalg/prod.hpp b/viennacl/linalg/prod.hpp index c04a02e6..091e8cc7 100644 --- a/viennacl/linalg/prod.hpp +++ b/viennacl/linalg/prod.hpp @@ -138,14 +138,14 @@ namespace viennacl template< typename NumericT, typename F1, typename F2> viennacl::matrix_expression< const viennacl::matrix_base, const viennacl::matrix_base, - viennacl::op_prod > + viennacl::op_mat_mat_prod > prod(viennacl::matrix_base const & A, viennacl::matrix_base const & B) { // std::cout << "viennacl .. " << std::endl; return viennacl::matrix_expression< const viennacl::matrix_base, const viennacl::matrix_base, - viennacl::op_prod >(A, B); + viennacl::op_mat_mat_prod >(A, B); } // right factor is transposed: @@ -154,7 +154,7 @@ namespace viennacl const viennacl::matrix_expression, const viennacl::matrix_base, op_trans>, - viennacl::op_prod > + viennacl::op_mat_mat_prod > prod(viennacl::matrix_base const & A, viennacl::matrix_expression, const viennacl::matrix_base, @@ -165,7 +165,7 @@ namespace viennacl const viennacl::matrix_expression, const viennacl::matrix_base, op_trans>, - viennacl::op_prod >(A, B); + viennacl::op_mat_mat_prod >(A, B); } // left factor transposed: @@ -174,7 +174,7 @@ namespace viennacl const viennacl::matrix_base, op_trans>, const viennacl::matrix_base, - viennacl::op_prod > + viennacl::op_mat_mat_prod > prod(viennacl::matrix_expression, const viennacl::matrix_base, op_trans> const & A, @@ -185,7 +185,7 @@ namespace viennacl const viennacl::matrix_base, op_trans>, const viennacl::matrix_base, - viennacl::op_prod >(A, B); + viennacl::op_mat_mat_prod >(A, B); } @@ -197,7 +197,7 @@ namespace viennacl const viennacl::matrix_expression, const viennacl::matrix_base, op_trans>, - viennacl::op_prod > + viennacl::op_mat_mat_prod > prod(viennacl::matrix_expression, const viennacl::matrix_base, op_trans> const & A, @@ -212,7 +212,7 @@ namespace viennacl const viennacl::matrix_expression, const viennacl::matrix_base, op_trans>, - viennacl::op_prod >(A, B); + viennacl::op_mat_mat_prod >(A, B); } diff --git a/viennacl/matrix.hpp b/viennacl/matrix.hpp index 6beefb12..b4c4a683 100644 --- a/viennacl/matrix.hpp +++ b/viennacl/matrix.hpp @@ -2503,9 +2503,9 @@ namespace viennacl // C = A * B template - struct op_executor, op_assign, matrix_expression, const matrix_base, op_prod> > + struct op_executor, op_assign, matrix_expression, const matrix_base, op_mat_mat_prod> > { - static void apply(matrix_base & lhs, matrix_expression, const matrix_base, op_prod> const & rhs) + static void apply(matrix_base & lhs, matrix_expression, const matrix_base, op_mat_mat_prod> const & rhs) { viennacl::linalg::prod_impl(rhs.lhs(), rhs.rhs(), lhs, T(1.0), T(0)); } @@ -2515,11 +2515,11 @@ namespace viennacl template struct op_executor, op_assign, matrix_expression, const matrix_expression, const matrix_base, op_trans>, - op_prod> > + op_mat_mat_prod> > { static void apply(matrix_base & lhs, matrix_expression, const matrix_expression, const matrix_base, op_trans>, - op_prod> const & rhs) + op_mat_mat_prod> const & rhs) { viennacl::linalg::prod_impl(rhs.lhs(), rhs.rhs(), lhs, T(1.0), T(0)); } @@ -2529,11 +2529,11 @@ namespace viennacl template struct op_executor, op_assign, matrix_expression, const matrix_base, op_trans>, const matrix_base, - op_prod> > + op_mat_mat_prod> > { static void apply(matrix_base & lhs, matrix_expression, const matrix_base, op_trans>, const matrix_base, - op_prod> const & rhs) + op_mat_mat_prod> const & rhs) { viennacl::linalg::prod_impl(rhs.lhs(), rhs.rhs(), lhs, T(1.0), T(0)); } @@ -2543,11 +2543,11 @@ namespace viennacl template struct op_executor, op_assign, matrix_expression, const matrix_base, op_trans>, const matrix_expression, const matrix_base, op_trans>, - op_prod> > + op_mat_mat_prod> > { static void apply(matrix_base & lhs, matrix_expression, const matrix_base, op_trans>, const matrix_expression, const matrix_base, op_trans>, - op_prod> const & rhs) + op_mat_mat_prod> const & rhs) { viennacl::linalg::prod_impl(rhs.lhs(), rhs.rhs(), lhs, T(1.0), T(0)); } @@ -2556,9 +2556,9 @@ namespace viennacl // C += A * B template - struct op_executor, op_inplace_add, matrix_expression, const matrix_base, op_prod> > + struct op_executor, op_inplace_add, matrix_expression, const matrix_base, op_mat_mat_prod> > { - static void apply(matrix_base & lhs, matrix_expression, const matrix_base, op_prod> const & rhs) + static void apply(matrix_base & lhs, matrix_expression, const matrix_base, op_mat_mat_prod> const & rhs) { viennacl::linalg::prod_impl(rhs.lhs(), rhs.rhs(), lhs, T(1.0), T(1.0)); } @@ -2568,11 +2568,11 @@ namespace viennacl template struct op_executor, op_inplace_add, matrix_expression, const matrix_expression, const matrix_base, op_trans>, - op_prod> > + op_mat_mat_prod> > { static void apply(matrix_base & lhs, matrix_expression, const matrix_expression, const matrix_base, op_trans>, - op_prod> const & rhs) + op_mat_mat_prod> const & rhs) { viennacl::linalg::prod_impl(rhs.lhs(), rhs.rhs(), lhs, T(1.0), T(1.0)); } @@ -2582,11 +2582,11 @@ namespace viennacl template struct op_executor, op_inplace_add, matrix_expression, const matrix_base, op_trans>, const matrix_base, - op_prod> > + op_mat_mat_prod> > { static void apply(matrix_base & lhs, matrix_expression, const matrix_base, op_trans>, const matrix_base, - op_prod> const & rhs) + op_mat_mat_prod> const & rhs) { viennacl::linalg::prod_impl(rhs.lhs(), rhs.rhs(), lhs, T(1.0), T(1.0)); } @@ -2596,11 +2596,11 @@ namespace viennacl template struct op_executor, op_inplace_add, matrix_expression, const matrix_base, op_trans>, const matrix_expression, const matrix_base, op_trans>, - op_prod> > + op_mat_mat_prod> > { static void apply(matrix_base & lhs, matrix_expression, const matrix_base, op_trans>, const matrix_expression, const matrix_base, op_trans>, - op_prod> const & rhs) + op_mat_mat_prod> const & rhs) { viennacl::linalg::prod_impl(rhs.lhs(), rhs.rhs(), lhs, T(1.0), T(1.0)); } @@ -2609,9 +2609,9 @@ namespace viennacl // C -= A * B template - struct op_executor, op_inplace_sub, matrix_expression, const matrix_base, op_prod> > + struct op_executor, op_inplace_sub, matrix_expression, const matrix_base, op_mat_mat_prod> > { - static void apply(matrix_base & lhs, matrix_expression, const matrix_base, op_prod> const & rhs) + static void apply(matrix_base & lhs, matrix_expression, const matrix_base, op_mat_mat_prod> const & rhs) { viennacl::linalg::prod_impl(rhs.lhs(), rhs.rhs(), lhs, T(-1.0), T(1.0)); } @@ -2621,11 +2621,11 @@ namespace viennacl template struct op_executor, op_inplace_sub, matrix_expression, const matrix_expression, const matrix_base, op_trans>, - op_prod> > + op_mat_mat_prod> > { static void apply(matrix_base & lhs, matrix_expression, const matrix_expression, const matrix_base, op_trans>, - op_prod> const & rhs) + op_mat_mat_prod> const & rhs) { viennacl::linalg::prod_impl(rhs.lhs(), rhs.rhs(), lhs, T(-1.0), T(1.0)); } @@ -2635,11 +2635,11 @@ namespace viennacl template struct op_executor, op_inplace_sub, matrix_expression, const matrix_base, op_trans>, const matrix_base, - op_prod> > + op_mat_mat_prod> > { static void apply(matrix_base & lhs, matrix_expression, const matrix_base, op_trans>, const matrix_base, - op_prod> const & rhs) + op_mat_mat_prod> const & rhs) { viennacl::linalg::prod_impl(rhs.lhs(), rhs.rhs(), lhs, T(-1.0), T(1.0)); } @@ -2649,11 +2649,11 @@ namespace viennacl template struct op_executor, op_inplace_sub, matrix_expression, const matrix_base, op_trans>, const matrix_expression, const matrix_base, op_trans>, - op_prod> > + op_mat_mat_prod> > { static void apply(matrix_base & lhs, matrix_expression, const matrix_base, op_trans>, const matrix_expression, const matrix_base, op_trans>, - op_prod> const & rhs) + op_mat_mat_prod> const & rhs) { viennacl::linalg::prod_impl(rhs.lhs(), rhs.rhs(), lhs, T(-1.0), T(1.0)); } diff --git a/viennacl/scheduler/forwards.h b/viennacl/scheduler/forwards.h index af6ba9cc..9295cbd7 100644 --- a/viennacl/scheduler/forwards.h +++ b/viennacl/scheduler/forwards.h @@ -69,7 +69,8 @@ namespace viennacl OPERATION_BINARY_INPLACE_SUB_TYPE, OPERATION_BINARY_ADD_TYPE, OPERATION_BINARY_SUB_TYPE, - OPERATION_BINARY_PROD_TYPE, + OPERATION_BINARY_MAT_VEC_PROD_TYPE, + OPERATION_BINARY_MAT_MAT_PROD_TYPE, OPERATION_BINARY_MULT_TYPE, // scalar times vector/matrix OPERATION_BINARY_ELEMENT_MULT_TYPE, OPERATION_BINARY_ELEMENT_DIV_TYPE, @@ -114,7 +115,8 @@ namespace viennacl template <> struct op_type_info { enum { id = OPERATION_BINARY_INPLACE_SUB_TYPE, family = OPERATION_BINARY_TYPE_FAMILY }; }; template <> struct op_type_info { enum { id = OPERATION_BINARY_ADD_TYPE, family = OPERATION_BINARY_TYPE_FAMILY }; }; template <> struct op_type_info { enum { id = OPERATION_BINARY_SUB_TYPE, family = OPERATION_BINARY_TYPE_FAMILY }; }; - template <> struct op_type_info { enum { id = OPERATION_BINARY_PROD_TYPE, family = OPERATION_BINARY_TYPE_FAMILY }; }; + template <> struct op_type_info { enum { id = OPERATION_BINARY_MAT_VEC_PROD_TYPE, family = OPERATION_BINARY_TYPE_FAMILY }; }; + template <> struct op_type_info { enum { id = OPERATION_BINARY_MAT_MAT_PROD_TYPE, family = OPERATION_BINARY_TYPE_FAMILY }; }; template <> struct op_type_info { enum { id = OPERATION_BINARY_MULT_TYPE, family = OPERATION_BINARY_TYPE_FAMILY }; }; template <> struct op_type_info > { enum { id = OPERATION_BINARY_ELEMENT_MULT_TYPE, family = OPERATION_BINARY_TYPE_FAMILY }; }; template <> struct op_type_info > { enum { id = OPERATION_BINARY_ELEMENT_DIV_TYPE, family = OPERATION_BINARY_TYPE_FAMILY }; }; diff --git a/viennacl/tools/matrix_size_deducer.hpp b/viennacl/tools/matrix_size_deducer.hpp index 551c431a..e9d35e5e 100644 --- a/viennacl/tools/matrix_size_deducer.hpp +++ b/viennacl/tools/matrix_size_deducer.hpp @@ -141,7 +141,7 @@ namespace viennacl struct MATRIX_SIZE_DEDUCER, const viennacl::matrix_base, - viennacl::op_prod> + viennacl::op_mat_mat_prod> { static std::size_t size1(viennacl::matrix_expression, const viennacl::matrix_expression, - viennacl::op_prod> + viennacl::op_mat_mat_prod> { static std::size_t size1(viennacl::matrix_base const & lhs, viennacl::matrix_expression, const viennacl::matrix_expression, - viennacl::op_prod> + viennacl::op_mat_mat_prod> { typedef viennacl::matrix_expression LHSType; typedef viennacl::matrix_expression RHSType; -- GitLab