[CONFORMANCE] Readability in model_utils (#19175)
This commit is contained in:
parent
e48b2dfc34
commit
a0d1b91a78
@ -74,82 +74,89 @@ inline ExtractedPattern
|
||||
generate_model(const std::set<std::shared_ptr<ov::Node>>& nodes,
|
||||
std::unordered_set<std::string>& checked_ops,
|
||||
const std::string& extractor_name) {
|
||||
std::unordered_map<std::string, std::shared_ptr<ov::Node>> model_map;
|
||||
// to create reults: { out_op_name, out_ports_without_target_inputs }
|
||||
// map to recover graph using cloned nodes and original connections
|
||||
// { original_node_name, cloned_node }
|
||||
std::unordered_map<std::string, std::shared_ptr<ov::Node>> cloned_node_map;
|
||||
// map to fill output nodes in models:
|
||||
// { original_node_names, out_port_idx_without_orig_node_to_check }
|
||||
std::unordered_map<std::string, std::unordered_set<size_t>> model_output_nodes;
|
||||
std::map<std::string, InputInfo> input_info;
|
||||
ov::ParameterVector params;
|
||||
std::map<std::string, InputInfo> model_input_info;
|
||||
ov::ParameterVector model_parameters;
|
||||
{
|
||||
// prepare map { original_op_name, cloned_op }
|
||||
size_t functional_op_cnt = 0;
|
||||
for (const auto& op : nodes) {
|
||||
auto op_name = op->get_friendly_name();
|
||||
checked_ops.insert(op_name);
|
||||
auto cloned_op = clone_node(op, true, false, op->get_friendly_name());
|
||||
model_map.insert({ op_name, cloned_op });
|
||||
|
||||
size_t output_cnt = op->outputs().size();
|
||||
std::vector<size_t> out_ports(output_cnt);
|
||||
// prepare map { original_op_name, cloned_node }
|
||||
size_t functional_node_cnt = 0;
|
||||
for (const auto& node : nodes) {
|
||||
auto orig_node_name = node->get_friendly_name();
|
||||
checked_ops.insert(orig_node_name);
|
||||
cloned_node_map.insert({ orig_node_name,
|
||||
clone_node(node, true, false, orig_node_name) });
|
||||
|
||||
// create temporary vector to fill node output indexes
|
||||
std::vector<size_t> out_ports(node->outputs().size());
|
||||
std::iota(out_ports.begin(), out_ports.end(), 0);
|
||||
std::unordered_set<size_t> out_ports_set(out_ports.begin(), out_ports.end());
|
||||
model_output_nodes.insert({ op_name, out_ports_set });
|
||||
if (!ov::op::util::is_output(op) && !ov::op::util::is_constant(op) && !ov::op::util::is_parameter(op)) {
|
||||
++functional_op_cnt;
|
||||
// fill by all nodes with output ports
|
||||
model_output_nodes.insert({
|
||||
orig_node_name,
|
||||
std::unordered_set<size_t>(out_ports.begin(), out_ports.end()) });
|
||||
if (!ov::op::util::is_output(node) &&
|
||||
!ov::op::util::is_constant(node) &&
|
||||
!ov::op::util::is_parameter(node)) {
|
||||
++functional_node_cnt;
|
||||
}
|
||||
}
|
||||
|
||||
if (functional_op_cnt < 2) {
|
||||
throw std::runtime_error("Incorrect node number to create model");
|
||||
if (functional_node_cnt < 2) {
|
||||
throw std::runtime_error("Incorrect node number to create model!");
|
||||
}
|
||||
// replace new inputs by taken from graph if possible
|
||||
for (const auto& op : nodes) {
|
||||
// replace new inputs by taken from graph if possible and
|
||||
// find input and output nodes in future subgraph
|
||||
for (const auto& node : nodes) {
|
||||
// variable to store updated input index
|
||||
int filled_input_idx = -1;
|
||||
std::vector<size_t> not_filled_ports;
|
||||
auto in_cnt = op->inputs().size();
|
||||
auto cloned_op = model_map[op->get_friendly_name()];
|
||||
std::map<std::string, InputInfo> this_input_info = get_input_info_by_node(cloned_op);
|
||||
for (size_t in_idx = 0; in_idx < in_cnt; ++in_idx) {
|
||||
auto in_node = op->get_input_node_ptr(in_idx)->shared_from_this();
|
||||
for (size_t in_out_idx = 0; in_out_idx < in_node->outputs().size(); ++in_out_idx) {
|
||||
for (const auto& target_input : in_node->output(in_out_idx).get_target_inputs()) {
|
||||
auto out_in_node = target_input.get_node()->shared_from_this();
|
||||
if (out_in_node == op) {
|
||||
auto in_node_name = in_node->get_friendly_name();
|
||||
auto in_cloned_node = cloned_op->get_input_node_shared_ptr(in_idx);
|
||||
// if op input node is in subgraph
|
||||
if (model_map.count(in_node_name)) {
|
||||
auto in_node = model_map[in_node_name];
|
||||
auto in_cloned_friendly_name = in_cloned_node->get_friendly_name();
|
||||
ov::replace_output_update_name(in_cloned_node->get_default_output(), in_node->output(in_out_idx));
|
||||
in_cloned_node->clear_control_dependencies();
|
||||
if (ov::op::util::is_parameter(in_node)) {
|
||||
auto param = std::dynamic_pointer_cast<ov::op::v0::Parameter>(in_node);
|
||||
params.push_back(param);
|
||||
this_input_info.insert({ in_node->get_friendly_name(), this_input_info[in_cloned_friendly_name]});
|
||||
} else if (ov::op::util::is_constant(in_node)) {
|
||||
auto op_to_replace = std::dynamic_pointer_cast<ov::op::v0::Constant>(in_node);
|
||||
auto cloned_node = cloned_node_map[node->get_friendly_name()];
|
||||
auto node_input_info = get_input_info_by_node(cloned_node);
|
||||
for (size_t in_idx = 0; in_idx < node->inputs().size(); ++in_idx) {
|
||||
auto orig_in_node = node->get_input_node_ptr(in_idx)->shared_from_this();
|
||||
for (size_t out_idx = 0; out_idx < orig_in_node->outputs().size(); ++out_idx) {
|
||||
for (const auto& orig_node_to_check : orig_in_node->output(out_idx).get_target_inputs()) {
|
||||
if (orig_node_to_check.get_node()->shared_from_this() == node) {
|
||||
auto orig_in_node_name = orig_in_node->get_friendly_name();
|
||||
auto cloned_in_node = cloned_node->get_input_node_shared_ptr(in_idx);
|
||||
// if op input node is in subgraph replace parameters
|
||||
// in cloned node by other nodes from the map
|
||||
if (cloned_node_map.count(orig_in_node_name)) {
|
||||
auto orig_in_node = cloned_node_map[orig_in_node_name];
|
||||
auto cloned_in_node_name = cloned_in_node->get_friendly_name();
|
||||
ov::replace_output_update_name(cloned_in_node->get_default_output(), orig_in_node->output(out_idx));
|
||||
if (ov::op::util::is_parameter(orig_in_node)) {
|
||||
auto param = std::dynamic_pointer_cast<ov::op::v0::Parameter>(orig_in_node);
|
||||
model_parameters.push_back(param);
|
||||
node_input_info.insert({ orig_in_node->get_friendly_name(),
|
||||
node_input_info[cloned_in_node_name]});
|
||||
} else if (ov::op::util::is_constant(orig_in_node)) {
|
||||
auto op_to_replace = std::dynamic_pointer_cast<ov::op::v0::Constant>(orig_in_node);
|
||||
auto param = convert_const_to_param(op_to_replace);
|
||||
if (param != nullptr) {
|
||||
params.push_back(param);
|
||||
model_parameters.push_back(param);
|
||||
}
|
||||
// insert in_info with updated in_name
|
||||
this_input_info.insert({ in_node->get_friendly_name(), this_input_info[in_cloned_friendly_name]});
|
||||
node_input_info.insert({ orig_in_node->get_friendly_name(),
|
||||
node_input_info[cloned_in_node_name]});
|
||||
}
|
||||
// remove in_info with old name from input info
|
||||
this_input_info.erase(in_cloned_friendly_name);
|
||||
filled_input_idx++;
|
||||
model_output_nodes[in_node_name].erase(in_out_idx);
|
||||
if (model_output_nodes[in_node_name].empty()) {
|
||||
model_output_nodes.erase(in_node_name);
|
||||
// clean up replaced node data
|
||||
node_input_info.erase(cloned_in_node_name);
|
||||
model_output_nodes[orig_in_node_name].erase(out_idx);
|
||||
if (model_output_nodes[orig_in_node_name].empty()) {
|
||||
model_output_nodes.erase(orig_in_node_name);
|
||||
}
|
||||
} else if (ov::op::util::is_parameter(in_cloned_node)) {
|
||||
auto param = std::dynamic_pointer_cast<ov::op::v0::Parameter>(in_cloned_node);
|
||||
params.push_back(param);
|
||||
} else if (ov::op::util::is_constant(in_cloned_node)) {
|
||||
auto op_to_replace = std::dynamic_pointer_cast<ov::op::v0::Constant>(in_cloned_node);
|
||||
} else if (ov::op::util::is_parameter(cloned_in_node)) {
|
||||
auto param = std::dynamic_pointer_cast<ov::op::v0::Parameter>(cloned_in_node);
|
||||
model_parameters.push_back(param);
|
||||
} else if (ov::op::util::is_constant(cloned_in_node)) {
|
||||
auto op_to_replace = std::dynamic_pointer_cast<ov::op::v0::Constant>(cloned_in_node);
|
||||
auto param = convert_const_to_param(op_to_replace);
|
||||
if (param != nullptr) {
|
||||
params.push_back(param);
|
||||
model_parameters.push_back(param);
|
||||
}
|
||||
}
|
||||
break;
|
||||
@ -160,23 +167,25 @@ generate_model(const std::set<std::shared_ptr<ov::Node>>& nodes,
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!this_input_info.empty()) {
|
||||
input_info.insert(this_input_info.begin(), this_input_info.end());
|
||||
if (!node_input_info.empty()) {
|
||||
model_input_info.insert(node_input_info.begin(), node_input_info.end());
|
||||
}
|
||||
}
|
||||
}
|
||||
ov::ResultVector results;
|
||||
ov::ResultVector model_results;
|
||||
for (const auto& out_node_name : model_output_nodes) {
|
||||
auto out_node = model_map[out_node_name.first];
|
||||
auto out_node = cloned_node_map[out_node_name.first];
|
||||
if (ov::op::util::is_output(out_node)) {
|
||||
results.push_back(std::dynamic_pointer_cast<ov::op::v0::Result>(out_node));
|
||||
model_results.push_back(std::dynamic_pointer_cast<ov::op::v0::Result>(out_node));
|
||||
} else {
|
||||
for (const auto& out_port_id : out_node_name.second) {
|
||||
results.push_back(std::make_shared<ov::op::v0::Result>(out_node->output(out_port_id)));
|
||||
model_results.push_back(std::make_shared<ov::op::v0::Result>(out_node->output(out_port_id)));
|
||||
}
|
||||
}
|
||||
}
|
||||
auto model = std::make_shared<ov::Model>(results, params);
|
||||
auto model = std::make_shared<ov::Model>(model_results, model_parameters);
|
||||
|
||||
// prepare unique model name based on operations from model
|
||||
std::string string_to_hash;
|
||||
for (const auto& op : model->get_ordered_ops()) {
|
||||
std::ostringstream result;
|
||||
@ -193,12 +202,13 @@ generate_model(const std::set<std::shared_ptr<ov::Node>>& nodes,
|
||||
}
|
||||
string_to_hash += result.str();
|
||||
}
|
||||
for (const auto& in : input_info) {
|
||||
for (const auto& in : model_input_info) {
|
||||
string_to_hash += (in.second.is_const ? "1" : "0");
|
||||
}
|
||||
auto h1 = std::hash<std::string>{}(string_to_hash);
|
||||
model->set_friendly_name(std::to_string(h1));
|
||||
return { model, input_info, extractor_name };
|
||||
|
||||
return { model, model_input_info, extractor_name };
|
||||
}
|
||||
|
||||
} // namespace subgraph_dumper
|
||||
|
Loading…
Reference in New Issue
Block a user