[dynamism][CPU] MatMul transformations: dynamic shapes support (#7582)
* [TESTS] CPUUnitTests: unitTestUtils added to CMakeLists * Reshape1D: functions moved to namespace * [nGraph] PartialShape extending * CPU specific transformations for FullyConnected enabled to work with dynamic shapes * [Transformations] FC transformations: dynamic shapes support * [Transformations] FCBiasFusion: removed legacy check on bcast on weights * [TESTS] FC transformations: tests * [Transformations] SmartReshape: avoid transformation utils methods * codestyle fix * codestyle fix * [TESTS] MatMulTransformation tests: compilation error fix * [CPU] FullyConnected: shape inference * postreview fixes * [CPU] FullyConnected: shape inference fix * [nGraph] PShape dimensions insertion fixed Co-authored-by: Stepyreva, Evgenya <evgenya.stepyreva@intel.com>
This commit is contained in:
parent
18aaaa79a0
commit
c84db94697
@ -4,7 +4,6 @@
|
|||||||
|
|
||||||
#include "convert_matmul_to_fc.hpp"
|
#include "convert_matmul_to_fc.hpp"
|
||||||
#include "op/fully_connected.hpp"
|
#include "op/fully_connected.hpp"
|
||||||
#include <numeric>
|
|
||||||
#include <ngraph/opsets/opset1.hpp>
|
#include <ngraph/opsets/opset1.hpp>
|
||||||
#include <ngraph/rt_info.hpp>
|
#include <ngraph/rt_info.hpp>
|
||||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||||
@ -13,25 +12,39 @@
|
|||||||
NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::ConvertMatMulToFC, "ConvertMatMulToFC", 0);
|
NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::ConvertMatMulToFC, "ConvertMatMulToFC", 0);
|
||||||
|
|
||||||
MKLDNNPlugin::ConvertMatMulToFC::ConvertMatMulToFC() {
|
MKLDNNPlugin::ConvertMatMulToFC::ConvertMatMulToFC() {
|
||||||
auto matmul = ngraph::pattern::wrap_type<ngraph::opset1::MatMul>({ngraph::pattern::any_input(ngraph::pattern::has_static_shape()),
|
auto activations_m = ngraph::pattern::any_input(ngraph::pattern::has_static_rank());
|
||||||
ngraph::pattern::any_input(ngraph::pattern::has_static_shape())},
|
auto weights_m = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
|
||||||
ngraph::pattern::has_static_shape());
|
auto matmul_m = ngraph::pattern::wrap_type<ngraph::opset1::MatMul>({ activations_m, weights_m }, ngraph::pattern::has_static_rank());
|
||||||
|
|
||||||
ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher& m) {
|
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||||
auto matmul = std::dynamic_pointer_cast<ngraph::opset1::MatMul>(m.get_match_root());
|
const auto& pattern_map = m.get_pattern_value_map();
|
||||||
if (!matmul) {
|
|
||||||
|
auto matmul = std::dynamic_pointer_cast<ngraph::opset1::MatMul>(pattern_map.at(matmul_m).get_node_shared_ptr());
|
||||||
|
if (!matmul || transformation_callback(matmul)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto input_a = matmul->input(0).get_source_output();
|
// fc_input_a and fc_input_b are the final inputs that will be set on the FullyConnected or GemmIE operations.
|
||||||
auto input_b = matmul->input(1).get_source_output();
|
// So when adding new operations that take the matmul inputs, we need to keep fc_input_a and fc_input_b updated.
|
||||||
|
auto fc_input_a = pattern_map.at(activations_m);
|
||||||
|
auto fc_input_b = pattern_map.at(weights_m);
|
||||||
|
|
||||||
auto shape_a = input_a.get_shape();
|
auto shape_a = fc_input_a.get_partial_shape();
|
||||||
auto shape_b = input_b.get_shape();
|
auto shape_b = fc_input_b.get_partial_shape();
|
||||||
auto output_shape = matmul->get_shape();
|
NGRAPH_CHECK(shape_b.is_static()); // requested 2nd input with static shape in the matcher
|
||||||
|
|
||||||
|
auto rank_a = shape_a.rank().get_length();
|
||||||
|
auto rank_b = shape_b.rank().get_length();
|
||||||
|
|
||||||
// Transformation to FC is not supported for 1D second input
|
// Transformation to FC is not supported for 1D second input
|
||||||
if (shape_b.size() == 1) {
|
if (rank_b == 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the second input is a Constant and its shape, ignoring dimensions equal to 1, has at most 2 dimensions,
|
||||||
|
// we replace MatMul with FullyConnected operation.
|
||||||
|
if (!std::dynamic_pointer_cast<ngraph::opset1::Constant>(fc_input_b.get_node_shared_ptr()) ||
|
||||||
|
std::count_if(shape_b.begin(), shape_b.end(), [](ngraph::Dimension x) { return x != 1; }) > 2) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -42,15 +55,17 @@ MKLDNNPlugin::ConvertMatMulToFC::ConvertMatMulToFC() {
|
|||||||
* for example: [2, 32, 64] [3, 64, 64] it will raise an exception.
|
* for example: [2, 32, 64] [3, 64, 64] it will raise an exception.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
auto get_aligned_shapes = [shape_a, shape_b, &matmul]() -> std::pair<ngraph::Shape, ngraph::Shape> {
|
auto get_aligned_shapes = [shape_a, shape_b, rank_a, rank_b, &matmul]() -> std::tuple<bool, ngraph::PartialShape, ngraph::PartialShape> {
|
||||||
ngraph::Shape shape_a_aligned(shape_a), shape_b_aligned(shape_b);
|
ngraph::PartialShape shape_a_aligned(shape_a), shape_b_aligned(shape_b);
|
||||||
size_t max_size = std::max(shape_a_aligned.size(), shape_b_aligned.size());
|
size_t max_size = std::max(rank_a, rank_b);
|
||||||
for (size_t i = 0, cnt = max_size - shape_a_aligned.size(); i < cnt; ++i)
|
for (size_t i = 0, cnt = max_size - rank_a; i < cnt; ++i) {
|
||||||
shape_a_aligned.insert(shape_a_aligned.begin(), 1);
|
shape_a_aligned.insert(shape_a_aligned.begin(), 1);
|
||||||
for (size_t i = 0, cnt = max_size - shape_b_aligned.size(); i < cnt; ++i)
|
}
|
||||||
|
for (size_t i = 0, cnt = max_size - rank_b; i < cnt; ++i) {
|
||||||
shape_b_aligned.insert(shape_b_aligned.begin(), 1);
|
shape_b_aligned.insert(shape_b_aligned.begin(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
if (matmul->get_transpose_a() && shape_a.size() != 1) {
|
if (matmul->get_transpose_a() && rank_a != 1) {
|
||||||
std::swap(*(shape_a_aligned.end() - 1), *(shape_a_aligned.end() - 2));
|
std::swap(*(shape_a_aligned.end() - 1), *(shape_a_aligned.end() - 2));
|
||||||
}
|
}
|
||||||
if (matmul->get_transpose_b()) {
|
if (matmul->get_transpose_b()) {
|
||||||
@ -58,16 +73,25 @@ MKLDNNPlugin::ConvertMatMulToFC::ConvertMatMulToFC() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (size_t i = 0; i < max_size - 2; ++i) {
|
for (size_t i = 0; i < max_size - 2; ++i) {
|
||||||
if (shape_a_aligned[i] != shape_b_aligned[i] && shape_a_aligned[i] > 1 && shape_b_aligned[i] > 1) {
|
auto a_dim = shape_a_aligned[i], b_dim = shape_b_aligned[i];
|
||||||
|
if (a_dim.is_dynamic()) {
|
||||||
|
if (b_dim == 1) {
|
||||||
|
shape_a_aligned[i] = shape_b_aligned[i] = a_dim;
|
||||||
|
} else {
|
||||||
|
return std::make_tuple(false, ngraph::PartialShape{shape_a_aligned}, ngraph::PartialShape{shape_b_aligned});
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// both dimensions are static
|
||||||
|
if (a_dim != b_dim && a_dim.get_length() > 1 && b_dim.get_length() > 1) {
|
||||||
std::ostringstream stream;
|
std::ostringstream stream;
|
||||||
stream << "Shapes can't be aligned: " << shape_a_aligned << " " << shape_b_aligned;
|
stream << "Shapes can't be aligned: " << shape_a_aligned << " " << shape_b_aligned;
|
||||||
throw ngraph::ngraph_error(stream.str());
|
throw ngraph::ngraph_error(stream.str());
|
||||||
}
|
}
|
||||||
size_t max_value = std::max(shape_a_aligned[i], shape_b_aligned[i]);
|
size_t max_value = std::max(a_dim.get_length(), b_dim.get_length());
|
||||||
shape_a_aligned[i] = shape_b_aligned[i] = max_value;
|
shape_a_aligned[i] = shape_b_aligned[i] = max_value;
|
||||||
}
|
}
|
||||||
|
return std::make_tuple(true, shape_a_aligned, shape_b_aligned);
|
||||||
return {shape_a_aligned, shape_b_aligned};
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -78,76 +102,68 @@ MKLDNNPlugin::ConvertMatMulToFC::ConvertMatMulToFC() {
|
|||||||
* order will be [0, 1, 3, 2] that emulates transpose_a or transpose_b attribute.
|
* order will be [0, 1, 3, 2] that emulates transpose_a or transpose_b attribute.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
auto create_transpose = [this](ngraph::Output<ngraph::Node> node, const std::string& transpose_name) -> std::shared_ptr<ngraph::Node> {
|
auto create_transpose = [this](const ngraph::Output<ngraph::Node>& node, const std::string& transpose_name) {
|
||||||
ngraph::Shape output_shape = node.get_node_shared_ptr()->get_shape();
|
auto rank = node.get_partial_shape().rank();
|
||||||
|
std::vector<size_t> transpose_order(rank.get_length());
|
||||||
std::vector<size_t> transpose_order(output_shape.size());
|
|
||||||
std::iota(transpose_order.begin(), transpose_order.end(), 0);
|
std::iota(transpose_order.begin(), transpose_order.end(), 0);
|
||||||
std::swap(*(transpose_order.end() - 1), *(transpose_order.end() - 2));
|
std::swap(*(transpose_order.end() - 1), *(transpose_order.end() - 2));
|
||||||
|
|
||||||
auto transpose = ngraph::pass::MatcherPass::register_new_node<ngraph::opset1::Transpose>(
|
auto transpose_const = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ transpose_order.size() }, transpose_order);
|
||||||
node, ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{transpose_order.size()}, transpose_order));
|
auto transpose = ngraph::op::util::make_try_fold<ngraph::opset1::Transpose>(node, transpose_const);
|
||||||
|
if (!ngraph::is_type<ngraph::opset1::Constant>(transpose)) {
|
||||||
|
MatcherPass::register_new_node(transpose);
|
||||||
|
}
|
||||||
transpose->set_friendly_name(transpose_name);
|
transpose->set_friendly_name(transpose_name);
|
||||||
return transpose;
|
return transpose;
|
||||||
};
|
};
|
||||||
|
|
||||||
// fc_input_a and fc_input_b - are the final inputs that will be set to FullyConnected of GemmIE operations.
|
|
||||||
// So in case of adding new operations that takes matmul inputs we need keep update fc_input_a and
|
|
||||||
// fc_input_b updated.
|
|
||||||
auto fc_input_a = input_a, fc_input_b = input_b;
|
|
||||||
|
|
||||||
// vector of new nGraph operations
|
|
||||||
ngraph::NodeVector new_ops;
|
ngraph::NodeVector new_ops;
|
||||||
|
bool success = true;
|
||||||
// Check that if second inputs is Constant operation and it's shape without ones dimensions has length <= 2
|
ngraph::PartialShape shape_a_aligned, shape_b_aligned;
|
||||||
// we replace MatMul with FullyConnected operation.
|
std::tie(success, shape_a_aligned, shape_b_aligned) = get_aligned_shapes();
|
||||||
// Otherwise we replace MatMul with Gemm.
|
if (!success) {
|
||||||
if ((std::dynamic_pointer_cast<ngraph::opset1::Constant>(fc_input_b.get_node_shared_ptr()) ||
|
return false;
|
||||||
std::dynamic_pointer_cast<ngraph::opset1::FakeQuantize>(fc_input_b.get_node_shared_ptr())) &&
|
|
||||||
std::count_if(shape_b.begin(), shape_b.end(), [](size_t x) { return x != 1; }) <= 2) {
|
|
||||||
ngraph::Shape shape_a_aligned, shape_b_aligned;
|
|
||||||
std::tie(shape_a_aligned, shape_b_aligned) = get_aligned_shapes();
|
|
||||||
|
|
||||||
if (shape_a_aligned.size() < 2 || shape_b_aligned.size() < 2) {
|
|
||||||
throw ngraph::ngraph_error("MatMul " + matmul->get_friendly_name() + " shapes are inconsistent.");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Transferring from MatMul representation: [B, I, K] * [B, K, O] = [B, I, O]
|
|
||||||
// to FullyConnected representation: [I, K] * [K, O] = [I, O]
|
|
||||||
size_t K = *(shape_a_aligned.end() - 1);
|
|
||||||
ngraph::Shape B(shape_a_aligned.begin(), shape_a_aligned.end() - 2);
|
|
||||||
|
|
||||||
// Weights normalization
|
|
||||||
if (!matmul->get_transpose_b()) {
|
|
||||||
fc_input_b = create_transpose(fc_input_b, matmul->get_friendly_name() + "/transpose_b");
|
|
||||||
new_ops.push_back(fc_input_b.get_node_shared_ptr());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (shape_b.size() != 2) {
|
|
||||||
auto reshape_shape =
|
|
||||||
ngraph::opset1::Constant::create<int64_t>(ngraph::element::i64, ngraph::Shape{2}, {-1ll, static_cast<int64_t>(K)});
|
|
||||||
fc_input_b = std::make_shared<ngraph::opset1::Reshape>(fc_input_b, reshape_shape, true);
|
|
||||||
new_ops.push_back(fc_input_b.get_node_shared_ptr());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Input normalization
|
|
||||||
if (matmul->get_transpose_a() && shape_a.size() != 1) {
|
|
||||||
fc_input_a = create_transpose(fc_input_a, matmul->get_friendly_name() + "/transpose_a");
|
|
||||||
new_ops.push_back(fc_input_a.get_node_shared_ptr());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create FullyConnected
|
|
||||||
auto fc = std::make_shared<MKLDNNPlugin::FullyConnectedNode>(fc_input_a, fc_input_b, output_shape, matmul->output(0).get_element_type());
|
|
||||||
fc->set_friendly_name(matmul->get_friendly_name());
|
|
||||||
new_ops.push_back(fc);
|
|
||||||
|
|
||||||
ngraph::copy_runtime_info(matmul, new_ops);
|
|
||||||
ngraph::replace_node(matmul, fc);
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
return false;
|
|
||||||
|
auto aligned_a_rank = shape_a_aligned.rank(), aligned_b_rank = shape_b_aligned.rank();
|
||||||
|
if (aligned_a_rank.is_dynamic() || aligned_b_rank.is_dynamic() || aligned_a_rank.get_length() < 2 || aligned_b_rank.get_length() < 2) {
|
||||||
|
throw ngraph::ngraph_error("MatMul " + matmul->get_friendly_name() + " shapes are inconsistent.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transferring from MatMul representation: [B, I, K] * [B, K, O] = [B, I, O]
|
||||||
|
// to FullyConnected representation: [I, K] * [K, O] = [I, O]
|
||||||
|
|
||||||
|
// Weights normalization
|
||||||
|
if (!matmul->get_transpose_b()) {
|
||||||
|
fc_input_b = create_transpose(fc_input_b, matmul->get_friendly_name() + "/transpose_b");
|
||||||
|
new_ops.push_back(fc_input_b.get_node_shared_ptr());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (rank_b != 2) {
|
||||||
|
ngraph::Dimension K = *(shape_b_aligned.rbegin() + 1);
|
||||||
|
NGRAPH_CHECK(K.is_static()); // requested 2nd input with static shape in the matcher
|
||||||
|
std::vector<int64_t> reshape_shape_values = { -1ll, static_cast<int64_t>(K.get_length()) };
|
||||||
|
auto reshape_shape = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 2 }, reshape_shape_values);
|
||||||
|
fc_input_b = ngraph::op::util::make_try_fold<ngraph::opset1::Reshape>(fc_input_b, reshape_shape, false);
|
||||||
|
new_ops.push_back(fc_input_b.get_node_shared_ptr());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Input normalization
|
||||||
|
if (matmul->get_transpose_a() && rank_a != 1) {
|
||||||
|
fc_input_a = create_transpose(fc_input_a, matmul->get_friendly_name() + "/transpose_a");
|
||||||
|
new_ops.push_back(fc_input_a.get_node_shared_ptr());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto output_rank = matmul->get_output_partial_shape(0).rank();
|
||||||
|
// Create FullyConnected
|
||||||
|
auto fc = std::make_shared<MKLDNNPlugin::FullyConnectedNode>(fc_input_a, fc_input_b, output_rank, matmul->get_output_element_type(0));
|
||||||
|
fc->set_friendly_name(matmul->get_friendly_name());
|
||||||
|
new_ops.push_back(fc);
|
||||||
|
ngraph::copy_runtime_info(matmul, new_ops);
|
||||||
|
ngraph::replace_node(matmul, fc);
|
||||||
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, "ConvertMatMulToFC");
|
auto m = std::make_shared<ngraph::pattern::Matcher>(matmul_m, "ConvertMatMulToFC");
|
||||||
this->register_matcher(m, callback);
|
this->register_matcher(m, callback);
|
||||||
}
|
}
|
||||||
|
@ -9,53 +9,58 @@
|
|||||||
#include <ngraph/rt_info.hpp>
|
#include <ngraph/rt_info.hpp>
|
||||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||||
|
|
||||||
|
#include "transformations/utils/utils.hpp"
|
||||||
|
|
||||||
NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::FullyConnectedBiasFusion, "FullyConnectedBiasFusion", 0);
|
NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::FullyConnectedBiasFusion, "FullyConnectedBiasFusion", 0);
|
||||||
|
|
||||||
MKLDNNPlugin::FullyConnectedBiasFusion::FullyConnectedBiasFusion() {
|
MKLDNNPlugin::FullyConnectedBiasFusion::FullyConnectedBiasFusion() {
|
||||||
auto m_fc = ngraph::pattern::wrap_type<MKLDNNPlugin::FullyConnectedNode>([](ngraph::Output<ngraph::Node> output) {
|
auto input = ngraph::pattern::any_input();
|
||||||
return ngraph::pattern::consumers_count(1)(output) && ngraph::pattern::has_static_shape()(output);
|
auto weights = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
|
||||||
|
auto m_fc = ngraph::pattern::wrap_type<MKLDNNPlugin::FullyConnectedNode>({ input, weights }, [](ngraph::Output<ngraph::Node> output) {
|
||||||
|
return ngraph::pattern::consumers_count(1)(output) && ngraph::pattern::has_static_rank()(output);
|
||||||
});
|
});
|
||||||
auto m_bias = ngraph::pattern::any_input();
|
auto m_bias = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
|
||||||
auto m_add = ngraph::pattern::wrap_type<ngraph::opset1::Add>({m_fc, m_bias});
|
auto m_add = ngraph::pattern::wrap_type<ngraph::opset1::Add>({m_fc, m_bias});
|
||||||
|
|
||||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
|
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
|
||||||
auto & pattern_to_output = m.get_pattern_value_map();
|
auto& pattern_to_output = m.get_pattern_value_map();
|
||||||
|
|
||||||
auto add = pattern_to_output[m_add].get_node_shared_ptr();
|
auto add = pattern_to_output[m_add].get_node_shared_ptr();
|
||||||
auto bias = pattern_to_output[m_bias].get_node_shared_ptr();
|
auto bias = pattern_to_output[m_bias].get_node_shared_ptr();
|
||||||
auto fc = std::dynamic_pointer_cast<MKLDNNPlugin::FullyConnectedNode>(pattern_to_output[m_fc].get_node_shared_ptr());
|
auto fc = std::dynamic_pointer_cast<MKLDNNPlugin::FullyConnectedNode>(pattern_to_output[m_fc].get_node_shared_ptr());
|
||||||
if (!fc) {
|
if (!fc || transformation_callback(fc)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto bcast = std::dynamic_pointer_cast<ngraph::opset1::Broadcast>(bias)) {
|
|
||||||
bias = bcast->input_value(0).get_node_shared_ptr();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!std::dynamic_pointer_cast<ngraph::opset1::Constant>(bias)) {
|
if (!std::dynamic_pointer_cast<ngraph::opset1::Constant>(bias)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
ngraph::Shape bias_shape(bias->get_shape());
|
ngraph::Shape bias_shape(bias->get_shape());
|
||||||
ngraph::Shape output_shape(fc->get_shape());
|
ngraph::PartialShape output_shape(fc->get_output_partial_shape(0));
|
||||||
size_t bias_size = std::accumulate(bias_shape.begin(), bias_shape.end(), size_t{1}, std::multiplies<int64_t>());
|
size_t bias_size = ngraph::shape_size(bias_shape);
|
||||||
if (bias_shape.empty() || bias_shape.back() != output_shape.back() || bias_shape.back() != bias_size) {
|
auto rank = output_shape.rank().get_length();
|
||||||
|
if (rank == 0 || output_shape[rank - 1].is_dynamic()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (bias_shape.empty() || bias_shape.back() != output_shape[rank - 1].get_length() || bias_shape.back() != bias_size) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
ngraph::NodeVector new_ops;
|
ngraph::NodeVector new_ops;
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Node> final_bias = bias;
|
std::shared_ptr<ngraph::Node> final_bias = bias;
|
||||||
if (bias->get_shape().size() >= 2) {
|
if (bias_shape.size() >= 2) {
|
||||||
final_bias = std::make_shared<ngraph::opset1::Reshape>(final_bias, ngraph::opset1::Constant::create(ngraph::element::i64,
|
auto reshape_const = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 1 }, { -1 });
|
||||||
ngraph::Shape{1}, {-1}), true);
|
final_bias = ngraph::op::util::make_try_fold<ngraph::opset1::Reshape>(final_bias, reshape_const, true);
|
||||||
new_ops.push_back(final_bias);
|
new_ops.push_back(final_bias);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto new_fc = std::make_shared<MKLDNNPlugin::FullyConnectedNode>(fc->input(0).get_source_output(),
|
auto new_fc = std::make_shared<MKLDNNPlugin::FullyConnectedNode>(fc->input_value(0),
|
||||||
fc->input(1).get_source_output(),
|
fc->input_value(1),
|
||||||
final_bias,
|
final_bias,
|
||||||
fc->get_shape(),
|
fc->get_output_rank(),
|
||||||
fc->get_output_type());
|
fc->get_output_type());
|
||||||
new_ops.push_back(new_fc);
|
new_ops.push_back(new_fc);
|
||||||
|
|
||||||
|
@ -8,40 +8,99 @@ constexpr ngraph::NodeTypeInfo MKLDNNPlugin::FullyConnectedNode::type_info;
|
|||||||
|
|
||||||
MKLDNNPlugin::FullyConnectedNode::FullyConnectedNode(const ngraph::Output<Node>& A,
|
MKLDNNPlugin::FullyConnectedNode::FullyConnectedNode(const ngraph::Output<Node>& A,
|
||||||
const ngraph::Output<Node>& B,
|
const ngraph::Output<Node>& B,
|
||||||
const ngraph::Shape& output_shape,
|
const ngraph::Rank& output_rank,
|
||||||
const ngraph::element::Type output_type)
|
const ngraph::element::Type output_type)
|
||||||
: Op({A, B}), m_output_shape(output_shape), m_output_type(output_type) {
|
: Op({A, B}), m_output_rank(output_rank), m_output_type(output_type) {
|
||||||
validate_and_infer_types();
|
constructor_validate_and_infer_types();
|
||||||
}
|
}
|
||||||
|
|
||||||
MKLDNNPlugin::FullyConnectedNode::FullyConnectedNode(const ngraph::Output<Node>& A,
|
MKLDNNPlugin::FullyConnectedNode::FullyConnectedNode(const ngraph::Output<Node>& A,
|
||||||
const ngraph::Output<Node>& B,
|
const ngraph::Output<Node>& B,
|
||||||
const ngraph::Output<Node>& C,
|
const ngraph::Output<Node>& C,
|
||||||
const ngraph::Shape& output_shape,
|
const ngraph::Rank& output_rank,
|
||||||
const ngraph::element::Type output_type)
|
const ngraph::element::Type output_type)
|
||||||
: Op({A, B, C}), m_output_shape(output_shape), m_output_type(output_type) {
|
: Op({A, B, C}), m_output_rank(output_rank), m_output_type(output_type) {
|
||||||
validate_and_infer_types();
|
constructor_validate_and_infer_types();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Node> MKLDNNPlugin::FullyConnectedNode::clone_with_new_inputs(const ngraph::OutputVector& new_args) const {
|
std::shared_ptr<ngraph::Node> MKLDNNPlugin::FullyConnectedNode::clone_with_new_inputs(const ngraph::OutputVector& new_args) const {
|
||||||
check_new_args_count(this, new_args);
|
check_new_args_count(this, new_args);
|
||||||
if (new_args.size() == 2) {
|
if (new_args.size() == 2) {
|
||||||
return std::make_shared<MKLDNNPlugin::FullyConnectedNode>(new_args.at(0), new_args.at(1), m_output_shape, m_output_type);
|
return std::make_shared<MKLDNNPlugin::FullyConnectedNode>(new_args.at(0), new_args.at(1), m_output_rank, m_output_type);
|
||||||
} else if (new_args.size() == 3) {
|
} else if (new_args.size() == 3) {
|
||||||
return std::make_shared<MKLDNNPlugin::FullyConnectedNode>(new_args.at(0), new_args.at(1), new_args.at(2), m_output_shape, m_output_type);
|
return std::make_shared<MKLDNNPlugin::FullyConnectedNode>(new_args.at(0), new_args.at(1), new_args.at(2), m_output_rank, m_output_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
throw ngraph::ngraph_error("Unsupported number of arguments for FullyConnected operation");
|
throw ngraph::ngraph_error("Unsupported number of arguments for FullyConnected operation");
|
||||||
}
|
}
|
||||||
|
|
||||||
void MKLDNNPlugin::FullyConnectedNode::validate_and_infer_types() {
|
void MKLDNNPlugin::FullyConnectedNode::validate_and_infer_types() {
|
||||||
m_output_size = m_output_shape.back();
|
const auto input_size = get_input_size();
|
||||||
set_output_type(0, m_output_type == ngraph::element::undefined ? input_value(0).get_element_type() : m_output_type, m_output_shape);
|
NODE_VALIDATION_CHECK(this,
|
||||||
|
input_size == 2 || input_size == 3,
|
||||||
|
"Number of inputs is incorrect. Current value is: ",
|
||||||
|
input_size,
|
||||||
|
", expected: 2 or 3.");
|
||||||
|
|
||||||
|
const auto output_size = get_output_size();
|
||||||
|
NODE_VALIDATION_CHECK(this,
|
||||||
|
output_size == 1,
|
||||||
|
"Number of outputs is incorrect. Current value is: ",
|
||||||
|
output_size,
|
||||||
|
", expected: 1.");
|
||||||
|
|
||||||
|
// Weights shape: [O, I1, ..., Im];
|
||||||
|
// O - output channels dimensions, Ik - input channels dimensions
|
||||||
|
const auto weights_pshape = get_input_partial_shape(1);
|
||||||
|
NODE_VALIDATION_CHECK(this,
|
||||||
|
weights_pshape.is_static(),
|
||||||
|
"Weights pshape must be static");
|
||||||
|
const auto weights_shape = weights_pshape.to_shape();
|
||||||
|
|
||||||
|
const auto o_channels = weights_pshape[0];
|
||||||
|
if (input_size == 3) {
|
||||||
|
const auto bias_shape = get_input_partial_shape(2);
|
||||||
|
const auto expected_bias_shape = ngraph::PartialShape{ o_channels };
|
||||||
|
NODE_VALIDATION_CHECK(this,
|
||||||
|
bias_shape == expected_bias_shape,
|
||||||
|
"Bias shape is incorrect. Current value is: ",
|
||||||
|
bias_shape,
|
||||||
|
", expected: ",
|
||||||
|
expected_bias_shape,
|
||||||
|
".");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Activations shape: [B1, ..., Bn, I1, ..., Im];
|
||||||
|
// Bi - batch dimensions, Ik - input channels dimensions
|
||||||
|
const auto activations_pshape = get_input_partial_shape(0);
|
||||||
|
|
||||||
|
// Result shape: [B1, ..., Bn, O]
|
||||||
|
ngraph::PartialShape output_pshape;
|
||||||
|
if (activations_pshape.rank().is_static()) {
|
||||||
|
size_t output_channels_dimensions_count = weights_shape.size() - 1;
|
||||||
|
for (size_t i = 0; i < activations_pshape.size() - output_channels_dimensions_count; ++i) {
|
||||||
|
output_pshape.push_back(activations_pshape[i]);
|
||||||
|
}
|
||||||
|
output_pshape.push_back(o_channels);
|
||||||
|
|
||||||
|
if (m_output_rank.is_static()) {
|
||||||
|
while (output_pshape.rank().get_length() < m_output_rank.get_length()) {
|
||||||
|
output_pshape.insert(output_pshape.begin(), 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
output_pshape = ngraph::PartialShape::dynamic();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto output_type = m_output_type == ngraph::element::undefined ? get_input_element_type(0) : m_output_type;
|
||||||
|
set_output_type(0, output_type, output_pshape);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MKLDNNPlugin::FullyConnectedNode::visit_attributes(ngraph::AttributeVisitor &visitor) {
|
bool MKLDNNPlugin::FullyConnectedNode::visit_attributes(ngraph::AttributeVisitor &visitor) {
|
||||||
visitor.on_attribute("out-size", m_output_size);
|
if (m_output_rank.is_static()) {
|
||||||
visitor.on_attribute("out-shape", m_output_shape);
|
std::int64_t value = m_output_rank.get_length();
|
||||||
|
visitor.on_attribute("out-rank", value);
|
||||||
|
}
|
||||||
visitor.on_attribute("out-type", m_output_type);
|
visitor.on_attribute("out-type", m_output_type);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -19,13 +19,13 @@ public:
|
|||||||
|
|
||||||
FullyConnectedNode(const ngraph::Output<Node> &A,
|
FullyConnectedNode(const ngraph::Output<Node> &A,
|
||||||
const ngraph::Output<Node> &B,
|
const ngraph::Output<Node> &B,
|
||||||
const ngraph::Shape &output_shape,
|
const ngraph::Rank& output_rank,
|
||||||
const ngraph::element::Type output_type = ngraph::element::undefined);
|
const ngraph::element::Type output_type = ngraph::element::undefined);
|
||||||
|
|
||||||
FullyConnectedNode(const ngraph::Output<Node> &A,
|
FullyConnectedNode(const ngraph::Output<Node> &A,
|
||||||
const ngraph::Output<Node> &B,
|
const ngraph::Output<Node> &B,
|
||||||
const ngraph::Output<Node> &C,
|
const ngraph::Output<Node> &C,
|
||||||
const ngraph::Shape &output_shape,
|
const ngraph::Rank& output_rank,
|
||||||
const ngraph::element::Type output_type = ngraph::element::undefined);
|
const ngraph::element::Type output_type = ngraph::element::undefined);
|
||||||
|
|
||||||
bool visit_attributes(ngraph::AttributeVisitor &visitor) override;
|
bool visit_attributes(ngraph::AttributeVisitor &visitor) override;
|
||||||
@ -34,13 +34,11 @@ public:
|
|||||||
|
|
||||||
std::shared_ptr<Node> clone_with_new_inputs(const ngraph::OutputVector& new_args) const override;
|
std::shared_ptr<Node> clone_with_new_inputs(const ngraph::OutputVector& new_args) const override;
|
||||||
|
|
||||||
size_t get_out_size() const { return m_output_size; }
|
ngraph::Rank get_output_rank() const { return m_output_rank; }
|
||||||
|
|
||||||
ngraph::element::Type get_output_type() const { return m_output_type; }
|
ngraph::element::Type get_output_type() const { return m_output_type; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
size_t m_output_size = 0;
|
ngraph::Rank m_output_rank;
|
||||||
ngraph::Shape m_output_shape = {};
|
|
||||||
ngraph::element::Type m_output_type;
|
ngraph::element::Type m_output_type;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
#include "transformations/utils/utils.hpp"
|
#include "transformations/utils/utils.hpp"
|
||||||
|
|
||||||
|
namespace Reshape1DOps {
|
||||||
template <class BaseOp>
|
template <class BaseOp>
|
||||||
std::shared_ptr<ngraph::Node> convert(const ngraph::Output<ngraph::Node> & data, std::shared_ptr<BaseOp> node, ngraph::NodeVector &new_ops) {
|
std::shared_ptr<ngraph::Node> convert(const ngraph::Output<ngraph::Node> & data, std::shared_ptr<BaseOp> node, ngraph::NodeVector &new_ops) {
|
||||||
auto new_strides = node->get_strides();
|
auto new_strides = node->get_strides();
|
||||||
@ -171,13 +172,14 @@ ngraph::matcher_pass_callback get_callback() {
|
|||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
} // namespace Reshape1DOps
|
||||||
|
|
||||||
NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::Reshape1DConvolution, "Reshape1DConvolution", 0);
|
NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::Reshape1DConvolution, "Reshape1DConvolution", 0);
|
||||||
|
|
||||||
MKLDNNPlugin::Reshape1DConvolution::Reshape1DConvolution() {
|
MKLDNNPlugin::Reshape1DConvolution::Reshape1DConvolution() {
|
||||||
auto conv = ngraph::pattern::wrap_type<ngraph::opset1::Convolution>(ngraph::pattern::has_static_shape());
|
auto conv = ngraph::pattern::wrap_type<ngraph::opset1::Convolution>(ngraph::pattern::has_static_shape());
|
||||||
auto m = std::make_shared<ngraph::pattern::Matcher>(conv, "Reshape1DConvolution");
|
auto m = std::make_shared<ngraph::pattern::Matcher>(conv, "Reshape1DConvolution");
|
||||||
this->register_matcher(m, get_callback());
|
this->register_matcher(m, Reshape1DOps::get_callback());
|
||||||
}
|
}
|
||||||
|
|
||||||
NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::Reshape1DGroupConvolution, "Reshape1DGroupConvolution", 0);
|
NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::Reshape1DGroupConvolution, "Reshape1DGroupConvolution", 0);
|
||||||
@ -185,7 +187,7 @@ NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::Reshape1DGroupConvolution, "Reshape1DGroupC
|
|||||||
MKLDNNPlugin::Reshape1DGroupConvolution::Reshape1DGroupConvolution() {
|
MKLDNNPlugin::Reshape1DGroupConvolution::Reshape1DGroupConvolution() {
|
||||||
auto group_conv = ngraph::pattern::wrap_type<ngraph::opset1::GroupConvolution>(ngraph::pattern::has_static_shape());
|
auto group_conv = ngraph::pattern::wrap_type<ngraph::opset1::GroupConvolution>(ngraph::pattern::has_static_shape());
|
||||||
auto m = std::make_shared<ngraph::pattern::Matcher>(group_conv, "Reshape1DGroupConvolution");
|
auto m = std::make_shared<ngraph::pattern::Matcher>(group_conv, "Reshape1DGroupConvolution");
|
||||||
this->register_matcher(m, get_callback());
|
this->register_matcher(m, Reshape1DOps::get_callback());
|
||||||
}
|
}
|
||||||
|
|
||||||
NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::Reshape1DAvgPool, "Reshape1DAvgPool", 0);
|
NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::Reshape1DAvgPool, "Reshape1DAvgPool", 0);
|
||||||
@ -193,7 +195,7 @@ NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::Reshape1DAvgPool, "Reshape1DAvgPool", 0);
|
|||||||
MKLDNNPlugin::Reshape1DAvgPool::Reshape1DAvgPool() {
|
MKLDNNPlugin::Reshape1DAvgPool::Reshape1DAvgPool() {
|
||||||
auto pool = ngraph::pattern::wrap_type<ngraph::opset1::AvgPool>(ngraph::pattern::has_static_shape());
|
auto pool = ngraph::pattern::wrap_type<ngraph::opset1::AvgPool>(ngraph::pattern::has_static_shape());
|
||||||
auto m = std::make_shared<ngraph::pattern::Matcher>(pool, "Reshape1DAvgPool");
|
auto m = std::make_shared<ngraph::pattern::Matcher>(pool, "Reshape1DAvgPool");
|
||||||
this->register_matcher(m, get_callback());
|
this->register_matcher(m, Reshape1DOps::get_callback());
|
||||||
}
|
}
|
||||||
|
|
||||||
NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::Reshape1DMaxPool, "Reshape1DMaxPool", 0);
|
NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::Reshape1DMaxPool, "Reshape1DMaxPool", 0);
|
||||||
@ -201,5 +203,5 @@ NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::Reshape1DMaxPool, "Reshape1DMaxPool", 0);
|
|||||||
MKLDNNPlugin::Reshape1DMaxPool::Reshape1DMaxPool() {
|
MKLDNNPlugin::Reshape1DMaxPool::Reshape1DMaxPool() {
|
||||||
auto pool = ngraph::pattern::wrap_type<ngraph::opset1::MaxPool>(ngraph::pattern::has_static_shape());
|
auto pool = ngraph::pattern::wrap_type<ngraph::opset1::MaxPool>(ngraph::pattern::has_static_shape());
|
||||||
auto m = std::make_shared<ngraph::pattern::Matcher>(pool, "Reshape1DMaxPool");
|
auto m = std::make_shared<ngraph::pattern::Matcher>(pool, "Reshape1DMaxPool");
|
||||||
this->register_matcher(m, get_callback());
|
this->register_matcher(m, Reshape1DOps::get_callback());
|
||||||
}
|
}
|
||||||
|
@ -63,13 +63,13 @@ MKLDNNPlugin::ReshapeFullyConnectedFusion::ReshapeFullyConnectedFusion() {
|
|||||||
if (fc->get_input_size() == 2) {
|
if (fc->get_input_size() == 2) {
|
||||||
new_fc = std::make_shared<MKLDNNPlugin::FullyConnectedNode>(reshape->input_value(0),
|
new_fc = std::make_shared<MKLDNNPlugin::FullyConnectedNode>(reshape->input_value(0),
|
||||||
weightInput,
|
weightInput,
|
||||||
outShape,
|
ngraph::Rank(outShape.size()),
|
||||||
fc->output(0).get_element_type());
|
fc->output(0).get_element_type());
|
||||||
} else if (fc->get_input_size() == 3) {
|
} else if (fc->get_input_size() == 3) {
|
||||||
new_fc = std::make_shared<MKLDNNPlugin::FullyConnectedNode>(reshape->input_value(0),
|
new_fc = std::make_shared<MKLDNNPlugin::FullyConnectedNode>(reshape->input_value(0),
|
||||||
weightInput,
|
weightInput,
|
||||||
fc->input_value(2),
|
fc->input_value(2),
|
||||||
outShape,
|
ngraph::Rank(outShape.size()),
|
||||||
fc->output(0).get_element_type());
|
fc->output(0).get_element_type());
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
|
@ -5,60 +5,66 @@
|
|||||||
#include "reshape_fully_connected.hpp"
|
#include "reshape_fully_connected.hpp"
|
||||||
#include "op/fully_connected.hpp"
|
#include "op/fully_connected.hpp"
|
||||||
#include <ngraph/opsets/opset1.hpp>
|
#include <ngraph/opsets/opset1.hpp>
|
||||||
|
#include <ngraph/opsets/opset7.hpp>
|
||||||
#include <ngraph/rt_info.hpp>
|
#include <ngraph/rt_info.hpp>
|
||||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||||
#include <transformations/utils/utils.hpp>
|
|
||||||
#include <ngraph/pattern/op/or.hpp>
|
#include <ngraph/pattern/op/or.hpp>
|
||||||
|
#include <transformations/utils/utils.hpp>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::ReshapeFullyConnected, "ReshapeFullyConnected", 0);
|
NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::ReshapeFullyConnected, "ReshapeFullyConnected", 0);
|
||||||
|
|
||||||
MKLDNNPlugin::ReshapeFullyConnected::ReshapeFullyConnected() {
|
MKLDNNPlugin::ReshapeFullyConnected::ReshapeFullyConnected() {
|
||||||
ngraph::OutputVector twoInputs = {ngraph::pattern::any_input(ngraph::pattern::has_static_shape()), ngraph::pattern::any_input()};
|
ngraph::OutputVector twoInputs = {
|
||||||
ngraph::OutputVector threeInputs = {ngraph::pattern::any_input(ngraph::pattern::has_static_shape()), ngraph::pattern::any_input(),
|
ngraph::pattern::any_input(ngraph::pattern::has_static_rank()), ngraph::pattern::any_input(ngraph::pattern::has_static_shape())};
|
||||||
|
ngraph::OutputVector threeInputs = {
|
||||||
|
ngraph::pattern::any_input(ngraph::pattern::has_static_rank()), ngraph::pattern::any_input(ngraph::pattern::has_static_shape()),
|
||||||
ngraph::pattern::any_input()};
|
ngraph::pattern::any_input()};
|
||||||
auto fcTwoInputs = ngraph::pattern::wrap_type<MKLDNNPlugin::FullyConnectedNode>(twoInputs, ngraph::pattern::has_static_shape());
|
auto fcTwoInputs = ngraph::pattern::wrap_type<MKLDNNPlugin::FullyConnectedNode>(twoInputs, ngraph::pattern::has_static_rank());
|
||||||
auto fcThreeInputs = ngraph::pattern::wrap_type<MKLDNNPlugin::FullyConnectedNode>(threeInputs, ngraph::pattern::has_static_shape());
|
auto fcThreeInputs = ngraph::pattern::wrap_type<MKLDNNPlugin::FullyConnectedNode>(threeInputs, ngraph::pattern::has_static_rank());
|
||||||
const auto fcTwoOrThreeInputs = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{fcTwoInputs, fcThreeInputs});
|
const auto fcTwoOrThreeInputs = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{fcTwoInputs, fcThreeInputs});
|
||||||
|
|
||||||
ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher& m) {
|
ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher& m) {
|
||||||
auto fc = std::dynamic_pointer_cast<MKLDNNPlugin::FullyConnectedNode> (m.get_match_root());
|
auto fc = std::dynamic_pointer_cast<MKLDNNPlugin::FullyConnectedNode>(m.get_match_root());
|
||||||
if (!fc || transformation_callback(fc)) {
|
if (!fc || transformation_callback(fc)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto input_shape = fc->input_value(0).get_shape();
|
auto fc_input_shape = fc->get_input_partial_shape(0);
|
||||||
auto output_shape = fc->get_shape();
|
auto input_rank = fc_input_shape.rank().get_length();
|
||||||
|
auto output_shape = fc->get_output_partial_shape(0);
|
||||||
|
|
||||||
if (input_shape.size() == 2) {
|
if (input_rank == 2 || input_rank == 0) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
ngraph::NodeVector new_ops;
|
ngraph::NodeVector new_ops;
|
||||||
|
int64_t K = *(fc->get_input_shape(1).rbegin()); // requested 2nd input with static shape in the matcher
|
||||||
std::vector<int64_t> reshape_shape{-1, static_cast<int64_t>(input_shape.back())};
|
auto reshape = std::make_shared<ngraph::opset1::Reshape>(
|
||||||
auto reshape = std::make_shared<ngraph::opset1::Reshape>(fc->input_value(0),
|
fc->input_value(0), ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{2}, std::vector<int64_t>{-1, K}), false);
|
||||||
ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{2}, reshape_shape), true);
|
if (reshape->get_output_partial_shape(0).rank().is_dynamic())
|
||||||
|
return false;
|
||||||
new_ops.push_back(reshape);
|
new_ops.push_back(reshape);
|
||||||
|
|
||||||
reshape->set_friendly_name(fc->get_friendly_name() + "/Reshape");
|
reshape->set_friendly_name(fc->get_friendly_name() + "/Reshape");
|
||||||
|
|
||||||
// Calculate output shape for new FullyConnected layer
|
// Calculate output shape for new FullyConnected layer
|
||||||
// [I, K] * [O, K] = [I, O]
|
// [I, K] * [O, K] = [I, O]
|
||||||
auto I = reshape->get_shape()[0];
|
auto I = reshape->get_output_partial_shape(0)[0];
|
||||||
auto O = fc->input_value(1).get_shape()[0];
|
auto O = fc->get_input_partial_shape(1)[0];
|
||||||
ngraph::Shape output_shape_new{I, O};
|
ngraph::PartialShape output_shape_new{I, O};
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Node> fc_new;
|
std::shared_ptr<ngraph::Node> fc_new;
|
||||||
if (fc->get_input_size() == 2) {
|
if (fc->get_input_size() == 2) {
|
||||||
fc_new = std::make_shared<MKLDNNPlugin::FullyConnectedNode>(reshape,
|
fc_new = std::make_shared<MKLDNNPlugin::FullyConnectedNode>(reshape,
|
||||||
fc->input_value(1),
|
fc->input_value(1),
|
||||||
output_shape_new,
|
output_shape_new.rank(),
|
||||||
fc->get_output_type());
|
fc->get_output_type());
|
||||||
} else if (fc->get_input_size() == 3) {
|
} else if (fc->get_input_size() == 3) {
|
||||||
fc_new = std::make_shared<MKLDNNPlugin::FullyConnectedNode>(reshape,
|
fc_new = std::make_shared<MKLDNNPlugin::FullyConnectedNode>(reshape,
|
||||||
fc->input_value(1),
|
fc->input_value(1),
|
||||||
fc->input_value(2),
|
fc->input_value(2),
|
||||||
output_shape_new,
|
output_shape_new.rank(),
|
||||||
fc->get_output_type());
|
fc->get_output_type());
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
@ -66,7 +72,29 @@ MKLDNNPlugin::ReshapeFullyConnected::ReshapeFullyConnected() {
|
|||||||
new_ops.push_back(fc_new);
|
new_ops.push_back(fc_new);
|
||||||
|
|
||||||
if (output_shape != output_shape_new) {
|
if (output_shape != output_shape_new) {
|
||||||
auto reshape_output = ngraph::op::util::reshapeTo(fc_new, output_shape);
|
auto I_idxs = std::vector<size_t>(input_rank - 1);
|
||||||
|
std::iota(I_idxs.begin(), I_idxs.end(), 0);
|
||||||
|
auto A_input_shape = ngraph::op::util::make_try_fold<ngraph::opset7::ShapeOf>(fc->input_value(0));
|
||||||
|
auto B_input_shape = ngraph::op::util::make_try_fold<ngraph::opset7::ShapeOf>(fc->input_value(1));
|
||||||
|
auto I_node = ngraph::op::util::node_to_get_shape_value_of_indices_from_shape_node(A_input_shape, {I_idxs});
|
||||||
|
auto O_node = ngraph::op::util::node_to_get_shape_value_of_indices_from_shape_node(B_input_shape, {0});
|
||||||
|
ngraph::OutputVector output_shape_dims{I_node, O_node};
|
||||||
|
|
||||||
|
const auto original_rank = fc->get_output_rank();
|
||||||
|
NGRAPH_CHECK(original_rank.is_static());
|
||||||
|
if (input_rank < original_rank.get_length()) {
|
||||||
|
const size_t const_shape_value = original_rank.get_length() - input_rank;
|
||||||
|
output_shape_dims.insert(
|
||||||
|
output_shape_dims.begin(), ngraph::opset1::Constant::create(I_node->get_element_type(), { const_shape_value }, { 1 }));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto reshape_output_shape = ngraph::op::util::make_try_fold<ngraph::opset1::Concat>(output_shape_dims, 0);
|
||||||
|
auto reshape_output = std::make_shared<ngraph::opset1::Reshape>(fc_new, reshape_output_shape, false);
|
||||||
|
new_ops.push_back(A_input_shape);
|
||||||
|
new_ops.push_back(B_input_shape);
|
||||||
|
new_ops.push_back(I_node);
|
||||||
|
new_ops.push_back(O_node);
|
||||||
|
new_ops.push_back(reshape_output_shape);
|
||||||
new_ops.push_back(reshape_output);
|
new_ops.push_back(reshape_output);
|
||||||
reshape_output->set_friendly_name(fc->get_friendly_name());
|
reshape_output->set_friendly_name(fc->get_friendly_name());
|
||||||
fc_new->set_friendly_name(fc->get_friendly_name() + "/FC");
|
fc_new->set_friendly_name(fc->get_friendly_name() + "/FC");
|
||||||
|
@ -1,30 +0,0 @@
|
|||||||
// Copyright (C) 2018-2021 Intel Corporation
|
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
|
||||||
//
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <transformations_visibility.hpp>
|
|
||||||
#include <ngraph/op/util/op_annotations.hpp>
|
|
||||||
#include <ngraph/opsets/opset4.hpp>
|
|
||||||
|
|
||||||
namespace ngraph {
|
|
||||||
namespace op {
|
|
||||||
namespace util {
|
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Node> node_to_get_shape_value_of_indices_from_shape_node(const std::shared_ptr<ngraph::Node>& shape_node,
|
|
||||||
const std::vector<size_t>& indices) {
|
|
||||||
return std::make_shared<ngraph::opset4::Gather>(shape_node,
|
|
||||||
ngraph::opset4::Constant::create(ngraph::element::i64, {indices.size()}, indices),
|
|
||||||
ngraph::opset4::Constant::create(ngraph::element::i64, {}, {0}));
|
|
||||||
}
|
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Node> node_to_get_shape_value_of_indices_from_shape_source(const ngraph::Output<ngraph::Node>& shape_source,
|
|
||||||
const std::vector<size_t>& indices) {
|
|
||||||
const auto & shape_node = std::make_shared<ngraph::opset4::ShapeOf>(shape_source);
|
|
||||||
return node_to_get_shape_value_of_indices_from_shape_node(shape_node, indices);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace util
|
|
||||||
} // namespace op
|
|
||||||
} // namespace ngraph
|
|
@ -142,6 +142,12 @@ Output<Node> eltwise_fold(const Output<Node> & input0, const Output<Node> & inpu
|
|||||||
}
|
}
|
||||||
|
|
||||||
TRANSFORMATIONS_API std::vector<Input<Node>> get_node_target_inputs(const std::shared_ptr<Node>& node);
|
TRANSFORMATIONS_API std::vector<Input<Node>> get_node_target_inputs(const std::shared_ptr<Node>& node);
|
||||||
|
|
||||||
|
TRANSFORMATIONS_API std::shared_ptr<ngraph::Node> node_to_get_shape_value_of_indices_from_shape_node(
|
||||||
|
const std::shared_ptr<ngraph::Node>& shape_node, const std::vector<size_t>& indices);
|
||||||
|
|
||||||
|
TRANSFORMATIONS_API std::shared_ptr<ngraph::Node> node_to_get_shape_value_of_indices_from_shape_source(
|
||||||
|
const ngraph::Output<ngraph::Node>& shape_source, const std::vector<size_t>& indices);
|
||||||
} // namespace util
|
} // namespace util
|
||||||
} // namespace op
|
} // namespace op
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -11,6 +11,7 @@
|
|||||||
#include <ngraph/op/broadcast.hpp>
|
#include <ngraph/op/broadcast.hpp>
|
||||||
#include <ngraph/op/constant.hpp>
|
#include <ngraph/op/constant.hpp>
|
||||||
#include <ngraph/op/reshape.hpp>
|
#include <ngraph/op/reshape.hpp>
|
||||||
|
#include <ngraph/op/gather.hpp>
|
||||||
#include <ngraph/op/util/op_annotations.hpp>
|
#include <ngraph/op/util/op_annotations.hpp>
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ngraph {
|
||||||
@ -152,6 +153,20 @@ std::vector<Input<Node>> get_node_target_inputs(const std::shared_ptr<Node>& nod
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<ngraph::Node> node_to_get_shape_value_of_indices_from_shape_node(const std::shared_ptr<ngraph::Node>& shape_node,
|
||||||
|
const std::vector<size_t>& indices) {
|
||||||
|
return make_try_fold<v7::Gather>(
|
||||||
|
shape_node,
|
||||||
|
v0::Constant::create(ngraph::element::i64, {indices.size()}, indices),
|
||||||
|
v0::Constant::create(ngraph::element::i64, {}, {0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<ngraph::Node> node_to_get_shape_value_of_indices_from_shape_source(const ngraph::Output<ngraph::Node>& shape_source,
|
||||||
|
const std::vector<size_t>& indices) {
|
||||||
|
const auto & shape_node = make_try_fold<v3::ShapeOf>(shape_source);
|
||||||
|
return node_to_get_shape_value_of_indices_from_shape_node(shape_node, indices);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace util
|
} // namespace util
|
||||||
} // namespace op
|
} // namespace op
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -22,6 +22,7 @@ addIeTargetTest(
|
|||||||
inference_engine_lp_transformations
|
inference_engine_lp_transformations
|
||||||
ov_shape_inference
|
ov_shape_inference
|
||||||
inference_engine_s
|
inference_engine_s
|
||||||
|
unitTestUtils
|
||||||
ADD_CPPLINT
|
ADD_CPPLINT
|
||||||
LABELS
|
LABELS
|
||||||
CPU
|
CPU
|
||||||
|
567
inference-engine/tests/unit/cpu/convert_matmul_test.cpp
Normal file
567
inference-engine/tests/unit/cpu/convert_matmul_test.cpp
Normal file
@ -0,0 +1,567 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include <ngraph/function.hpp>
|
||||||
|
#include <ngraph/opsets/opset1.hpp>
|
||||||
|
#include <ngraph/opsets/opset7.hpp>
|
||||||
|
#include <ngraph_transformations/op/fully_connected.hpp>
|
||||||
|
#include <ngraph_transformations/convert_matmul_to_fc.hpp>
|
||||||
|
#include <ngraph_transformations/fc_bias_fusion.hpp>
|
||||||
|
#include <ngraph_transformations/reshape_fully_connected.hpp>
|
||||||
|
#include <transformations/init_node_info.hpp>
|
||||||
|
#include <transformations/utils/utils.hpp>
|
||||||
|
#include <ngraph/pass/manager.hpp>
|
||||||
|
|
||||||
|
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||||
|
|
||||||
|
using namespace testing;
|
||||||
|
using namespace MKLDNNPlugin;
|
||||||
|
|
||||||
|
TEST(TransformationTests, ConvertMatMulToFCTest1) {
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||||
|
{
|
||||||
|
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 3, 2, 2 });
|
||||||
|
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 2, 2 }, { 1 });
|
||||||
|
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, true, false);
|
||||||
|
|
||||||
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
|
||||||
|
ngraph::pass::Manager m;
|
||||||
|
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
m.register_pass<ConvertMatMulToFC>();
|
||||||
|
m.run_passes(f);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 3, 2, 2 });
|
||||||
|
auto transpose_constant = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 3 }, { 0, 2, 1 });
|
||||||
|
auto transpose = std::make_shared<ngraph::opset1::Transpose>(input1, transpose_constant);
|
||||||
|
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 2, 2 }, { 1 });
|
||||||
|
auto matmul = std::make_shared<FullyConnectedNode>(transpose, input2, ngraph::Rank(3));
|
||||||
|
|
||||||
|
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = compare_functions(f, f_ref);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, ConvertMatMulToFCTest2) {
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||||
|
{
|
||||||
|
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||||
|
auto input2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 1});
|
||||||
|
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, false);
|
||||||
|
|
||||||
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2});
|
||||||
|
ngraph::pass::Manager m;
|
||||||
|
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
m.register_pass<ConvertMatMulToFC>();
|
||||||
|
m.run_passes(f);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||||
|
auto input2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 1});
|
||||||
|
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, false);
|
||||||
|
|
||||||
|
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2});
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = compare_functions(f, f_ref);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, ConvertMatMulToFCTest3) {
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||||
|
{
|
||||||
|
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 2});
|
||||||
|
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 2}, {1});
|
||||||
|
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
|
||||||
|
|
||||||
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||||
|
ngraph::pass::Manager m;
|
||||||
|
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
m.register_pass<ConvertMatMulToFC>();
|
||||||
|
m.run_passes(f);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 2});
|
||||||
|
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 2}, {1});
|
||||||
|
auto matmul = std::make_shared<FullyConnectedNode>(input1, input2, ngraph::Rank(3));
|
||||||
|
|
||||||
|
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = compare_functions(f, f_ref);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, ConvertMatMulToFCTest4) {
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||||
|
{
|
||||||
|
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{-1, -1, 2});
|
||||||
|
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 2}, {1});
|
||||||
|
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
|
||||||
|
|
||||||
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||||
|
ngraph::pass::Manager m;
|
||||||
|
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
m.register_pass<ConvertMatMulToFC>();
|
||||||
|
m.run_passes(f);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{-1, -1, 2});
|
||||||
|
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 2}, {1});
|
||||||
|
auto matmul = std::make_shared<FullyConnectedNode>(input1, input2, ngraph::Rank(3));
|
||||||
|
|
||||||
|
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = compare_functions(f, f_ref);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, ConvertMatMulToFCTest5) {
|
||||||
|
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{ -1, -1, 2 });
|
||||||
|
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3, 2, 2 }, { 1 });
|
||||||
|
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
|
||||||
|
|
||||||
|
auto f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
|
||||||
|
|
||||||
|
ngraph::pass::Manager m;
|
||||||
|
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
m.register_pass<ConvertMatMulToFC>();
|
||||||
|
ASSERT_NO_THROW(m.run_passes(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, ConvertMatMulToFCTest6) {
|
||||||
|
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{ -1, -1, 2 });
|
||||||
|
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3, 1, 2 }, { 1 });
|
||||||
|
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
|
||||||
|
|
||||||
|
auto f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
|
||||||
|
|
||||||
|
ngraph::pass::Manager m;
|
||||||
|
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
m.register_pass<ConvertMatMulToFC>();
|
||||||
|
ASSERT_NO_THROW(m.run_passes(f));
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, ConvertMatMulToFCTest7) {
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||||
|
{
|
||||||
|
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 2});
|
||||||
|
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 2}, {1});
|
||||||
|
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
|
||||||
|
|
||||||
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||||
|
ngraph::pass::Manager m;
|
||||||
|
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
m.register_pass<ConvertMatMulToFC>();
|
||||||
|
m.register_pass<ReshapeFullyConnected>();
|
||||||
|
m.run_passes(f);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 2});
|
||||||
|
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 2}, {1});
|
||||||
|
auto reshape_begin = std::make_shared<ngraph::opset1::Reshape>(
|
||||||
|
input1, ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{2}, std::vector<int64_t>{-1, 2}), false);
|
||||||
|
auto fc = std::make_shared<FullyConnectedNode>(reshape_begin, input2, ngraph::Rank(2));
|
||||||
|
auto reshape_end = ngraph::op::util::reshapeTo(fc, ngraph::Shape{3, 2, 3});
|
||||||
|
|
||||||
|
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_end}, ngraph::ParameterVector{input1});
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = compare_functions(f, f_ref, true);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, ConvertMatMulToFCTest8) {
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||||
|
{
|
||||||
|
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{-1, -1, 2});
|
||||||
|
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 2}, {1});
|
||||||
|
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
|
||||||
|
|
||||||
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||||
|
ngraph::pass::Manager m;
|
||||||
|
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
m.register_pass<ConvertMatMulToFC>();
|
||||||
|
m.register_pass<ReshapeFullyConnected>();
|
||||||
|
m.run_passes(f);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{-1, -1, 2});
|
||||||
|
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 2}, {1});
|
||||||
|
|
||||||
|
auto reshape_begin = std::make_shared<ngraph::opset1::Reshape>(
|
||||||
|
input1, ngraph::opset1::Constant::create(ngraph::element::i64, {2}, {-1, 2}), false);
|
||||||
|
|
||||||
|
auto fc = std::make_shared<FullyConnectedNode>(reshape_begin, input2, ngraph::Rank(2));
|
||||||
|
auto a_shape = std::make_shared<ngraph::opset3::ShapeOf>(input1);
|
||||||
|
|
||||||
|
auto I = ngraph::op::util::node_to_get_shape_value_of_indices_from_shape_node(a_shape, {0, 1});
|
||||||
|
auto O = ngraph::opset1::Constant::create(ngraph::element::i64, { 1 }, { 3 });
|
||||||
|
auto output_shape = std::make_shared<ngraph::opset1::Concat>(ngraph::OutputVector{I, O}, 0);
|
||||||
|
auto reshape_end = std::make_shared<ngraph::opset1::Reshape>(fc, output_shape, false);
|
||||||
|
|
||||||
|
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_end}, ngraph::ParameterVector{input1});
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = compare_functions(f, f_ref, true);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, ConvertMatMulToFCTest9) {
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||||
|
{
|
||||||
|
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 2});
|
||||||
|
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 2}, {1});
|
||||||
|
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
|
||||||
|
|
||||||
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||||
|
|
||||||
|
ngraph::pass::Manager m;
|
||||||
|
auto pass_config = m.get_pass_config();
|
||||||
|
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
m.register_pass<ConvertMatMulToFC>();
|
||||||
|
m.run_passes(f);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 2});
|
||||||
|
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 2}, {1});
|
||||||
|
auto matmul = std::make_shared<FullyConnectedNode>(input1, input2, ngraph::Rank(3));
|
||||||
|
|
||||||
|
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = compare_functions(f, f_ref);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, ConvertMatMulToFCTest10) {
|
||||||
|
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
|
||||||
|
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 2, 2 }, { 1 });
|
||||||
|
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
|
||||||
|
|
||||||
|
auto f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
|
||||||
|
|
||||||
|
ngraph::pass::Manager m;
|
||||||
|
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
m.register_pass<ConvertMatMulToFC>();
|
||||||
|
m.register_pass<ReshapeFullyConnected>();
|
||||||
|
ASSERT_NO_THROW(m.run_passes(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Static 3D activations: FullyConnected followed by Add with a 1D constant.
// The reference graph is a single FullyConnected that takes the constant as its bias input.
TEST(TransformationTests, FullyConnectedBiasFusionTest1) {
    const ngraph::Shape data_shape{1, 128, 3072};
    const ngraph::Shape weights_shape{786, 3072};
    const ngraph::Shape bias_shape{786};

    std::shared_ptr<ngraph::Function> f;
    std::shared_ptr<ngraph::Function> f_ref;

    // Graph under test: FullyConnected -> Add(const), then run the fusion.
    {
        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, data_shape);
        auto fc_weights = ngraph::opset1::Constant::create(ngraph::element::f32, weights_shape, {1});
        auto fc_node = std::make_shared<FullyConnectedNode>(data, fc_weights, ngraph::Rank(3));

        auto add_const = ngraph::opset1::Constant::create(ngraph::element::f32, bias_shape, {1});
        auto bias_add = std::make_shared<ngraph::opset1::Add>(fc_node, add_const);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{bias_add}, ngraph::ParameterVector{data});

        ngraph::pass::Manager pass_manager;
        pass_manager.register_pass<ngraph::pass::InitNodeInfo>();
        pass_manager.register_pass<FullyConnectedBiasFusion>();
        // Runs check_rt_info on the function as the last registered pass.
        pass_manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> func) {
            check_rt_info(func);
        });
        ASSERT_NO_THROW(pass_manager.run_passes(f));
    }

    // Reference graph: FullyConnected with an explicit bias input.
    {
        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, data_shape);
        auto fc_weights = ngraph::opset1::Constant::create(ngraph::element::f32, weights_shape, {1});
        auto fc_bias = ngraph::opset1::Constant::create(ngraph::element::f32, bias_shape, {1});
        auto fc_node = std::make_shared<FullyConnectedNode>(data, fc_weights, fc_bias, ngraph::Rank(3));

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc_node}, ngraph::ParameterVector{data});
    }

    const auto cmp_result = compare_functions(f, f_ref, true);
    ASSERT_TRUE(cmp_result.first) << cmp_result.second;
}
|
||||||
|
|
||||||
|
// Same scenario as a 3D FullyConnected + Add(const), but the two leading
// activation dimensions are dynamic (-1).
TEST(TransformationTests, FullyConnectedBiasFusionTest2) {
    const ngraph::PartialShape data_shape{-1, -1, 3072};
    const ngraph::Shape weights_shape{786, 3072};
    const ngraph::Shape bias_shape{786};

    std::shared_ptr<ngraph::Function> f;
    std::shared_ptr<ngraph::Function> f_ref;

    // Graph under test: FullyConnected -> Add(const), then run the fusion.
    {
        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, data_shape);
        auto fc_weights = ngraph::opset1::Constant::create(ngraph::element::f32, weights_shape, {1});
        auto fc_node = std::make_shared<FullyConnectedNode>(data, fc_weights, ngraph::Rank(3));

        auto add_const = ngraph::opset1::Constant::create(ngraph::element::f32, bias_shape, {1});
        auto bias_add = std::make_shared<ngraph::opset1::Add>(fc_node, add_const);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{bias_add}, ngraph::ParameterVector{data});

        ngraph::pass::Manager pass_manager;
        pass_manager.register_pass<ngraph::pass::InitNodeInfo>();
        pass_manager.register_pass<FullyConnectedBiasFusion>();
        // Runs check_rt_info on the function as the last registered pass.
        pass_manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> func) {
            check_rt_info(func);
        });
        ASSERT_NO_THROW(pass_manager.run_passes(f));
    }

    // Reference graph: FullyConnected with an explicit bias input.
    {
        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, data_shape);
        auto fc_weights = ngraph::opset1::Constant::create(ngraph::element::f32, weights_shape, {1});
        auto fc_bias = ngraph::opset1::Constant::create(ngraph::element::f32, bias_shape, {1});
        auto fc_node = std::make_shared<FullyConnectedNode>(data, fc_weights, fc_bias, ngraph::Rank(3));

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc_node}, ngraph::ParameterVector{data});
    }

    const auto cmp_result = compare_functions(f, f_ref, true);
    ASSERT_TRUE(cmp_result.first) << cmp_result.second;
}
|
||||||
|
|
||||||
|
// Static 2D activations with a {1, 786} constant added after FullyConnected.
// The reference graph carries the bias as a 1D {786} constant on the FullyConnected node.
TEST(TransformationTests, FullyConnectedBiasFusionTest3) {
    const ngraph::Shape data_shape{1, 128};
    const ngraph::Shape weights_shape{786, 128};
    const ngraph::Shape add_const_shape{1, 786};
    const ngraph::Shape fused_bias_shape{786};

    std::shared_ptr<ngraph::Function> f;
    std::shared_ptr<ngraph::Function> f_ref;

    // Graph under test: FullyConnected -> Add(const), then run the fusion.
    {
        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, data_shape);
        auto fc_weights = ngraph::opset1::Constant::create(ngraph::element::f32, weights_shape, {1});
        auto fc_node = std::make_shared<FullyConnectedNode>(data, fc_weights, ngraph::Rank(2));

        auto add_const = ngraph::opset1::Constant::create(ngraph::element::f32, add_const_shape, {1});
        auto bias_add = std::make_shared<ngraph::opset1::Add>(fc_node, add_const);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{bias_add}, ngraph::ParameterVector{data});

        ngraph::pass::Manager pass_manager;
        pass_manager.register_pass<ngraph::pass::InitNodeInfo>();
        pass_manager.register_pass<FullyConnectedBiasFusion>();
        // Runs check_rt_info on the function as the last registered pass.
        pass_manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> func) {
            check_rt_info(func);
        });
        ASSERT_NO_THROW(pass_manager.run_passes(f));
    }

    // Reference graph: FullyConnected with an explicit 1D bias input.
    {
        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, data_shape);
        auto fc_weights = ngraph::opset1::Constant::create(ngraph::element::f32, weights_shape, {1});
        auto fc_bias = ngraph::opset1::Constant::create(ngraph::element::f32, fused_bias_shape, {1});
        auto fc_node = std::make_shared<FullyConnectedNode>(data, fc_weights, fc_bias, ngraph::Rank(2));

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc_node}, ngraph::ParameterVector{data});
    }

    const auto cmp_result = compare_functions(f, f_ref, true);
    ASSERT_TRUE(cmp_result.first) << cmp_result.second;
}
|
||||||
|
|
||||||
|
// 2D activations with a dynamic first dimension and a {1, 786} constant added
// after FullyConnected. The reference graph carries a 1D {786} bias on the FullyConnected node.
TEST(TransformationTests, FullyConnectedBiasFusionTest4) {
    const ngraph::PartialShape data_shape{-1, 128};
    const ngraph::Shape weights_shape{786, 128};
    const ngraph::Shape add_const_shape{1, 786};
    const ngraph::Shape fused_bias_shape{786};

    std::shared_ptr<ngraph::Function> f;
    std::shared_ptr<ngraph::Function> f_ref;

    // Graph under test: FullyConnected -> Add(const), then run the fusion.
    {
        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, data_shape);
        auto fc_weights = ngraph::opset1::Constant::create(ngraph::element::f32, weights_shape, {1});
        auto fc_node = std::make_shared<FullyConnectedNode>(data, fc_weights, ngraph::Rank(2));

        auto add_const = ngraph::opset1::Constant::create(ngraph::element::f32, add_const_shape, {1});
        auto bias_add = std::make_shared<ngraph::opset1::Add>(fc_node, add_const);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{bias_add}, ngraph::ParameterVector{data});

        ngraph::pass::Manager pass_manager;
        pass_manager.register_pass<ngraph::pass::InitNodeInfo>();
        pass_manager.register_pass<FullyConnectedBiasFusion>();
        // Runs check_rt_info on the function as the last registered pass.
        pass_manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> func) {
            check_rt_info(func);
        });
        ASSERT_NO_THROW(pass_manager.run_passes(f));
    }

    // Reference graph: FullyConnected with an explicit 1D bias input.
    {
        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, data_shape);
        auto fc_weights = ngraph::opset1::Constant::create(ngraph::element::f32, weights_shape, {1});
        auto fc_bias = ngraph::opset1::Constant::create(ngraph::element::f32, fused_bias_shape, {1});
        auto fc_node = std::make_shared<FullyConnectedNode>(data, fc_weights, fc_bias, ngraph::Rank(2));

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc_node}, ngraph::ParameterVector{data});
    }

    const auto cmp_result = compare_functions(f, f_ref, true);
    ASSERT_TRUE(cmp_result.first) << cmp_result.second;
}
|
||||||
|
|
||||||
|
// Fully dynamic activations (PartialShape::dynamic()): the test only checks that
// running FullyConnectedBiasFusion on FullyConnected -> Add(const) does not throw.
TEST(TransformationTests, FullyConnectedBiasFusionTest5) {
    auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
    auto fc_weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786, 128}, {1});
    auto fc_node = std::make_shared<FullyConnectedNode>(data, fc_weights, ngraph::Rank(2));

    auto add_const = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1, 786}, {1});
    auto bias_add = std::make_shared<ngraph::opset1::Add>(fc_node, add_const);

    auto func = std::make_shared<ngraph::Function>(ngraph::NodeVector{bias_add}, ngraph::ParameterVector{data});

    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<FullyConnectedBiasFusion>();
    ASSERT_NO_THROW(pass_manager.run_passes(func));
}
|
||||||
|
|
||||||
|
// u8 activations of shape {-1, -1} with i8 weights and an f32 FullyConnected output type:
// the test only checks that running FullyConnectedBiasFusion does not throw.
TEST(TransformationTests, FullyConnectedBiasFusionTest6) {
    auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::u8, ngraph::PartialShape{-1, -1});
    auto fc_weights = ngraph::opset1::Constant::create(ngraph::element::i8, ngraph::Shape{786, 128}, {1});
    auto fc_node = std::make_shared<FullyConnectedNode>(data, fc_weights, ngraph::Rank(2), ngraph::element::f32);

    auto add_const = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1, 786}, {1});
    auto bias_add = std::make_shared<ngraph::opset1::Add>(fc_node, add_const);

    auto func = std::make_shared<ngraph::Function>(ngraph::NodeVector{bias_add}, ngraph::ParameterVector{data});

    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<FullyConnectedBiasFusion>();
    ASSERT_NO_THROW(pass_manager.run_passes(func));
}
|
||||||
|
|
||||||
|
// MatMul (transpose_b = true) with 3D static activations and a 4D {1, 1, 2, 3} constant
// second input. Reference: Reshape to {-1, 3} -> FullyConnected with {2, 3} weights ->
// Reshape to {1, 5, 2, 2}.
TEST(TransformationTests, ConvertMatMulToFCTest_second_input_rank_adj_1) {
    const ngraph::Shape data_shape{5, 2, 3};

    std::shared_ptr<ngraph::Function> f;
    std::shared_ptr<ngraph::Function> f_ref;

    // Graph under test: plain MatMul, converted by the registered passes.
    {
        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, data_shape);
        auto mm_weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1, 1, 2, 3}, {1});
        auto mat_mul = std::make_shared<ngraph::opset1::MatMul>(data, mm_weights, false, true);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mat_mul}, ngraph::ParameterVector{data});

        ngraph::pass::Manager pass_manager;
        pass_manager.register_pass<ngraph::pass::InitNodeInfo>();
        pass_manager.register_pass<ConvertMatMulToFC>();
        pass_manager.register_pass<ReshapeFullyConnected>();
        pass_manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    // Reference graph.
    {
        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, data_shape);

        auto flatten_pattern = ngraph::opset1::Constant::create(ngraph::element::i64, {2}, {-1, 3});
        auto flatten = std::make_shared<ngraph::opset1::Reshape>(data, flatten_pattern, false);

        auto fc_weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 3}, {1});
        auto fc_node = std::make_shared<FullyConnectedNode>(flatten, fc_weights, ngraph::Rank(2));

        auto restore_pattern = ngraph::opset1::Constant::create(ngraph::element::i64, {4}, {1, 5, 2, 2});
        auto restore = std::make_shared<ngraph::opset1::Reshape>(fc_node, restore_pattern, false);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{restore}, ngraph::ParameterVector{data});
    }

    const auto cmp_result = compare_functions(f, f_ref, true);
    ASSERT_TRUE(cmp_result.first) << cmp_result.second;
}
|
||||||
|
|
||||||
|
// MatMul (transpose_b = true) with 2D activations and 2D constant weights.
// Reference: a single FullyConnected on the same inputs, with no Reshape nodes.
TEST(TransformationTests, ConvertMatMulToFCTest_second_input_rank_adj_2) {
    const ngraph::Shape data_shape{2, 3};
    const ngraph::Shape weights_shape{2, 3};

    std::shared_ptr<ngraph::Function> f;
    std::shared_ptr<ngraph::Function> f_ref;

    // Graph under test: plain MatMul, converted by the registered passes.
    {
        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, data_shape);
        auto mm_weights = ngraph::opset1::Constant::create(ngraph::element::f32, weights_shape, {1});
        auto mat_mul = std::make_shared<ngraph::opset1::MatMul>(data, mm_weights, false, true);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mat_mul}, ngraph::ParameterVector{data});

        ngraph::pass::Manager pass_manager;
        pass_manager.register_pass<ngraph::pass::InitNodeInfo>();
        pass_manager.register_pass<ConvertMatMulToFC>();
        pass_manager.register_pass<ReshapeFullyConnected>();
        pass_manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    // Reference graph.
    {
        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, data_shape);
        auto fc_weights = ngraph::opset1::Constant::create(ngraph::element::f32, weights_shape, {1});
        auto fc_node = std::make_shared<FullyConnectedNode>(data, fc_weights, ngraph::Rank(2));

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc_node}, ngraph::ParameterVector{data});
    }

    const auto cmp_result = compare_functions(f, f_ref, true);
    ASSERT_TRUE(cmp_result.first) << cmp_result.second;
}
|
||||||
|
|
||||||
|
// MatMul (transpose_b = true) with a 4D constant second input, followed by Add with a
// {1, 1, 1, 2} constant. Reference: Reshape to {-1, 3} -> FullyConnected with {2, 3}
// weights and a {2} bias -> Reshape to {1, 5, 2, 2}.
TEST(TransformationTests, ConvertMatMulToFCTest_second_input_rank_adj_3) {
    const ngraph::Shape data_shape{5, 2, 3};

    std::shared_ptr<ngraph::Function> f;
    std::shared_ptr<ngraph::Function> f_ref;

    // Graph under test: MatMul -> Add(const), converted by the registered passes.
    {
        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, data_shape);
        auto mm_weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1, 1, 2, 3}, {1});
        auto mat_mul = std::make_shared<ngraph::opset1::MatMul>(data, mm_weights, false, true);
        auto add_const = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1, 1, 1, 2}, {1});
        auto bias_add = std::make_shared<ngraph::opset1::Add>(mat_mul, add_const);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{bias_add}, ngraph::ParameterVector{data});

        ngraph::pass::Manager pass_manager;
        pass_manager.register_pass<ngraph::pass::InitNodeInfo>();
        pass_manager.register_pass<ConvertMatMulToFC>();
        pass_manager.register_pass<FullyConnectedBiasFusion>();
        pass_manager.register_pass<ReshapeFullyConnected>();
        pass_manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    // Reference graph.
    {
        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, data_shape);

        auto flatten_pattern = ngraph::opset1::Constant::create(ngraph::element::i64, {2}, {-1, 3});
        auto flatten = std::make_shared<ngraph::opset1::Reshape>(data, flatten_pattern, false);

        auto fc_weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 3}, {1});
        auto fc_bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2}, {1});
        auto fc_node = std::make_shared<FullyConnectedNode>(flatten, fc_weights, fc_bias, ngraph::Rank(2));

        auto restore_pattern = ngraph::opset1::Constant::create(ngraph::element::i64, {4}, {1, 5, 2, 2});
        auto restore = std::make_shared<ngraph::opset1::Reshape>(fc_node, restore_pattern, false);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{restore}, ngraph::ParameterVector{data});
    }

    const auto cmp_result = compare_functions(f, f_ref, true);
    ASSERT_TRUE(cmp_result.first) << cmp_result.second;
}
|
||||||
|
|
||||||
|
// MatMul (transpose_b = true) with {-1, 2, 3} activations and a 4D constant second input.
// Reference: Reshape to {-1, 3} -> FullyConnected -> Reshape whose target shape is built at
// runtime as Concat(1, Gather(ShapeOf(input), {0, 1}), 2).
TEST(TransformationTests, ConvertMatMulToFCTest_second_input_rank_adj_dynamic) {
    const ngraph::PartialShape data_shape{-1, 2, 3};

    std::shared_ptr<ngraph::Function> f;
    std::shared_ptr<ngraph::Function> f_ref;

    // Graph under test: plain MatMul, converted by the registered passes.
    {
        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, data_shape);
        auto mm_weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1, 1, 2, 3}, {1});
        auto mat_mul = std::make_shared<ngraph::opset1::MatMul>(data, mm_weights, false, true);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mat_mul}, ngraph::ParameterVector{data});

        ngraph::pass::Manager pass_manager;
        pass_manager.register_pass<ngraph::pass::InitNodeInfo>();
        pass_manager.register_pass<ConvertMatMulToFC>();
        pass_manager.register_pass<ReshapeFullyConnected>();
        pass_manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    // Reference graph.
    {
        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, data_shape);

        auto flatten_pattern = ngraph::opset1::Constant::create(ngraph::element::i64, {2}, {-1, 3});
        auto flatten = std::make_shared<ngraph::opset1::Reshape>(data, flatten_pattern, false);

        auto fc_weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 3}, {1});
        auto fc_node = std::make_shared<FullyConnectedNode>(flatten, fc_weights, ngraph::Rank(2));

        // Take the first two dimensions of the original input shape.
        auto input_shape = std::make_shared<ngraph::opset7::ShapeOf>(data);
        auto gather_indices = ngraph::opset1::Constant::create(ngraph::element::i64, {2}, {0, 1});
        auto gather_axis = ngraph::opset1::Constant::create(ngraph::element::i64, {}, {0});
        auto leading_dims = std::make_shared<ngraph::opset7::Gather>(input_shape, gather_indices, gather_axis);

        // Target shape for the output Reshape: {1, <gathered dims>, 2}.
        auto prefix_dim = ngraph::opset1::Constant::create(ngraph::element::i64, {1}, {1});
        auto suffix_dim = ngraph::opset1::Constant::create(ngraph::element::i64, {1}, {2});
        auto target_shape =
            std::make_shared<ngraph::opset1::Concat>(ngraph::OutputVector{prefix_dim, leading_dims, suffix_dim}, 0);

        auto restore = std::make_shared<ngraph::opset1::Reshape>(fc_node, target_shape, false);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{restore}, ngraph::ParameterVector{data});
    }

    const auto cmp_result = compare_functions(f, f_ref, true);
    ASSERT_TRUE(cmp_result.first) << cmp_result.second;
}
|
@ -304,6 +304,35 @@ public:
|
|||||||
OPENVINO_ASSERT(rank().is_static());
|
OPENVINO_ASSERT(rank().is_static());
|
||||||
return m_dimensions.size();
|
return m_dimensions.size();
|
||||||
}
|
}
|
||||||
|
/// \brief Returns a read/write iterator that points to the inserted element in the shape.
|
||||||
|
iterator insert(iterator position, const Dimension& val) {
|
||||||
|
m_rank_is_static = true;
|
||||||
|
m_shape_type = ShapeType::SHAPE_IS_UPDATED;
|
||||||
|
return m_dimensions.insert(position, val);
|
||||||
|
}
|
||||||
|
/// \brief Inserts count copies of the value before position
|
||||||
|
void insert(iterator position, size_t n, const Dimension& val) {
|
||||||
|
m_dimensions.insert(position, n, val);
|
||||||
|
m_rank_is_static = true;
|
||||||
|
m_shape_type = ShapeType::SHAPE_IS_UPDATED;
|
||||||
|
}
|
||||||
|
/// \brief Inserts elements from range [first, last) before position
|
||||||
|
template <class InputIterator>
|
||||||
|
void insert(iterator position, InputIterator first, InputIterator last) {
|
||||||
|
m_dimensions.insert(position, first, last);
|
||||||
|
m_rank_is_static = true;
|
||||||
|
m_shape_type = ShapeType::SHAPE_IS_UPDATED;
|
||||||
|
}
|
||||||
|
/// \brief Requests that the dimensions vector capacity be enough to contain n elements
|
||||||
|
void reserve(size_t n) {
|
||||||
|
m_dimensions.reserve(n);
|
||||||
|
}
|
||||||
|
/// \brief push element to the end of partial shape
|
||||||
|
void push_back(const Dimension& val) {
|
||||||
|
m_dimensions.push_back(val);
|
||||||
|
m_rank_is_static = true;
|
||||||
|
m_shape_type = ShapeType::SHAPE_IS_UPDATED;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Private constructor for PartialShape::dynamic().
|
// Private constructor for PartialShape::dynamic().
|
||||||
|
@ -13,7 +13,6 @@
|
|||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
#include "itt.hpp"
|
#include "itt.hpp"
|
||||||
#include "transformations/smart_reshape/utils.hpp"
|
|
||||||
|
|
||||||
bool relax_hc_reshape_followed_by_matmul(const ngraph::pattern::PatternValueMap& pattern_to_output,
|
bool relax_hc_reshape_followed_by_matmul(const ngraph::pattern::PatternValueMap& pattern_to_output,
|
||||||
const std::shared_ptr<ngraph::Node>& matmul_label,
|
const std::shared_ptr<ngraph::Node>& matmul_label,
|
||||||
@ -35,7 +34,10 @@ bool relax_hc_reshape_followed_by_matmul(const ngraph::pattern::PatternValueMap&
|
|||||||
const auto& raw_idx =
|
const auto& raw_idx =
|
||||||
reshape_is_A_input ? (matmul->get_transpose_b() ? -1 : -2) : (matmul->get_transpose_a() ? -2 : -1);
|
reshape_is_A_input ? (matmul->get_transpose_b() ? -1 : -2) : (matmul->get_transpose_a() ? -2 : -1);
|
||||||
const auto& idx = ngraph::normalize_axes(matmul->description(), {raw_idx}, reshape_rank);
|
const auto& idx = ngraph::normalize_axes(matmul->description(), {raw_idx}, reshape_rank);
|
||||||
const auto& C = ngraph::op::util::node_to_get_shape_value_of_indices_from_shape_source(shape_source, idx);
|
const auto& C = std::make_shared<ngraph::opset4::Gather>(
|
||||||
|
std::make_shared<ngraph::opset4::ShapeOf>(shape_source),
|
||||||
|
ngraph::opset4::Constant::create(ngraph::element::i64, {idx.size()}, idx),
|
||||||
|
ngraph::opset4::Constant::create(ngraph::element::i64, {}, {0}));
|
||||||
const auto& N = ngraph::opset4::Constant::create(ngraph::element::i64, {1}, {-1});
|
const auto& N = ngraph::opset4::Constant::create(ngraph::element::i64, {1}, {-1});
|
||||||
const auto& pattern_vector =
|
const auto& pattern_vector =
|
||||||
reshape_is_A_input ? (matmul->get_transpose_a() ? ngraph::OutputVector({C, N}) : ngraph::OutputVector({N, C}))
|
reshape_is_A_input ? (matmul->get_transpose_a() ? ngraph::OutputVector({C, N}) : ngraph::OutputVector({N, C}))
|
||||||
|
Loading…
Reference in New Issue
Block a user