[dynamism][CPU] MatMul transformations: dynamic shapes support (#7582)

* [TESTS] CPUUnitTests: unitTestUtils added to CMakeLists

* Reshape1D: functions moved to namespace

* [nGraph] PartialShape extending

* CPU specific transformations for FullyConnected enabled to work with dynamic shapes

* [Transformations] FC transformations: dynamic shapes support

* [Transformations] FCBiasFusion: removed legacy check on bcast on weights

* [TESTS] FC transformations: tests

* [Transformations] SmartReshape: avoid transformation utils methods

* codestyle fix

* codestyle fix

* [TESTS] MatMulTransformation tests: compilation error fix

* [CPU] FullyConnected: shape inference

* postreview fixes

* [CPU] FullyConnected: shape inference fix

* [nGraph] PShape dimensions insertion fixed

Co-authored-by: Stepyreva, Evgenya <evgenya.stepyreva@intel.com>
This commit is contained in:
Vladislav Golubev 2021-10-19 10:59:29 +03:00 committed by GitHub
parent 18aaaa79a0
commit c84db94697
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 874 additions and 176 deletions

View File

@ -4,7 +4,6 @@
#include "convert_matmul_to_fc.hpp"
#include "op/fully_connected.hpp"
#include <numeric>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
@ -13,25 +12,39 @@
NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::ConvertMatMulToFC, "ConvertMatMulToFC", 0);
MKLDNNPlugin::ConvertMatMulToFC::ConvertMatMulToFC() {
auto matmul = ngraph::pattern::wrap_type<ngraph::opset1::MatMul>({ngraph::pattern::any_input(ngraph::pattern::has_static_shape()),
ngraph::pattern::any_input(ngraph::pattern::has_static_shape())},
ngraph::pattern::has_static_shape());
auto activations_m = ngraph::pattern::any_input(ngraph::pattern::has_static_rank());
auto weights_m = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
auto matmul_m = ngraph::pattern::wrap_type<ngraph::opset1::MatMul>({ activations_m, weights_m }, ngraph::pattern::has_static_rank());
ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher& m) {
auto matmul = std::dynamic_pointer_cast<ngraph::opset1::MatMul>(m.get_match_root());
if (!matmul) {
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto matmul = std::dynamic_pointer_cast<ngraph::opset1::MatMul>(pattern_map.at(matmul_m).get_node_shared_ptr());
if (!matmul || transformation_callback(matmul)) {
return false;
}
auto input_a = matmul->input(0).get_source_output();
auto input_b = matmul->input(1).get_source_output();
// fc_input_a and fc_input_b - are the final inputs that will be set to FullyConnected of GemmIE operations.
// So in case of adding new operations that takes matmul inputs we need keep update fc_input_a and fc_input_b.
auto fc_input_a = pattern_map.at(activations_m);
auto fc_input_b = pattern_map.at(weights_m);
auto shape_a = input_a.get_shape();
auto shape_b = input_b.get_shape();
auto output_shape = matmul->get_shape();
auto shape_a = fc_input_a.get_partial_shape();
auto shape_b = fc_input_b.get_partial_shape();
NGRAPH_CHECK(shape_b.is_static()); // requested 2nd input with static shape in the matcher
auto rank_a = shape_a.rank().get_length();
auto rank_b = shape_b.rank().get_length();
// Transformation to FC is not supported for 1D second input
if (shape_b.size() == 1) {
if (rank_b == 1) {
return false;
}
// Check that if second inputs is Constant path and it's shape without ones dimensions has length <= 2
// we replace MatMul with FullyConnected operation.
if (!std::dynamic_pointer_cast<ngraph::opset1::Constant>(fc_input_b.get_node_shared_ptr()) ||
std::count_if(shape_b.begin(), shape_b.end(), [](ngraph::Dimension x) { return x != 1; }) > 2) {
return false;
}
@ -42,15 +55,17 @@ MKLDNNPlugin::ConvertMatMulToFC::ConvertMatMulToFC() {
* for example: [2, 32, 64] [3, 64, 64] it will raise an exception.
*/
auto get_aligned_shapes = [shape_a, shape_b, &matmul]() -> std::pair<ngraph::Shape, ngraph::Shape> {
ngraph::Shape shape_a_aligned(shape_a), shape_b_aligned(shape_b);
size_t max_size = std::max(shape_a_aligned.size(), shape_b_aligned.size());
for (size_t i = 0, cnt = max_size - shape_a_aligned.size(); i < cnt; ++i)
auto get_aligned_shapes = [shape_a, shape_b, rank_a, rank_b, &matmul]() -> std::tuple<bool, ngraph::PartialShape, ngraph::PartialShape> {
ngraph::PartialShape shape_a_aligned(shape_a), shape_b_aligned(shape_b);
size_t max_size = std::max(rank_a, rank_b);
for (size_t i = 0, cnt = max_size - rank_a; i < cnt; ++i) {
shape_a_aligned.insert(shape_a_aligned.begin(), 1);
for (size_t i = 0, cnt = max_size - shape_b_aligned.size(); i < cnt; ++i)
}
for (size_t i = 0, cnt = max_size - rank_b; i < cnt; ++i) {
shape_b_aligned.insert(shape_b_aligned.begin(), 1);
}
if (matmul->get_transpose_a() && shape_a.size() != 1) {
if (matmul->get_transpose_a() && rank_a != 1) {
std::swap(*(shape_a_aligned.end() - 1), *(shape_a_aligned.end() - 2));
}
if (matmul->get_transpose_b()) {
@ -58,16 +73,25 @@ MKLDNNPlugin::ConvertMatMulToFC::ConvertMatMulToFC() {
}
for (size_t i = 0; i < max_size - 2; ++i) {
if (shape_a_aligned[i] != shape_b_aligned[i] && shape_a_aligned[i] > 1 && shape_b_aligned[i] > 1) {
auto a_dim = shape_a_aligned[i], b_dim = shape_b_aligned[i];
if (a_dim.is_dynamic()) {
if (b_dim == 1) {
shape_a_aligned[i] = shape_b_aligned[i] = a_dim;
} else {
return std::make_tuple(false, ngraph::PartialShape{shape_a_aligned}, ngraph::PartialShape{shape_b_aligned});
}
continue;
}
// both dimensions are static
if (a_dim != b_dim && a_dim.get_length() > 1 && b_dim.get_length() > 1) {
std::ostringstream stream;
stream << "Shapes can't be aligned: " << shape_a_aligned << " " << shape_b_aligned;
throw ngraph::ngraph_error(stream.str());
}
size_t max_value = std::max(shape_a_aligned[i], shape_b_aligned[i]);
size_t max_value = std::max(a_dim.get_length(), b_dim.get_length());
shape_a_aligned[i] = shape_b_aligned[i] = max_value;
}
return {shape_a_aligned, shape_b_aligned};
return std::make_tuple(true, shape_a_aligned, shape_b_aligned);
};
/*
@ -78,76 +102,68 @@ MKLDNNPlugin::ConvertMatMulToFC::ConvertMatMulToFC() {
* order will be [0, 1, 3, 2] that emulates transpose_a or transpose_b attribute.
*/
auto create_transpose = [this](ngraph::Output<ngraph::Node> node, const std::string& transpose_name) -> std::shared_ptr<ngraph::Node> {
ngraph::Shape output_shape = node.get_node_shared_ptr()->get_shape();
std::vector<size_t> transpose_order(output_shape.size());
auto create_transpose = [this](const ngraph::Output<ngraph::Node>& node, const std::string& transpose_name) {
auto rank = node.get_partial_shape().rank();
std::vector<size_t> transpose_order(rank.get_length());
std::iota(transpose_order.begin(), transpose_order.end(), 0);
std::swap(*(transpose_order.end() - 1), *(transpose_order.end() - 2));
auto transpose = ngraph::pass::MatcherPass::register_new_node<ngraph::opset1::Transpose>(
node, ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{transpose_order.size()}, transpose_order));
auto transpose_const = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ transpose_order.size() }, transpose_order);
auto transpose = ngraph::op::util::make_try_fold<ngraph::opset1::Transpose>(node, transpose_const);
if (!ngraph::is_type<ngraph::opset1::Constant>(transpose)) {
MatcherPass::register_new_node(transpose);
}
transpose->set_friendly_name(transpose_name);
return transpose;
};
// fc_input_a and fc_input_b - are the final inputs that will be set to FullyConnected of GemmIE operations.
// So in case of adding new operations that takes matmul inputs we need keep update fc_input_a and
// fc_input_b updated.
auto fc_input_a = input_a, fc_input_b = input_b;
// vector of new nGraph operations
ngraph::NodeVector new_ops;
// Check that if second inputs is Constant operation and it's shape without ones dimensions has length <= 2
// we replace MatMul with FullyConnected operation.
// Otherwise we replace MatMul with Gemm.
if ((std::dynamic_pointer_cast<ngraph::opset1::Constant>(fc_input_b.get_node_shared_ptr()) ||
std::dynamic_pointer_cast<ngraph::opset1::FakeQuantize>(fc_input_b.get_node_shared_ptr())) &&
std::count_if(shape_b.begin(), shape_b.end(), [](size_t x) { return x != 1; }) <= 2) {
ngraph::Shape shape_a_aligned, shape_b_aligned;
std::tie(shape_a_aligned, shape_b_aligned) = get_aligned_shapes();
if (shape_a_aligned.size() < 2 || shape_b_aligned.size() < 2) {
throw ngraph::ngraph_error("MatMul " + matmul->get_friendly_name() + " shapes are inconsistent.");
}
// Transferring from MatMul representation: [B, I, K] * [B, K, O] = [B, I, O]
// to FullyConnected representation: [I, K] * [K, O] = [I, O]
size_t K = *(shape_a_aligned.end() - 1);
ngraph::Shape B(shape_a_aligned.begin(), shape_a_aligned.end() - 2);
// Weights normalization
if (!matmul->get_transpose_b()) {
fc_input_b = create_transpose(fc_input_b, matmul->get_friendly_name() + "/transpose_b");
new_ops.push_back(fc_input_b.get_node_shared_ptr());
}
if (shape_b.size() != 2) {
auto reshape_shape =
ngraph::opset1::Constant::create<int64_t>(ngraph::element::i64, ngraph::Shape{2}, {-1ll, static_cast<int64_t>(K)});
fc_input_b = std::make_shared<ngraph::opset1::Reshape>(fc_input_b, reshape_shape, true);
new_ops.push_back(fc_input_b.get_node_shared_ptr());
}
// Input normalization
if (matmul->get_transpose_a() && shape_a.size() != 1) {
fc_input_a = create_transpose(fc_input_a, matmul->get_friendly_name() + "/transpose_a");
new_ops.push_back(fc_input_a.get_node_shared_ptr());
}
// Create FullyConnected
auto fc = std::make_shared<MKLDNNPlugin::FullyConnectedNode>(fc_input_a, fc_input_b, output_shape, matmul->output(0).get_element_type());
fc->set_friendly_name(matmul->get_friendly_name());
new_ops.push_back(fc);
ngraph::copy_runtime_info(matmul, new_ops);
ngraph::replace_node(matmul, fc);
return true;
bool success = true;
ngraph::PartialShape shape_a_aligned, shape_b_aligned;
std::tie(success, shape_a_aligned, shape_b_aligned) = get_aligned_shapes();
if (!success) {
return false;
}
return false;
auto aligned_a_rank = shape_a_aligned.rank(), aligned_b_rank = shape_b_aligned.rank();
if (aligned_a_rank.is_dynamic() || aligned_b_rank.is_dynamic() || aligned_a_rank.get_length() < 2 || aligned_b_rank.get_length() < 2) {
throw ngraph::ngraph_error("MatMul " + matmul->get_friendly_name() + " shapes are inconsistent.");
}
// Transferring from MatMul representation: [B, I, K] * [B, K, O] = [B, I, O]
// to FullyConnected representation: [I, K] * [K, O] = [I, O]
// Weights normalization
if (!matmul->get_transpose_b()) {
fc_input_b = create_transpose(fc_input_b, matmul->get_friendly_name() + "/transpose_b");
new_ops.push_back(fc_input_b.get_node_shared_ptr());
}
if (rank_b != 2) {
ngraph::Dimension K = *(shape_b_aligned.rbegin() + 1);
NGRAPH_CHECK(K.is_static()); // requested 2nd input with static shape in the matcher
std::vector<int64_t> reshape_shape_values = { -1ll, static_cast<int64_t>(K.get_length()) };
auto reshape_shape = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 2 }, reshape_shape_values);
fc_input_b = ngraph::op::util::make_try_fold<ngraph::opset1::Reshape>(fc_input_b, reshape_shape, false);
new_ops.push_back(fc_input_b.get_node_shared_ptr());
}
// Input normalization
if (matmul->get_transpose_a() && rank_a != 1) {
fc_input_a = create_transpose(fc_input_a, matmul->get_friendly_name() + "/transpose_a");
new_ops.push_back(fc_input_a.get_node_shared_ptr());
}
auto output_rank = matmul->get_output_partial_shape(0).rank();
// Create FullyConnected
auto fc = std::make_shared<MKLDNNPlugin::FullyConnectedNode>(fc_input_a, fc_input_b, output_rank, matmul->get_output_element_type(0));
fc->set_friendly_name(matmul->get_friendly_name());
new_ops.push_back(fc);
ngraph::copy_runtime_info(matmul, new_ops);
ngraph::replace_node(matmul, fc);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, "ConvertMatMulToFC");
auto m = std::make_shared<ngraph::pattern::Matcher>(matmul_m, "ConvertMatMulToFC");
this->register_matcher(m, callback);
}

View File

@ -9,53 +9,58 @@
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include "transformations/utils/utils.hpp"
NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::FullyConnectedBiasFusion, "FullyConnectedBiasFusion", 0);
MKLDNNPlugin::FullyConnectedBiasFusion::FullyConnectedBiasFusion() {
auto m_fc = ngraph::pattern::wrap_type<MKLDNNPlugin::FullyConnectedNode>([](ngraph::Output<ngraph::Node> output) {
return ngraph::pattern::consumers_count(1)(output) && ngraph::pattern::has_static_shape()(output);
auto input = ngraph::pattern::any_input();
auto weights = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto m_fc = ngraph::pattern::wrap_type<MKLDNNPlugin::FullyConnectedNode>({ input, weights }, [](ngraph::Output<ngraph::Node> output) {
return ngraph::pattern::consumers_count(1)(output) && ngraph::pattern::has_static_rank()(output);
});
auto m_bias = ngraph::pattern::any_input();
auto m_bias = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto m_add = ngraph::pattern::wrap_type<ngraph::opset1::Add>({m_fc, m_bias});
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
auto & pattern_to_output = m.get_pattern_value_map();
auto& pattern_to_output = m.get_pattern_value_map();
auto add = pattern_to_output[m_add].get_node_shared_ptr();
auto bias = pattern_to_output[m_bias].get_node_shared_ptr();
auto fc = std::dynamic_pointer_cast<MKLDNNPlugin::FullyConnectedNode>(pattern_to_output[m_fc].get_node_shared_ptr());
if (!fc) {
if (!fc || transformation_callback(fc)) {
return false;
}
if (auto bcast = std::dynamic_pointer_cast<ngraph::opset1::Broadcast>(bias)) {
bias = bcast->input_value(0).get_node_shared_ptr();
}
if (!std::dynamic_pointer_cast<ngraph::opset1::Constant>(bias)) {
return false;
}
ngraph::Shape bias_shape(bias->get_shape());
ngraph::Shape output_shape(fc->get_shape());
size_t bias_size = std::accumulate(bias_shape.begin(), bias_shape.end(), size_t{1}, std::multiplies<int64_t>());
if (bias_shape.empty() || bias_shape.back() != output_shape.back() || bias_shape.back() != bias_size) {
ngraph::PartialShape output_shape(fc->get_output_partial_shape(0));
size_t bias_size = ngraph::shape_size(bias_shape);
auto rank = output_shape.rank().get_length();
if (rank == 0 || output_shape[rank - 1].is_dynamic()) {
return false;
}
if (bias_shape.empty() || bias_shape.back() != output_shape[rank - 1].get_length() || bias_shape.back() != bias_size) {
return false;
}
ngraph::NodeVector new_ops;
std::shared_ptr<ngraph::Node> final_bias = bias;
if (bias->get_shape().size() >= 2) {
final_bias = std::make_shared<ngraph::opset1::Reshape>(final_bias, ngraph::opset1::Constant::create(ngraph::element::i64,
ngraph::Shape{1}, {-1}), true);
if (bias_shape.size() >= 2) {
auto reshape_const = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 1 }, { -1 });
final_bias = ngraph::op::util::make_try_fold<ngraph::opset1::Reshape>(final_bias, reshape_const, true);
new_ops.push_back(final_bias);
}
auto new_fc = std::make_shared<MKLDNNPlugin::FullyConnectedNode>(fc->input(0).get_source_output(),
fc->input(1).get_source_output(),
auto new_fc = std::make_shared<MKLDNNPlugin::FullyConnectedNode>(fc->input_value(0),
fc->input_value(1),
final_bias,
fc->get_shape(),
fc->get_output_rank(),
fc->get_output_type());
new_ops.push_back(new_fc);

View File

@ -8,40 +8,99 @@ constexpr ngraph::NodeTypeInfo MKLDNNPlugin::FullyConnectedNode::type_info;
MKLDNNPlugin::FullyConnectedNode::FullyConnectedNode(const ngraph::Output<Node>& A,
const ngraph::Output<Node>& B,
const ngraph::Shape& output_shape,
const ngraph::Rank& output_rank,
const ngraph::element::Type output_type)
: Op({A, B}), m_output_shape(output_shape), m_output_type(output_type) {
validate_and_infer_types();
: Op({A, B}), m_output_rank(output_rank), m_output_type(output_type) {
constructor_validate_and_infer_types();
}
MKLDNNPlugin::FullyConnectedNode::FullyConnectedNode(const ngraph::Output<Node>& A,
const ngraph::Output<Node>& B,
const ngraph::Output<Node>& C,
const ngraph::Shape& output_shape,
const ngraph::Rank& output_rank,
const ngraph::element::Type output_type)
: Op({A, B, C}), m_output_shape(output_shape), m_output_type(output_type) {
validate_and_infer_types();
: Op({A, B, C}), m_output_rank(output_rank), m_output_type(output_type) {
constructor_validate_and_infer_types();
}
std::shared_ptr<ngraph::Node> MKLDNNPlugin::FullyConnectedNode::clone_with_new_inputs(const ngraph::OutputVector& new_args) const {
check_new_args_count(this, new_args);
if (new_args.size() == 2) {
return std::make_shared<MKLDNNPlugin::FullyConnectedNode>(new_args.at(0), new_args.at(1), m_output_shape, m_output_type);
return std::make_shared<MKLDNNPlugin::FullyConnectedNode>(new_args.at(0), new_args.at(1), m_output_rank, m_output_type);
} else if (new_args.size() == 3) {
return std::make_shared<MKLDNNPlugin::FullyConnectedNode>(new_args.at(0), new_args.at(1), new_args.at(2), m_output_shape, m_output_type);
return std::make_shared<MKLDNNPlugin::FullyConnectedNode>(new_args.at(0), new_args.at(1), new_args.at(2), m_output_rank, m_output_type);
}
throw ngraph::ngraph_error("Unsupported number of arguments for FullyConnected operation");
}
void MKLDNNPlugin::FullyConnectedNode::validate_and_infer_types() {
m_output_size = m_output_shape.back();
set_output_type(0, m_output_type == ngraph::element::undefined ? input_value(0).get_element_type() : m_output_type, m_output_shape);
const auto input_size = get_input_size();
NODE_VALIDATION_CHECK(this,
input_size == 2 || input_size == 3,
"Number of inputs is incorrect. Current value is: ",
input_size,
", expected: 2 or 3.");
const auto output_size = get_output_size();
NODE_VALIDATION_CHECK(this,
output_size == 1,
"Number of outputs is incorrect. Current value is: ",
output_size,
", expected: 1.");
// Weights shape: [O, I1, ..., Im];
// O - output channels dimensions, Ik - input channels dimensions
const auto weights_pshape = get_input_partial_shape(1);
NODE_VALIDATION_CHECK(this,
weights_pshape.is_static(),
"Weights pshape must be static");
const auto weights_shape = weights_pshape.to_shape();
const auto o_channels = weights_pshape[0];
if (input_size == 3) {
const auto bias_shape = get_input_partial_shape(2);
const auto expected_bias_shape = ngraph::PartialShape{ o_channels };
NODE_VALIDATION_CHECK(this,
bias_shape == expected_bias_shape,
"Bias shape is incorrect. Current value is: ",
bias_shape,
", expected: ",
expected_bias_shape,
".");
}
// Activations shape: [B1, ..., Bn, I1, ..., Im];
// Bi - batch dimensions, Ik - input channels dimensions
const auto activations_pshape = get_input_partial_shape(0);
// Result shape: [B1, ..., Bn, O]
ngraph::PartialShape output_pshape;
if (activations_pshape.rank().is_static()) {
size_t output_channels_dimensions_count = weights_shape.size() - 1;
for (size_t i = 0; i < activations_pshape.size() - output_channels_dimensions_count; ++i) {
output_pshape.push_back(activations_pshape[i]);
}
output_pshape.push_back(o_channels);
if (m_output_rank.is_static()) {
while (output_pshape.rank().get_length() < m_output_rank.get_length()) {
output_pshape.insert(output_pshape.begin(), 1);
}
}
} else {
output_pshape = ngraph::PartialShape::dynamic();
}
auto output_type = m_output_type == ngraph::element::undefined ? get_input_element_type(0) : m_output_type;
set_output_type(0, output_type, output_pshape);
}
bool MKLDNNPlugin::FullyConnectedNode::visit_attributes(ngraph::AttributeVisitor &visitor) {
visitor.on_attribute("out-size", m_output_size);
visitor.on_attribute("out-shape", m_output_shape);
if (m_output_rank.is_static()) {
std::int64_t value = m_output_rank.get_length();
visitor.on_attribute("out-rank", value);
}
visitor.on_attribute("out-type", m_output_type);
return true;
}

View File

@ -19,13 +19,13 @@ public:
FullyConnectedNode(const ngraph::Output<Node> &A,
const ngraph::Output<Node> &B,
const ngraph::Shape &output_shape,
const ngraph::Rank& output_rank,
const ngraph::element::Type output_type = ngraph::element::undefined);
FullyConnectedNode(const ngraph::Output<Node> &A,
const ngraph::Output<Node> &B,
const ngraph::Output<Node> &C,
const ngraph::Shape &output_shape,
const ngraph::Rank& output_rank,
const ngraph::element::Type output_type = ngraph::element::undefined);
bool visit_attributes(ngraph::AttributeVisitor &visitor) override;
@ -34,13 +34,11 @@ public:
std::shared_ptr<Node> clone_with_new_inputs(const ngraph::OutputVector& new_args) const override;
size_t get_out_size() const { return m_output_size; }
ngraph::Rank get_output_rank() const { return m_output_rank; }
ngraph::element::Type get_output_type() const { return m_output_type; }
private:
size_t m_output_size = 0;
ngraph::Shape m_output_shape = {};
ngraph::Rank m_output_rank;
ngraph::element::Type m_output_type;
};

View File

@ -14,6 +14,7 @@
#include "transformations/utils/utils.hpp"
namespace Reshape1DOps {
template <class BaseOp>
std::shared_ptr<ngraph::Node> convert(const ngraph::Output<ngraph::Node> & data, std::shared_ptr<BaseOp> node, ngraph::NodeVector &new_ops) {
auto new_strides = node->get_strides();
@ -171,13 +172,14 @@ ngraph::matcher_pass_callback get_callback() {
return true;
};
}
} // namespace Reshape1DOps
NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::Reshape1DConvolution, "Reshape1DConvolution", 0);
MKLDNNPlugin::Reshape1DConvolution::Reshape1DConvolution() {
auto conv = ngraph::pattern::wrap_type<ngraph::opset1::Convolution>(ngraph::pattern::has_static_shape());
auto m = std::make_shared<ngraph::pattern::Matcher>(conv, "Reshape1DConvolution");
this->register_matcher(m, get_callback());
this->register_matcher(m, Reshape1DOps::get_callback());
}
NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::Reshape1DGroupConvolution, "Reshape1DGroupConvolution", 0);
@ -185,7 +187,7 @@ NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::Reshape1DGroupConvolution, "Reshape1DGroupC
MKLDNNPlugin::Reshape1DGroupConvolution::Reshape1DGroupConvolution() {
auto group_conv = ngraph::pattern::wrap_type<ngraph::opset1::GroupConvolution>(ngraph::pattern::has_static_shape());
auto m = std::make_shared<ngraph::pattern::Matcher>(group_conv, "Reshape1DGroupConvolution");
this->register_matcher(m, get_callback());
this->register_matcher(m, Reshape1DOps::get_callback());
}
NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::Reshape1DAvgPool, "Reshape1DAvgPool", 0);
@ -193,7 +195,7 @@ NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::Reshape1DAvgPool, "Reshape1DAvgPool", 0);
MKLDNNPlugin::Reshape1DAvgPool::Reshape1DAvgPool() {
auto pool = ngraph::pattern::wrap_type<ngraph::opset1::AvgPool>(ngraph::pattern::has_static_shape());
auto m = std::make_shared<ngraph::pattern::Matcher>(pool, "Reshape1DAvgPool");
this->register_matcher(m, get_callback());
this->register_matcher(m, Reshape1DOps::get_callback());
}
NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::Reshape1DMaxPool, "Reshape1DMaxPool", 0);
@ -201,5 +203,5 @@ NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::Reshape1DMaxPool, "Reshape1DMaxPool", 0);
MKLDNNPlugin::Reshape1DMaxPool::Reshape1DMaxPool() {
auto pool = ngraph::pattern::wrap_type<ngraph::opset1::MaxPool>(ngraph::pattern::has_static_shape());
auto m = std::make_shared<ngraph::pattern::Matcher>(pool, "Reshape1DMaxPool");
this->register_matcher(m, get_callback());
this->register_matcher(m, Reshape1DOps::get_callback());
}

View File

@ -63,13 +63,13 @@ MKLDNNPlugin::ReshapeFullyConnectedFusion::ReshapeFullyConnectedFusion() {
if (fc->get_input_size() == 2) {
new_fc = std::make_shared<MKLDNNPlugin::FullyConnectedNode>(reshape->input_value(0),
weightInput,
outShape,
ngraph::Rank(outShape.size()),
fc->output(0).get_element_type());
} else if (fc->get_input_size() == 3) {
new_fc = std::make_shared<MKLDNNPlugin::FullyConnectedNode>(reshape->input_value(0),
weightInput,
fc->input_value(2),
outShape,
ngraph::Rank(outShape.size()),
fc->output(0).get_element_type());
} else {
return false;

View File

@ -5,60 +5,66 @@
#include "reshape_fully_connected.hpp"
#include "op/fully_connected.hpp"
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <transformations/utils/utils.hpp>
#include <numeric>
NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::ReshapeFullyConnected, "ReshapeFullyConnected", 0);
MKLDNNPlugin::ReshapeFullyConnected::ReshapeFullyConnected() {
ngraph::OutputVector twoInputs = {ngraph::pattern::any_input(ngraph::pattern::has_static_shape()), ngraph::pattern::any_input()};
ngraph::OutputVector threeInputs = {ngraph::pattern::any_input(ngraph::pattern::has_static_shape()), ngraph::pattern::any_input(),
ngraph::OutputVector twoInputs = {
ngraph::pattern::any_input(ngraph::pattern::has_static_rank()), ngraph::pattern::any_input(ngraph::pattern::has_static_shape())};
ngraph::OutputVector threeInputs = {
ngraph::pattern::any_input(ngraph::pattern::has_static_rank()), ngraph::pattern::any_input(ngraph::pattern::has_static_shape()),
ngraph::pattern::any_input()};
auto fcTwoInputs = ngraph::pattern::wrap_type<MKLDNNPlugin::FullyConnectedNode>(twoInputs, ngraph::pattern::has_static_shape());
auto fcThreeInputs = ngraph::pattern::wrap_type<MKLDNNPlugin::FullyConnectedNode>(threeInputs, ngraph::pattern::has_static_shape());
auto fcTwoInputs = ngraph::pattern::wrap_type<MKLDNNPlugin::FullyConnectedNode>(twoInputs, ngraph::pattern::has_static_rank());
auto fcThreeInputs = ngraph::pattern::wrap_type<MKLDNNPlugin::FullyConnectedNode>(threeInputs, ngraph::pattern::has_static_rank());
const auto fcTwoOrThreeInputs = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{fcTwoInputs, fcThreeInputs});
ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher& m) {
auto fc = std::dynamic_pointer_cast<MKLDNNPlugin::FullyConnectedNode> (m.get_match_root());
auto fc = std::dynamic_pointer_cast<MKLDNNPlugin::FullyConnectedNode>(m.get_match_root());
if (!fc || transformation_callback(fc)) {
return false;
}
auto input_shape = fc->input_value(0).get_shape();
auto output_shape = fc->get_shape();
auto fc_input_shape = fc->get_input_partial_shape(0);
auto input_rank = fc_input_shape.rank().get_length();
auto output_shape = fc->get_output_partial_shape(0);
if (input_shape.size() == 2) {
if (input_rank == 2 || input_rank == 0) {
return false;
}
ngraph::NodeVector new_ops;
std::vector<int64_t> reshape_shape{-1, static_cast<int64_t>(input_shape.back())};
auto reshape = std::make_shared<ngraph::opset1::Reshape>(fc->input_value(0),
ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{2}, reshape_shape), true);
int64_t K = *(fc->get_input_shape(1).rbegin()); // requested 2nd input with static shape in the matcher
auto reshape = std::make_shared<ngraph::opset1::Reshape>(
fc->input_value(0), ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{2}, std::vector<int64_t>{-1, K}), false);
if (reshape->get_output_partial_shape(0).rank().is_dynamic())
return false;
new_ops.push_back(reshape);
reshape->set_friendly_name(fc->get_friendly_name() + "/Reshape");
// Calculate output shape for new FullyConnected layer
// [I, K] * [O, K] = [I, O]
auto I = reshape->get_shape()[0];
auto O = fc->input_value(1).get_shape()[0];
ngraph::Shape output_shape_new{I, O};
auto I = reshape->get_output_partial_shape(0)[0];
auto O = fc->get_input_partial_shape(1)[0];
ngraph::PartialShape output_shape_new{I, O};
std::shared_ptr<ngraph::Node> fc_new;
if (fc->get_input_size() == 2) {
fc_new = std::make_shared<MKLDNNPlugin::FullyConnectedNode>(reshape,
fc->input_value(1),
output_shape_new,
output_shape_new.rank(),
fc->get_output_type());
} else if (fc->get_input_size() == 3) {
fc_new = std::make_shared<MKLDNNPlugin::FullyConnectedNode>(reshape,
fc->input_value(1),
fc->input_value(2),
output_shape_new,
output_shape_new.rank(),
fc->get_output_type());
} else {
return false;
@ -66,7 +72,29 @@ MKLDNNPlugin::ReshapeFullyConnected::ReshapeFullyConnected() {
new_ops.push_back(fc_new);
if (output_shape != output_shape_new) {
auto reshape_output = ngraph::op::util::reshapeTo(fc_new, output_shape);
auto I_idxs = std::vector<size_t>(input_rank - 1);
std::iota(I_idxs.begin(), I_idxs.end(), 0);
auto A_input_shape = ngraph::op::util::make_try_fold<ngraph::opset7::ShapeOf>(fc->input_value(0));
auto B_input_shape = ngraph::op::util::make_try_fold<ngraph::opset7::ShapeOf>(fc->input_value(1));
auto I_node = ngraph::op::util::node_to_get_shape_value_of_indices_from_shape_node(A_input_shape, {I_idxs});
auto O_node = ngraph::op::util::node_to_get_shape_value_of_indices_from_shape_node(B_input_shape, {0});
ngraph::OutputVector output_shape_dims{I_node, O_node};
const auto original_rank = fc->get_output_rank();
NGRAPH_CHECK(original_rank.is_static());
if (input_rank < original_rank.get_length()) {
const size_t const_shape_value = original_rank.get_length() - input_rank;
output_shape_dims.insert(
output_shape_dims.begin(), ngraph::opset1::Constant::create(I_node->get_element_type(), { const_shape_value }, { 1 }));
}
auto reshape_output_shape = ngraph::op::util::make_try_fold<ngraph::opset1::Concat>(output_shape_dims, 0);
auto reshape_output = std::make_shared<ngraph::opset1::Reshape>(fc_new, reshape_output_shape, false);
new_ops.push_back(A_input_shape);
new_ops.push_back(B_input_shape);
new_ops.push_back(I_node);
new_ops.push_back(O_node);
new_ops.push_back(reshape_output_shape);
new_ops.push_back(reshape_output);
reshape_output->set_friendly_name(fc->get_friendly_name());
fc_new->set_friendly_name(fc->get_friendly_name() + "/FC");

View File

@ -1,30 +0,0 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <transformations_visibility.hpp>
#include <ngraph/op/util/op_annotations.hpp>
#include <ngraph/opsets/opset4.hpp>
namespace ngraph {
namespace op {
namespace util {
std::shared_ptr<ngraph::Node> node_to_get_shape_value_of_indices_from_shape_node(const std::shared_ptr<ngraph::Node>& shape_node,
const std::vector<size_t>& indices) {
return std::make_shared<ngraph::opset4::Gather>(shape_node,
ngraph::opset4::Constant::create(ngraph::element::i64, {indices.size()}, indices),
ngraph::opset4::Constant::create(ngraph::element::i64, {}, {0}));
}
std::shared_ptr<ngraph::Node> node_to_get_shape_value_of_indices_from_shape_source(const ngraph::Output<ngraph::Node>& shape_source,
const std::vector<size_t>& indices) {
const auto & shape_node = std::make_shared<ngraph::opset4::ShapeOf>(shape_source);
return node_to_get_shape_value_of_indices_from_shape_node(shape_node, indices);
}
} // namespace util
} // namespace op
} // namespace ngraph

View File

@ -142,6 +142,12 @@ Output<Node> eltwise_fold(const Output<Node> & input0, const Output<Node> & inpu
}
TRANSFORMATIONS_API std::vector<Input<Node>> get_node_target_inputs(const std::shared_ptr<Node>& node);
TRANSFORMATIONS_API std::shared_ptr<ngraph::Node> node_to_get_shape_value_of_indices_from_shape_node(
const std::shared_ptr<ngraph::Node>& shape_node, const std::vector<size_t>& indices);
TRANSFORMATIONS_API std::shared_ptr<ngraph::Node> node_to_get_shape_value_of_indices_from_shape_source(
const ngraph::Output<ngraph::Node>& shape_source, const std::vector<size_t>& indices);
} // namespace util
} // namespace op
} // namespace ngraph

View File

@ -11,6 +11,7 @@
#include <ngraph/op/broadcast.hpp>
#include <ngraph/op/constant.hpp>
#include <ngraph/op/reshape.hpp>
#include <ngraph/op/gather.hpp>
#include <ngraph/op/util/op_annotations.hpp>
namespace ngraph {
@ -152,6 +153,20 @@ std::vector<Input<Node>> get_node_target_inputs(const std::shared_ptr<Node>& nod
return result;
}
std::shared_ptr<ngraph::Node> node_to_get_shape_value_of_indices_from_shape_node(const std::shared_ptr<ngraph::Node>& shape_node,
const std::vector<size_t>& indices) {
return make_try_fold<v7::Gather>(
shape_node,
v0::Constant::create(ngraph::element::i64, {indices.size()}, indices),
v0::Constant::create(ngraph::element::i64, {}, {0}));
}
std::shared_ptr<ngraph::Node> node_to_get_shape_value_of_indices_from_shape_source(const ngraph::Output<ngraph::Node>& shape_source,
                                                                                   const std::vector<size_t>& indices) {
    // Build a (possibly constant-folded) ShapeOf over the source and reuse the
    // shape-node overload to extract the requested dimension indices.
    return node_to_get_shape_value_of_indices_from_shape_node(make_try_fold<v3::ShapeOf>(shape_source), indices);
}
} // namespace util
} // namespace op
} // namespace ngraph

View File

@ -22,6 +22,7 @@ addIeTargetTest(
inference_engine_lp_transformations
ov_shape_inference
inference_engine_s
unitTestUtils
ADD_CPPLINT
LABELS
CPU

View File

@ -0,0 +1,567 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph_transformations/op/fully_connected.hpp>
#include <ngraph_transformations/convert_matmul_to_fc.hpp>
#include <ngraph_transformations/fc_bias_fusion.hpp>
#include <ngraph_transformations/reshape_fully_connected.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
using namespace MKLDNNPlugin;
// MatMul(transpose_a=true) with a 3D constant weight {1,2,2}: expected to become
// Transpose(activations) followed by FullyConnected with the weight squeezed to 2D.
TEST(TransformationTests, ConvertMatMulToFCTest1) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        // Actual graph: run ConvertMatMulToFC on a MatMul with transposed activations.
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 3, 2, 2 });
        auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 2, 2 }, { 1 });
        auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, true, false);
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
        ngraph::pass::Manager m;
        m.register_pass<ngraph::pass::InitNodeInfo>();
        m.register_pass<ConvertMatMulToFC>();
        m.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }
    {
        // Reference graph: explicit Transpose + FullyConnected with 2D weights.
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 3, 2, 2 });
        auto transpose_constant = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 3 }, { 0, 2, 1 });
        auto transpose = std::make_shared<ngraph::opset1::Transpose>(input1, transpose_constant);
        auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 2, 2 }, { 1 });
        auto matmul = std::make_shared<FullyConnectedNode>(transpose, input2, ngraph::Rank(3));
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
    }
    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
