Fix TI Serialization for IR 11 (#7977)
* Fix TI Serialization for IR 11 * Enable TI Serialize tests with IR10
This commit is contained in:
parent
028d78a90f
commit
5cd63f47e5
@ -10,8 +10,12 @@
|
||||
using namespace LayerTestsDefinitions;
|
||||
|
||||
namespace {
|
||||
TEST_P(TensorIteratorTest, Serialize) {
|
||||
Serialize();
|
||||
TEST_P(TensorIteratorTest, Serialize_IR10) {
|
||||
Serialize(ov::pass::Serialize::Version::IR_V10);
|
||||
}
|
||||
|
||||
TEST_P(TensorIteratorTest, Serialize_IR11) {
|
||||
Serialize(ov::pass::Serialize::Version::IR_V11);
|
||||
}
|
||||
|
||||
const std::vector<InferenceEngine::Precision> netPrecisions = {
|
||||
|
@ -16,6 +16,7 @@
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <ngraph/type/bfloat16.hpp>
|
||||
#include <ngraph/pass/serialize.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "common_test_utils/common_utils.hpp"
|
||||
@ -55,7 +56,7 @@ public:
|
||||
|
||||
virtual void Run();
|
||||
|
||||
virtual void Serialize();
|
||||
virtual void Serialize(ngraph::pass::Serialize::Version ir_version = ngraph::pass::Serialize::Version::UNSPECIFIED);
|
||||
|
||||
virtual void QueryNetwork();
|
||||
|
||||
|
@ -94,7 +94,7 @@ void LayerTestsCommon::Run() {
|
||||
}
|
||||
}
|
||||
|
||||
void LayerTestsCommon::Serialize() {
|
||||
void LayerTestsCommon::Serialize(ngraph::pass::Serialize::Version ir_version) {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED();
|
||||
|
||||
std::string output_name = GetTestName().substr(0, CommonTestUtils::maxFileNameLength) + "_" + GetTimestamp();
|
||||
@ -103,7 +103,7 @@ void LayerTestsCommon::Serialize() {
|
||||
std::string out_bin_path = output_name + ".bin";
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ov::pass::Serialize>(out_xml_path, out_bin_path);
|
||||
manager.register_pass<ov::pass::Serialize>(out_xml_path, out_bin_path, ir_version);
|
||||
manager.run_passes(function);
|
||||
function->validate_nodes_and_infer_types();
|
||||
|
||||
|
@ -261,16 +261,19 @@ class XmlSerializer : public ngraph::AttributeVisitor {
|
||||
|
||||
std::vector<std::string> map_type_from_body(const pugi::xml_node& xml_node,
|
||||
const std::string& map_type,
|
||||
int64_t ir_version,
|
||||
const std::string& body_name = "body") {
|
||||
std::vector<std::string> output;
|
||||
for (pugi::xml_node node : xml_node.child(body_name.c_str()).child("layers")) {
|
||||
if (!map_type.compare(node.attribute("type").value())) {
|
||||
if (map_type == node.attribute("type").value()) {
|
||||
output.emplace_back(node.attribute("id").value());
|
||||
}
|
||||
}
|
||||
|
||||
if (ir_version < 11) {
|
||||
// ops for serialized body function are provided in reversed order
|
||||
std::reverse(output.begin(), output.end());
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
@ -401,9 +404,10 @@ public:
|
||||
if (is_body_target) {
|
||||
auto body_name = std::get<0>(bnames);
|
||||
auto portmap_name = std::get<1>(bnames);
|
||||
std::vector<std::string> result_mapping = map_type_from_body(m_xml_node.parent(), "Result", body_name);
|
||||
std::vector<std::string> result_mapping =
|
||||
map_type_from_body(m_xml_node.parent(), "Result", m_version, body_name);
|
||||
std::vector<std::string> parameter_mapping =
|
||||
map_type_from_body(m_xml_node.parent(), "Parameter", body_name);
|
||||
map_type_from_body(m_xml_node.parent(), "Parameter", m_version, body_name);
|
||||
|
||||
pugi::xml_node port_map = m_xml_node.parent().child(portmap_name.c_str());
|
||||
|
||||
@ -504,9 +508,7 @@ public:
|
||||
// to layer above (m_xml_node.parent()) as in ngfunction_2_ir() layer (m_xml_node) with empty attributes
|
||||
// is removed.
|
||||
pugi::xml_node xml_body = m_xml_node.parent().append_child(name.c_str());
|
||||
// FIXME: the issue with TensorIteratorTest.Serialize doesn't allow to use v11 order of operations
|
||||
// Need to use m_version instead of 10
|
||||
ngfunction_2_ir(xml_body, *adapter.get(), m_custom_opsets, m_constant_write_handler, 10);
|
||||
ngfunction_2_ir(xml_body, *adapter.get(), m_custom_opsets, m_constant_write_handler, m_version);
|
||||
xml_body.remove_attribute("name");
|
||||
xml_body.remove_attribute("version");
|
||||
} else if (name == "net") {
|
||||
@ -848,9 +850,6 @@ void ngfunction_2_ir(pugi::xml_node& netXml,
|
||||
|
||||
const bool exec_graph = is_exec_graph(f);
|
||||
|
||||
// change the order for parameters/results/sinks
|
||||
// It should be a part of get_ordered_ops method
|
||||
// FIXME: TensorIteratorTest.Serialize tests
|
||||
auto sorted_ops = f.get_ordered_ops();
|
||||
if (version >= 11) {
|
||||
std::vector<std::shared_ptr<ov::Node>> result;
|
||||
|
Loading…
Reference in New Issue
Block a user