diff --git a/inference-engine/src/mkldnn_plugin/ngraph_transformations/convert_matmul_to_fc.cpp b/inference-engine/src/mkldnn_plugin/ngraph_transformations/convert_matmul_to_fc.cpp index 30cdb9063f9..b16e13eb064 100644 --- a/inference-engine/src/mkldnn_plugin/ngraph_transformations/convert_matmul_to_fc.cpp +++ b/inference-engine/src/mkldnn_plugin/ngraph_transformations/convert_matmul_to_fc.cpp @@ -4,7 +4,6 @@ #include "convert_matmul_to_fc.hpp" #include "op/fully_connected.hpp" -#include #include #include #include @@ -13,25 +12,39 @@ NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::ConvertMatMulToFC, "ConvertMatMulToFC", 0); MKLDNNPlugin::ConvertMatMulToFC::ConvertMatMulToFC() { - auto matmul = ngraph::pattern::wrap_type({ngraph::pattern::any_input(ngraph::pattern::has_static_shape()), - ngraph::pattern::any_input(ngraph::pattern::has_static_shape())}, - ngraph::pattern::has_static_shape()); + auto activations_m = ngraph::pattern::any_input(ngraph::pattern::has_static_rank()); + auto weights_m = ngraph::pattern::wrap_type(); + auto matmul_m = ngraph::pattern::wrap_type({ activations_m, weights_m }, ngraph::pattern::has_static_rank()); - ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher& m) { - auto matmul = std::dynamic_pointer_cast(m.get_match_root()); - if (!matmul) { + ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { + const auto& pattern_map = m.get_pattern_value_map(); + + auto matmul = std::dynamic_pointer_cast(pattern_map.at(matmul_m).get_node_shared_ptr()); + if (!matmul || transformation_callback(matmul)) { return false; } - auto input_a = matmul->input(0).get_source_output(); - auto input_b = matmul->input(1).get_source_output(); + // fc_input_a and fc_input_b - are the final inputs that will be set to FullyConnected of GemmIE operations. + // So in case of adding new operations that takes matmul inputs we need keep update fc_input_a and fc_input_b. + auto fc_input_a = pattern_map.at(activations_m); + auto fc_input_b = pattern_map.at(weights_m); - auto shape_a = input_a.get_shape(); - auto shape_b = input_b.get_shape(); - auto output_shape = matmul->get_shape(); + auto shape_a = fc_input_a.get_partial_shape(); + auto shape_b = fc_input_b.get_partial_shape(); + NGRAPH_CHECK(shape_b.is_static()); // requested 2nd input with static shape in the matcher + + auto rank_a = shape_a.rank().get_length(); + auto rank_b = shape_b.rank().get_length(); // Transformation to FC is not supported for 1D second input - if (shape_b.size() == 1) { + if (rank_b == 1) { + return false; + } + + // Check that if second inputs is Constant path and it's shape without ones dimensions has length <= 2 + // we replace MatMul with FullyConnected operation. + if (!std::dynamic_pointer_cast(fc_input_b.get_node_shared_ptr()) || + std::count_if(shape_b.begin(), shape_b.end(), [](ngraph::Dimension x) { return x != 1; }) > 2) { return false; } @@ -42,15 +55,17 @@ MKLDNNPlugin::ConvertMatMulToFC::ConvertMatMulToFC() { * for example: [2, 32, 64] [3, 64, 64] it will raise an exception. */ - auto get_aligned_shapes = [shape_a, shape_b, &matmul]() -> std::pair { - ngraph::Shape shape_a_aligned(shape_a), shape_b_aligned(shape_b); - size_t max_size = std::max(shape_a_aligned.size(), shape_b_aligned.size()); - for (size_t i = 0, cnt = max_size - shape_a_aligned.size(); i < cnt; ++i) + auto get_aligned_shapes = [shape_a, shape_b, rank_a, rank_b, &matmul]() -> std::tuple { + ngraph::PartialShape shape_a_aligned(shape_a), shape_b_aligned(shape_b); + size_t max_size = std::max(rank_a, rank_b); + for (size_t i = 0, cnt = max_size - rank_a; i < cnt; ++i) { shape_a_aligned.insert(shape_a_aligned.begin(), 1); - for (size_t i = 0, cnt = max_size - shape_b_aligned.size(); i < cnt; ++i) + } + for (size_t i = 0, cnt = max_size - rank_b; i < cnt; ++i) { shape_b_aligned.insert(shape_b_aligned.begin(), 1); + } - if (matmul->get_transpose_a() && shape_a.size() != 1) { + if (matmul->get_transpose_a() && rank_a != 1) { std::swap(*(shape_a_aligned.end() - 1), *(shape_a_aligned.end() - 2)); } if (matmul->get_transpose_b()) { @@ -58,16 +73,25 @@ MKLDNNPlugin::ConvertMatMulToFC::ConvertMatMulToFC() { } for (size_t i = 0; i < max_size - 2; ++i) { - if (shape_a_aligned[i] != shape_b_aligned[i] && shape_a_aligned[i] > 1 && shape_b_aligned[i] > 1) { + auto a_dim = shape_a_aligned[i], b_dim = shape_b_aligned[i]; + if (a_dim.is_dynamic()) { + if (b_dim == 1) { + shape_a_aligned[i] = shape_b_aligned[i] = a_dim; + } else { + return std::make_tuple(false, ngraph::PartialShape{shape_a_aligned}, ngraph::PartialShape{shape_b_aligned}); + } + continue; + } + // both dimensions are static + if (a_dim != b_dim && a_dim.get_length() > 1 && b_dim.get_length() > 1) { std::ostringstream stream; stream << "Shapes can't be aligned: " << shape_a_aligned << " " << shape_b_aligned; throw ngraph::ngraph_error(stream.str()); } - size_t max_value = std::max(shape_a_aligned[i], shape_b_aligned[i]); + size_t max_value = std::max(a_dim.get_length(), b_dim.get_length()); shape_a_aligned[i] = shape_b_aligned[i] = max_value; } - - return {shape_a_aligned, shape_b_aligned}; + return std::make_tuple(true, shape_a_aligned, shape_b_aligned); }; /* @@ -78,76 +102,68 @@ MKLDNNPlugin::ConvertMatMulToFC::ConvertMatMulToFC() { * order will be [0, 1, 3, 2] that emulates transpose_a or transpose_b attribute. */ - auto create_transpose = [this](ngraph::Output node, const std::string& transpose_name) -> std::shared_ptr { - ngraph::Shape output_shape = node.get_node_shared_ptr()->get_shape(); - - std::vector transpose_order(output_shape.size()); + auto create_transpose = [this](const ngraph::Output& node, const std::string& transpose_name) { + auto rank = node.get_partial_shape().rank(); + std::vector transpose_order(rank.get_length()); std::iota(transpose_order.begin(), transpose_order.end(), 0); std::swap(*(transpose_order.end() - 1), *(transpose_order.end() - 2)); - auto transpose = ngraph::pass::MatcherPass::register_new_node( - node, ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{transpose_order.size()}, transpose_order)); + auto transpose_const = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ transpose_order.size() }, transpose_order); + auto transpose = ngraph::op::util::make_try_fold(node, transpose_const); + if (!ngraph::is_type(transpose)) { + MatcherPass::register_new_node(transpose); + } transpose->set_friendly_name(transpose_name); return transpose; }; - // fc_input_a and fc_input_b - are the final inputs that will be set to FullyConnected of GemmIE operations. - // So in case of adding new operations that takes matmul inputs we need keep update fc_input_a and - // fc_input_b updated. - auto fc_input_a = input_a, fc_input_b = input_b; - - // vector of new nGraph operations ngraph::NodeVector new_ops; - - // Check that if second inputs is Constant operation and it's shape without ones dimensions has length <= 2 - // we replace MatMul with FullyConnected operation. - // Otherwise we replace MatMul with Gemm. - if ((std::dynamic_pointer_cast(fc_input_b.get_node_shared_ptr()) || - std::dynamic_pointer_cast(fc_input_b.get_node_shared_ptr())) && - std::count_if(shape_b.begin(), shape_b.end(), [](size_t x) { return x != 1; }) <= 2) { - ngraph::Shape shape_a_aligned, shape_b_aligned; - std::tie(shape_a_aligned, shape_b_aligned) = get_aligned_shapes(); - - if (shape_a_aligned.size() < 2 || shape_b_aligned.size() < 2) { - throw ngraph::ngraph_error("MatMul " + matmul->get_friendly_name() + " shapes are inconsistent."); - } - - // Transferring from MatMul representation: [B, I, K] * [B, K, O] = [B, I, O] - // to FullyConnected representation: [I, K] * [K, O] = [I, O] - size_t K = *(shape_a_aligned.end() - 1); - ngraph::Shape B(shape_a_aligned.begin(), shape_a_aligned.end() - 2); - - // Weights normalization - if (!matmul->get_transpose_b()) { - fc_input_b = create_transpose(fc_input_b, matmul->get_friendly_name() + "/transpose_b"); - new_ops.push_back(fc_input_b.get_node_shared_ptr()); - } - - if (shape_b.size() != 2) { - auto reshape_shape = - ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {-1ll, static_cast(K)}); - fc_input_b = std::make_shared(fc_input_b, reshape_shape, true); - new_ops.push_back(fc_input_b.get_node_shared_ptr()); - } - - // Input normalization - if (matmul->get_transpose_a() && shape_a.size() != 1) { - fc_input_a = create_transpose(fc_input_a, matmul->get_friendly_name() + "/transpose_a"); - new_ops.push_back(fc_input_a.get_node_shared_ptr()); - } - - // Create FullyConnected - auto fc = std::make_shared(fc_input_a, fc_input_b, output_shape, matmul->output(0).get_element_type()); - fc->set_friendly_name(matmul->get_friendly_name()); - new_ops.push_back(fc); - - ngraph::copy_runtime_info(matmul, new_ops); - ngraph::replace_node(matmul, fc); - return true; + bool success = true; + ngraph::PartialShape shape_a_aligned, shape_b_aligned; + std::tie(success, shape_a_aligned, shape_b_aligned) = get_aligned_shapes(); + if (!success) { + return false; } - return false; + + auto aligned_a_rank = shape_a_aligned.rank(), aligned_b_rank = shape_b_aligned.rank(); + if (aligned_a_rank.is_dynamic() || aligned_b_rank.is_dynamic() || aligned_a_rank.get_length() < 2 || aligned_b_rank.get_length() < 2) { + throw ngraph::ngraph_error("MatMul " + matmul->get_friendly_name() + " shapes are inconsistent."); + } + + // Transferring from MatMul representation: [B, I, K] * [B, K, O] = [B, I, O] + // to FullyConnected representation: [I, K] * [K, O] = [I, O] + + // Weights normalization + if (!matmul->get_transpose_b()) { + fc_input_b = create_transpose(fc_input_b, matmul->get_friendly_name() + "/transpose_b"); + new_ops.push_back(fc_input_b.get_node_shared_ptr()); + } + + if (rank_b != 2) { + ngraph::Dimension K = *(shape_b_aligned.rbegin() + 1); + NGRAPH_CHECK(K.is_static()); // requested 2nd input with static shape in the matcher + std::vector reshape_shape_values = { -1ll, static_cast(K.get_length()) }; + auto reshape_shape = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 2 }, reshape_shape_values); + fc_input_b = ngraph::op::util::make_try_fold(fc_input_b, reshape_shape, false); + new_ops.push_back(fc_input_b.get_node_shared_ptr()); + } + + // Input normalization + if (matmul->get_transpose_a() && rank_a != 1) { + fc_input_a = create_transpose(fc_input_a, matmul->get_friendly_name() + "/transpose_a"); + new_ops.push_back(fc_input_a.get_node_shared_ptr()); + } + + auto output_rank = matmul->get_output_partial_shape(0).rank(); + // Create FullyConnected + auto fc = std::make_shared(fc_input_a, fc_input_b, output_rank, matmul->get_output_element_type(0)); + fc->set_friendly_name(matmul->get_friendly_name()); + new_ops.push_back(fc); + ngraph::copy_runtime_info(matmul, new_ops); + ngraph::replace_node(matmul, fc); + return true; }; - auto m = std::make_shared(matmul, "ConvertMatMulToFC"); + auto m = std::make_shared(matmul_m, "ConvertMatMulToFC"); this->register_matcher(m, callback); } diff --git a/inference-engine/src/mkldnn_plugin/ngraph_transformations/fc_bias_fusion.cpp b/inference-engine/src/mkldnn_plugin/ngraph_transformations/fc_bias_fusion.cpp index 5ad5b180b52..fa9bba7dd6f 100644 --- a/inference-engine/src/mkldnn_plugin/ngraph_transformations/fc_bias_fusion.cpp +++ b/inference-engine/src/mkldnn_plugin/ngraph_transformations/fc_bias_fusion.cpp @@ -9,53 +9,58 @@ #include #include +#include "transformations/utils/utils.hpp" + NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::FullyConnectedBiasFusion, "FullyConnectedBiasFusion", 0); MKLDNNPlugin::FullyConnectedBiasFusion::FullyConnectedBiasFusion() { - auto m_fc = ngraph::pattern::wrap_type([](ngraph::Output output) { - return ngraph::pattern::consumers_count(1)(output) && ngraph::pattern::has_static_shape()(output); + auto input = ngraph::pattern::any_input(); + auto weights = ngraph::pattern::any_input(ngraph::pattern::has_static_shape()); + auto m_fc = ngraph::pattern::wrap_type({ input, weights }, [](ngraph::Output output) { + return ngraph::pattern::consumers_count(1)(output) && ngraph::pattern::has_static_rank()(output); }); - auto m_bias = ngraph::pattern::any_input(); + auto m_bias = ngraph::pattern::any_input(ngraph::pattern::has_static_shape()); auto m_add = ngraph::pattern::wrap_type({m_fc, m_bias}); ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) { - auto & pattern_to_output = m.get_pattern_value_map(); + auto& pattern_to_output = m.get_pattern_value_map(); auto add = pattern_to_output[m_add].get_node_shared_ptr(); auto bias = pattern_to_output[m_bias].get_node_shared_ptr(); auto fc = std::dynamic_pointer_cast(pattern_to_output[m_fc].get_node_shared_ptr()); - if (!fc) { + if (!fc || transformation_callback(fc)) { return false; } - if (auto bcast = std::dynamic_pointer_cast(bias)) { - bias = bcast->input_value(0).get_node_shared_ptr(); - } - if (!std::dynamic_pointer_cast(bias)) { return false; } ngraph::Shape bias_shape(bias->get_shape()); - ngraph::Shape output_shape(fc->get_shape()); - size_t bias_size = std::accumulate(bias_shape.begin(), bias_shape.end(), size_t{1}, std::multiplies()); - if (bias_shape.empty() || bias_shape.back() != output_shape.back() || bias_shape.back() != bias_size) { + ngraph::PartialShape output_shape(fc->get_output_partial_shape(0)); + size_t bias_size = ngraph::shape_size(bias_shape); + auto rank = output_shape.rank().get_length(); + if (rank == 0 || output_shape[rank - 1].is_dynamic()) { + return false; + } + + if (bias_shape.empty() || bias_shape.back() != output_shape[rank - 1].get_length() || bias_shape.back() != bias_size) { return false; } ngraph::NodeVector new_ops; std::shared_ptr final_bias = bias; - if (bias->get_shape().size() >= 2) { - final_bias = std::make_shared(final_bias, ngraph::opset1::Constant::create(ngraph::element::i64, - ngraph::Shape{1}, {-1}), true); + if (bias_shape.size() >= 2) { + auto reshape_const = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 1 }, { -1 }); + final_bias = ngraph::op::util::make_try_fold(final_bias, reshape_const, true); new_ops.push_back(final_bias); } - auto new_fc = std::make_shared(fc->input(0).get_source_output(), - fc->input(1).get_source_output(), + auto new_fc = std::make_shared(fc->input_value(0), + fc->input_value(1), final_bias, - fc->get_shape(), + fc->get_output_rank(), fc->get_output_type()); new_ops.push_back(new_fc); diff --git a/inference-engine/src/mkldnn_plugin/ngraph_transformations/op/fully_connected.cpp b/inference-engine/src/mkldnn_plugin/ngraph_transformations/op/fully_connected.cpp index f36091330f9..2d5ef0d39b6 100644 --- a/inference-engine/src/mkldnn_plugin/ngraph_transformations/op/fully_connected.cpp +++ b/inference-engine/src/mkldnn_plugin/ngraph_transformations/op/fully_connected.cpp @@ -8,40 +8,99 @@ constexpr ngraph::NodeTypeInfo MKLDNNPlugin::FullyConnectedNode::type_info; MKLDNNPlugin::FullyConnectedNode::FullyConnectedNode(const ngraph::Output& A, const ngraph::Output& B, - const ngraph::Shape& output_shape, + const ngraph::Rank& output_rank, const ngraph::element::Type output_type) - : Op({A, B}), m_output_shape(output_shape), m_output_type(output_type) { - validate_and_infer_types(); + : Op({A, B}), m_output_rank(output_rank), m_output_type(output_type) { + constructor_validate_and_infer_types(); } MKLDNNPlugin::FullyConnectedNode::FullyConnectedNode(const ngraph::Output& A, const ngraph::Output& B, const ngraph::Output& C, - const ngraph::Shape& output_shape, + const ngraph::Rank& output_rank, const ngraph::element::Type output_type) - : Op({A, B, C}), m_output_shape(output_shape), m_output_type(output_type) { - validate_and_infer_types(); + : Op({A, B, C}), m_output_rank(output_rank), m_output_type(output_type) { + constructor_validate_and_infer_types(); } std::shared_ptr MKLDNNPlugin::FullyConnectedNode::clone_with_new_inputs(const ngraph::OutputVector& new_args) const { check_new_args_count(this, new_args); if (new_args.size() == 2) { - return std::make_shared(new_args.at(0), new_args.at(1), m_output_shape, m_output_type); + return std::make_shared(new_args.at(0), new_args.at(1), m_output_rank, m_output_type); } else if (new_args.size() == 3) { - return std::make_shared(new_args.at(0), new_args.at(1), new_args.at(2), m_output_shape, m_output_type); + return std::make_shared(new_args.at(0), new_args.at(1), new_args.at(2), m_output_rank, m_output_type); } throw ngraph::ngraph_error("Unsupported number of arguments for FullyConnected operation"); } void MKLDNNPlugin::FullyConnectedNode::validate_and_infer_types() { - m_output_size = m_output_shape.back(); - set_output_type(0, m_output_type == ngraph::element::undefined ? input_value(0).get_element_type() : m_output_type, m_output_shape); + const auto input_size = get_input_size(); + NODE_VALIDATION_CHECK(this, + input_size == 2 || input_size == 3, + "Number of inputs is incorrect. Current value is: ", + input_size, + ", expected: 2 or 3."); + + const auto output_size = get_output_size(); + NODE_VALIDATION_CHECK(this, + output_size == 1, + "Number of outputs is incorrect. Current value is: ", + output_size, + ", expected: 1."); + + // Weights shape: [O, I1, ..., Im]; + // O - output channels dimensions, Ik - input channels dimensions + const auto weights_pshape = get_input_partial_shape(1); + NODE_VALIDATION_CHECK(this, + weights_pshape.is_static(), + "Weights pshape must be static"); + const auto weights_shape = weights_pshape.to_shape(); + + const auto o_channels = weights_pshape[0]; + if (input_size == 3) { + const auto bias_shape = get_input_partial_shape(2); + const auto expected_bias_shape = ngraph::PartialShape{ o_channels }; + NODE_VALIDATION_CHECK(this, + bias_shape == expected_bias_shape, + "Bias shape is incorrect. Current value is: ", + bias_shape, + ", expected: ", + expected_bias_shape, + "."); + } + + // Activations shape: [B1, ..., Bn, I1, ..., Im]; + // Bi - batch dimensions, Ik - input channels dimensions + const auto activations_pshape = get_input_partial_shape(0); + + // Result shape: [B1, ..., Bn, O] + ngraph::PartialShape output_pshape; + if (activations_pshape.rank().is_static()) { + size_t output_channels_dimensions_count = weights_shape.size() - 1; + for (size_t i = 0; i < activations_pshape.size() - output_channels_dimensions_count; ++i) { + output_pshape.push_back(activations_pshape[i]); + } + output_pshape.push_back(o_channels); + + if (m_output_rank.is_static()) { + while (output_pshape.rank().get_length() < m_output_rank.get_length()) { + output_pshape.insert(output_pshape.begin(), 1); + } + } + } else { + output_pshape = ngraph::PartialShape::dynamic(); + } + + auto output_type = m_output_type == ngraph::element::undefined ? get_input_element_type(0) : m_output_type; + set_output_type(0, output_type, output_pshape); } bool MKLDNNPlugin::FullyConnectedNode::visit_attributes(ngraph::AttributeVisitor &visitor) { - visitor.on_attribute("out-size", m_output_size); - visitor.on_attribute("out-shape", m_output_shape); + if (m_output_rank.is_static()) { + std::int64_t value = m_output_rank.get_length(); + visitor.on_attribute("out-rank", value); + } visitor.on_attribute("out-type", m_output_type); return true; } diff --git a/inference-engine/src/mkldnn_plugin/ngraph_transformations/op/fully_connected.hpp b/inference-engine/src/mkldnn_plugin/ngraph_transformations/op/fully_connected.hpp index 47c6509db58..ab24b302176 100644 --- a/inference-engine/src/mkldnn_plugin/ngraph_transformations/op/fully_connected.hpp +++ b/inference-engine/src/mkldnn_plugin/ngraph_transformations/op/fully_connected.hpp @@ -19,13 +19,13 @@ public: FullyConnectedNode(const ngraph::Output &A, const ngraph::Output &B, - const ngraph::Shape &output_shape, + const ngraph::Rank& output_rank, const ngraph::element::Type output_type = ngraph::element::undefined); FullyConnectedNode(const ngraph::Output &A, const ngraph::Output &B, const ngraph::Output &C, - const ngraph::Shape &output_shape, + const ngraph::Rank& output_rank, const ngraph::element::Type output_type = ngraph::element::undefined); bool visit_attributes(ngraph::AttributeVisitor &visitor) override; @@ -34,13 +34,11 @@ public: std::shared_ptr clone_with_new_inputs(const ngraph::OutputVector& new_args) const override; - size_t get_out_size() const { return m_output_size; } - + ngraph::Rank get_output_rank() const { return m_output_rank; } ngraph::element::Type get_output_type() const { return m_output_type; } private: - size_t m_output_size = 0; - ngraph::Shape m_output_shape = {}; + ngraph::Rank m_output_rank; ngraph::element::Type m_output_type; }; diff --git a/inference-engine/src/mkldnn_plugin/ngraph_transformations/reshape_1d_ops.cpp b/inference-engine/src/mkldnn_plugin/ngraph_transformations/reshape_1d_ops.cpp index db7d3ca6971..ebdf5cde844 100644 --- a/inference-engine/src/mkldnn_plugin/ngraph_transformations/reshape_1d_ops.cpp +++ b/inference-engine/src/mkldnn_plugin/ngraph_transformations/reshape_1d_ops.cpp @@ -14,6 +14,7 @@ #include "transformations/utils/utils.hpp" +namespace Reshape1DOps { template std::shared_ptr convert(const ngraph::Output & data, std::shared_ptr node, ngraph::NodeVector &new_ops) { auto new_strides = node->get_strides(); @@ -171,13 +172,14 @@ ngraph::matcher_pass_callback get_callback() { return true; }; } +} // namespace Reshape1DOps NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::Reshape1DConvolution, "Reshape1DConvolution", 0); MKLDNNPlugin::Reshape1DConvolution::Reshape1DConvolution() { auto conv = ngraph::pattern::wrap_type(ngraph::pattern::has_static_shape()); auto m = std::make_shared(conv, "Reshape1DConvolution"); - this->register_matcher(m, get_callback()); + this->register_matcher(m, Reshape1DOps::get_callback()); } NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::Reshape1DGroupConvolution, "Reshape1DGroupConvolution", 0); @@ -185,7 +187,7 @@ NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::Reshape1DGroupConvolution, "Reshape1DGroupC MKLDNNPlugin::Reshape1DGroupConvolution::Reshape1DGroupConvolution() { auto group_conv = ngraph::pattern::wrap_type(ngraph::pattern::has_static_shape()); auto m = std::make_shared(group_conv, "Reshape1DGroupConvolution"); - this->register_matcher(m, get_callback()); + this->register_matcher(m, Reshape1DOps::get_callback()); } NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::Reshape1DAvgPool, "Reshape1DAvgPool", 0); @@ -193,7 +195,7 @@ NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::Reshape1DAvgPool, "Reshape1DAvgPool", 0); MKLDNNPlugin::Reshape1DAvgPool::Reshape1DAvgPool() { auto pool = ngraph::pattern::wrap_type(ngraph::pattern::has_static_shape()); auto m = std::make_shared(pool, "Reshape1DAvgPool"); - this->register_matcher(m, get_callback()); + this->register_matcher(m, Reshape1DOps::get_callback()); } NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::Reshape1DMaxPool, "Reshape1DMaxPool", 0); @@ -201,5 +203,5 @@ NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::Reshape1DMaxPool, "Reshape1DMaxPool", 0); MKLDNNPlugin::Reshape1DMaxPool::Reshape1DMaxPool() { auto pool = ngraph::pattern::wrap_type(ngraph::pattern::has_static_shape()); auto m = std::make_shared(pool, "Reshape1DMaxPool"); - this->register_matcher(m, get_callback()); + this->register_matcher(m, Reshape1DOps::get_callback()); } diff --git a/inference-engine/src/mkldnn_plugin/ngraph_transformations/reshape_fc_fusion.cpp b/inference-engine/src/mkldnn_plugin/ngraph_transformations/reshape_fc_fusion.cpp index b850bd98ae2..75614004147 100644 --- a/inference-engine/src/mkldnn_plugin/ngraph_transformations/reshape_fc_fusion.cpp +++ b/inference-engine/src/mkldnn_plugin/ngraph_transformations/reshape_fc_fusion.cpp @@ -63,13 +63,13 @@ MKLDNNPlugin::ReshapeFullyConnectedFusion::ReshapeFullyConnectedFusion() { if (fc->get_input_size() == 2) { new_fc = std::make_shared(reshape->input_value(0), weightInput, - outShape, + ngraph::Rank(outShape.size()), fc->output(0).get_element_type()); } else if (fc->get_input_size() == 3) { new_fc = std::make_shared(reshape->input_value(0), weightInput, fc->input_value(2), - outShape, + ngraph::Rank(outShape.size()), fc->output(0).get_element_type()); } else { return false; diff --git a/inference-engine/src/mkldnn_plugin/ngraph_transformations/reshape_fully_connected.cpp b/inference-engine/src/mkldnn_plugin/ngraph_transformations/reshape_fully_connected.cpp index f140f44e74e..2446e7694a8 100644 --- a/inference-engine/src/mkldnn_plugin/ngraph_transformations/reshape_fully_connected.cpp +++ b/inference-engine/src/mkldnn_plugin/ngraph_transformations/reshape_fully_connected.cpp @@ -5,60 +5,66 @@ #include "reshape_fully_connected.hpp" #include "op/fully_connected.hpp" #include +#include #include #include -#include #include +#include +#include NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::ReshapeFullyConnected, "ReshapeFullyConnected", 0); MKLDNNPlugin::ReshapeFullyConnected::ReshapeFullyConnected() { - ngraph::OutputVector twoInputs = {ngraph::pattern::any_input(ngraph::pattern::has_static_shape()), ngraph::pattern::any_input()}; - ngraph::OutputVector threeInputs = {ngraph::pattern::any_input(ngraph::pattern::has_static_shape()), ngraph::pattern::any_input(), + ngraph::OutputVector twoInputs = { + ngraph::pattern::any_input(ngraph::pattern::has_static_rank()), ngraph::pattern::any_input(ngraph::pattern::has_static_shape())}; + ngraph::OutputVector threeInputs = { + ngraph::pattern::any_input(ngraph::pattern::has_static_rank()), ngraph::pattern::any_input(ngraph::pattern::has_static_shape()), ngraph::pattern::any_input()}; - auto fcTwoInputs = ngraph::pattern::wrap_type(twoInputs, ngraph::pattern::has_static_shape()); - auto fcThreeInputs = ngraph::pattern::wrap_type(threeInputs, ngraph::pattern::has_static_shape()); + auto fcTwoInputs = ngraph::pattern::wrap_type(twoInputs, ngraph::pattern::has_static_rank()); + auto fcThreeInputs = ngraph::pattern::wrap_type(threeInputs, ngraph::pattern::has_static_rank()); const auto fcTwoOrThreeInputs = std::make_shared(ngraph::OutputVector{fcTwoInputs, fcThreeInputs}); ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher& m) { - auto fc = std::dynamic_pointer_cast (m.get_match_root()); + auto fc = std::dynamic_pointer_cast(m.get_match_root()); if (!fc || transformation_callback(fc)) { return false; } - auto input_shape = fc->input_value(0).get_shape(); - auto output_shape = fc->get_shape(); + auto fc_input_shape = fc->get_input_partial_shape(0); + auto input_rank = fc_input_shape.rank().get_length(); + auto output_shape = fc->get_output_partial_shape(0); - if (input_shape.size() == 2) { + if (input_rank == 2 || input_rank == 0) { return false; } ngraph::NodeVector new_ops; - - std::vector reshape_shape{-1, static_cast(input_shape.back())}; - auto reshape = std::make_shared(fc->input_value(0), - ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{2}, reshape_shape), true); + int64_t K = *(fc->get_input_shape(1).rbegin()); // requested 2nd input with static shape in the matcher + auto reshape = std::make_shared( + fc->input_value(0), ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{2}, std::vector{-1, K}), false); + if (reshape->get_output_partial_shape(0).rank().is_dynamic()) + return false; new_ops.push_back(reshape); reshape->set_friendly_name(fc->get_friendly_name() + "/Reshape"); // Calculate output shape for new FullyConnected layer // [I, K] * [O, K] = [I, O] - auto I = reshape->get_shape()[0]; - auto O = fc->input_value(1).get_shape()[0]; - ngraph::Shape output_shape_new{I, O}; + auto I = reshape->get_output_partial_shape(0)[0]; + auto O = fc->get_input_partial_shape(1)[0]; + ngraph::PartialShape output_shape_new{I, O}; std::shared_ptr fc_new; if (fc->get_input_size() == 2) { fc_new = std::make_shared(reshape, fc->input_value(1), - output_shape_new, + output_shape_new.rank(), fc->get_output_type()); } else if (fc->get_input_size() == 3) { fc_new = std::make_shared(reshape, fc->input_value(1), fc->input_value(2), - output_shape_new, + output_shape_new.rank(), fc->get_output_type()); } else { return false; @@ -66,7 +72,29 @@ MKLDNNPlugin::ReshapeFullyConnected::ReshapeFullyConnected() { new_ops.push_back(fc_new); if (output_shape != output_shape_new) { - auto reshape_output = ngraph::op::util::reshapeTo(fc_new, output_shape); + auto I_idxs = std::vector(input_rank - 1); + std::iota(I_idxs.begin(), I_idxs.end(), 0); + auto A_input_shape = ngraph::op::util::make_try_fold(fc->input_value(0)); + auto B_input_shape = ngraph::op::util::make_try_fold(fc->input_value(1)); + auto I_node = ngraph::op::util::node_to_get_shape_value_of_indices_from_shape_node(A_input_shape, {I_idxs}); + auto O_node = ngraph::op::util::node_to_get_shape_value_of_indices_from_shape_node(B_input_shape, {0}); + ngraph::OutputVector output_shape_dims{I_node, O_node}; + + const auto original_rank = fc->get_output_rank(); + NGRAPH_CHECK(original_rank.is_static()); + if (input_rank < original_rank.get_length()) { + const size_t const_shape_value = original_rank.get_length() - input_rank; + output_shape_dims.insert( + output_shape_dims.begin(), ngraph::opset1::Constant::create(I_node->get_element_type(), { const_shape_value }, { 1 })); + } + + auto reshape_output_shape = ngraph::op::util::make_try_fold(output_shape_dims, 0); + auto reshape_output = std::make_shared(fc_new, reshape_output_shape, false); + new_ops.push_back(A_input_shape); + new_ops.push_back(B_input_shape); + new_ops.push_back(I_node); + new_ops.push_back(O_node); + new_ops.push_back(reshape_output_shape); new_ops.push_back(reshape_output); reshape_output->set_friendly_name(fc->get_friendly_name()); fc_new->set_friendly_name(fc->get_friendly_name() + "/FC"); diff --git a/inference-engine/src/transformations/include/transformations/smart_reshape/utils.hpp b/inference-engine/src/transformations/include/transformations/smart_reshape/utils.hpp deleted file mode 100644 index c04d2e8364b..00000000000 --- a/inference-engine/src/transformations/include/transformations/smart_reshape/utils.hpp +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (C) 2018-2021 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include -#include -#include - -namespace ngraph { -namespace op { -namespace util { - -std::shared_ptr node_to_get_shape_value_of_indices_from_shape_node(const std::shared_ptr& shape_node, - const std::vector& indices) { - return std::make_shared(shape_node, - ngraph::opset4::Constant::create(ngraph::element::i64, {indices.size()}, indices), - ngraph::opset4::Constant::create(ngraph::element::i64, {}, {0})); -} - -std::shared_ptr node_to_get_shape_value_of_indices_from_shape_source(const ngraph::Output& shape_source, - const std::vector& indices) { - const auto & shape_node = std::make_shared(shape_source); - return node_to_get_shape_value_of_indices_from_shape_node(shape_node, indices); -} - -} // namespace util -} // namespace op -} // namespace ngraph \ No newline at end of file diff --git a/inference-engine/src/transformations/include/transformations/utils/utils.hpp b/inference-engine/src/transformations/include/transformations/utils/utils.hpp index 0ac77626253..c3672fb9961 100644 --- a/inference-engine/src/transformations/include/transformations/utils/utils.hpp +++ b/inference-engine/src/transformations/include/transformations/utils/utils.hpp @@ -142,6 +142,12 @@ Output eltwise_fold(const Output & input0, const Output & inpu } TRANSFORMATIONS_API std::vector> get_node_target_inputs(const std::shared_ptr& node); + +TRANSFORMATIONS_API std::shared_ptr node_to_get_shape_value_of_indices_from_shape_node( + const std::shared_ptr& shape_node, const std::vector& indices); + +TRANSFORMATIONS_API std::shared_ptr node_to_get_shape_value_of_indices_from_shape_source( + const ngraph::Output& shape_source, const std::vector& indices); } // namespace util } // namespace op } // namespace ngraph diff --git a/inference-engine/src/transformations/src/transformations/utils/utils.cpp b/inference-engine/src/transformations/src/transformations/utils/utils.cpp index 31d30031b56..984e34dd47f 100644 --- a/inference-engine/src/transformations/src/transformations/utils/utils.cpp +++ b/inference-engine/src/transformations/src/transformations/utils/utils.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include namespace ngraph { @@ -152,6 +153,20 @@ std::vector> get_node_target_inputs(const std::shared_ptr& nod return result; } +std::shared_ptr node_to_get_shape_value_of_indices_from_shape_node(const std::shared_ptr& shape_node, + const std::vector& indices) { + return make_try_fold( + shape_node, + v0::Constant::create(ngraph::element::i64, {indices.size()}, indices), + v0::Constant::create(ngraph::element::i64, {}, {0})); +} + +std::shared_ptr node_to_get_shape_value_of_indices_from_shape_source(const ngraph::Output& shape_source, + const std::vector& indices) { + const auto & shape_node = make_try_fold(shape_source); + return node_to_get_shape_value_of_indices_from_shape_node(shape_node, indices); +} + } // namespace util } // namespace op } // namespace ngraph diff --git a/inference-engine/tests/unit/cpu/CMakeLists.txt b/inference-engine/tests/unit/cpu/CMakeLists.txt index f90dd34bdd7..a66c27aa5c2 100644 --- a/inference-engine/tests/unit/cpu/CMakeLists.txt +++ b/inference-engine/tests/unit/cpu/CMakeLists.txt @@ -22,6 +22,7 @@ addIeTargetTest( inference_engine_lp_transformations ov_shape_inference inference_engine_s + unitTestUtils ADD_CPPLINT LABELS CPU diff --git a/inference-engine/tests/unit/cpu/convert_matmul_test.cpp b/inference-engine/tests/unit/cpu/convert_matmul_test.cpp new file mode 100644 index 00000000000..79fb1d8a387 --- /dev/null +++ b/inference-engine/tests/unit/cpu/convert_matmul_test.cpp @@ -0,0 +1,567 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common_test_utils/ngraph_test_utils.hpp" + +using namespace testing; +using namespace MKLDNNPlugin; + +TEST(TransformationTests, ConvertMatMulToFCTest1) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::Shape{ 3, 2, 2 }); + auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 2, 2 }, { 1 }); + auto matmul = std::make_shared(input1, input2, true, false); + + f = std::make_shared(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 }); + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::Shape{ 3, 2, 2 }); + auto transpose_constant = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 3 }, { 0, 2, 1 }); + auto transpose = std::make_shared(input1, transpose_constant); + auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 2, 2 }, { 1 }); + auto matmul = std::make_shared(transpose, input2, ngraph::Rank(3)); + + f_ref = std::make_shared(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 }); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, ConvertMatMulToFCTest2) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::Shape{3, 1, 2}); + auto input2 = std::make_shared(ngraph::element::f32, ngraph::Shape{3, 2, 1}); + auto matmul = std::make_shared(input1, input2, false, false); + + f = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2}); + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::Shape{3, 1, 2}); + auto input2 = std::make_shared(ngraph::element::f32, ngraph::Shape{3, 2, 1}); + auto matmul = std::make_shared(input1, input2, false, false); + + f_ref = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, ConvertMatMulToFCTest3) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::Shape{3, 2, 2}); + auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 2}, {1}); + auto matmul = std::make_shared(input1, input2, false, true); + + f = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1}); + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::Shape{3, 2, 2}); + auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 2}, {1}); + auto matmul = std::make_shared(input1, input2, ngraph::Rank(3)); + + f_ref = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, ConvertMatMulToFCTest4) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::PartialShape{-1, -1, 2}); + auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 2}, {1}); + auto matmul = std::make_shared(input1, input2, false, true); + + f = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1}); + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::PartialShape{-1, -1, 2}); + auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 2}, {1}); + auto matmul = std::make_shared(input1, input2, ngraph::Rank(3)); + + f_ref = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, ConvertMatMulToFCTest5) { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::PartialShape{ -1, -1, 2 }); + auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3, 2, 2 }, { 1 }); + auto matmul = std::make_shared(input1, input2, false, true); + + auto f = std::make_shared(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 }); + + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + ASSERT_NO_THROW(m.run_passes(f)); +} + +TEST(TransformationTests, ConvertMatMulToFCTest6) { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::PartialShape{ -1, -1, 2 }); + auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3, 1, 2 }, { 1 }); + auto matmul = std::make_shared(input1, input2, false, true); + + auto f = std::make_shared(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 }); + + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + ASSERT_NO_THROW(m.run_passes(f)); + ASSERT_NO_THROW(check_rt_info(f)); +} + +TEST(TransformationTests, ConvertMatMulToFCTest7) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::Shape{3, 2, 2}); + auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 2}, {1}); + auto matmul = std::make_shared(input1, input2, false, true); + + f = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1}); + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::Shape{3, 2, 2}); + auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 2}, {1}); + auto reshape_begin = std::make_shared( + input1, ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{2}, std::vector{-1, 2}), false); + auto fc = std::make_shared(reshape_begin, input2, ngraph::Rank(2)); + auto reshape_end = ngraph::op::util::reshapeTo(fc, ngraph::Shape{3, 2, 3}); + + f_ref = std::make_shared(ngraph::NodeVector{reshape_end}, ngraph::ParameterVector{input1}); + } + + auto res = compare_functions(f, f_ref, true); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, ConvertMatMulToFCTest8) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::PartialShape{-1, -1, 2}); + auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 2}, {1}); + auto matmul = std::make_shared(input1, input2, false, true); + + f = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1}); + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::PartialShape{-1, -1, 2}); + auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 2}, {1}); + + auto reshape_begin = std::make_shared( + input1, ngraph::opset1::Constant::create(ngraph::element::i64, {2}, {-1, 2}), false); + + auto fc = std::make_shared(reshape_begin, input2, ngraph::Rank(2)); + auto a_shape = std::make_shared(input1); + + auto I = ngraph::op::util::node_to_get_shape_value_of_indices_from_shape_node(a_shape, {0, 1}); + auto O = ngraph::opset1::Constant::create(ngraph::element::i64, { 1 }, { 3 }); + auto output_shape = std::make_shared(ngraph::OutputVector{I, O}, 0); + auto reshape_end = std::make_shared(fc, output_shape, false); + + f_ref = std::make_shared(ngraph::NodeVector{reshape_end}, ngraph::ParameterVector{input1}); + } + + auto res = compare_functions(f, f_ref, true); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, ConvertMatMulToFCTest9) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::Shape{3, 2, 2}); + auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 2}, {1}); + auto matmul = std::make_shared(input1, input2, false, true); + + f = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1}); + + ngraph::pass::Manager m; + auto pass_config = m.get_pass_config(); + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::Shape{3, 2, 2}); + auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 2}, {1}); + auto matmul = std::make_shared(input1, input2, ngraph::Rank(3)); + + f_ref = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, ConvertMatMulToFCTest10) { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::PartialShape::dynamic()); + auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 2, 2 }, { 1 }); + auto matmul = std::make_shared(input1, input2, false, true); + + auto f = std::make_shared(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 }); + + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.register_pass(); + ASSERT_NO_THROW(m.run_passes(f)); +} + +TEST(TransformationTests, FullyConnectedBiasFusionTest1) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 128, 3072}); + auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786, 3072}, {1}); + auto fc = std::make_shared(input1, weights, ngraph::Rank(3)); + + auto const_bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786}, {1}); + auto add = std::make_shared(fc, const_bias); + + f = std::make_shared(ngraph::NodeVector{add}, ngraph::ParameterVector{input1}); + + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.register_pass([](std::shared_ptr f) { + check_rt_info(f); + }); + ASSERT_NO_THROW(manager.run_passes(f)); + } + + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 128, 3072}); + auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786, 3072}, {1}); + auto bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786}, {1}); + auto fc = std::make_shared(input1, weights, bias, ngraph::Rank(3)); + + f_ref = std::make_shared(ngraph::NodeVector{fc}, ngraph::ParameterVector{input1}); + } + + auto res = compare_functions(f, f_ref, true); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, FullyConnectedBiasFusionTest2) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::PartialShape{-1, -1, 3072}); + auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786, 3072}, {1}); + auto fc = std::make_shared(input1, weights, ngraph::Rank(3)); + + auto const_bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786}, {1}); + auto add = std::make_shared(fc, const_bias); + + f = std::make_shared(ngraph::NodeVector{add}, ngraph::ParameterVector{input1}); + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.register_pass([](std::shared_ptr f) { + check_rt_info(f); + }); + ASSERT_NO_THROW(manager.run_passes(f)); + } + + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::PartialShape{-1, -1, 3072}); + auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786, 3072}, {1}); + auto bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786}, {1}); + auto fc = std::make_shared(input1, weights, bias, ngraph::Rank(3)); + + f_ref = std::make_shared(ngraph::NodeVector{fc}, ngraph::ParameterVector{input1}); + } + + auto res = compare_functions(f, f_ref, true); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, FullyConnectedBiasFusionTest3) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 128}); + auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786, 128}, {1}); + auto fc = std::make_shared(input1, weights, ngraph::Rank(2)); + + auto const_bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1, 786}, {1}); + auto add = std::make_shared(fc, const_bias); + + f = std::make_shared(ngraph::NodeVector{add}, ngraph::ParameterVector{input1}); + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.register_pass([](std::shared_ptr f) { + check_rt_info(f); + }); + ASSERT_NO_THROW(manager.run_passes(f)); + } + + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 128}); + auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786, 128}, {1}); + auto bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786}, {1}); + auto fc = std::make_shared(input1, weights, bias, ngraph::Rank(2)); + + f_ref = std::make_shared(ngraph::NodeVector{fc}, ngraph::ParameterVector{input1}); + } + + auto res = compare_functions(f, f_ref, true); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, FullyConnectedBiasFusionTest4) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::PartialShape{-1, 128}); + auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786, 128}, {1}); + auto fc = std::make_shared(input1, weights, ngraph::Rank(2)); + + auto const_bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1, 786}, {1}); + auto add = std::make_shared(fc, const_bias); + + f = std::make_shared(ngraph::NodeVector{add}, ngraph::ParameterVector{input1}); + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.register_pass([](std::shared_ptr f) { + check_rt_info(f); + }); + ASSERT_NO_THROW(manager.run_passes(f)); + } + + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::PartialShape{-1, 128}); + auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786, 128}, {1}); + auto bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786}, {1}); + auto fc = std::make_shared(input1, weights, bias, ngraph::Rank(2)); + + f_ref = std::make_shared(ngraph::NodeVector{fc}, ngraph::ParameterVector{input1}); + } + + auto res = compare_functions(f, f_ref, true); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, FullyConnectedBiasFusionTest5) { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::PartialShape::dynamic()); + auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 786, 128 }, { 1 }); + auto fc = std::make_shared(input1, weights, ngraph::Rank(2)); + + auto const_bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 786 }, { 1 }); + auto add = std::make_shared(fc, const_bias); + + auto f = std::make_shared(ngraph::NodeVector{ add }, ngraph::ParameterVector{ input1 }); + + ngraph::pass::Manager manager; + manager.register_pass(); + ASSERT_NO_THROW(manager.run_passes(f)); +} + +TEST(TransformationTests, FullyConnectedBiasFusionTest6) { + auto input1 = std::make_shared(ngraph::element::u8, ngraph::PartialShape{ -1, -1 }); + auto weights = ngraph::opset1::Constant::create(ngraph::element::i8, ngraph::Shape{ 786, 128 }, { 1 }); + auto fc = std::make_shared(input1, weights, ngraph::Rank(2), ngraph::element::f32); + + auto const_bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 786 }, { 1 }); + auto add = std::make_shared(fc, const_bias); + + auto f = std::make_shared(ngraph::NodeVector{ add }, ngraph::ParameterVector{ input1 }); + + ngraph::pass::Manager manager; + manager.register_pass(); + ASSERT_NO_THROW(manager.run_passes(f)); +} + +TEST(TransformationTests, ConvertMatMulToFCTest_second_input_rank_adj_1) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::Shape{5, 2, 3}); + auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1, 1, 2, 3}, {1}); + auto matmul = std::make_shared(input1, input2, false, true); + + f = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1}); + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::Shape{5, 2, 3}); + auto reshape_1 = std::make_shared(input1, ngraph::opset1::Constant::create(ngraph::element::i64, {2}, {-1, 3}), false); + auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 3}, {1}); + auto matmul = std::make_shared(reshape_1, input2, ngraph::Rank(2)); + auto reshape_out = std::make_shared(matmul, ngraph::opset1::Constant::create(ngraph::element::i64, {4}, {1, 5, 2, 2}), false); + f_ref = std::make_shared(ngraph::NodeVector{reshape_out}, ngraph::ParameterVector{input1}); + } + + auto res = compare_functions(f, f_ref, true); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, ConvertMatMulToFCTest_second_input_rank_adj_2) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::Shape{ 2, 3 }); + auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 2, 3 }, { 1 }); + auto matmul = std::make_shared(input1, weights, false, true); + + f = std::make_shared(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 }); + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::Shape{ 2, 3 }); + auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 2, 3 }, { 1 }); + auto matmul = std::make_shared(input1, weights, ngraph::Rank(2)); + f_ref = std::make_shared(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 }); + } + + auto res = compare_functions(f, f_ref, true); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, ConvertMatMulToFCTest_second_input_rank_adj_3) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::Shape{ 5, 2, 3 }); + auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 1, 2, 3 }, { 1 }); + auto matmul = std::make_shared(input1, weights, false, true); + auto biases = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 1, 1, 2 }, { 1 }); + auto add = std::make_shared(matmul, biases); + + f = std::make_shared(ngraph::NodeVector{ add }, ngraph::ParameterVector{ input1 }); + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::Shape{ 5, 2, 3 }); + auto reshape_before_const = ngraph::opset1::Constant::create(ngraph::element::i64, { 2 }, { -1, 3 }); + auto reshape_1 = std::make_shared(input1, reshape_before_const, false); + + auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 2, 3 }, { 1 }); + auto biases = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 2 }, { 1 }); + auto matmul = std::make_shared(reshape_1, weights, biases, ngraph::Rank(2)); + + auto reshape_after_const = ngraph::opset1::Constant::create(ngraph::element::i64, { 4 }, { 1, 5, 2, 2 }); + auto reshape_out = std::make_shared(matmul, reshape_after_const, false); + f_ref = std::make_shared(ngraph::NodeVector{ reshape_out }, ngraph::ParameterVector{ input1 }); + } + + auto res = compare_functions(f, f_ref, true); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, ConvertMatMulToFCTest_second_input_rank_adj_dynamic) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::PartialShape{-1, 2, 3}); + auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1, 1, 2, 3}, {1}); + auto matmul = std::make_shared(input1, input2, false, true); + + f = std::make_shared(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1}); + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto input1 = std::make_shared(ngraph::element::f32, ngraph::PartialShape{-1, 2, 3}); + auto reshape_1 = std::make_shared(input1, ngraph::opset1::Constant::create(ngraph::element::i64, {2}, {-1, 3}), false); + auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 3}, {1}); + auto matmul = std::make_shared(reshape_1, input2, ngraph::Rank(2)); + + auto shape_of = std::make_shared(input1); + auto gather = std::make_shared( + shape_of, ngraph::opset1::Constant::create(ngraph::element::i64, {2}, {0, 1}), ngraph::opset1::Constant::create(ngraph::element::i64, {}, {0})); + auto concat = std::make_shared(ngraph::OutputVector{ + ngraph::opset1::Constant::create(ngraph::element::i64, {1}, {1}), + gather, + ngraph::opset1::Constant::create(ngraph::element::i64, {1}, {2}), + }, 0); + auto reshape_out = std::make_shared(matmul, concat, false); + f_ref = std::make_shared(ngraph::NodeVector{reshape_out}, ngraph::ParameterVector{input1}); + } + + auto res = compare_functions(f, f_ref, true); + ASSERT_TRUE(res.first) << res.second; +} diff --git a/ngraph/core/include/openvino/core/partial_shape.hpp b/ngraph/core/include/openvino/core/partial_shape.hpp index a2ab80ff616..0464b891a09 100644 --- a/ngraph/core/include/openvino/core/partial_shape.hpp +++ b/ngraph/core/include/openvino/core/partial_shape.hpp @@ -304,6 +304,35 @@ public: OPENVINO_ASSERT(rank().is_static()); return m_dimensions.size(); } + /// \brief Returns a read/write iterator that points to the inserted element in the shape. + iterator insert(iterator position, const Dimension& val) { + m_rank_is_static = true; + m_shape_type = ShapeType::SHAPE_IS_UPDATED; + return m_dimensions.insert(position, val); + } + /// \brief Inserts count copies of the value before position + void insert(iterator position, size_t n, const Dimension& val) { + m_dimensions.insert(position, n, val); + m_rank_is_static = true; + m_shape_type = ShapeType::SHAPE_IS_UPDATED; + } + /// \brief Inserts elements from range [first, last) before position + template + void insert(iterator position, InputIterator first, InputIterator last) { + m_dimensions.insert(position, first, last); + m_rank_is_static = true; + m_shape_type = ShapeType::SHAPE_IS_UPDATED; + } + /// \brief Requests that the dimensions vector capacity be enough to contain n elements + void reserve(size_t n) { + m_dimensions.reserve(n); + } + /// \brief push element to the end of partial shape + void push_back(const Dimension& val) { + m_dimensions.push_back(val); + m_rank_is_static = true; + m_shape_type = ShapeType::SHAPE_IS_UPDATED; + } private: // Private constructor for PartialShape::dynamic(). diff --git a/ngraph/core/src/pass/smart_reshape/matmul_sr.cpp b/ngraph/core/src/pass/smart_reshape/matmul_sr.cpp index 3532694536b..a4241bd3fba 100644 --- a/ngraph/core/src/pass/smart_reshape/matmul_sr.cpp +++ b/ngraph/core/src/pass/smart_reshape/matmul_sr.cpp @@ -13,7 +13,6 @@ #include #include "itt.hpp" -#include "transformations/smart_reshape/utils.hpp" bool relax_hc_reshape_followed_by_matmul(const ngraph::pattern::PatternValueMap& pattern_to_output, const std::shared_ptr& matmul_label, @@ -35,7 +34,10 @@ bool relax_hc_reshape_followed_by_matmul(const ngraph::pattern::PatternValueMap& const auto& raw_idx = reshape_is_A_input ? (matmul->get_transpose_b() ? -1 : -2) : (matmul->get_transpose_a() ? -2 : -1); const auto& idx = ngraph::normalize_axes(matmul->description(), {raw_idx}, reshape_rank); - const auto& C = ngraph::op::util::node_to_get_shape_value_of_indices_from_shape_source(shape_source, idx); + const auto& C = std::make_shared( + std::make_shared(shape_source), + ngraph::opset4::Constant::create(ngraph::element::i64, {idx.size()}, idx), + ngraph::opset4::Constant::create(ngraph::element::i64, {}, {0})); const auto& N = ngraph::opset4::Constant::create(ngraph::element::i64, {1}, {-1}); const auto& pattern_vector = reshape_is_A_input ? (matmul->get_transpose_a() ? ngraph::OutputVector({C, N}) : ngraph::OutputVector({N, C}))