Fixed build of RTInfoSerialization tests (#7978)

This commit is contained in:
Ilya Churaev 2021-10-13 13:50:30 +03:00 committed by GitHub
parent 917a29255e
commit 02f3a175d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 65 additions and 44 deletions

View File

@ -2,21 +2,23 @@
// SPDX-License-Identifier: Apache-2.0
//
#include <common_test_utils/file_utils.hpp>
#include <gtest/gtest.h>
#include <common_test_utils/file_utils.hpp>
#include <file_utils.h>
#include <ie_api.h>
#include <ie_iextension.h>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "ie_core.hpp"
#include "ngraph/ngraph.hpp"
#include "transformations/serialize.hpp"
#include <openvino/opsets/opset8.hpp>
#include <transformations/rt_info/attributes.hpp>
#include "frontend_manager/frontend_manager.hpp"
#include "openvino/opsets/opset8.hpp"
#include "openvino/pass/manager.hpp"
#include "openvino/pass/serialize.hpp"
#include "read_ir.hpp"
#include "transformations/rt_info/attributes.hpp"
#include "util/test_common.hpp"
using namespace ov;
using namespace ngraph;
class RTInfoSerializationTest : public ov::test::TestsCommon {
class RTInfoSerializationTest : public CommonTestUtils::TestsCommon {
protected:
std::string test_name = GetTestName() + "_" + GetTimestamp();
std::string m_out_xml_path = test_name + ".xml";
@ -26,17 +28,35 @@ protected:
CommonTestUtils::removeIRFiles(m_out_xml_path, m_out_bin_path);
}
std::shared_ptr<ngraph::Function> getWithIRFrontend(const std::string& model_path,
const std::string& weights_path) {
ngraph::frontend::FrontEnd::Ptr FE;
ngraph::frontend::InputModel::Ptr inputModel;
ov::VariantVector params{ov::make_variant(model_path), ov::make_variant(weights_path)};
FE = manager.load_by_model(params);
if (FE)
inputModel = FE->load(params);
if (inputModel)
return FE->convert(inputModel);
return nullptr;
}
private:
ngraph::frontend::FrontEndManager manager;
};
TEST_F(RTInfoSerializationTest, all_attributes_latest) {
auto init_info = [](RTMap& info) {
auto init_info = [](RTMap & info) {
info[VariantWrapper<ngraph::FusedNames>::get_type_info_static()] =
std::make_shared<VariantWrapper<ngraph::FusedNames>>(ngraph::FusedNames("add"));
info[ov::PrimitivesPriority::get_type_info_static()] = std::make_shared<ov::PrimitivesPriority>("priority");
std::make_shared<VariantWrapper<ngraph::FusedNames>>(ngraph::FusedNames("add"));
info[ov::PrimitivesPriority::get_type_info_static()] =
std::make_shared<ov::PrimitivesPriority>("priority");
info[ov::OldApiMap::get_type_info_static()] = std::make_shared<ov::OldApiMap>(
ov::OldApiMapAttr(std::vector<uint64_t>{0, 2, 3, 1}, ngraph::element::Type_t::f32));
ov::OldApiMapAttr(std::vector<uint64_t>{0, 2, 3, 1}, ngraph::element::Type_t::f32));
};
std::shared_ptr<ngraph::Function> function;
@ -51,26 +71,26 @@ TEST_F(RTInfoSerializationTest, all_attributes_latest) {
}
pass::Manager m;
m.register_pass<ov::pass::Serialize>(m_out_xml_path, m_out_bin_path);
m.register_pass<pass::Serialize>(m_out_xml_path, m_out_bin_path);
m.run_passes(function);
auto f = ov::test::readIR(m_out_xml_path, m_out_bin_path);
auto f = getWithIRFrontend(m_out_xml_path, m_out_bin_path);
ASSERT_NE(nullptr, f);
auto check_info = [](const RTMap& info) {
const std::string& key = VariantWrapper<ngraph::FusedNames>::get_type_info_static();
auto check_info = [](const RTMap & info) {
const std::string & key = VariantWrapper<ngraph::FusedNames>::get_type_info_static();
ASSERT_TRUE(info.count(key));
auto fused_names_attr = std::dynamic_pointer_cast<VariantWrapper<ngraph::FusedNames>>(info.at(key));
ASSERT_TRUE(fused_names_attr);
ASSERT_EQ(fused_names_attr->get().getNames(), "add");
const std::string& pkey = ov::PrimitivesPriority::get_type_info_static();
const std::string & pkey = ov::PrimitivesPriority::get_type_info_static();
ASSERT_TRUE(info.count(pkey));
auto primitives_priority_attr = std::dynamic_pointer_cast<ov::PrimitivesPriority>(info.at(pkey));
ASSERT_TRUE(primitives_priority_attr);
ASSERT_EQ(primitives_priority_attr->get(), "priority");
const std::string& old_api_map_key = ov::OldApiMap::get_type_info_static();
const std::string & old_api_map_key = ov::OldApiMap::get_type_info_static();
ASSERT_TRUE(info.count(old_api_map_key));
auto old_api_map_attr = std::dynamic_pointer_cast<ov::OldApiMap>(info.at(old_api_map_key));
ASSERT_TRUE(old_api_map_attr);
@ -87,10 +107,11 @@ TEST_F(RTInfoSerializationTest, all_attributes_latest) {
}
TEST_F(RTInfoSerializationTest, all_attributes_v10) {
auto init_info = [](RTMap& info) {
auto init_info = [](RTMap & info) {
info[VariantWrapper<ngraph::FusedNames>::get_type_info_static()] =
std::make_shared<VariantWrapper<ngraph::FusedNames>>(ngraph::FusedNames("add"));
info[ov::PrimitivesPriority::get_type_info_static()] = std::make_shared<ov::PrimitivesPriority>("priority");
std::make_shared<VariantWrapper<ngraph::FusedNames>>(ngraph::FusedNames("add"));
info[ov::PrimitivesPriority::get_type_info_static()] =
std::make_shared<ov::PrimitivesPriority>("priority");
};
std::shared_ptr<ngraph::Function> function;
@ -105,14 +126,14 @@ TEST_F(RTInfoSerializationTest, all_attributes_v10) {
}
pass::Manager m;
m.register_pass<ov::pass::Serialize>(m_out_xml_path, m_out_bin_path, ov::pass::Serialize::Version::IR_V10);
m.register_pass<pass::Serialize>(m_out_xml_path, m_out_bin_path, pass::Serialize::Version::IR_V10);
m.run_passes(function);
auto f = ov::test::readIR(m_out_xml_path, m_out_bin_path);
auto f = getWithIRFrontend(m_out_xml_path, m_out_bin_path);
ASSERT_NE(nullptr, f);
auto check_info = [](const RTMap& info) {
const std::string& key = VariantWrapper<ngraph::FusedNames>::get_type_info_static();
auto check_info = [](const RTMap & info) {
const std::string & key = VariantWrapper<ngraph::FusedNames>::get_type_info_static();
ASSERT_FALSE(info.count(key));
};
@ -124,10 +145,11 @@ TEST_F(RTInfoSerializationTest, all_attributes_v10) {
}
TEST_F(RTInfoSerializationTest, all_attributes_v11) {
auto init_info = [](RTMap& info) {
auto init_info = [](RTMap & info) {
info[VariantWrapper<ngraph::FusedNames>::get_type_info_static()] =
std::make_shared<VariantWrapper<ngraph::FusedNames>>(ngraph::FusedNames("add"));
info[ov::PrimitivesPriority::get_type_info_static()] = std::make_shared<ov::PrimitivesPriority>("priority");
std::make_shared<VariantWrapper<ngraph::FusedNames>>(ngraph::FusedNames("add"));
info[ov::PrimitivesPriority::get_type_info_static()] =
std::make_shared<ov::PrimitivesPriority>("priority");
};
std::shared_ptr<ngraph::Function> function;
@ -142,20 +164,20 @@ TEST_F(RTInfoSerializationTest, all_attributes_v11) {
}
pass::Manager m;
m.register_pass<ov::pass::Serialize>(m_out_xml_path, m_out_bin_path);
m.register_pass<pass::Serialize>(m_out_xml_path, m_out_bin_path);
m.run_passes(function);
auto f = ov::test::readIR(m_out_xml_path, m_out_bin_path);
auto f = getWithIRFrontend(m_out_xml_path, m_out_bin_path);
ASSERT_NE(nullptr, f);
auto check_info = [](const RTMap& info) {
const std::string& key = VariantWrapper<ngraph::FusedNames>::get_type_info_static();
auto check_info = [](const RTMap & info) {
const std::string & key = VariantWrapper<ngraph::FusedNames>::get_type_info_static();
ASSERT_TRUE(info.count(key));
auto fused_names_attr = std::dynamic_pointer_cast<VariantWrapper<ngraph::FusedNames>>(info.at(key));
ASSERT_TRUE(fused_names_attr);
ASSERT_EQ(fused_names_attr->get().getNames(), "add");
const std::string& pkey = ov::PrimitivesPriority::get_type_info_static();
const std::string & pkey = ov::PrimitivesPriority::get_type_info_static();
ASSERT_TRUE(info.count(pkey));
auto primitives_priority_attr = std::dynamic_pointer_cast<ov::PrimitivesPriority>(info.at(pkey));
ASSERT_TRUE(primitives_priority_attr);
@ -194,10 +216,10 @@ TEST_F(RTInfoSerializationTest, parameter_result_v11) {
}
pass::Manager m;
m.register_pass<ov::pass::Serialize>(m_out_xml_path, m_out_bin_path, ov::pass::Serialize::Version::IR_V11);
m.register_pass<pass::Serialize>(m_out_xml_path, m_out_bin_path, pass::Serialize::Version::IR_V11);
m.run_passes(function);
auto f = ov::test::readIR(m_out_xml_path, m_out_bin_path);
auto f = getWithIRFrontend(m_out_xml_path, m_out_bin_path);
ASSERT_NE(nullptr, f);
ASSERT_EQ(function->get_results().size(), f->get_results().size());

View File

@ -86,7 +86,6 @@ set(SRC
pass/serialization/cleanup.cpp
pass/serialization/const_compression.cpp
pass/serialization/deterministicity.cpp
pass/serialization/rt_info_serialization.cpp
pass/serialization/serialize.cpp
pattern.cpp
preprocess.cpp

View File

@ -6,8 +6,9 @@
#include <fstream>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "openvino/opsets/opset8.hpp"
#include "openvino/pass/serialize.hpp"
#include "util/graph_comparator.hpp"
#include "util/test_common.hpp"
class SerializationCleanupTest : public ov::test::TestsCommon {
@ -24,11 +25,10 @@ protected:
namespace {
std::shared_ptr<ngraph::Function> CreateTestFunction(const std::string& name, const ngraph::PartialShape& ps) {
using namespace ngraph;
const auto param = std::make_shared<op::Parameter>(element::f16, ps);
const auto convert = std::make_shared<op::Convert>(param, element::f32);
const auto result = std::make_shared<op::Result>(convert);
return std::make_shared<Function>(ResultVector{result}, ParameterVector{param}, name);
const auto param = std::make_shared<ov::opset8::Parameter>(ov::element::f16, ps);
const auto convert = std::make_shared<ov::opset8::Convert>(param, ov::element::f32);
const auto result = std::make_shared<ov::opset8::Result>(convert);
return std::make_shared<ov::Function>(ov::ResultVector{result}, ov::ParameterVector{param}, name);
}
} // namespace