[Snippets] LIR serialization: additional connections between LoopBegin and LoopEnd nodes (#19630)
* [Snippets] LIR serialization improvements * Minor correction * Review comments
This commit is contained in:
parent
3c1b384694
commit
faa6b77247
@ -22,7 +22,7 @@ public:
|
||||
OPENVINO_OP("SerializationNode", "SnippetsOpset");
|
||||
|
||||
SerializationNode() = default;
|
||||
SerializationNode(const Output<Node> &arg, const std::shared_ptr<lowered::Expression>& expr);
|
||||
SerializationNode(const ov::OutputVector& args, const std::shared_ptr<lowered::Expression>& expr);
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector &new_args) const override;
|
||||
|
@ -70,11 +70,27 @@ void LinearIR::serialize(const std::string& xml, const std::string& bin) const {
|
||||
auto first_node = std::make_shared<ov::op::v0::Parameter>(element::f32, Shape{});
|
||||
first_node->set_friendly_name("Start");
|
||||
first_node->get_rt_info()["execTimeMcs"] = 0;
|
||||
std::shared_ptr<Node> body_node = first_node;
|
||||
std::shared_ptr<Node> serialization_node = first_node;
|
||||
|
||||
// This map allows to get LoopBegin serialization node by original LoopBegin node
|
||||
// It is used to draw an edge between LoopBegin and LoopEnd serialization nodes
|
||||
std::map<std::shared_ptr<snippets::op::LoopBegin>, std::shared_ptr<Node>> loops_map;
|
||||
for (const auto& expr : m_expressions) {
|
||||
body_node = std::make_shared<op::SerializationNode>(body_node, expr);
|
||||
const auto node = expr->get_node();
|
||||
if (auto loop_end = ov::as_type_ptr<snippets::op::LoopEnd>(node)) {
|
||||
OPENVINO_ASSERT(loops_map.count(loop_end->get_loop_begin()),
|
||||
"Serialization can't find LoopBegin that corresponds to LoopEnd with friendly name ",
|
||||
loop_end->get_friendly_name());
|
||||
auto loop_begin_serialization_node = loops_map.at(loop_end->get_loop_begin());
|
||||
serialization_node = std::make_shared<op::SerializationNode>(ov::OutputVector{serialization_node, loop_begin_serialization_node}, expr);
|
||||
} else {
|
||||
serialization_node = std::make_shared<op::SerializationNode>(ov::OutputVector{serialization_node}, expr);
|
||||
if (auto loop_begin = ov::as_type_ptr<snippets::op::LoopBegin>(node)) {
|
||||
loops_map[loop_begin] = serialization_node;
|
||||
}
|
||||
}
|
||||
}
|
||||
auto last_node = std::make_shared<ov::op::v0::Result>(body_node);
|
||||
auto last_node = std::make_shared<ov::op::v0::Result>(serialization_node);
|
||||
last_node->set_friendly_name("End");
|
||||
const auto tmp_model = std::make_shared<ov::Model>(ResultVector {last_node},
|
||||
ParameterVector {first_node},
|
||||
|
@ -9,8 +9,8 @@ namespace ov {
|
||||
namespace snippets {
|
||||
namespace op {
|
||||
|
||||
SerializationNode::SerializationNode(const Output<Node> &arg, const std::shared_ptr<lowered::Expression> &expr)
|
||||
: Op({arg}), m_expr(expr) {
|
||||
SerializationNode::SerializationNode(const ov::OutputVector& args, const std::shared_ptr<lowered::Expression>& expr)
|
||||
: Op(args), m_expr(expr) {
|
||||
if (!m_expr || !m_expr->get_node())
|
||||
OPENVINO_THROW("SerializationNode requires a valid expression with non-null node pointer");
|
||||
const auto &node = expr->get_node();
|
||||
@ -28,22 +28,22 @@ void SerializationNode::validate_and_infer_types() {
|
||||
|
||||
std::shared_ptr<Node> SerializationNode::clone_with_new_inputs(const OutputVector &new_args) const {
|
||||
check_new_args_count(this, new_args);
|
||||
return std::make_shared<SerializationNode>(new_args.at(0), m_expr);
|
||||
return std::make_shared<SerializationNode>(new_args, m_expr);
|
||||
}
|
||||
|
||||
bool SerializationNode::visit_attributes(AttributeVisitor &visitor) {
|
||||
std::vector<std::pair<std::string, ov::PartialShape>> shapes;
|
||||
const auto &node = m_expr->get_node();
|
||||
for (size_t i = 0; i < node->get_input_size(); i++) {
|
||||
const auto &pshape = node->get_input_partial_shape(i);
|
||||
if (pshape.begin() != pshape.end())
|
||||
shapes.emplace_back("in_shape_" + std::to_string(i), node->get_input_partial_shape(i));
|
||||
std::vector<std::pair<std::string, std::vector<size_t>>> shapes;
|
||||
for (size_t i = 0; i < m_expr->get_input_count(); i++) {
|
||||
const auto &shape = m_expr->get_input_port_descriptor(i)->get_shape();
|
||||
if (!shape.empty())
|
||||
shapes.emplace_back("in_shape_" + std::to_string(i), shape);
|
||||
}
|
||||
for (size_t i = 0; i < node->get_output_size(); i++) {
|
||||
const auto &pshape = node->get_output_partial_shape(i);
|
||||
if (pshape.begin() != pshape.end())
|
||||
shapes.emplace_back("out_shape_" + std::to_string(i), pshape);
|
||||
for (size_t i = 0; i < m_expr->get_output_count(); i++) {
|
||||
const auto &shape = m_expr->get_output_port_descriptor(i)->get_shape();
|
||||
if (!shape.empty())
|
||||
shapes.emplace_back("out_shape_" + std::to_string(i), shape);
|
||||
}
|
||||
|
||||
auto loop_ids = m_expr->get_loop_ids();
|
||||
auto rinfo = m_expr->get_reg_info();
|
||||
if (!rinfo.first.empty())
|
||||
@ -54,7 +54,7 @@ bool SerializationNode::visit_attributes(AttributeVisitor &visitor) {
|
||||
visitor.on_attribute(s.first, s.second);
|
||||
|
||||
visitor.on_attribute("loop_ids", loop_ids);
|
||||
node->visit_attributes(visitor);
|
||||
m_expr->get_node()->visit_attributes(visitor);
|
||||
return true;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user