// Both MatMul inputs are Parameters (no constant weights): the transformation
// must leave the graph unchanged — reference equals the original graph.
TEST(TransformationTests, ConvertMatMulToFCTest2) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
        auto input2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 1});
        auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, false);
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2});
        ngraph::pass::Manager m;
        m.register_pass<ngraph::pass::InitNodeInfo>();
        m.register_pass<ConvertMatMulToFC>();
        m.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }
    {
        // Reference: identical MatMul — no conversion expected.
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
        auto input2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 1});
        auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, false);
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2});
    }
    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
// Static 3D activations x 2D constant weights (transpose_b=true):
// expected to convert directly into a FullyConnected with output rank 3.
TEST(TransformationTests, ConvertMatMulToFCTest3) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 2});
        auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 2}, {1});
        auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
        ngraph::pass::Manager m;
        m.register_pass<ngraph::pass::InitNodeInfo>();
        m.register_pass<ConvertMatMulToFC>();
        m.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }
    {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 2});
        auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 2}, {1});
        auto matmul = std::make_shared<FullyConnectedNode>(input1, input2, ngraph::Rank(3));
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
    }
    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
// Dynamic batch/sequence dims {-1,-1,2} with static 2D weights:
// the conversion must still apply (dynamic-shape support of this PR).
TEST(TransformationTests, ConvertMatMulToFCTest4) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{-1, -1, 2});
        auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 2}, {1});
        auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
        ngraph::pass::Manager m;
        m.register_pass<ngraph::pass::InitNodeInfo>();
        m.register_pass<ConvertMatMulToFC>();
        m.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }
    {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{-1, -1, 2});
        auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 2}, {1});
        auto matmul = std::make_shared<FullyConnectedNode>(input1, input2, ngraph::Rank(3));
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
    }
    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
