[CPU] ConvertMatMulToFC fix (#15933)
This commit is contained in:
parent
3de00347f3
commit
84285ac317
@ -73,26 +73,15 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
|
||||
std::swap(*(shape_b_aligned.end() - 1), *(shape_b_aligned.end() - 2));
|
||||
}
|
||||
|
||||
// check on per-batch MatMul which can't be converted to FC
|
||||
for (size_t i = 0; i < max_size - 2; ++i) {
|
||||
auto a_dim = shape_a_aligned[i], b_dim = shape_b_aligned[i];
|
||||
if (a_dim.is_dynamic()) {
|
||||
if (b_dim == 1) {
|
||||
shape_a_aligned[i] = shape_b_aligned[i] = a_dim;
|
||||
} else {
|
||||
return std::make_tuple(false, ngraph::PartialShape{shape_a_aligned}, ngraph::PartialShape{shape_b_aligned});
|
||||
}
|
||||
continue;
|
||||
if (shape_b_aligned[i] == 1) {
|
||||
shape_b_aligned[i] = shape_a_aligned[i];
|
||||
} else {
|
||||
return std::make_tuple(false, std::move(shape_a_aligned), std::move(shape_b_aligned));
|
||||
}
|
||||
// both dimensions are static
|
||||
if (a_dim != b_dim && a_dim.get_length() > 1 && b_dim.get_length() > 1) {
|
||||
std::ostringstream stream;
|
||||
stream << "Shapes can't be aligned: " << shape_a_aligned << " " << shape_b_aligned;
|
||||
throw ngraph::ngraph_error(stream.str());
|
||||
}
|
||||
size_t max_value = std::max(a_dim.get_length(), b_dim.get_length());
|
||||
shape_a_aligned[i] = shape_b_aligned[i] = max_value;
|
||||
}
|
||||
return std::make_tuple(true, shape_a_aligned, shape_b_aligned);
|
||||
return std::make_tuple(true, std::move(shape_a_aligned), std::move(shape_b_aligned));
|
||||
};
|
||||
|
||||
/*
|
||||
|
@ -17,27 +17,22 @@
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <ov_ops/type_relaxed.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
using namespace ov::intel_cpu;
|
||||
|
||||
TEST(TransformationTests, ConvertMatMulToFCTest1) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
TEST_F(TransformationTestsF, ConvertMatMulToFCTest1) {
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 3, 2, 2 });
|
||||
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 2, 2 }, { 1 });
|
||||
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, true, false);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ov::pass::InitNodeInfo>();
|
||||
m.register_pass<ConvertMatMulToFC>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
|
||||
manager.register_pass<ConvertMatMulToFC>();
|
||||
}
|
||||
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 3, 2, 2 });
|
||||
auto transpose_constant = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 3 }, { 0, 2, 1 });
|
||||
@ -45,163 +40,109 @@ TEST(TransformationTests, ConvertMatMulToFCTest1) {
|
||||
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 2, 2 }, { 1 });
|
||||
auto matmul = std::make_shared<FullyConnectedNode>(transpose, input2, ngraph::Rank(3));
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertMatMulToFCTest2) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
TEST_F(TransformationTestsF, ConvertMatMulToFCTest2) {
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||
auto input2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 1});
|
||||
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, false);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2});
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ov::pass::InitNodeInfo>();
|
||||
m.register_pass<ConvertMatMulToFC>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2});
|
||||
manager.register_pass<ConvertMatMulToFC>();
|
||||
}
|
||||
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||
auto input2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 1});
|
||||
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, false);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2});
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertMatMulToFCTest3) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
TEST_F(TransformationTestsF, ConvertMatMulToFCTest3) {
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 2});
|
||||
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 2}, {1});
|
||||
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ov::pass::InitNodeInfo>();
|
||||
m.register_pass<ConvertMatMulToFC>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
manager.register_pass<ConvertMatMulToFC>();
|
||||
}
|
||||
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 2});
|
||||
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 2}, {1});
|
||||
auto matmul = std::make_shared<FullyConnectedNode>(input1, input2, ngraph::Rank(3));
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertMatMulToFCTest4) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
TEST_F(TransformationTestsF, ConvertMatMulToFCTest4) {
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{-1, -1, 2});
|
||||
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 2}, {1});
|
||||
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ov::pass::InitNodeInfo>();
|
||||
m.register_pass<ConvertMatMulToFC>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
manager.register_pass<ConvertMatMulToFC>();
|
||||
}
|
||||
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{-1, -1, 2});
|
||||
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 2}, {1});
|
||||
auto matmul = std::make_shared<FullyConnectedNode>(input1, input2, ngraph::Rank(3));
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertMatMulToFCTest5) {
|
||||
TEST_F(TransformationTestsF, ConvertMatMulToFCTest5) {
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{ -1, -1, 2 });
|
||||
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3, 2, 2 }, { 1 });
|
||||
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
|
||||
|
||||
auto f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
|
||||
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ov::pass::InitNodeInfo>();
|
||||
m.register_pass<ConvertMatMulToFC>();
|
||||
ASSERT_NO_THROW(m.run_passes(f));
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
|
||||
manager.register_pass<ConvertMatMulToFC>();
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertMatMulToFCTest6) {
|
||||
TEST_F(TransformationTestsF, ConvertMatMulToFCTest6) {
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{ -1, -1, 2 });
|
||||
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 3, 1, 2 }, { 1 });
|
||||
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
|
||||
|
||||
auto f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
|
||||
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ov::pass::InitNodeInfo>();
|
||||
m.register_pass<ConvertMatMulToFC>();
|
||||
ASSERT_NO_THROW(m.run_passes(f));
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
|
||||
manager.register_pass<ConvertMatMulToFC>();
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertMatMulToFCTest7) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
TEST_F(TransformationTestsF, ConvertMatMulToFCTest7) {
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 2});
|
||||
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 2}, {1});
|
||||
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ov::pass::InitNodeInfo>();
|
||||
m.register_pass<ConvertMatMulToFC>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
manager.register_pass<ConvertMatMulToFC>();
|
||||
}
|
||||
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 2});
|
||||
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 2}, {1});
|
||||
auto fc = std::make_shared<FullyConnectedNode>(input1, input2, ngraph::Rank(2));
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc}, ngraph::ParameterVector{input1});
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc}, ngraph::ParameterVector{input1});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertMatMulToFCTest8) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
TEST_F(TransformationTestsF, ConvertMatMulToFCTest8) {
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{-1, -1, 2});
|
||||
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 2}, {1});
|
||||
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ov::pass::InitNodeInfo>();
|
||||
m.register_pass<ConvertMatMulToFC>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
manager.register_pass<ConvertMatMulToFC>();
|
||||
}
|
||||
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{-1, -1, 2});
|
||||
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 2}, {1});
|
||||
@ -213,57 +154,98 @@ TEST(TransformationTests, ConvertMatMulToFCTest8) {
|
||||
auto O = ngraph::opset1::Constant::create(ngraph::element::i64, { 1 }, { 3 });
|
||||
auto output_shape = std::make_shared<ngraph::opset1::Concat>(ngraph::OutputVector{I, O}, 0);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc}, ngraph::ParameterVector{input1});
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc}, ngraph::ParameterVector{input1});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertMatMulToFCTest9) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
TEST_F(TransformationTestsF, ConvertMatMulToFCTest9) {
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 2});
|
||||
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 2}, {1});
|
||||
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
|
||||
ngraph::pass::Manager m;
|
||||
auto pass_config = m.get_pass_config();
|
||||
m.register_pass<ov::pass::InitNodeInfo>();
|
||||
m.register_pass<ConvertMatMulToFC>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
manager.register_pass<ConvertMatMulToFC>();
|
||||
}
|
||||
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 2});
|
||||
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 2}, {1});
|
||||
auto matmul = std::make_shared<FullyConnectedNode>(input1, input2, ngraph::Rank(3));
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertMatMulToFCTest10) {
|
||||
TEST_F(TransformationTestsF, ConvertMatMulToFCTest10) {
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
|
||||
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 2, 2 }, { 1 });
|
||||
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
|
||||
|
||||
auto f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
|
||||
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ov::pass::InitNodeInfo>();
|
||||
m.register_pass<ConvertMatMulToFC>();
|
||||
ASSERT_NO_THROW(m.run_passes(f));
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
|
||||
manager.register_pass<ConvertMatMulToFC>();
|
||||
}
|
||||
|
||||
TEST(TransformationTests, FullyConnectedBiasFusionTest1) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
TEST_F(TransformationTestsF, ConvertMatMulToFCTest11) {
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{18, -1, 1});
|
||||
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{18, 80, 1}, {1});
|
||||
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
|
||||
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
manager.register_pass<ConvertMatMulToFC>();
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertMatMulToFCTest12) {
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{1, -1, 1});
|
||||
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 80, 1}, {1});
|
||||
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
|
||||
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
manager.register_pass<ConvertMatMulToFC>();
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertMatMulToFCTest13) {
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{-1, -1, 1});
|
||||
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1, 80, 1}, {1});
|
||||
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
|
||||
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
manager.register_pass<ConvertMatMulToFC>();
|
||||
}
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{-1, -1, 1});
|
||||
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{80, 1}, {1});
|
||||
auto matmul = std::make_shared<FullyConnectedNode>(input1, input2, ngraph::Rank(3));
|
||||
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertMatMulToFCTest14) {
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::u8, ngraph::PartialShape{-1, -1, 1});
|
||||
auto input2 = ngraph::opset1::Constant::create(ngraph::element::i8, ngraph::Shape{1, 80, 1}, {1});
|
||||
auto matmul = std::make_shared<ov::op::TypeRelaxed<ngraph::opset1::MatMul>>(
|
||||
ov::element::TypeVector{ngraph::element::f32, ngraph::element::f32},
|
||||
ov::element::TypeVector{ngraph::element::f32},
|
||||
ov::op::TemporaryReplaceOutputType(input1, ngraph::element::f32).get(),
|
||||
ov::op::TemporaryReplaceOutputType(input2, ngraph::element::f32).get(),
|
||||
false,
|
||||
true);
|
||||
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
manager.register_pass<ConvertMatMulToFC>();
|
||||
}
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::u8, ngraph::PartialShape{-1, -1, 1});
|
||||
auto input2 = ngraph::opset1::Constant::create(ngraph::element::i8, ngraph::Shape{80, 1}, {1});
|
||||
auto matmul = std::make_shared<FullyConnectedNode>(input1, input2, ngraph::Rank(3), ngraph::element::f32);
|
||||
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, FullyConnectedBiasFusionTest1) {
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 3072});
|
||||
auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786, 3072}, {1});
|
||||
@ -272,32 +254,20 @@ TEST(TransformationTests, FullyConnectedBiasFusionTest1) {
|
||||
auto const_bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786}, {1});
|
||||
auto add = std::make_shared<ngraph::opset1::Add>(fc, const_bias);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{add}, ngraph::ParameterVector{input1});
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ov::pass::InitNodeInfo>();
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{add}, ngraph::ParameterVector{input1});
|
||||
manager.register_pass<FullyConnectedBiasFusion>();
|
||||
manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> f) {
|
||||
check_rt_info(f);
|
||||
});
|
||||
ASSERT_NO_THROW(manager.run_passes(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 3072});
|
||||
auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786, 3072}, {1});
|
||||
auto bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786}, {1});
|
||||
auto fc = std::make_shared<FullyConnectedNode>(input1, weights, bias, ngraph::Rank(3));
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc}, ngraph::ParameterVector{input1});
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc}, ngraph::ParameterVector{input1});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, FullyConnectedBiasFusionTest2) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
TEST_F(TransformationTestsF, FullyConnectedBiasFusionTest2) {
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{-1, -1, 3072});
|
||||
auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786, 3072}, {1});
|
||||
@ -306,31 +276,20 @@ TEST(TransformationTests, FullyConnectedBiasFusionTest2) {
|
||||
auto const_bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786}, {1});
|
||||
auto add = std::make_shared<ngraph::opset1::Add>(fc, const_bias);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{add}, ngraph::ParameterVector{input1});
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ov::pass::InitNodeInfo>();
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{add}, ngraph::ParameterVector{input1});
|
||||
manager.register_pass<FullyConnectedBiasFusion>();
|
||||
manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> f) {
|
||||
check_rt_info(f);
|
||||
});
|
||||
ASSERT_NO_THROW(manager.run_passes(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{-1, -1, 3072});
|
||||
auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786, 3072}, {1});
|
||||
auto bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786}, {1});
|
||||
auto fc = std::make_shared<FullyConnectedNode>(input1, weights, bias, ngraph::Rank(3));
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc}, ngraph::ParameterVector{input1});
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc}, ngraph::ParameterVector{input1});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, FullyConnectedBiasFusionTest3) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
TEST_F(TransformationTestsF, FullyConnectedBiasFusionTest3) {
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128});
|
||||
auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786, 128}, {1});
|
||||
@ -339,31 +298,20 @@ TEST(TransformationTests, FullyConnectedBiasFusionTest3) {
|
||||
auto const_bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1, 786}, {1});
|
||||
auto add = std::make_shared<ngraph::opset1::Add>(fc, const_bias);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{add}, ngraph::ParameterVector{input1});
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ov::pass::InitNodeInfo>();
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{add}, ngraph::ParameterVector{input1});
|
||||
manager.register_pass<FullyConnectedBiasFusion>();
|
||||
manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> f) {
|
||||
check_rt_info(f);
|
||||
});
|
||||
ASSERT_NO_THROW(manager.run_passes(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128});
|
||||
auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786, 128}, {1});
|
||||
auto bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786}, {1});
|
||||
auto fc = std::make_shared<FullyConnectedNode>(input1, weights, bias, ngraph::Rank(2));
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc}, ngraph::ParameterVector{input1});
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc}, ngraph::ParameterVector{input1});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, FullyConnectedBiasFusionTest4) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
TEST_F(TransformationTestsF, FullyConnectedBiasFusionTest4) {
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{-1, 128});
|
||||
auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786, 128}, {1});
|
||||
@ -372,30 +320,20 @@ TEST(TransformationTests, FullyConnectedBiasFusionTest4) {
|
||||
auto const_bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1, 786}, {1});
|
||||
auto add = std::make_shared<ngraph::opset1::Add>(fc, const_bias);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{add}, ngraph::ParameterVector{input1});
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ov::pass::InitNodeInfo>();
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{add}, ngraph::ParameterVector{input1});
|
||||
manager.register_pass<FullyConnectedBiasFusion>();
|
||||
manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> f) {
|
||||
check_rt_info(f);
|
||||
});
|
||||
ASSERT_NO_THROW(manager.run_passes(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{-1, 128});
|
||||
auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786, 128}, {1});
|
||||
auto bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786}, {1});
|
||||
auto fc = std::make_shared<FullyConnectedNode>(input1, weights, bias, ngraph::Rank(2));
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc}, ngraph::ParameterVector{input1});
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc}, ngraph::ParameterVector{input1});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, FullyConnectedBiasFusionTest5) {
|
||||
TEST_F(TransformationTestsF, FullyConnectedBiasFusionTest5) {
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
|
||||
auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 786, 128 }, { 1 });
|
||||
auto fc = std::make_shared<FullyConnectedNode>(input1, weights, ngraph::Rank(2));
|
||||
@ -403,82 +341,68 @@ TEST(TransformationTests, FullyConnectedBiasFusionTest5) {
|
||||
auto const_bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 786 }, { 1 });
|
||||
auto add = std::make_shared<ngraph::opset1::Add>(fc, const_bias);
|
||||
|
||||
auto f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ add }, ngraph::ParameterVector{ input1 });
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ add }, ngraph::ParameterVector{ input1 });
|
||||
manager.register_pass<FullyConnectedBiasFusion>();
|
||||
ASSERT_NO_THROW(manager.run_passes(f));
|
||||
}
|
||||
|
||||
TEST(TransformationTests, FullyConnectedBiasFusionTest6) {
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::u8, ngraph::PartialShape{ -1, -1 });
|
||||
auto weights = ngraph::opset1::Constant::create(ngraph::element::i8, ngraph::Shape{ 786, 128 }, { 1 });
|
||||
auto fc = std::make_shared<FullyConnectedNode>(input1, weights, ngraph::Rank(2), ngraph::element::f32);
|
||||
TEST_F(TransformationTestsF, FullyConnectedBiasFusionTest6) {
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::u8, ngraph::PartialShape{ -1, -1 });
|
||||
auto weights = ngraph::opset1::Constant::create(ngraph::element::i8, ngraph::Shape{ 786, 128 }, { 1 });
|
||||
auto fc = std::make_shared<FullyConnectedNode>(input1, weights, ngraph::Rank(2), ngraph::element::f32);
|
||||
|
||||
auto const_bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 786 }, { 1 });
|
||||
auto add = std::make_shared<ngraph::opset1::Add>(fc, const_bias);
|
||||
auto const_bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 786 }, { 1 });
|
||||
auto add = std::make_shared<ngraph::opset1::Add>(fc, const_bias);
|
||||
|
||||
auto f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ add }, ngraph::ParameterVector{ input1 });
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ add }, ngraph::ParameterVector{ input1 });
|
||||
manager.register_pass<FullyConnectedBiasFusion>();
|
||||
}
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::u8, ngraph::PartialShape{ -1, -1 });
|
||||
auto weights = ngraph::opset1::Constant::create(ngraph::element::i8, ngraph::Shape{ 786, 128 }, { 1 });
|
||||
auto bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786}, {1});
|
||||
auto matmul = std::make_shared<FullyConnectedNode>(input1, weights, bias, ngraph::Rank(2), ngraph::element::f32);
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<FullyConnectedBiasFusion>();
|
||||
ASSERT_NO_THROW(manager.run_passes(f));
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertMatMulToFCTest_second_input_rank_adj_1) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
TEST_F(TransformationTestsF, ConvertMatMulToFCTest_second_input_rank_adj_1) {
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{5, 2, 3});
|
||||
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1, 2, 3}, {1});
|
||||
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ov::pass::InitNodeInfo>();
|
||||
m.register_pass<ConvertMatMulToFC>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
manager.register_pass<ConvertMatMulToFC>();
|
||||
}
|
||||
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{5, 2, 3});
|
||||
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 3}, {1});
|
||||
auto matmul = std::make_shared<FullyConnectedNode>(input1, input2, ngraph::Rank(2));
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertMatMulToFCTest_second_input_rank_adj_2) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
TEST_F(TransformationTestsF, ConvertMatMulToFCTest_second_input_rank_adj_2) {
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 2, 3 });
|
||||
auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 2, 3 }, { 1 });
|
||||
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, weights, false, true);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ov::pass::InitNodeInfo>();
|
||||
m.register_pass<ConvertMatMulToFC>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
|
||||
manager.register_pass<ConvertMatMulToFC>();
|
||||
}
|
||||
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 2, 3 });
|
||||
auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 2, 3 }, { 1 });
|
||||
auto matmul = std::make_shared<FullyConnectedNode>(input1, weights, ngraph::Rank(2));
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertMatMulToFCTest_second_input_rank_adj_3) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
TEST_F(TransformationTestsF, ConvertMatMulToFCTest_second_input_rank_adj_3) {
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 5, 2, 3 });
|
||||
auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 2, 3 }, { 1 });
|
||||
@ -486,15 +410,10 @@ TEST(TransformationTests, ConvertMatMulToFCTest_second_input_rank_adj_3) {
|
||||
auto biases = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 1, 2 }, { 1 });
|
||||
auto add = std::make_shared<ngraph::opset1::Add>(matmul, biases);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ add }, ngraph::ParameterVector{ input1 });
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ov::pass::InitNodeInfo>();
|
||||
m.register_pass<ConvertMatMulToFC>();
|
||||
m.register_pass<FullyConnectedBiasFusion>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ add }, ngraph::ParameterVector{ input1 });
|
||||
manager.register_pass<ConvertMatMulToFC>();
|
||||
manager.register_pass<FullyConnectedBiasFusion>();
|
||||
}
|
||||
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 5, 2, 3 });
|
||||
auto reshape_before_const = ngraph::opset1::Constant::create(ngraph::element::i64, { 2 }, { -1, 3 });
|
||||
@ -502,11 +421,8 @@ TEST(TransformationTests, ConvertMatMulToFCTest_second_input_rank_adj_3) {
|
||||
auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 2, 3 }, { 1 });
|
||||
auto biases = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 2 }, { 1 });
|
||||
auto matmul = std::make_shared<FullyConnectedNode>(input1, weights, biases, ngraph::Rank(2));
|
||||
|
||||
auto reshape_after_const = ngraph::opset1::Constant::create(ngraph::element::i64, { 4 }, { 1, 5, 2, 2 });
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user