[CONFORMANCE] Readability in model_utils (#19175)

This commit is contained in:
Irina Efode 2023-08-14 23:15:20 +04:00 committed by GitHub
parent e48b2dfc34
commit a0d1b91a78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -74,82 +74,89 @@ inline ExtractedPattern
generate_model(const std::set<std::shared_ptr<ov::Node>>& nodes,
std::unordered_set<std::string>& checked_ops,
const std::string& extractor_name) {
std::unordered_map<std::string, std::shared_ptr<ov::Node>> model_map;
// to create reults: { out_op_name, out_ports_without_target_inputs }
// map to recover graph using cloned nodes and original connections
// { original_node_name, cloned_node }
std::unordered_map<std::string, std::shared_ptr<ov::Node>> cloned_node_map;
// map to fill output nodes in models:
// { original_node_names, out_port_idx_without_orig_node_to_check }
std::unordered_map<std::string, std::unordered_set<size_t>> model_output_nodes;
std::map<std::string, InputInfo> input_info;
ov::ParameterVector params;
std::map<std::string, InputInfo> model_input_info;
ov::ParameterVector model_parameters;
{
// prepare map { original_op_name, cloned_op }
size_t functional_op_cnt = 0;
for (const auto& op : nodes) {
auto op_name = op->get_friendly_name();
checked_ops.insert(op_name);
auto cloned_op = clone_node(op, true, false, op->get_friendly_name());
model_map.insert({ op_name, cloned_op });
size_t output_cnt = op->outputs().size();
std::vector<size_t> out_ports(output_cnt);
// prepare map { original_op_name, cloned_node }
size_t functional_node_cnt = 0;
for (const auto& node : nodes) {
auto orig_node_name = node->get_friendly_name();
checked_ops.insert(orig_node_name);
cloned_node_map.insert({ orig_node_name,
clone_node(node, true, false, orig_node_name) });
// create temporary vector to fill node output indexes
std::vector<size_t> out_ports(node->outputs().size());
std::iota(out_ports.begin(), out_ports.end(), 0);
std::unordered_set<size_t> out_ports_set(out_ports.begin(), out_ports.end());
model_output_nodes.insert({ op_name, out_ports_set });
if (!ov::op::util::is_output(op) && !ov::op::util::is_constant(op) && !ov::op::util::is_parameter(op)) {
++functional_op_cnt;
// fill by all nodes with output ports
model_output_nodes.insert({
orig_node_name,
std::unordered_set<size_t>(out_ports.begin(), out_ports.end()) });
if (!ov::op::util::is_output(node) &&
!ov::op::util::is_constant(node) &&
!ov::op::util::is_parameter(node)) {
++functional_node_cnt;
}
}
if (functional_op_cnt < 2) {
throw std::runtime_error("Incorrect node number to create model");
if (functional_node_cnt < 2) {
throw std::runtime_error("Incorrect node number to create model!");
}
// replace new inputs by taken from graph if possible
for (const auto& op : nodes) {
// replace new inputs by taken from graph if possible and
// find input and output nodes in future subgraph
for (const auto& node : nodes) {
// variable to store updated input index
int filled_input_idx = -1;
std::vector<size_t> not_filled_ports;
auto in_cnt = op->inputs().size();
auto cloned_op = model_map[op->get_friendly_name()];
std::map<std::string, InputInfo> this_input_info = get_input_info_by_node(cloned_op);
for (size_t in_idx = 0; in_idx < in_cnt; ++in_idx) {
auto in_node = op->get_input_node_ptr(in_idx)->shared_from_this();
for (size_t in_out_idx = 0; in_out_idx < in_node->outputs().size(); ++in_out_idx) {
for (const auto& target_input : in_node->output(in_out_idx).get_target_inputs()) {
auto out_in_node = target_input.get_node()->shared_from_this();
if (out_in_node == op) {
auto in_node_name = in_node->get_friendly_name();
auto in_cloned_node = cloned_op->get_input_node_shared_ptr(in_idx);
// if op input node is in subgraph
if (model_map.count(in_node_name)) {
auto in_node = model_map[in_node_name];
auto in_cloned_friendly_name = in_cloned_node->get_friendly_name();
ov::replace_output_update_name(in_cloned_node->get_default_output(), in_node->output(in_out_idx));
in_cloned_node->clear_control_dependencies();
if (ov::op::util::is_parameter(in_node)) {
auto param = std::dynamic_pointer_cast<ov::op::v0::Parameter>(in_node);
params.push_back(param);
this_input_info.insert({ in_node->get_friendly_name(), this_input_info[in_cloned_friendly_name]});
} else if (ov::op::util::is_constant(in_node)) {
auto op_to_replace = std::dynamic_pointer_cast<ov::op::v0::Constant>(in_node);
auto cloned_node = cloned_node_map[node->get_friendly_name()];
auto node_input_info = get_input_info_by_node(cloned_node);
for (size_t in_idx = 0; in_idx < node->inputs().size(); ++in_idx) {
auto orig_in_node = node->get_input_node_ptr(in_idx)->shared_from_this();
for (size_t out_idx = 0; out_idx < orig_in_node->outputs().size(); ++out_idx) {
for (const auto& orig_node_to_check : orig_in_node->output(out_idx).get_target_inputs()) {
if (orig_node_to_check.get_node()->shared_from_this() == node) {
auto orig_in_node_name = orig_in_node->get_friendly_name();
auto cloned_in_node = cloned_node->get_input_node_shared_ptr(in_idx);
// if op input node is in subgraph replace parameters
// in cloned node by other nodes from the map
if (cloned_node_map.count(orig_in_node_name)) {
auto orig_in_node = cloned_node_map[orig_in_node_name];
auto cloned_in_node_name = cloned_in_node->get_friendly_name();
ov::replace_output_update_name(cloned_in_node->get_default_output(), orig_in_node->output(out_idx));
if (ov::op::util::is_parameter(orig_in_node)) {
auto param = std::dynamic_pointer_cast<ov::op::v0::Parameter>(orig_in_node);
model_parameters.push_back(param);
node_input_info.insert({ orig_in_node->get_friendly_name(),
node_input_info[cloned_in_node_name]});
} else if (ov::op::util::is_constant(orig_in_node)) {
auto op_to_replace = std::dynamic_pointer_cast<ov::op::v0::Constant>(orig_in_node);
auto param = convert_const_to_param(op_to_replace);
if (param != nullptr) {
params.push_back(param);
model_parameters.push_back(param);
}
// insert in_info with updated in_name
this_input_info.insert({ in_node->get_friendly_name(), this_input_info[in_cloned_friendly_name]});
node_input_info.insert({ orig_in_node->get_friendly_name(),
node_input_info[cloned_in_node_name]});
}
// remove in_info with old name from input info
this_input_info.erase(in_cloned_friendly_name);
filled_input_idx++;
model_output_nodes[in_node_name].erase(in_out_idx);
if (model_output_nodes[in_node_name].empty()) {
model_output_nodes.erase(in_node_name);
// clean up replaced node data
node_input_info.erase(cloned_in_node_name);
model_output_nodes[orig_in_node_name].erase(out_idx);
if (model_output_nodes[orig_in_node_name].empty()) {
model_output_nodes.erase(orig_in_node_name);
}
} else if (ov::op::util::is_parameter(in_cloned_node)) {
auto param = std::dynamic_pointer_cast<ov::op::v0::Parameter>(in_cloned_node);
params.push_back(param);
} else if (ov::op::util::is_constant(in_cloned_node)) {
auto op_to_replace = std::dynamic_pointer_cast<ov::op::v0::Constant>(in_cloned_node);
} else if (ov::op::util::is_parameter(cloned_in_node)) {
auto param = std::dynamic_pointer_cast<ov::op::v0::Parameter>(cloned_in_node);
model_parameters.push_back(param);
} else if (ov::op::util::is_constant(cloned_in_node)) {
auto op_to_replace = std::dynamic_pointer_cast<ov::op::v0::Constant>(cloned_in_node);
auto param = convert_const_to_param(op_to_replace);
if (param != nullptr) {
params.push_back(param);
model_parameters.push_back(param);
}
}
break;
@ -160,23 +167,25 @@ generate_model(const std::set<std::shared_ptr<ov::Node>>& nodes,
}
}
}
if (!this_input_info.empty()) {
input_info.insert(this_input_info.begin(), this_input_info.end());
if (!node_input_info.empty()) {
model_input_info.insert(node_input_info.begin(), node_input_info.end());
}
}
}
ov::ResultVector results;
ov::ResultVector model_results;
for (const auto& out_node_name : model_output_nodes) {
auto out_node = model_map[out_node_name.first];
auto out_node = cloned_node_map[out_node_name.first];
if (ov::op::util::is_output(out_node)) {
results.push_back(std::dynamic_pointer_cast<ov::op::v0::Result>(out_node));
model_results.push_back(std::dynamic_pointer_cast<ov::op::v0::Result>(out_node));
} else {
for (const auto& out_port_id : out_node_name.second) {
results.push_back(std::make_shared<ov::op::v0::Result>(out_node->output(out_port_id)));
model_results.push_back(std::make_shared<ov::op::v0::Result>(out_node->output(out_port_id)));
}
}
}
auto model = std::make_shared<ov::Model>(results, params);
auto model = std::make_shared<ov::Model>(model_results, model_parameters);
// prepare unique model name based on operations from model
std::string string_to_hash;
for (const auto& op : model->get_ordered_ops()) {
std::ostringstream result;
@ -193,12 +202,13 @@ generate_model(const std::set<std::shared_ptr<ov::Node>>& nodes,
}
string_to_hash += result.str();
}
for (const auto& in : input_info) {
for (const auto& in : model_input_info) {
string_to_hash += (in.second.is_const ? "1" : "0");
}
auto h1 = std::hash<std::string>{}(string_to_hash);
model->set_friendly_name(std::to_string(h1));
return { model, input_info, extractor_name };
return { model, model_input_info, extractor_name };
}
} // namespace subgraph_dumper