[Snippets] LIR serialization: additional connections between LoopBegin and LoopEnd nodes (#19630)

* [Snippets] LIR serialization improvements

* Minor correction

* Review comments
This commit is contained in:
Vladislav Golubev 2023-09-12 11:20:18 +02:00 committed by GitHub
parent 3c1b384694
commit faa6b77247
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 34 additions and 18 deletions

View File

@ -22,7 +22,7 @@ public:
OPENVINO_OP("SerializationNode", "SnippetsOpset");
SerializationNode() = default;
SerializationNode(const Output<Node> &arg, const std::shared_ptr<lowered::Expression>& expr);
SerializationNode(const ov::OutputVector& args, const std::shared_ptr<lowered::Expression>& expr);
void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector &new_args) const override;

View File

@ -70,11 +70,27 @@ void LinearIR::serialize(const std::string& xml, const std::string& bin) const {
auto first_node = std::make_shared<ov::op::v0::Parameter>(element::f32, Shape{});
first_node->set_friendly_name("Start");
first_node->get_rt_info()["execTimeMcs"] = 0;
std::shared_ptr<Node> body_node = first_node;
std::shared_ptr<Node> serialization_node = first_node;
// This map allows to get LoopBegin serialization node by original LoopBegin node
// It is used to draw an edge between LoopBegin and LoopEnd serialization nodes
std::map<std::shared_ptr<snippets::op::LoopBegin>, std::shared_ptr<Node>> loops_map;
for (const auto& expr : m_expressions) {
body_node = std::make_shared<op::SerializationNode>(body_node, expr);
const auto node = expr->get_node();
if (auto loop_end = ov::as_type_ptr<snippets::op::LoopEnd>(node)) {
OPENVINO_ASSERT(loops_map.count(loop_end->get_loop_begin()),
"Serialization can't find LoopBegin that corresponds to LoopEnd with friendly name ",
loop_end->get_friendly_name());
auto loop_begin_serialization_node = loops_map.at(loop_end->get_loop_begin());
serialization_node = std::make_shared<op::SerializationNode>(ov::OutputVector{serialization_node, loop_begin_serialization_node}, expr);
} else {
serialization_node = std::make_shared<op::SerializationNode>(ov::OutputVector{serialization_node}, expr);
if (auto loop_begin = ov::as_type_ptr<snippets::op::LoopBegin>(node)) {
loops_map[loop_begin] = serialization_node;
}
}
}
auto last_node = std::make_shared<ov::op::v0::Result>(body_node);
auto last_node = std::make_shared<ov::op::v0::Result>(serialization_node);
last_node->set_friendly_name("End");
const auto tmp_model = std::make_shared<ov::Model>(ResultVector {last_node},
ParameterVector {first_node},

View File

@ -9,8 +9,8 @@ namespace ov {
namespace snippets {
namespace op {
SerializationNode::SerializationNode(const Output<Node> &arg, const std::shared_ptr<lowered::Expression> &expr)
: Op({arg}), m_expr(expr) {
SerializationNode::SerializationNode(const ov::OutputVector& args, const std::shared_ptr<lowered::Expression>& expr)
: Op(args), m_expr(expr) {
if (!m_expr || !m_expr->get_node())
OPENVINO_THROW("SerializationNode requires a valid expression with non-null node pointer");
const auto &node = expr->get_node();
@ -28,22 +28,22 @@ void SerializationNode::validate_and_infer_types() {
std::shared_ptr<Node> SerializationNode::clone_with_new_inputs(const OutputVector &new_args) const {
check_new_args_count(this, new_args);
return std::make_shared<SerializationNode>(new_args.at(0), m_expr);
return std::make_shared<SerializationNode>(new_args, m_expr);
}
bool SerializationNode::visit_attributes(AttributeVisitor &visitor) {
std::vector<std::pair<std::string, ov::PartialShape>> shapes;
const auto &node = m_expr->get_node();
for (size_t i = 0; i < node->get_input_size(); i++) {
const auto &pshape = node->get_input_partial_shape(i);
if (pshape.begin() != pshape.end())
shapes.emplace_back("in_shape_" + std::to_string(i), node->get_input_partial_shape(i));
std::vector<std::pair<std::string, std::vector<size_t>>> shapes;
for (size_t i = 0; i < m_expr->get_input_count(); i++) {
const auto &shape = m_expr->get_input_port_descriptor(i)->get_shape();
if (!shape.empty())
shapes.emplace_back("in_shape_" + std::to_string(i), shape);
}
for (size_t i = 0; i < node->get_output_size(); i++) {
const auto &pshape = node->get_output_partial_shape(i);
if (pshape.begin() != pshape.end())
shapes.emplace_back("out_shape_" + std::to_string(i), pshape);
for (size_t i = 0; i < m_expr->get_output_count(); i++) {
const auto &shape = m_expr->get_output_port_descriptor(i)->get_shape();
if (!shape.empty())
shapes.emplace_back("out_shape_" + std::to_string(i), shape);
}
auto loop_ids = m_expr->get_loop_ids();
auto rinfo = m_expr->get_reg_info();
if (!rinfo.first.empty())
@ -54,7 +54,7 @@ bool SerializationNode::visit_attributes(AttributeVisitor &visitor) {
visitor.on_attribute(s.first, s.second);
visitor.on_attribute("loop_ids", loop_ids);
node->visit_attributes(visitor);
m_expr->get_node()->visit_attributes(visitor);
return true;
}