diff --git a/viennacl/forwards.h b/viennacl/forwards.h index 4e12db143d593da3ea08fb34847690f33d7b62ce..4dad770af6bee83107011b44ad053d24d3e8067a 100644 --- a/viennacl/forwards.h +++ b/viennacl/forwards.h @@ -72,6 +72,8 @@ namespace viennacl struct op_mult {}; /** @brief A tag class representing matrix-vector products */ struct op_prod {}; + /** @brief A tag class representing matrix-matrix products */ + struct op_mat_mat_prod {}; /** @brief A tag class representing division */ struct op_div {}; diff --git a/viennacl/linalg/prod.hpp b/viennacl/linalg/prod.hpp index c04a02e61fd33de404e070e12ac85831d5c5f6f2..091e8cc716e829acdcd18e6c0ef7bba9d6c3f0fb 100644 --- a/viennacl/linalg/prod.hpp +++ b/viennacl/linalg/prod.hpp @@ -138,14 +138,14 @@ namespace viennacl template< typename NumericT, typename F1, typename F2> viennacl::matrix_expression< const viennacl::matrix_base, const viennacl::matrix_base, - viennacl::op_prod > + viennacl::op_mat_mat_prod > prod(viennacl::matrix_base const & A, viennacl::matrix_base const & B) { // std::cout << "viennacl .. " << std::endl; return viennacl::matrix_expression< const viennacl::matrix_base, const viennacl::matrix_base, - viennacl::op_prod >(A, B); + viennacl::op_mat_mat_prod >(A, B); } // right factor is transposed: @@ -154,7 +154,7 @@ namespace viennacl const viennacl::matrix_expression, const viennacl::matrix_base, op_trans>, - viennacl::op_prod > + viennacl::op_mat_mat_prod > prod(viennacl::matrix_base const & A, viennacl::matrix_expression, const viennacl::matrix_base, @@ -165,7 +165,7 @@ namespace viennacl const viennacl::matrix_expression, const viennacl::matrix_base, op_trans>, - viennacl::op_prod >(A, B); + viennacl::op_mat_mat_prod >(A, B); } // left factor transposed: @@ -174,7 +174,7 @@ namespace viennacl const viennacl::matrix_base, op_trans>, const viennacl::matrix_base, - viennacl::op_prod > + viennacl::op_mat_mat_prod > prod(viennacl::matrix_expression, const viennacl::matrix_base, op_trans> const & A, @@ -185,7 +185,7 @@ namespace viennacl const viennacl::matrix_base, op_trans>, const viennacl::matrix_base, - viennacl::op_prod >(A, B); + viennacl::op_mat_mat_prod >(A, B); } @@ -197,7 +197,7 @@ namespace viennacl const viennacl::matrix_expression, const viennacl::matrix_base, op_trans>, - viennacl::op_prod > + viennacl::op_mat_mat_prod > prod(viennacl::matrix_expression, const viennacl::matrix_base, op_trans> const & A, @@ -212,7 +212,7 @@ namespace viennacl const viennacl::matrix_expression, const viennacl::matrix_base, op_trans>, - viennacl::op_prod >(A, B); + viennacl::op_mat_mat_prod >(A, B); } diff --git a/viennacl/matrix.hpp b/viennacl/matrix.hpp index 6beefb1269b56b8c807dd60dd767a68f0bdc95e1..b4c4a683554df052682bd35f93ab11821bfead32 100644 --- a/viennacl/matrix.hpp +++ b/viennacl/matrix.hpp @@ -2503,9 +2503,9 @@ namespace viennacl // C = A * B template - struct op_executor, op_assign, matrix_expression, const matrix_base, op_prod> > + struct op_executor, op_assign, matrix_expression, const matrix_base, op_mat_mat_prod> > { - static void apply(matrix_base & lhs, matrix_expression, const matrix_base, op_prod> const & rhs) + static void apply(matrix_base & lhs, matrix_expression, const matrix_base, op_mat_mat_prod> const & rhs) { viennacl::linalg::prod_impl(rhs.lhs(), rhs.rhs(), lhs, T(1.0), T(0)); } @@ -2515,11 +2515,11 @@ namespace viennacl template struct op_executor, op_assign, matrix_expression, const matrix_expression, const matrix_base, op_trans>, - op_prod> > + op_mat_mat_prod> > { static void apply(matrix_base & lhs, matrix_expression, const matrix_expression, const matrix_base, op_trans>, - op_prod> const & rhs) + op_mat_mat_prod> const & rhs) { viennacl::linalg::prod_impl(rhs.lhs(), rhs.rhs(), lhs, T(1.0), T(0)); } @@ -2529,11 +2529,11 @@ namespace viennacl template struct op_executor, op_assign, matrix_expression, const matrix_base, op_trans>, const matrix_base, - op_prod> > + op_mat_mat_prod> > { static void apply(matrix_base & lhs, matrix_expression, const matrix_base, op_trans>, const matrix_base, - op_prod> const & rhs) + op_mat_mat_prod> const & rhs) { viennacl::linalg::prod_impl(rhs.lhs(), rhs.rhs(), lhs, T(1.0), T(0)); } @@ -2543,11 +2543,11 @@ namespace viennacl template struct op_executor, op_assign, matrix_expression, const matrix_base, op_trans>, const matrix_expression, const matrix_base, op_trans>, - op_prod> > + op_mat_mat_prod> > { static void apply(matrix_base & lhs, matrix_expression, const matrix_base, op_trans>, const matrix_expression, const matrix_base, op_trans>, - op_prod> const & rhs) + op_mat_mat_prod> const & rhs) { viennacl::linalg::prod_impl(rhs.lhs(), rhs.rhs(), lhs, T(1.0), T(0)); } @@ -2556,9 +2556,9 @@ namespace viennacl // C += A * B template - struct op_executor, op_inplace_add, matrix_expression, const matrix_base, op_prod> > + struct op_executor, op_inplace_add, matrix_expression, const matrix_base, op_mat_mat_prod> > { - static void apply(matrix_base & lhs, matrix_expression, const matrix_base, op_prod> const & rhs) + static void apply(matrix_base & lhs, matrix_expression, const matrix_base, op_mat_mat_prod> const & rhs) { viennacl::linalg::prod_impl(rhs.lhs(), rhs.rhs(), lhs, T(1.0), T(1.0)); } @@ -2568,11 +2568,11 @@ namespace viennacl template struct op_executor, op_inplace_add, matrix_expression, const matrix_expression, const matrix_base, op_trans>, - op_prod> > + op_mat_mat_prod> > { static void apply(matrix_base & lhs, matrix_expression, const matrix_expression, const matrix_base, op_trans>, - op_prod> const & rhs) + op_mat_mat_prod> const & rhs) { viennacl::linalg::prod_impl(rhs.lhs(), rhs.rhs(), lhs, T(1.0), T(1.0)); } @@ -2582,11 +2582,11 @@ namespace viennacl template struct op_executor, op_inplace_add, matrix_expression, const matrix_base, op_trans>, const matrix_base, - op_prod> > + op_mat_mat_prod> > { static void apply(matrix_base & lhs, matrix_expression, const matrix_base, op_trans>, const matrix_base, - op_prod> const & rhs) + op_mat_mat_prod> const & rhs) { viennacl::linalg::prod_impl(rhs.lhs(), rhs.rhs(), lhs, T(1.0), T(1.0)); } @@ -2596,11 +2596,11 @@ namespace viennacl template struct op_executor, op_inplace_add, matrix_expression, const matrix_base, op_trans>, const matrix_expression, const matrix_base, op_trans>, - op_prod> > + op_mat_mat_prod> > { static void apply(matrix_base & lhs, matrix_expression, const matrix_base, op_trans>, const matrix_expression, const matrix_base, op_trans>, - op_prod> const & rhs) + op_mat_mat_prod> const & rhs) { viennacl::linalg::prod_impl(rhs.lhs(), rhs.rhs(), lhs, T(1.0), T(1.0)); } @@ -2609,9 +2609,9 @@ namespace viennacl // C -= A * B template - struct op_executor, op_inplace_sub, matrix_expression, const matrix_base, op_prod> > + struct op_executor, op_inplace_sub, matrix_expression, const matrix_base, op_mat_mat_prod> > { - static void apply(matrix_base & lhs, matrix_expression, const matrix_base, op_prod> const & rhs) + static void apply(matrix_base & lhs, matrix_expression, const matrix_base, op_mat_mat_prod> const & rhs) { viennacl::linalg::prod_impl(rhs.lhs(), rhs.rhs(), lhs, T(-1.0), T(1.0)); } @@ -2621,11 +2621,11 @@ namespace viennacl template struct op_executor, op_inplace_sub, matrix_expression, const matrix_expression, const matrix_base, op_trans>, - op_prod> > + op_mat_mat_prod> > { static void apply(matrix_base & lhs, matrix_expression, const matrix_expression, const matrix_base, op_trans>, - op_prod> const & rhs) + op_mat_mat_prod> const & rhs) { viennacl::linalg::prod_impl(rhs.lhs(), rhs.rhs(), lhs, T(-1.0), T(1.0)); } @@ -2635,11 +2635,11 @@ namespace viennacl template struct op_executor, op_inplace_sub, matrix_expression, const matrix_base, op_trans>, const matrix_base, - op_prod> > + op_mat_mat_prod> > { static void apply(matrix_base & lhs, matrix_expression, const matrix_base, op_trans>, const matrix_base, - op_prod> const & rhs) + op_mat_mat_prod> const & rhs) { viennacl::linalg::prod_impl(rhs.lhs(), rhs.rhs(), lhs, T(-1.0), T(1.0)); } @@ -2649,11 +2649,11 @@ namespace viennacl template struct op_executor, op_inplace_sub, matrix_expression, const matrix_base, op_trans>, const matrix_expression, const matrix_base, op_trans>, - op_prod> > + op_mat_mat_prod> > { static void apply(matrix_base & lhs, matrix_expression, const matrix_base, op_trans>, const matrix_expression, const matrix_base, op_trans>, - op_prod> const & rhs) + op_mat_mat_prod> const & rhs) { viennacl::linalg::prod_impl(rhs.lhs(), rhs.rhs(), lhs, T(-1.0), T(1.0)); } diff --git a/viennacl/scheduler/forwards.h b/viennacl/scheduler/forwards.h index af6ba9cc0d92f80f323aceb966bfba3e63b5b194..9295cbd731404973ce978a69b87f8ebdee8e618e 100644 --- a/viennacl/scheduler/forwards.h +++ b/viennacl/scheduler/forwards.h @@ -69,7 +69,8 @@ namespace viennacl OPERATION_BINARY_INPLACE_SUB_TYPE, OPERATION_BINARY_ADD_TYPE, OPERATION_BINARY_SUB_TYPE, - OPERATION_BINARY_PROD_TYPE, + OPERATION_BINARY_MAT_VEC_PROD_TYPE, + OPERATION_BINARY_MAT_MAT_PROD_TYPE, OPERATION_BINARY_MULT_TYPE, // scalar times vector/matrix OPERATION_BINARY_ELEMENT_MULT_TYPE, OPERATION_BINARY_ELEMENT_DIV_TYPE, @@ -114,7 +115,8 @@ namespace viennacl template <> struct op_type_info { enum { id = OPERATION_BINARY_INPLACE_SUB_TYPE, family = OPERATION_BINARY_TYPE_FAMILY }; }; template <> struct op_type_info { enum { id = OPERATION_BINARY_ADD_TYPE, family = OPERATION_BINARY_TYPE_FAMILY }; }; template <> struct op_type_info { enum { id = OPERATION_BINARY_SUB_TYPE, family = OPERATION_BINARY_TYPE_FAMILY }; }; - template <> struct op_type_info { enum { id = OPERATION_BINARY_PROD_TYPE, family = OPERATION_BINARY_TYPE_FAMILY }; }; + template <> struct op_type_info { enum { id = OPERATION_BINARY_MAT_VEC_PROD_TYPE, family = OPERATION_BINARY_TYPE_FAMILY }; }; + template <> struct op_type_info { enum { id = OPERATION_BINARY_MAT_MAT_PROD_TYPE, family = OPERATION_BINARY_TYPE_FAMILY }; }; template <> struct op_type_info { enum { id = OPERATION_BINARY_MULT_TYPE, family = OPERATION_BINARY_TYPE_FAMILY }; }; template <> struct op_type_info > { enum { id = OPERATION_BINARY_ELEMENT_MULT_TYPE, family = OPERATION_BINARY_TYPE_FAMILY }; }; template <> struct op_type_info > { enum { id = OPERATION_BINARY_ELEMENT_DIV_TYPE, family = OPERATION_BINARY_TYPE_FAMILY }; }; diff --git a/viennacl/tools/matrix_size_deducer.hpp b/viennacl/tools/matrix_size_deducer.hpp index 551c431a6f661732925da589805447ce4e99fb03..e9d35e5e56249a7a1156d88b9577f754f522831a 100644 --- a/viennacl/tools/matrix_size_deducer.hpp +++ b/viennacl/tools/matrix_size_deducer.hpp @@ -141,7 +141,7 @@ namespace viennacl struct MATRIX_SIZE_DEDUCER, const viennacl::matrix_base, - viennacl::op_prod> + viennacl::op_mat_mat_prod> { static std::size_t size1(viennacl::matrix_expression, const viennacl::matrix_expression, - viennacl::op_prod> + viennacl::op_mat_mat_prod> { static std::size_t size1(viennacl::matrix_base const & lhs, viennacl::matrix_expression, const viennacl::matrix_expression, - viennacl::op_prod> + viennacl::op_mat_mat_prod> { typedef viennacl::matrix_expression LHSType; typedef viennacl::matrix_expression RHSType;