// Dynamic activations with a non-squeezable 3D weight {3,2,2}: only checks the
// pass pipeline runs without throwing. Also validates rt_info afterwards for
// consistency with ConvertMatMulToFCTest6, which exercises the same scenario
// with a squeezable weight shape.
TEST(TransformationTests, ConvertMatMulToFCTest5) {
    auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{ -1, -1, 2 });
    auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3, 2, 2 }, { 1 });
    auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
    auto f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
    ngraph::pass::Manager m;
    m.register_pass<ngraph::pass::InitNodeInfo>();
    m.register_pass<ConvertMatMulToFC>();
    ASSERT_NO_THROW(m.run_passes(f));
    // Consistency with Test6: ensure rt_info survives the transformation run.
    ASSERT_NO_THROW(check_rt_info(f));
}
// Dynamic activations with a 3D weight {3,1,2} (leading dims squeezable to 2D):
// the pass must run without throwing and keep rt_info intact.
TEST(TransformationTests, ConvertMatMulToFCTest6) {
    auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{ -1, -1, 2 });
    auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3, 1, 2 }, { 1 });
    auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
    auto f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
    ngraph::pass::Manager m;
    m.register_pass<ngraph::pass::InitNodeInfo>();
    m.register_pass<ConvertMatMulToFC>();
    ASSERT_NO_THROW(m.run_passes(f));
    ASSERT_NO_THROW(check_rt_info(f));
}
// ConvertMatMulToFC + ReshapeFullyConnected on static 3D input:
// expected graph is Reshape(to 2D) -> FullyConnected -> Reshape(back to 3D).
TEST(TransformationTests, ConvertMatMulToFCTest7) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 2});
        auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 2}, {1});
        auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
        ngraph::pass::Manager m;
        m.register_pass<ngraph::pass::InitNodeInfo>();
        m.register_pass<ConvertMatMulToFC>();
        m.register_pass<ReshapeFullyConnected>();
        m.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }
    {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 2});
        auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 2}, {1});
        // Flatten the leading dims into one (-1) so the FC operates on a 2D input.
        auto reshape_begin = std::make_shared<ngraph::opset1::Reshape>(
            input1, ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{2}, std::vector<int64_t>{-1, 2}), false);
        auto fc = std::make_shared<FullyConnectedNode>(reshape_begin, input2, ngraph::Rank(2));
        auto reshape_end = ngraph::op::util::reshapeTo(fc, ngraph::Shape{3, 2, 3});
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_end}, ngraph::ParameterVector{input1});
    }
    auto res = compare_functions(f, f_ref, true);
    ASSERT_TRUE(res.first) << res.second;
}
// ReshapeFullyConnected with a dynamic input shape: the trailing Reshape's target
// shape cannot be a constant, so it must be built at runtime from ShapeOf + Gather
// of the leading input dims concatenated with the static output channel count.
TEST(TransformationTests, ConvertMatMulToFCTest8) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{-1, -1, 2});
        auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 2}, {1});
        auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
        ngraph::pass::Manager m;
        m.register_pass<ngraph::pass::InitNodeInfo>();
        m.register_pass<ConvertMatMulToFC>();
        m.register_pass<ReshapeFullyConnected>();
        m.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }
    {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{-1, -1, 2});
        auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 2}, {1});
        auto reshape_begin = std::make_shared<ngraph::opset1::Reshape>(
            input1, ngraph::opset1::Constant::create(ngraph::element::i64, {2}, {-1, 2}), false);
        auto fc = std::make_shared<FullyConnectedNode>(reshape_begin, input2, ngraph::Rank(2));
        // Recover the original leading dims dynamically; append the FC output size (3).
        auto a_shape = std::make_shared<ngraph::opset3::ShapeOf>(input1);
        auto I = ngraph::op::util::node_to_get_shape_value_of_indices_from_shape_node(a_shape, {0, 1});
        auto O = ngraph::opset1::Constant::create(ngraph::element::i64, { 1 }, { 3 });
        auto output_shape = std::make_shared<ngraph::opset1::Concat>(ngraph::OutputVector{I, O}, 0);
        auto reshape_end = std::make_shared<ngraph::opset1::Reshape>(fc, output_shape, false);
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_end}, ngraph::ParameterVector{input1});
    }
    auto res = compare_functions(f, f_ref, true);
    ASSERT_TRUE(res.first) << res.second;
}
// Static 3D activations x 2D constant weights: converted to a rank-3 FullyConnected.
// (Same scenario as Test3 but run without ReshapeFullyConnected in the pipeline.)
TEST(TransformationTests, ConvertMatMulToFCTest9) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 2});
        auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 2}, {1});
        auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
        ngraph::pass::Manager m;
        // NOTE: the previously fetched pass_config was never used — removed.
        m.register_pass<ngraph::pass::InitNodeInfo>();
        m.register_pass<ConvertMatMulToFC>();
        m.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }
    {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 2});
        auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 2}, {1});
        auto matmul = std::make_shared<FullyConnectedNode>(input1, input2, ngraph::Rank(3));
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
    }
    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
