Fix TI Serialization for IR 11 (#7977)

* Fix TI Serialization for IR 11

* Enable TI Serialize tests with IR10
This commit is contained in:
Ivan Tikhonov 2021-10-20 16:30:22 +03:00 committed by GitHub
parent 028d78a90f
commit 5cd63f47e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 20 additions and 16 deletions

View File

@ -10,8 +10,12 @@
using namespace LayerTestsDefinitions; using namespace LayerTestsDefinitions;
namespace { namespace {
TEST_P(TensorIteratorTest, Serialize) { TEST_P(TensorIteratorTest, Serialize_IR10) {
Serialize(); Serialize(ov::pass::Serialize::Version::IR_V10);
}
TEST_P(TensorIteratorTest, Serialize_IR11) {
Serialize(ov::pass::Serialize::Version::IR_V11);
} }
const std::vector<InferenceEngine::Precision> netPrecisions = { const std::vector<InferenceEngine::Precision> netPrecisions = {

View File

@ -16,6 +16,7 @@
#include <ngraph/function.hpp> #include <ngraph/function.hpp>
#include <ngraph/pass/manager.hpp> #include <ngraph/pass/manager.hpp>
#include <ngraph/type/bfloat16.hpp> #include <ngraph/type/bfloat16.hpp>
#include <ngraph/pass/serialize.hpp>
#include "common_test_utils/ngraph_test_utils.hpp" #include "common_test_utils/ngraph_test_utils.hpp"
#include "common_test_utils/common_utils.hpp" #include "common_test_utils/common_utils.hpp"
@ -55,7 +56,7 @@ public:
virtual void Run(); virtual void Run();
virtual void Serialize(); virtual void Serialize(ngraph::pass::Serialize::Version ir_version = ngraph::pass::Serialize::Version::UNSPECIFIED);
virtual void QueryNetwork(); virtual void QueryNetwork();

View File

@ -94,7 +94,7 @@ void LayerTestsCommon::Run() {
} }
} }
void LayerTestsCommon::Serialize() { void LayerTestsCommon::Serialize(ngraph::pass::Serialize::Version ir_version) {
SKIP_IF_CURRENT_TEST_IS_DISABLED(); SKIP_IF_CURRENT_TEST_IS_DISABLED();
std::string output_name = GetTestName().substr(0, CommonTestUtils::maxFileNameLength) + "_" + GetTimestamp(); std::string output_name = GetTestName().substr(0, CommonTestUtils::maxFileNameLength) + "_" + GetTimestamp();
@ -103,7 +103,7 @@ void LayerTestsCommon::Serialize() {
std::string out_bin_path = output_name + ".bin"; std::string out_bin_path = output_name + ".bin";
ngraph::pass::Manager manager; ngraph::pass::Manager manager;
manager.register_pass<ov::pass::Serialize>(out_xml_path, out_bin_path); manager.register_pass<ov::pass::Serialize>(out_xml_path, out_bin_path, ir_version);
manager.run_passes(function); manager.run_passes(function);
function->validate_nodes_and_infer_types(); function->validate_nodes_and_infer_types();

View File

@ -261,16 +261,19 @@ class XmlSerializer : public ngraph::AttributeVisitor {
std::vector<std::string> map_type_from_body(const pugi::xml_node& xml_node, std::vector<std::string> map_type_from_body(const pugi::xml_node& xml_node,
const std::string& map_type, const std::string& map_type,
int64_t ir_version,
const std::string& body_name = "body") { const std::string& body_name = "body") {
std::vector<std::string> output; std::vector<std::string> output;
for (pugi::xml_node node : xml_node.child(body_name.c_str()).child("layers")) { for (pugi::xml_node node : xml_node.child(body_name.c_str()).child("layers")) {
if (!map_type.compare(node.attribute("type").value())) { if (map_type == node.attribute("type").value()) {
output.emplace_back(node.attribute("id").value()); output.emplace_back(node.attribute("id").value());
} }
} }
// ops for serialized body function are provided in reversed order if (ir_version < 11) {
std::reverse(output.begin(), output.end()); // ops for serialized body function are provided in reversed order
std::reverse(output.begin(), output.end());
}
return output; return output;
} }
@ -401,9 +404,10 @@ public:
if (is_body_target) { if (is_body_target) {
auto body_name = std::get<0>(bnames); auto body_name = std::get<0>(bnames);
auto portmap_name = std::get<1>(bnames); auto portmap_name = std::get<1>(bnames);
std::vector<std::string> result_mapping = map_type_from_body(m_xml_node.parent(), "Result", body_name); std::vector<std::string> result_mapping =
map_type_from_body(m_xml_node.parent(), "Result", m_version, body_name);
std::vector<std::string> parameter_mapping = std::vector<std::string> parameter_mapping =
map_type_from_body(m_xml_node.parent(), "Parameter", body_name); map_type_from_body(m_xml_node.parent(), "Parameter", m_version, body_name);
pugi::xml_node port_map = m_xml_node.parent().child(portmap_name.c_str()); pugi::xml_node port_map = m_xml_node.parent().child(portmap_name.c_str());
@ -504,9 +508,7 @@ public:
// to layer above (m_xml_node.parent()) as in ngfunction_2_ir() layer (m_xml_node) with empty attributes // to layer above (m_xml_node.parent()) as in ngfunction_2_ir() layer (m_xml_node) with empty attributes
// is removed. // is removed.
pugi::xml_node xml_body = m_xml_node.parent().append_child(name.c_str()); pugi::xml_node xml_body = m_xml_node.parent().append_child(name.c_str());
// FIXME: the issue with TensorIteratorTest.Serialize doesn't allow to use v11 order of operations ngfunction_2_ir(xml_body, *adapter.get(), m_custom_opsets, m_constant_write_handler, m_version);
// Need to use m_version instead of 10
ngfunction_2_ir(xml_body, *adapter.get(), m_custom_opsets, m_constant_write_handler, 10);
xml_body.remove_attribute("name"); xml_body.remove_attribute("name");
xml_body.remove_attribute("version"); xml_body.remove_attribute("version");
} else if (name == "net") { } else if (name == "net") {
@ -848,9 +850,6 @@ void ngfunction_2_ir(pugi::xml_node& netXml,
const bool exec_graph = is_exec_graph(f); const bool exec_graph = is_exec_graph(f);
// change the order for parameters/results/sinks
// It should be a part of get_ordered_ops method
// FIXME: TensorIteratorTest.Serialize tests
auto sorted_ops = f.get_ordered_ops(); auto sorted_ops = f.get_ordered_ops();
if (version >= 11) { if (version >= 11) {
std::vector<std::shared_ptr<ov::Node>> result; std::vector<std::shared_ptr<ov::Node>> result;