// Fully dynamic input rank (PartialShape::dynamic()): both passes must tolerate
// the unknown rank and run without throwing.
TEST(TransformationTests, ConvertMatMulToFCTest10) {
    auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
    auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 2, 2 }, { 1 });
    auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
    auto f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
    ngraph::pass::Manager m;
    m.register_pass<ngraph::pass::InitNodeInfo>();
    m.register_pass<ConvertMatMulToFC>();
    m.register_pass<ReshapeFullyConnected>();
    ASSERT_NO_THROW(m.run_passes(f));
}
// FC (rank 3, static shapes) followed by Add with a 1D {786} bias constant:
// the Add must fuse into the FullyConnected's bias input.
TEST(TransformationTests, FullyConnectedBiasFusionTest1) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 3072});
        auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786, 3072}, {1});
        auto fc = std::make_shared<FullyConnectedNode>(input1, weights, ngraph::Rank(3));
        auto const_bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786}, {1});
        auto add = std::make_shared<ngraph::opset1::Add>(fc, const_bias);
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{add}, ngraph::ParameterVector{input1});
        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<FullyConnectedBiasFusion>();
        // Verify rt_info immediately after the fusion pass runs.
        manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> f) {
            check_rt_info(f);
        });
        ASSERT_NO_THROW(manager.run_passes(f));
    }
    {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 3072});
        auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786, 3072}, {1});
        auto bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786}, {1});
        auto fc = std::make_shared<FullyConnectedNode>(input1, weights, bias, ngraph::Rank(3));
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc}, ngraph::ParameterVector{input1});
    }
    auto res = compare_functions(f, f_ref, true);
    ASSERT_TRUE(res.first) << res.second;
}
// Same bias-fusion scenario as Test1 but with dynamic leading dims {-1,-1,3072}:
// fusion must still happen since the channel dim (3072/786) is static.
TEST(TransformationTests, FullyConnectedBiasFusionTest2) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{-1, -1, 3072});
        auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786, 3072}, {1});
        auto fc = std::make_shared<FullyConnectedNode>(input1, weights, ngraph::Rank(3));
        auto const_bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786}, {1});
        auto add = std::make_shared<ngraph::opset1::Add>(fc, const_bias);
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{add}, ngraph::ParameterVector{input1});
        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<FullyConnectedBiasFusion>();
        manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> f) {
            check_rt_info(f);
        });
        ASSERT_NO_THROW(manager.run_passes(f));
    }
    {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{-1, -1, 3072});
        auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786, 3072}, {1});
        auto bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786}, {1});
        auto fc = std::make_shared<FullyConnectedNode>(input1, weights, bias, ngraph::Rank(3));
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc}, ngraph::ParameterVector{input1});
    }
    auto res = compare_functions(f, f_ref, true);
    ASSERT_TRUE(res.first) << res.second;
}
// 2D FC with a {1,786} bias constant: the fusion must squeeze the bias to 1D {786}
// when attaching it as the FC's third input.
TEST(TransformationTests, FullyConnectedBiasFusionTest3) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128});
        auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786, 128}, {1});
        auto fc = std::make_shared<FullyConnectedNode>(input1, weights, ngraph::Rank(2));
        auto const_bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1, 786}, {1});
        auto add = std::make_shared<ngraph::opset1::Add>(fc, const_bias);
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{add}, ngraph::ParameterVector{input1});
        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<FullyConnectedBiasFusion>();
        manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> f) {
            check_rt_info(f);
        });
        ASSERT_NO_THROW(manager.run_passes(f));
    }
    {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128});
        auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786, 128}, {1});
        auto bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786}, {1});
        auto fc = std::make_shared<FullyConnectedNode>(input1, weights, bias, ngraph::Rank(2));
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc}, ngraph::ParameterVector{input1});
    }
    auto res = compare_functions(f, f_ref, true);
    ASSERT_TRUE(res.first) << res.second;
}
// Same as Test3 but with a dynamic batch dim {-1,128}: bias fusion must still apply.
TEST(TransformationTests, FullyConnectedBiasFusionTest4) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{-1, 128});
        auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786, 128}, {1});
        auto fc = std::make_shared<FullyConnectedNode>(input1, weights, ngraph::Rank(2));
        auto const_bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1, 786}, {1});
        auto add = std::make_shared<ngraph::opset1::Add>(fc, const_bias);
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{add}, ngraph::ParameterVector{input1});
        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<FullyConnectedBiasFusion>();
        manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> f) {
            check_rt_info(f);
        });
        ASSERT_NO_THROW(manager.run_passes(f));
    }
    {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{-1, 128});
        auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786, 128}, {1});
        auto bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786}, {1});
        auto fc = std::make_shared<FullyConnectedNode>(input1, weights, bias, ngraph::Rank(2));
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc}, ngraph::ParameterVector{input1});
    }
    auto res = compare_functions(f, f_ref, true);
    ASSERT_TRUE(res.first) << res.second;
}
// Bias fusion with a fully dynamic-rank FC input: only checks the pass does not throw.
TEST(TransformationTests, FullyConnectedBiasFusionTest5) {
    auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
    auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 786, 128 }, { 1 });
    auto fc = std::make_shared<FullyConnectedNode>(input1, weights, ngraph::Rank(2));
    auto const_bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 786 }, { 1 });
    auto add = std::make_shared<ngraph::opset1::Add>(fc, const_bias);
    auto f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ add }, ngraph::ParameterVector{ input1 });
    ngraph::pass::Manager manager;
    manager.register_pass<FullyConnectedBiasFusion>();
    ASSERT_NO_THROW(manager.run_passes(f));
}
// Quantized-style FC (u8 activations, i8 weights, explicit f32 output type) with
// dynamic dims: the fusion pass must handle the mixed precisions without throwing.
TEST(TransformationTests, FullyConnectedBiasFusionTest6) {
    auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::u8, ngraph::PartialShape{ -1, -1 });
    auto weights = ngraph::opset1::Constant::create(ngraph::element::i8, ngraph::Shape{ 786, 128 }, { 1 });
    auto fc = std::make_shared<FullyConnectedNode>(input1, weights, ngraph::Rank(2), ngraph::element::f32);
    auto const_bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 786 }, { 1 });
    auto add = std::make_shared<ngraph::opset1::Add>(fc, const_bias);
    auto f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ add }, ngraph::ParameterVector{ input1 });
    ngraph::pass::Manager manager;
    manager.register_pass<FullyConnectedBiasFusion>();
    ASSERT_NO_THROW(manager.run_passes(f));
}
// Weights rank (4: {1,1,2,3}) exceeds activations rank (3): the weight is squeezed
// to 2D and the activations are flattened to 2D, with a trailing Reshape restoring
// the broadcasted 4D output shape {1,5,2,2}.
TEST(TransformationTests, ConvertMatMulToFCTest_second_input_rank_adj_1) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{5, 2, 3});
        auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1, 1, 2, 3}, {1});
        auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
        ngraph::pass::Manager m;
        m.register_pass<ngraph::pass::InitNodeInfo>();
        m.register_pass<ConvertMatMulToFC>();
        m.register_pass<ReshapeFullyConnected>();
        m.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }
    {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{5, 2, 3});
        auto reshape_1 = std::make_shared<ngraph::opset1::Reshape>(input1, ngraph::opset1::Constant::create(ngraph::element::i64, {2}, {-1, 3}), false);
        auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 3}, {1});
        auto matmul = std::make_shared<FullyConnectedNode>(reshape_1, input2, ngraph::Rank(2));
        auto reshape_out = std::make_shared<ngraph::opset1::Reshape>(matmul, ngraph::opset1::Constant::create(ngraph::element::i64, {4}, {1, 5, 2, 2}), false);
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_out}, ngraph::ParameterVector{input1});
    }
    auto res = compare_functions(f, f_ref, true);
    ASSERT_TRUE(res.first) << res.second;
}
// 2D x 2D MatMul: converts straight to a rank-2 FullyConnected with no
// surrounding Reshape nodes needed.
TEST(TransformationTests, ConvertMatMulToFCTest_second_input_rank_adj_2) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 2, 3 });
        auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 2, 3 }, { 1 });
        auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, weights, false, true);
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
        ngraph::pass::Manager m;
        m.register_pass<ngraph::pass::InitNodeInfo>();
        m.register_pass<ConvertMatMulToFC>();
        m.register_pass<ReshapeFullyConnected>();
        m.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }
    {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 2, 3 });
        auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 2, 3 }, { 1 });
        auto matmul = std::make_shared<FullyConnectedNode>(input1, weights, ngraph::Rank(2));
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
    }
    auto res = compare_functions(f, f_ref, true);
    ASSERT_TRUE(res.first) << res.second;
}
// Full pipeline (FC conversion + bias fusion + reshape): rank-4 weights {1,1,2,3}
// are squeezed to 2D, the {1,1,1,2} bias is squeezed to {2} and fused, and the
// final output is reshaped back to the broadcasted 4D shape {1,5,2,2}.
TEST(TransformationTests, ConvertMatMulToFCTest_second_input_rank_adj_3) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 5, 2, 3 });
        auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 1, 2, 3 }, { 1 });
        auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, weights, false, true);
        auto biases = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 1, 1, 2 }, { 1 });
        auto add = std::make_shared<ngraph::opset1::Add>(matmul, biases);
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ add }, ngraph::ParameterVector{ input1 });
        ngraph::pass::Manager m;
        m.register_pass<ngraph::pass::InitNodeInfo>();
        m.register_pass<ConvertMatMulToFC>();
        m.register_pass<FullyConnectedBiasFusion>();
        m.register_pass<ReshapeFullyConnected>();
        m.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }
    {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 5, 2, 3 });
        auto reshape_before_const = ngraph::opset1::Constant::create(ngraph::element::i64, { 2 }, { -1, 3 });
        auto reshape_1 = std::make_shared<ngraph::opset1::Reshape>(input1, reshape_before_const, false);
        auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 2, 3 }, { 1 });
        auto biases = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 2 }, { 1 });
        auto matmul = std::make_shared<FullyConnectedNode>(reshape_1, weights, biases, ngraph::Rank(2));
        auto reshape_after_const = ngraph::opset1::Constant::create(ngraph::element::i64, { 4 }, { 1, 5, 2, 2 });
        auto reshape_out = std::make_shared<ngraph::opset1::Reshape>(matmul, reshape_after_const, false);
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ reshape_out }, ngraph::ParameterVector{ input1 });
    }
    auto res = compare_functions(f, f_ref, true);
    ASSERT_TRUE(res.first) << res.second;
}
// Rank-4 weights with a dynamic batch dim {-1,2,3}: the trailing Reshape's target
// shape must be assembled at runtime — Concat of a constant leading 1, the first
// two input dims (ShapeOf + Gather), and the constant FC output size 2.
TEST(TransformationTests, ConvertMatMulToFCTest_second_input_rank_adj_dynamic) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{-1, 2, 3});
        auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1, 1, 2, 3}, {1});
        auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
        ngraph::pass::Manager m;
        m.register_pass<ngraph::pass::InitNodeInfo>();
        m.register_pass<ConvertMatMulToFC>();
        m.register_pass<ReshapeFullyConnected>();
        m.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }
    {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{-1, 2, 3});
        auto reshape_1 = std::make_shared<ngraph::opset1::Reshape>(input1, ngraph::opset1::Constant::create(ngraph::element::i64, {2}, {-1, 3}), false);
        auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 3}, {1});
        auto matmul = std::make_shared<FullyConnectedNode>(reshape_1, input2, ngraph::Rank(2));
        // Runtime output shape: [1, d0, d1, 2] where d0/d1 come from the input's shape.
        auto shape_of = std::make_shared<ngraph::opset7::ShapeOf>(input1);
        auto gather = std::make_shared<ngraph::opset7::Gather>(
            shape_of, ngraph::opset1::Constant::create(ngraph::element::i64, {2}, {0, 1}), ngraph::opset1::Constant::create(ngraph::element::i64, {}, {0}));
        auto concat = std::make_shared<ngraph::opset1::Concat>(ngraph::OutputVector{
            ngraph::opset1::Constant::create(ngraph::element::i64, {1}, {1}),
            gather,
            ngraph::opset1::Constant::create(ngraph::element::i64, {1}, {2}),
        }, 0);
        auto reshape_out = std::make_shared<ngraph::opset1::Reshape>(matmul, concat, false);
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_out}, ngraph::ParameterVector{input1});
    }
    auto res = compare_functions(f, f_ref, true);
    ASSERT_TRUE(res.first) << res.second;
}

View File

@ -304,6 +304,35 @@ public:
OPENVINO_ASSERT(rank().is_static());
return m_dimensions.size();
}
    /// \brief Returns a read/write iterator that points to the inserted element in the shape.
    /// \param position Iterator into m_dimensions before which \p val is inserted.
    /// \param val Dimension value to insert.
    iterator insert(iterator position, const Dimension& val) {
        // Adding a concrete dimension slot means the rank is now known.
        m_rank_is_static = true;
        // NOTE(review): assumes SHAPE_IS_UPDATED marks the cached shape
        // classification as stale so it is recomputed on demand — confirm.
        m_shape_type = ShapeType::SHAPE_IS_UPDATED;
        return m_dimensions.insert(position, val);
    }
    /// \brief Inserts count copies of the value before position
    /// \param position Iterator into m_dimensions before which the copies go.
    /// \param n Number of copies of \p val to insert.
    /// \param val Dimension value to replicate.
    void insert(iterator position, size_t n, const Dimension& val) {
        m_dimensions.insert(position, n, val);
        // The rank becomes static and the cached shape classification is invalidated.
        m_rank_is_static = true;
        m_shape_type = ShapeType::SHAPE_IS_UPDATED;
    }
    /// \brief Inserts elements from range [first, last) before position
    /// \param position Iterator into m_dimensions before which the range is inserted.
    /// \param first,last Input-iterator range of Dimension values to copy in.
    template <class InputIterator>
    void insert(iterator position, InputIterator first, InputIterator last) {
        m_dimensions.insert(position, first, last);
        // Even an empty range leaves the rank marked static, matching the
        // behavior of the other insert overloads.
        m_rank_is_static = true;
        m_shape_type = ShapeType::SHAPE_IS_UPDATED;
    }
    /// \brief Requests that the dimensions vector capacity be enough to contain n elements
    /// \param n Minimum capacity to reserve. Does not change rank or shape state,
    ///          since reserving capacity adds no dimensions.
    void reserve(size_t n) {
        m_dimensions.reserve(n);
    }
    /// \brief push element to the end of partial shape
    /// \param val Dimension appended as the new innermost axis.
    void push_back(const Dimension& val) {
        m_dimensions.push_back(val);
        // Appending a dimension makes the rank static and invalidates the
        // cached shape classification (same bookkeeping as insert()).
        m_rank_is_static = true;
        m_shape_type = ShapeType::SHAPE_IS_UPDATED;
    }
private:
// Private constructor for PartialShape::dynamic().

View File

@ -13,7 +13,6 @@
#include <numeric>
#include "itt.hpp"
#include "transformations/smart_reshape/utils.hpp"
bool relax_hc_reshape_followed_by_matmul(const ngraph::pattern::PatternValueMap& pattern_to_output,
const std::shared_ptr<ngraph::Node>& matmul_label,
@ -35,7 +34,10 @@ bool relax_hc_reshape_followed_by_matmul(const ngraph::pattern::PatternValueMap&
const auto& raw_idx =
reshape_is_A_input ? (matmul->get_transpose_b() ? -1 : -2) : (matmul->get_transpose_a() ? -2 : -1);
const auto& idx = ngraph::normalize_axes(matmul->description(), {raw_idx}, reshape_rank);
const auto& C = ngraph::op::util::node_to_get_shape_value_of_indices_from_shape_source(shape_source, idx);
const auto& C = std::make_shared<ngraph::opset4::Gather>(
std::make_shared<ngraph::opset4::ShapeOf>(shape_source),
ngraph::opset4::Constant::create(ngraph::element::i64, {idx.size()}, idx),
ngraph::opset4::Constant::create(ngraph::element::i64, {}, {0}));
const auto& N = ngraph::opset4::Constant::create(ngraph::element::i64, {1}, {-1});
const auto& pattern_vector =
reshape_is_A_input ? (matmul->get_transpose_a() ? ngraph::OutputVector({C, N}) : ngraph::OutputVector({N, C}))