[CONFORMANCE TOOLS][SUBAGRAPHS DUMPER] Repeat pattern extractor improvement (#21200)

* Init

* Implemented custom sort

* algo

* extend

* time

* decrease exec time

* fix code

* Impove pattern extractor

* repe

* Improve

* partial fix

* Patual fix

* fix_tests

* Remove extra + fix tests

* enable model

* Patial fix for align_in_info

* Exclude the nodes with the same size

* Prints

* update

* fix tests

* Fix crash

* Update cache.hpp
This commit is contained in:
Irina Efode 2023-12-15 10:34:33 +04:00 committed by GitHub
parent 193cc26658
commit d5b0f4d2d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 712 additions and 273 deletions

View File

@ -60,7 +60,10 @@ struct InputInfo {
}
InputInfo operator=(const InputInfo& input_info) {
if (this->is_const != input_info.is_const) {
auto default_in_info = InputInfo();
if (input_info == default_in_info) {
this->is_const = input_info.is_const;
} else if (this->is_const != input_info.is_const) {
throw std::runtime_error("Cast Const to Parameter! Impossible to update Input Info!");
}
this->ranges = input_info.ranges;

View File

@ -20,6 +20,8 @@ public:
void set_matchers(const MatchersMap& matchers = {}) { m_matchers = matchers; }
void set_shape_strict_match(bool shape_strict_match);
void set_match_attributes(bool match_attribute);
void set_match_in_types(bool match_in_types);
const MatchersMap& get_matchers() { return m_matchers; }
iMatcherConfig::Ptr get_config(const std::shared_ptr<ov::Node> &node) const;

View File

@ -21,6 +21,8 @@ public:
iMatcherConfig::Ptr get_config(const std::shared_ptr<ov::Node> &node) const;
void set_strict_shape_match(bool strict_shape_match);
void set_match_attrib(bool match_attrib);
void set_match_in_types(bool match_in_types);
protected:
virtual void configure(const pugi::xml_document &cfg) {};
@ -37,6 +39,8 @@ protected:
std::vector<iMatcherConfig::Ptr> default_configs;
// match only shape ranks by default;
bool is_strict_shape_match = false;
bool is_match_attributes = true;
bool is_match_in_types = false;
};
} // namespace subgraph_dumper

View File

@ -11,10 +11,13 @@ namespace ov {
namespace tools {
namespace subgraph_dumper {
// note: please set any specific parameters related to graph comparation by `ModelComparetor::get()`
// for example node attributes match or shape strict comparation
class RepeatPatternExtractor final : public SubgraphExtractor {
private:
using InputVector = std::vector<ov::Input<ov::Node>>;
using OutputVector = std::vector<ov::Output<ov::Node>>;
using NodePair = std::pair<std::shared_ptr<ov::Node>, std::vector<size_t>>;
public:
using PatternBorders = std::pair<InputVector, OutputVector>;
@ -25,14 +28,28 @@ public:
std::vector<std::vector<ov::NodeVector>>
get_repeat_node_vectors(const std::shared_ptr<ov::Model> &model);
void set_recursive_extraction(bool _is_recursive_extraction);
std::vector<ExtractedPattern> extract(const std::shared_ptr<ov::Model> &model) override;
// minimal size of extracted subgraph
void set_min_graph_size(size_t _min_graph_size) {
min_graph_size = _min_graph_size;
}
// cut graphs by matched nodes
void set_split_by_matched_nodes(bool _is_split_by_matched_nodes) {
is_split_by_matched_nodes = _is_split_by_matched_nodes;
}
// recursive extraction from extracted subgraphs
void set_recursive_extraction(bool _is_recursive_extraction) {
is_recursive_extraction = _is_recursive_extraction;
}
protected:
// {subgraph, node_vector, input_info}
using ExtractedRepeatPattern = std::tuple<std::shared_ptr<ov::Model>, ov::NodeVector, std::map<std::string, ov::conformance::InputInfo>>;
bool is_recursive_extraction = true;
size_t min_graph_size = 2;
bool is_split_by_matched_nodes = false, is_recursive_extraction = false;
// find repeat patterns in model
std::list<std::vector<ExtractedRepeatPattern>>
find_repeat_patterns(const std::shared_ptr<ov::Model> &model,
bool is_save_borders_only = false);
@ -40,8 +57,12 @@ protected:
std::list<std::vector<ExtractedRepeatPattern>>& secondary_extracted_patterns);
void update_extractor_cache(std::list<std::vector<ExtractedRepeatPattern>>& extracted_patterns,
const std::shared_ptr<ov::Model>& pattern,
const ov::NodeVector& pattern_node_vector,
const std::vector<ov::NodeVector>& pattern_node_vector,
const std::map<std::string, ov::conformance::InputInfo>& in_info);
// extract repeated patterns by start_node
std::vector<std::vector<ov::NodeVector>>
get_patterns_by_nodes(const std::vector<size_t>& start_op_vec,
const ov::NodeVector& ordered_ops);
};

View File

@ -19,11 +19,15 @@ align_input_info(const std::shared_ptr<ov::Model>& model,
const std::shared_ptr<ov::Model>& model_ref,
const std::map<std::string, ov::conformance::InputInfo> &in_info,
const std::map<std::string, ov::conformance::InputInfo> &in_info_ref,
const std::map<std::string, std::string> &matched_op = {});
const std::unordered_map<std::string, std::string> &matched_op);
// get set nodes of subgraph after start_node
void
get_subgraph_set_node(std::unordered_set<std::shared_ptr<ov::Node>>& nodes_to_check,
const std::shared_ptr<ov::Node>& node);
inline std::pair<std::shared_ptr<ov::Model>, std::map<std::string, ov::conformance::InputInfo>>
generate_model(ov::NodeVector& nodes,
std::unordered_set<std::string>& checked_ops,
bool is_copy_constants = true,
bool is_save_only_borders = false) {
// map to recover graph using cloned nodes and original connections
@ -39,7 +43,6 @@ generate_model(ov::NodeVector& nodes,
size_t functional_node_cnt = 0;
for (const auto& node : nodes) {
auto orig_node_name = node->get_friendly_name();
checked_ops.insert(orig_node_name);
cloned_node_map.insert({ orig_node_name,
clone_node(node, is_copy_constants, false, orig_node_name) });

View File

@ -16,7 +16,7 @@ class ModelComparator {
public:
using Ptr = std::shared_ptr<ModelComparator>;
// { is_match, subgraph, graph, matched_nodes -> {subgraph_op_name, graph_op_name}}
using IsSubgraphTuple = std::tuple<bool, std::shared_ptr<ov::Model>, std::shared_ptr<ov::Model>, std::map<std::string, std::string>>;
using IsSubgraphTuple = std::tuple<bool, std::shared_ptr<ov::Model>, std::shared_ptr<ov::Model>, std::unordered_map<std::string, std::string>>;
using InputInfo = ov::conformance::InputInfo;
// { model, subgraph, graph, subgraph_in_info, model_in_info, }
using ExtractedSubgraphTuple = std::tuple<bool, std::shared_ptr<ov::Model>, std::shared_ptr<ov::Model>, std::map<std::string, InputInfo>, std::map<std::string, InputInfo>>;
@ -46,10 +46,22 @@ public:
const std::shared_ptr<ov::Model> &ref_model,
const std::map<std::string, InputInfo> &in_info,
const std::map<std::string, InputInfo> &in_info_ref);
// {{matched_node_id}}
std::vector<std::vector<size_t>>
get_matched_op_patterns(const ov::NodeVector& ordered_nodes);
// { op_name_subgraph, op_name_graph}
std::unordered_map<std::string, std::string>
get_matched_ops_in_graphs(const std::shared_ptr<ov::Model>& subgraph,
const std::shared_ptr<ov::Model>& graph,
bool is_check_inputs = false) const;
void set_match_coefficient(float _match_coefficient);
float get_match_coefficient() { return match_coefficient; }
void set_shape_strict_match(bool is_shape_strict_match);
void set_match_attributes(bool match_attributes);
void set_match_in_types(bool match_in_types);
protected:
ov::tools::subgraph_dumper::MatchersManager m_manager = ov::tools::subgraph_dumper::MatchersManager();

View File

@ -18,6 +18,9 @@ template <typename dType>
inline ov::conformance::InputInfo::Range
get_const_ranges(const std::shared_ptr<ov::op::v0::Constant>& node) {
size_t elements_count = ov::shape_size(node->get_shape());
if (elements_count == 0) {
throw std::runtime_error("Impossible to get const ranges! Incorrect const size!");
}
const auto& const_values = node->cast_vector<dType>();
auto max = *std::max_element(const_values.begin(), const_values.end());
auto min = *std::min_element(const_values.begin(), const_values.end());

View File

@ -129,16 +129,34 @@ void GraphCache::update_cache(const std::shared_ptr<ov::Model>& extracted_model,
if (subgraph == cached_model.first) {
auto meta = m_graph_cache[subgraph];
meta.set_input_info(graph_in_info);
m_graph_cache_bytesize += (graph->get_graph_size() - subgraph->get_graph_size());
m_graph_cache.erase(subgraph);
m_graph_cache.insert({graph, meta});
m_graph_cache_bytesize += (graph->get_graph_size() - subgraph->get_graph_size());
} else {
m_graph_cache[cached_model.first].update(model_path,
subgraph_in_info,
model_op_cnt,
this_op_cnt,
extractor_name);
}
m_graph_cache[cached_model.first].update(model_path,
subgraph_in_info,
model_op_cnt,
this_op_cnt,
extractor_name);
return;
} else {
auto matched_ops = std::get<3>(m_model_comparator->is_subgraph(extracted_model, cached_model.first));
auto cached_model_op_cnt =
cached_model.first->get_ops().size() - cached_model.second.get_input_info().size() -
cached_model.first->get_results().size();
auto extracted_model_op_cnt =
extracted_model->get_ops().size() - input_info.size() - extracted_model->get_results().size();
if (matched_ops.size() > 0.75 * extracted_model_op_cnt) {
if (cached_model_op_cnt > extracted_model_op_cnt) {
return;
}
m_graph_cache_bytesize += (extracted_model->get_graph_size() - cached_model.first->get_graph_size());
m_graph_cache.erase(cached_model.first);
ov::conformance::MetaInfo meta(model_path, input_info, model_op_cnt, this_op_cnt, extractor_name);
m_graph_cache.insert({extracted_model, meta});
return;
}
}
}
}
@ -158,8 +176,10 @@ void GraphCache::update_cache(const std::shared_ptr<ov::Model>& extracted_model,
if (pattern_model_size < cached_model_size) {
m_graph_cache_bytesize -= (cached_model_size - pattern_model_size);
auto meta = m_graph_cache[model_to_update];
auto matched_ops = m_model_comparator->get_matched_ops_in_graphs(model_to_update, extracted_model);
auto new_in_info = ov::util::align_input_info(model_to_update, extracted_model,
m_graph_cache.at(model_to_update).get_input_info(), input_info);
m_graph_cache.at(model_to_update).get_input_info(), input_info,
matched_ops);
meta.set_input_info(new_in_info);
m_graph_cache.erase(model_to_update);
model_to_update = extracted_model;

View File

@ -23,6 +23,18 @@ void MatchersManager::set_shape_strict_match(bool shape_strict_match) {
}
}
void MatchersManager::set_match_attributes(bool match_attribute) {
for (const auto& matcher : m_matchers) {
matcher.second->set_match_attrib(match_attribute);
}
}
void MatchersManager::set_match_in_types(bool match_in_types) {
for (const auto& matcher : m_matchers) {
matcher.second->set_match_in_types(match_in_types);
}
}
bool MatchersManager::match(const std::shared_ptr<ov::Node> &node,
const std::shared_ptr<ov::Node> &ref) const {
for (const auto& it : m_matchers) {

View File

@ -28,6 +28,14 @@ void SingleOpMatcher::set_strict_shape_match(bool strict_shape_match) {
is_strict_shape_match = strict_shape_match;
}
void SingleOpMatcher::set_match_attrib(bool match_attrib) {
is_match_attributes = match_attrib;
}
void SingleOpMatcher::set_match_in_types(bool match_in_types) {
is_match_in_types = match_in_types;
}
bool SingleOpMatcher::match_inputs(const std::shared_ptr<ov::Node> &node,
const std::shared_ptr<ov::Node> &ref) const {
if (node->get_input_size() != ref->get_input_size()) {
@ -52,6 +60,15 @@ bool SingleOpMatcher::match_inputs(const std::shared_ptr<ov::Node> &node,
if (partial_shape.is_dynamic() != ref_partial_shape.is_dynamic()) {
return false;
}
if (is_match_in_types) {
const auto& in_node = node->get_input_node_shared_ptr(port_id);
const auto& in_node_ref = ref->get_input_node_shared_ptr(port_id);
if (ov::util::is_node_to_skip(in_node) || ov::util::is_node_to_skip(in_node_ref)) {
continue;
} else if (in_node->get_type_info() != in_node_ref->get_type_info()) {
return false;
}
}
}
return true;
}
@ -125,15 +142,17 @@ bool SingleOpMatcher::match(const std::shared_ptr<ov::Node> &node,
if (!same_op_type(node, ref)) {
return false;
}
if (is_match_attributes) {
if (!match_attrs(node, ref) && !ov::util::is_node_to_skip(node)) {
return false;
}
}
if (!match_inputs(node, ref)) {
return false;
}
if (!match_outputs(node, ref)) {
return false;
}
if (!match_attrs(node, ref) && !ov::util::is_node_to_skip(node)) {
return false;
}
return true;
}

View File

@ -68,16 +68,15 @@ std::vector<FusedNamesExtractor::ExtractedPattern>
FusedNamesExtractor::extract(const std::shared_ptr<ov::Model> &model) {
auto compiled_op_name = extract_compiled_model_names(model);
std::vector<FusedNamesExtractor::ExtractedPattern> matched_patterns;
std::unordered_set<std::string> checked_ops;
ov::NodeVector nodes;
for (const auto& op : model->get_ordered_ops()) {
auto op_name = op->get_friendly_name();
if (ov::util::is_node_to_skip(op) || checked_ops.count(op_name)) {
if (ov::util::is_node_to_skip(op)) {
continue;
}
if (compiled_op_name.count(op_name)) {
try {
auto extracted_pattern = ov::util::generate_model(nodes, checked_ops, is_save_const);
auto extracted_pattern = ov::util::generate_model(nodes, is_save_const);
matched_patterns.push_back({ extracted_pattern.first, extracted_pattern.second, extractor_name });
} catch(std::exception& e) {
if (std::string(e.what()).find("Incorrect node number to create model") == std::string::npos) {
@ -111,7 +110,7 @@ FusedNamesExtractor::extract(const std::shared_ptr<ov::Model> &model) {
}
}
try {
auto extracted_pattern = ov::util::generate_model(nodes, checked_ops, is_save_const);
auto extracted_pattern = ov::util::generate_model(nodes, is_save_const);
matched_patterns.push_back({ extracted_pattern.first, extracted_pattern.second, extractor_name });
} catch(std::exception& e) {
if (std::string(e.what()).find("Incorrect node number to create model") == std::string::npos) {

View File

@ -11,10 +11,6 @@
using namespace ov::tools::subgraph_dumper;
void RepeatPatternExtractor::set_recursive_extraction(bool _is_recursive_extraction) {
is_recursive_extraction = _is_recursive_extraction;
}
std::vector<RepeatPatternExtractor::ExtractedPattern>
RepeatPatternExtractor::extract(const std::shared_ptr<ov::Model> &model) {
std::vector<ExtractedPattern> extracted_patterns;
@ -29,23 +25,31 @@ RepeatPatternExtractor::extract(const std::shared_ptr<ov::Model> &model) {
std::vector<std::vector<RepeatPatternExtractor::PatternBorders>>
RepeatPatternExtractor::get_repeat_pattern_borders(const std::shared_ptr<ov::Model> &model) {
std::vector<std::vector<RepeatPatternExtractor::PatternBorders>> extracted_patterns;
for (auto& pattern : find_repeat_patterns(model, true)) {
auto a = find_repeat_patterns(model, true);
for (auto& pattern : a) {
std::vector<RepeatPatternExtractor::PatternBorders> same_pattern_borders;
for (const auto& pattern_structure : pattern) {
std::set<std::string> output_names;
for (const auto& result : std::get<0>(pattern_structure)->get_results()) {
output_names.insert(result->get_input_node_shared_ptr(0)->get_friendly_name());
}
RepeatPatternExtractor::InputVector in_vec;
RepeatPatternExtractor::OutputVector out_vec;
for (const auto& node : std::get<1>(pattern_structure)) {
if (output_names.count(node->get_friendly_name())) {
OutputVector node_outputs = node->outputs();
out_vec.insert(out_vec.end(), node_outputs.begin(), node_outputs.end());
} else {
for (const auto& input : node->inputs()) {
in_vec.push_back(input);
auto node_vector = std::get<1>(pattern_structure);
for (const auto& node : node_vector) {
for (size_t out_idx = 0; out_idx < node->outputs().size(); ++out_idx) {
size_t idx = 0;
const auto target_inputs = node->get_output_target_inputs(out_idx);
for (const auto& target_input : target_inputs) {
const auto target_in_node = target_input.get_node()->shared_from_this();
if (std::find(node_vector.begin(), node_vector.end(), target_in_node) == node_vector.end()) {
++idx;
}
}
if (idx == target_inputs.size()) {
out_vec.push_back(node->output(out_idx));
}
}
for (size_t in_idx = 0; in_idx < node->inputs().size(); ++in_idx) {
auto in_node = node->get_input_node_shared_ptr(in_idx);
if (std::find(node_vector.begin(), node_vector.end(), in_node) == node_vector.end()) {
in_vec.push_back(node->input(in_idx));
}
}
}
@ -73,7 +77,7 @@ void
RepeatPatternExtractor::update_extractor_cache(
std::list<std::vector<RepeatPatternExtractor::ExtractedRepeatPattern>>& extracted_patterns,
const std::shared_ptr<ov::Model>& pattern,
const ov::NodeVector& pattern_node_vector,
const std::vector<ov::NodeVector>& pattern_node_vector,
const std::map<std::string, ov::conformance::InputInfo>& pattern_in_info) {
for (auto& extracted_pattern : extracted_patterns) {
auto& pattern_structure = extracted_pattern.front();
@ -81,13 +85,20 @@ RepeatPatternExtractor::update_extractor_cache(
if (model_comparator->match(pattern, cached_pattern)) {
try {
const auto& cached_in_info = std::get<2>(pattern_structure);
ov::util::align_input_info(pattern, cached_pattern, pattern_in_info, cached_in_info);
extracted_pattern.push_back({ pattern, pattern_node_vector, pattern_in_info });
ov::util::align_input_info(pattern, cached_pattern,
pattern_in_info, cached_in_info,
model_comparator->get_matched_ops_in_graphs(pattern, cached_pattern));
for (const auto& p : pattern_node_vector) {
extracted_pattern.push_back({ pattern, p, pattern_in_info });
}
return;
} catch(std::exception) {}
}
}
extracted_patterns.push_back({{ pattern, pattern_node_vector, pattern_in_info }});
extracted_patterns.push_back({{ pattern, pattern_node_vector.front(), pattern_in_info }});
for (size_t i = 1; i < pattern_node_vector.size(); ++i) {
extracted_patterns.back().push_back({ pattern, pattern_node_vector[i], pattern_in_info });
}
}
void
@ -101,117 +112,191 @@ RepeatPatternExtractor::update_extractor_cache(
const auto& pattern = std::get<0>(pattern_structure);
const auto& pattern_node_vector = std::get<1>(pattern_structure);
const auto& pattern_in_info = std::get<2>(pattern_structure);
update_extractor_cache(extracted_patterns, pattern, pattern_node_vector, pattern_in_info);
update_extractor_cache(extracted_patterns, pattern, {pattern_node_vector}, pattern_in_info);
extern_it->pop_back();
}
secondary_extracted_patterns.pop_front();
}
}
std::vector<std::vector<ov::NodeVector>>
RepeatPatternExtractor::get_patterns_by_nodes(const std::vector<size_t>& start_op_vec,
const ov::NodeVector& ordered_ops) {
// handle case impossible to extract repeat patterns
if (start_op_vec.size() < 2 || ordered_ops.size() < 3) {
return {{}};
}
// is_recursive_extraction = true;
// prepare node vectors contains potential patterns from start_node to output
// first one is biggest subgraph, last one is smallest one
auto pattern_cnt = start_op_vec.size();
std::vector<ov::NodeVector> patterns(pattern_cnt);
for (size_t pattern_idx = 0; pattern_idx < pattern_cnt; ++pattern_idx) {
// get only nodes are after start node in graph
std::unordered_set<std::shared_ptr<ov::Node>> nodes_to_check;
const auto& start_op_idx = start_op_vec[pattern_idx];
util::get_subgraph_set_node(nodes_to_check, ordered_ops[start_op_idx]);
for (const auto& op : ordered_ops) {
if (nodes_to_check.count(op)) {
patterns[pattern_idx].push_back(op);
}
}
}
{
// reverse node vectors to anylize from small to big subgraphs
std::reverse(patterns.begin(), patterns.end());
std::vector<ov::NodeVector> potential_patterns(pattern_cnt);
for (size_t i_orig = 0; i_orig < pattern_cnt - 1; ++i_orig) {
// skip comparation in case pattern is iniatized
if (!potential_patterns[i_orig].empty()) {
continue;
}
for (size_t i_ref = i_orig + 1; i_ref < pattern_cnt; ++i_ref) {
if (!potential_patterns[i_ref].empty()) {
continue;
}
// extract minimal intersected patterns
auto intersection_len = std::min(patterns[i_orig].size(), patterns[i_ref].size());
ov::NodeVector pattern_orig(intersection_len, nullptr), pattern_ref(intersection_len, nullptr);
for (size_t j = 0; j < intersection_len; ++j) {
if (model_comparator->match(patterns[i_orig][j], patterns[i_ref][j])) {
if (patterns[i_orig][j] == patterns[i_ref][j]) {
break;
}
if (is_split_by_matched_nodes) {
if (model_comparator->match(patterns[i_orig][0], patterns[i_orig][j]) ||
model_comparator->match(patterns[i_ref][0], patterns[i_ref][j])) {
break;
}
}
// check inputs and matching in case not start_node
if (j != 0) {
bool is_input_matched = false;
for (size_t input_idx = 0; input_idx < patterns[i_orig][j]->inputs().size(); ++input_idx) {
auto in_orig = patterns[i_orig][j]->get_input_node_shared_ptr(input_idx);
auto in_ref = patterns[i_ref][j]->get_input_node_shared_ptr(input_idx);
if (std::find(pattern_orig.begin(), pattern_orig.end(), in_orig) != pattern_orig.end() &&
std::find(pattern_ref.begin(), pattern_ref.end(), in_ref) != pattern_ref.end()) {
is_input_matched = true;
break;
}
}
if (!is_input_matched) {
continue;
}
}
pattern_orig[j] = patterns[i_orig][j];
pattern_ref[j] = patterns[i_ref][j];
}
}
// fill vectors only by valid nodes
ov::NodeVector orig, ref;
for (size_t node_idx = 0; node_idx < pattern_orig.size(); ++node_idx) {
if (pattern_orig[node_idx] != 0) {
orig.emplace_back(pattern_orig[node_idx]);
}
if (pattern_ref[node_idx] != 0) {
ref.emplace_back(pattern_ref[node_idx]);
}
}
if (orig.size() < min_graph_size) {
continue;
}
potential_patterns[i_orig] = orig;
potential_patterns[i_ref] = ref;
}
}
// sort patterns by node vectors size
std::sort(potential_patterns.begin(), potential_patterns.end(), [](const ov::NodeVector& a, const ov::NodeVector& b) {
return a.size() > b.size();
});
// exclude not repeated pattern
while (potential_patterns.rbegin()->size() < 2 && !potential_patterns.empty()) {
potential_patterns.pop_back();
}
patterns = potential_patterns;
}
// group node vectors to the patterns: std::vector<ov::NodeVector>
std::vector<std::vector<ov::NodeVector>> pattern_vec;
for (size_t pattern_idx = 0; pattern_idx < patterns.size(); ++pattern_idx) {
const auto& pattern = patterns[pattern_idx];
if (pattern_vec.empty()) {
pattern_vec.push_back({{pattern}});
} else if (pattern_vec.rbegin()->begin()->size() != pattern.size()) {
pattern_vec.push_back({{pattern}});
} else {
auto it = pattern_vec.rbegin();
while (it != pattern_vec.rend()) {
auto ref = it->front();
if (ref.size() != pattern.size()) {
pattern_vec.push_back({{pattern}});
break;
}
bool is_matched = true;
for (size_t i = 0; i < pattern.size(); ++i) {
if (!model_comparator->match(pattern[i], ref[i])) {
is_matched = false;
break;
}
}
if (is_matched) {
it->push_back(pattern);
break;
} else {
it++;
}
}
if (it == pattern_vec.rend()) {
pattern_vec.push_back({{pattern}});
}
}
}
return pattern_vec;
}
std::list<std::vector<RepeatPatternExtractor::ExtractedRepeatPattern>>
RepeatPatternExtractor::find_repeat_patterns(const std::shared_ptr<ov::Model> &model,
bool is_save_borders_only) {
std::list<std::vector<RepeatPatternExtractor::ExtractedRepeatPattern>> extracted_patterns;
std::unordered_set<std::string> checked_ops;
auto ordered_ops = model->get_ordered_ops();
auto op_cnt = ordered_ops.size();
if (ordered_ops.size() < 2) {
return extracted_patterns;
}
auto matched_nodes_pattern = model_comparator->get_matched_op_patterns(ordered_ops);
for (size_t idx = 0; idx < op_cnt; ++idx) {
auto op = ordered_ops[idx];
auto op_name = op->get_friendly_name();
if (checked_ops.count(op_name)|| ov::util::is_node_to_skip(op)) {
for (size_t i = 0; i < matched_nodes_pattern.size(); ++i) {
auto matched_nodes = matched_nodes_pattern[i];
if (matched_nodes.size() < 2 || i > 0 && matched_nodes.size() == matched_nodes_pattern[i - 1].size()) {
continue;
}
// find the same nodes
std::vector<size_t> start_node_idx{idx};
for (size_t i = idx + 1; i < op_cnt; ++i) {
if (model_comparator->match(op, ordered_ops[i])) {
start_node_idx.push_back(i);
}
}
if (start_node_idx.size() < 2) {
checked_ops.insert(op_name);
continue;
}
std::vector<std::set<std::shared_ptr<ov::Node>>> nodes(start_node_idx.size());
for (size_t i = 0; i < start_node_idx.size(); ++i) {
for (size_t j = i + 1; j < start_node_idx.size(); ++j) {
size_t node_idx = start_node_idx[i], ref_node_idx = start_node_idx[j];
while (node_idx < op_cnt && ref_node_idx < op_cnt) {
auto node = ordered_ops[node_idx];
auto ref_node = ordered_ops[ref_node_idx];
if (checked_ops.count(node->get_friendly_name()) ||
checked_ops.count(ref_node->get_friendly_name())) {
break;
}
if (!ov::util::is_node_to_skip(node) &&
!ov::util::is_node_to_skip(ref_node)) {
if (node_idx == start_node_idx[i] && ref_node_idx == start_node_idx[j]) {
nodes[i].insert(node);
nodes[j].insert(ref_node);
} else if (model_comparator->match(node, ref_node)) {
// check if we met the same node
if (model_comparator->match(node, op)) {
break;
}
if (checked_ops.count(node->get_friendly_name()) ||
checked_ops.count(ref_node->get_friendly_name())) {
break;
}
// check that any input node is using in graph
bool is_input_in_graph = false;
for (size_t in_idx = 0; in_idx < node->inputs().size(); ++in_idx) {
auto in_node = node->get_input_node_ptr(in_idx)->shared_from_this();
auto ref_in_node = ref_node->get_input_node_ptr(in_idx)->shared_from_this();
if (nodes[i].count(in_node) && nodes[j].count(ref_in_node)) {
is_input_in_graph = true;
break;
}
}
if (!is_input_in_graph) {
break;
}
nodes[i].insert(ordered_ops[node_idx]);
nodes[j].insert(ordered_ops[ref_node_idx]);
} else {
break;
}
}
++node_idx;
++ref_node_idx;
}
}
}
for (size_t i = 0; i < start_node_idx.size(); ++i) {
for (auto& nodes_vector : get_patterns_by_nodes(matched_nodes, ordered_ops)) {
try {
std::unordered_set<std::string> tmp_checked_ops;
// model, in_info, extractor_name
ov::NodeVector nodes_vector(nodes[i].begin(), nodes[i].end());
auto extracted_pattern = ov::util::generate_model(nodes_vector, tmp_checked_ops, is_save_const, is_save_borders_only);
if (nodes_vector.size() < 1) {
continue;
}
auto extracted_pattern = ov::util::generate_model(nodes_vector.front(), is_save_const, is_save_borders_only);
auto extracted_model = extracted_pattern.first;
if (is_recursive_extraction && nodes_vector.size() > 20) {
auto secondary_patterns = find_repeat_patterns(extracted_model, is_save_borders_only);
if (!secondary_patterns.empty()) {
tmp_checked_ops.clear();
update_extractor_cache(extracted_patterns, secondary_patterns);
} else {
update_extractor_cache(extracted_patterns,
extracted_model,
nodes_vector,
extracted_pattern.second);
auto extracted_input_info = extracted_pattern.second;
if (extracted_model == nullptr) {
continue;
}
bool is_insert_res = true;
if (is_recursive_extraction) {
auto tmp_extracted_patterns = find_repeat_patterns(extracted_model, is_save_borders_only);
if (!tmp_extracted_patterns.empty()) {
is_insert_res = false;
update_extractor_cache(extracted_patterns, tmp_extracted_patterns);
}
} else {
}
if (is_insert_res) {
update_extractor_cache(extracted_patterns,
extracted_model,
nodes_vector,
extracted_pattern.second);
extracted_input_info);
}
nodes[i].clear();
checked_ops.insert(tmp_checked_ops.begin(), tmp_checked_ops.end());
} catch(std::exception& e) {
if (std::string(e.what()).find("Incorrect node number to create model!") == std::string::npos) {
// std::cout << "[ WARNING ] Impossible to generate network and add to GraphCache: " <<e.what() << std::endl;
@ -219,23 +304,28 @@ RepeatPatternExtractor::find_repeat_patterns(const std::shared_ptr<ov::Model> &m
}
}
if (is_extract_body) {
if (std::dynamic_pointer_cast<ov::op::v0::TensorIterator>(op)) {
auto ti = ov::as_type_ptr<ov::op::v0::TensorIterator>(op);
auto ti_body = ti->get_function();
auto secondary_patterns = find_repeat_patterns(ti_body, is_save_borders_only);
update_extractor_cache(extracted_patterns, secondary_patterns);
} else if (std::dynamic_pointer_cast<ov::op::v5::Loop>(op)) {
auto loop = ov::as_type_ptr<ov::op::v5::Loop>(op);
auto loop_body = loop->get_function();
auto secondary_patterns = find_repeat_patterns(loop_body, is_save_borders_only);
update_extractor_cache(extracted_patterns, secondary_patterns);
} else if (std::dynamic_pointer_cast<ov::op::v8::If>(op)) {
auto if_op = ov::as_type_ptr<ov::op::v8::If>(op);
std::vector<std::shared_ptr<ov::Model>> bodies;
for (size_t i = 0; i < if_op->get_internal_subgraphs_size(); i++) {
auto if_body = if_op->get_function(i);
auto secondary_patterns = find_repeat_patterns(if_body, is_save_borders_only);
for (const auto& matched_node_idx : matched_nodes) {
const auto& matched_node = ordered_ops[matched_node_idx];
if (std::dynamic_pointer_cast<ov::op::v0::TensorIterator>(matched_node)) {
auto ti = ov::as_type_ptr<ov::op::v0::TensorIterator>(matched_node);
auto ti_body = ti->get_function();
auto secondary_patterns = find_repeat_patterns(ti_body, is_save_borders_only);
update_extractor_cache(extracted_patterns, secondary_patterns);
} else if (std::dynamic_pointer_cast<ov::op::v5::Loop>(matched_node)) {
auto loop = ov::as_type_ptr<ov::op::v5::Loop>(matched_node);
auto loop_body = loop->get_function();
auto secondary_patterns = find_repeat_patterns(loop_body, is_save_borders_only);
update_extractor_cache(extracted_patterns, secondary_patterns);
} else if (std::dynamic_pointer_cast<ov::op::v8::If>(matched_node)) {
auto if_op = ov::as_type_ptr<ov::op::v8::If>(matched_node);
std::vector<std::shared_ptr<ov::Model>> bodies;
for (size_t i = 0; i < if_op->get_internal_subgraphs_size(); i++) {
auto if_body = if_op->get_function(i);
auto secondary_patterns = find_repeat_patterns(if_body, is_save_borders_only);
update_extractor_cache(extracted_patterns, secondary_patterns);
}
} else {
break;
}
}
}

View File

@ -41,51 +41,40 @@ align_input_info(const std::shared_ptr<ov::Model>& model,
const std::shared_ptr<ov::Model>& model_ref,
const std::map<std::string, ov::conformance::InputInfo>& in_info,
const std::map<std::string, ov::conformance::InputInfo>& in_info_ref,
const std::map<std::string, std::string> &matched_op) {
bool is_update_required = !matched_op.empty();
if (!is_update_required) {
for (const auto& ref_item : in_info_ref) {
if (!in_info.count(ref_item.first)) {
is_update_required = true;
break;
} else if (in_info.at(ref_item.first).is_const != ref_item.second.is_const) {
throw std::runtime_error("Impossible to update input info!!!");
}
const std::unordered_map<std::string, std::string> &matched_op) {
std::map<std::string, ov::conformance::InputInfo> updated_input_info(in_info_ref);
for (const auto& op : model->get_ordered_ops()) {
const auto op_name = op->get_friendly_name();
if (!in_info.count(op_name)) {
continue;
}
}
std::map<std::string, ov::conformance::InputInfo> updated_input_info = in_info_ref;
if (is_update_required) {
// align matched model names
const auto& ref_model_ops = model_ref->get_ordered_ops();
const auto& model_ops = model->get_ordered_ops();
size_t ref_ordered_ops_size = ref_model_ops.size();
size_t ordered_ops_size = model_ops.size();
if (ref_ordered_ops_size != ordered_ops_size && matched_op.empty()) {
throw std::runtime_error("Matched models can not be compared according different op numbers!");
}
for (size_t i = 0; i < ordered_ops_size; ++i) {
auto model_op_name = model_ops[i]->get_friendly_name();
if (!in_info.count(model_op_name)) {
continue;
}
if (!matched_op.empty()) {
if (!matched_op.count(model_op_name)) {
continue;
}
}
auto model_ref_op_name = matched_op.empty() ? ref_model_ops[i]->get_friendly_name() : matched_op.at(model_op_name);
const auto& in_info_item = in_info.at(model_op_name);
const auto& ref_in_info_item = in_info_ref.at(model_ref_op_name);
if (in_info_item.is_const != ref_in_info_item.is_const) {
throw std::runtime_error("Impossible to update input info!!!");
}
updated_input_info[model_ref_op_name] = in_info_item;
if (matched_op.count(op_name) && in_info_ref.count(matched_op.at(op_name))) {
updated_input_info[matched_op.at(op_name)] = in_info.at(op_name);
}
}
return updated_input_info;
}
void
get_subgraph_set_node(std::unordered_set<std::shared_ptr<ov::Node>>& nodes_to_check,
const std::shared_ptr<ov::Node>& node) {
if (nodes_to_check.empty()) {
nodes_to_check.insert(node);
}
for (size_t out_idx = 0; out_idx < node->outputs().size(); ++out_idx) {
for (const auto& out : node->get_output_target_inputs(out_idx)) {
const auto& output_node = out.get_node()->shared_from_this();
if (ov::op::util::is_output(output_node)) {
return;
}
if (!nodes_to_check.count(output_node)) {
nodes_to_check.insert(output_node);
get_subgraph_set_node(nodes_to_check, output_node);
}
}
}
return;
}
} // namespace util
} // namespace ov

View File

@ -20,46 +20,43 @@ void ModelComparator::set_shape_strict_match(bool in_is_shape_strict_match) {
m_manager.set_shape_strict_match(in_is_shape_strict_match);
}
void ModelComparator::set_match_attributes(bool match_attributes) {
m_manager.set_match_attributes(match_attributes);
}
void ModelComparator::set_match_in_types(bool match_in_types) {
m_manager.set_match_in_types(true);
}
inline ModelComparator::IsSubgraphTuple
prepare_is_subgraph_result(bool is_subgraph,
const std::shared_ptr<ov::Model>& subgraph,
const std::shared_ptr<ov::Model>& graph,
const std::map<std::string, std::string>& matched_ops) {
return is_subgraph ?
std::make_tuple(is_subgraph, subgraph, graph, matched_ops) :
std::make_tuple(is_subgraph, nullptr, nullptr, std::map<std::string, std::string>());
const std::unordered_map<std::string, std::string>& matched_ops) {
return std::make_tuple(is_subgraph, subgraph, graph, matched_ops);
}
ModelComparator::IsSubgraphTuple
ModelComparator::is_subgraph(const std::shared_ptr<ov::Model> &model,
const std::shared_ptr<ov::Model> &ref_model) const {
std::vector<std::shared_ptr<ov::Node>> ordered_ops = model->get_ordered_ops(),
ref_ordered_ops = ref_model->get_ordered_ops();
bool is_model = ordered_ops.size() > ref_ordered_ops.size();
ov::NodeVector graph_to_check_ops, subgraph_to_check_ops;
auto in_info = ov::util::get_input_info_by_model(model);
auto in_info_ref = ov::util::get_input_info_by_model(ref_model);
size_t ordered_ops_cnt = model->get_ordered_ops().size() - in_info.size() - model->get_results().size(),
ref_ordered_ops_cnt = ref_model->get_ordered_ops().size() - in_info_ref.size() - ref_model->get_results().size();
bool is_model = ordered_ops_cnt > ref_ordered_ops_cnt;
size_t subgraph_to_check_ops_cnt;
std::shared_ptr<ov::Model> graph = nullptr, subgraph = nullptr;
if (is_model) {
graph_to_check_ops = ordered_ops;
subgraph_to_check_ops = ref_ordered_ops;
graph = model;
subgraph = ref_model;
subgraph_to_check_ops_cnt = ref_ordered_ops_cnt;
} else {
graph_to_check_ops = ref_ordered_ops;
subgraph_to_check_ops = ordered_ops;
graph = ref_model;
subgraph = model;
subgraph_to_check_ops_cnt = ordered_ops_cnt;
}
std::map<std::string, std::string> matched_op_names;
auto graph_it = graph_to_check_ops.begin(), subgraph_it = subgraph_to_check_ops.begin();
while (graph_it != graph_to_check_ops.end() && subgraph_it != subgraph_to_check_ops.end()) {
if (m_manager.match(*graph_it, *subgraph_it)) {
matched_op_names.insert({ (*subgraph_it)->get_friendly_name(), (*graph_it)->get_friendly_name()});
++subgraph_it;
}
++graph_it;
}
return prepare_is_subgraph_result(subgraph_it == subgraph_to_check_ops.end(), subgraph, graph, matched_op_names);
auto matched_op_names = get_matched_ops_in_graphs(subgraph, graph);
return prepare_is_subgraph_result(matched_op_names.size() == subgraph_to_check_ops_cnt, subgraph, graph, matched_op_names);
}
bool
@ -94,17 +91,17 @@ ModelComparator::is_subgraph(const std::shared_ptr<ov::Model> &model,
const std::shared_ptr<ov::Model> &ref_model,
const std::map<std::string, InputInfo> &in_info,
const std::map<std::string, InputInfo> &in_info_ref) {
ModelComparator::ExtractedSubgraphTuple res = { false, nullptr, nullptr, {}, {} };
m_manager.set_match_in_types(true);
auto extractor_res = is_subgraph(model, ref_model);
if (std::get<0>(extractor_res)) {
std::map<std::string, InputInfo> graph_in_info, subgraph_in_info;
std::shared_ptr<ov::Model> subgraph = nullptr, graph = nullptr;
// if (model == subgraph && ref_model == graph)
if (std::get<1>(extractor_res) == model && std::get<2>(extractor_res) == ref_model) {
subgraph = model;
subgraph_in_info = in_info;
graph = ref_model;
graph_in_info = in_info_ref;
// else if (subgraph == ref_model && graph = model)
} else if (std::get<1>(extractor_res) == ref_model && std::get<2>(extractor_res) == model) {
subgraph = ref_model;
subgraph_in_info = in_info_ref;
@ -114,11 +111,50 @@ ModelComparator::is_subgraph(const std::shared_ptr<ov::Model> &model,
throw std::runtime_error("Generated models are incompatible with original ones!");
}
try {
subgraph_in_info = ov::util::align_input_info(subgraph, graph, subgraph_in_info, graph_in_info);
return { true, subgraph, graph, subgraph_in_info, graph_in_info };
auto subgraph_in_info_new = ov::util::align_input_info(subgraph, graph,
subgraph_in_info, graph_in_info,
get_matched_ops_in_graphs(subgraph, graph));
res = { true, subgraph, graph, subgraph_in_info_new, graph_in_info };
} catch(std::exception) {}
}
return { false, nullptr, nullptr, {}, {} };
m_manager.set_match_in_types(false);
return res;
}
std::unordered_map<std::string, std::string>
ModelComparator::get_matched_ops_in_graphs(const std::shared_ptr<ov::Model>& subgraph,
const std::shared_ptr<ov::Model>& graph,
bool is_check_inputs) const {
std::unordered_map<std::string, std::string> matched_op_names;
std::unordered_set<std::string> checked_op;
const auto subgraph_to_check_ops = subgraph->get_ordered_ops();
const auto graph_to_check_ops = graph->get_ordered_ops();
for (const auto& subgraph_op : subgraph_to_check_ops) {
for (const auto& graph_op : graph_to_check_ops) {
if (ov::util::is_node_to_skip(subgraph_op) ||
ov::util::is_node_to_skip(graph_op)) {
continue;
}
if (match(subgraph_op, graph_op) && !checked_op.count(graph_op->get_friendly_name())) {
matched_op_names.insert({subgraph_op->get_friendly_name(), graph_op->get_friendly_name()});
checked_op.insert(graph_op->get_friendly_name());
if (is_check_inputs) {
for (size_t idx = 0; idx < graph_op->inputs().size(); ++idx) {
auto graph_in = graph_op->get_input_node_shared_ptr(idx);
auto subgraph_in = subgraph_op->get_input_node_shared_ptr(idx);
if (ov::util::is_node_to_skip(graph_in) && ov::util::is_node_to_skip(subgraph_in)) {
if (match(subgraph_in, graph_in)) {
matched_op_names.insert({subgraph_in->get_friendly_name(), graph_in->get_friendly_name()});
checked_op.insert(graph_in->get_friendly_name());
}
}
}
}
break;
}
}
}
return matched_op_names;
}
std::pair<bool, std::map<std::string, ov::conformance::InputInfo>>
@ -128,9 +164,35 @@ ModelComparator::match(const std::shared_ptr<ov::Model> &model,
const std::map<std::string, InputInfo> &in_info_ref) {
try {
if (match(model, model_ref)) {
auto new_input_info = ov::util::align_input_info(model, model_ref, in_info, in_info_ref);
auto new_input_info = ov::util::align_input_info(model, model_ref,
in_info, in_info_ref,
get_matched_ops_in_graphs(model, model_ref, true));
return {true, new_input_info};
}
} catch (std::exception) {}
return {false, {}};
}
std::vector<std::vector<size_t>>
ModelComparator::get_matched_op_patterns(const ov::NodeVector& ordered_ops) {
std::vector<std::vector<size_t>> matched_nodes;
for (size_t node_idx = 0; node_idx < ordered_ops.size(); ++node_idx) {
bool is_matched = false;
for (auto& matched_node_idx : matched_nodes) {
if (match(ordered_ops[matched_node_idx.front()], ordered_ops[node_idx])) {
matched_node_idx.push_back(node_idx);
is_matched = true;
break;
}
}
if (!is_matched && !ov::util::is_node_to_skip(ordered_ops[node_idx])) {
matched_nodes.push_back({node_idx});
}
}
std::sort(matched_nodes.begin(), matched_nodes.end(),
[](const std::vector<size_t>& a, const std::vector<size_t>& b){ return a.size() > b.size(); });
while (!matched_nodes.empty() && matched_nodes.rbegin()->size() == 1) {
matched_nodes.pop_back();
}
return matched_nodes;
}

View File

@ -12,14 +12,18 @@
#include "test_models/model_0.hpp"
#include "test_models/model_1.hpp"
#include "test_models/model_2.hpp"
#include "test_models/model_3.hpp"
#include "openvino/pass/manager.hpp"
#include "openvino/pass/serialize.hpp"
namespace {
using namespace ov::tools::subgraph_dumper;
// ======================= ExtractorsManagerTest Unit tests =======================
class RepeatPatternExtractorTest : public SubgraphsDumperBaseTest {
// ======================= ExtractorsManagerTest Func tests =======================
class RepeatPatternExtractorFuncTest : public SubgraphsDumperBaseTest {
protected:
RepeatPatternExtractor extractor;
@ -52,41 +56,57 @@ protected:
std::sort(pattern_vec.begin(), pattern_vec.end());
}
// not allowed to sort inputs/outputs according there are not copy constructor
// void sort_borders(std::vector<std::vector<RepeatPatternExtractor::PatternBorders>>& pattern_vec) {
// for (auto& pattern : pattern_vec) {
// for (auto& node_vec : pattern) {
// std::sort(node_vec.first.begin(), node_vec.first.end());
// std::sort(node_vec.second.begin(), node_vec.second.end());
// }
// std::sort(pattern.begin(), pattern.end());
// }
// std::sort(pattern_vec.begin(), pattern_vec.end());
// }
void
is_equal_borders(const std::vector<std::vector<RepeatPatternExtractor::PatternBorders>>& pattern_vec_orig,
const std::vector<std::vector<RepeatPatternExtractor::PatternBorders>>& pattern_vec_ref) {
ASSERT_EQ(pattern_vec_orig.size(), pattern_vec_ref.size());
size_t orig_borders_cnt = 0, ref_borderd_cnt = 0, eq_borders = 0;
for (const auto& pattern_orig : pattern_vec_orig) {
orig_borders_cnt += pattern_orig.size();
ref_borderd_cnt = 0;
for (const auto& pattern_ref : pattern_vec_ref) {
ref_borderd_cnt += pattern_ref.size();
if (pattern_ref.size() != pattern_orig.size()) {
continue;
}
for (const auto& node_vec_orig : pattern_orig) {
// size_t eq_pattens = 0;
for (const auto& node_vec_ref : pattern_ref) {
if (node_vec_orig == node_vec_ref) {
++eq_borders;
break;
}
}
}
}
}
ASSERT_EQ(orig_borders_cnt, ref_borderd_cnt);
ASSERT_EQ(orig_borders_cnt, eq_borders);
}
};
TEST_F(RepeatPatternExtractorTest, extract_0) {
TEST_F(RepeatPatternExtractorFuncTest, extract_0) {
auto test_model = Model_0();
auto models = extractor.extract(test_model.get());
auto ref = test_model.get_repeat_pattern_ref();
ASSERT_TRUE(is_match(models, ref));
}
TEST_F(RepeatPatternExtractorTest, extract_1) {
TEST_F(RepeatPatternExtractorFuncTest, extract_1) {
auto test_model = Model_1();
auto models = extractor.extract(test_model.get());
auto ref = test_model.get_repeat_pattern_ref();
ASSERT_TRUE(is_match(models, ref));
}
TEST_F(RepeatPatternExtractorTest, extract_2) {
TEST_F(RepeatPatternExtractorFuncTest, extract_2) {
auto test_model = Model_2();
auto models = extractor.extract(test_model.get());
auto ref = test_model.get_repeat_pattern_ref();
ASSERT_TRUE(is_match(models, ref));
}
TEST_F(RepeatPatternExtractorTest, get_repeat_node_vectors_model_0) {
TEST_F(RepeatPatternExtractorFuncTest, get_repeat_node_vectors_model_0) {
auto test_model = Model_0();
auto node_vector = extractor.get_repeat_node_vectors(test_model.get());
auto ref = test_model.get_ref_node_vector();
@ -95,7 +115,7 @@ TEST_F(RepeatPatternExtractorTest, get_repeat_node_vectors_model_0) {
ASSERT_EQ(node_vector, ref);
}
TEST_F(RepeatPatternExtractorTest, get_repeat_node_vectors_model_1) {
TEST_F(RepeatPatternExtractorFuncTest, get_repeat_node_vectors_model_1) {
auto test_model = Model_1();
auto node_vector = extractor.get_repeat_node_vectors(test_model.get());
auto ref = test_model.get_ref_node_vector();
@ -104,7 +124,7 @@ TEST_F(RepeatPatternExtractorTest, get_repeat_node_vectors_model_1) {
ASSERT_EQ(node_vector, ref);
}
TEST_F(RepeatPatternExtractorTest, get_repeat_node_vectors_model_2) {
TEST_F(RepeatPatternExtractorFuncTest, get_repeat_node_vectors_model_2) {
auto test_model = Model_2();
auto node_vector = extractor.get_repeat_node_vectors(test_model.get());
auto ref = test_model.get_ref_node_vector();
@ -113,32 +133,24 @@ TEST_F(RepeatPatternExtractorTest, get_repeat_node_vectors_model_2) {
ASSERT_EQ(node_vector, ref);
}
TEST_F(RepeatPatternExtractorTest, get_repeat_pattern_borders_model_0) {
TEST_F(RepeatPatternExtractorFuncTest, get_repeat_pattern_borders_model_0) {
auto test_model = Model_0();
auto extracted_borders = extractor.get_repeat_pattern_borders(test_model.get());
auto ref_borders = test_model.get_ref_node_borders();
// sort_borders(extracted_borders);
// sort_borders(ref_borders);
ASSERT_EQ(extracted_borders, ref_borders);
is_equal_borders(extracted_borders, ref_borders);
}
TEST_F(RepeatPatternExtractorTest, get_repeat_pattern_borders_model_1) {
TEST_F(RepeatPatternExtractorFuncTest, get_repeat_pattern_borders_model_1) {
auto test_model = Model_1();
auto extracted_borders = extractor.get_repeat_pattern_borders(test_model.get());
auto ref_borders = test_model.get_ref_node_borders();
// sort_borders(extracted_borders);
// sort_borders(ref_borders);
ASSERT_EQ(extracted_borders, ref_borders);
is_equal_borders(extracted_borders, ref_borders);
}
TEST_F(RepeatPatternExtractorTest, get_repeat_pattern_borders_model_2) {
TEST_F(RepeatPatternExtractorFuncTest, get_repeat_pattern_borders_model_2) {
auto test_model = Model_2();
auto extracted_borders = extractor.get_repeat_pattern_borders(test_model.get());
auto ref_borders = test_model.get_ref_node_borders();
// sort_borders(extracted_borders);
// sort_borders(ref_borders);
ASSERT_EQ(extracted_borders, ref_borders);
is_equal_borders(extracted_borders, ref_borders);
}
} // namespace

View File

@ -56,7 +56,7 @@ public:
{
PatternBorders ref_pattern_0 = {test_abs_0->inputs(), test_relu_0->outputs()},
ref_pattern_1 = {test_abs_1->inputs(), test_relu_1->outputs()};
std::vector<std::vector<PatternBorders>> ref_res = {{ref_pattern_0, ref_pattern_1}};
std::vector<std::vector<PatternBorders>> ref_res = {{ref_pattern_1, ref_pattern_0}};
ref_borders = std::move(ref_res);
}
}

View File

@ -93,7 +93,7 @@ public:
std::shared_ptr<ov::op::v1::Multiply> test_multiply_0_1 =
std::make_shared<ov::op::v1::Multiply>(test_add_0, test_multiply_0_0);
test_multiply_0_0->set_friendly_name("Op_" + std::to_string(op_idx++));
test_multiply_0_1->set_friendly_name("Op_" + std::to_string(op_idx++));
std::shared_ptr<ov::op::v0::Relu> test_relu_0_1 =
std::make_shared<ov::op::v0::Relu>(test_multiply_0_1);
@ -134,11 +134,15 @@ public:
ref_pattern_0_1_0 = {test_abs_0_1->inputs(), test_clamp_0_1->outputs()},
test_pattern_0_1_1 = {test_multiply_0_1->inputs(), test_relu_0_1->outputs()},
test_pattern_1_1 = {test_multiply_1_1->inputs(), test_relu_1_1->outputs()};
std::vector<std::vector<PatternBorders>> ref_res = {{ref_pattern_0, ref_pattern_0_0},
{ref_pattern_1, ref_pattern_0_1_0},
{test_pattern_0_1_1, test_pattern_1_1}};
std::vector<std::vector<PatternBorders>> ref_res = {{ref_pattern_0_0, ref_pattern_0},
{ref_pattern_0_1_0, ref_pattern_1},
{test_pattern_1_1, test_pattern_0_1_1}};
ref_borders = std::move(ref_res);
}
start_ops = {test_abs_0, test_abs_0_0, test_abs_0_1, test_abs_1};
out_nodes = {test_abs_0, test_relu_0, test_add_0, test_multiply_0_1,
test_relu_0_1, test_add};
start_node = test_abs_0;
}
std::shared_ptr<ov::Model> get() {
@ -197,8 +201,24 @@ public:
std::vector<std::vector<PatternBorders>>
get_ref_node_borders() { return ref_borders; }
ov::NodeVector
get_start_ops() { return start_ops; }
std::unordered_set<std::shared_ptr<ov::Node>>
get_out_nodes_after_abs_0() {
return out_nodes;
}
std::shared_ptr<ov::Node>
get_test_abs_0() {
return start_node;
}
protected:
std::shared_ptr<ov::Model> model;
std::vector<std::vector<ov::NodeVector>> ref_nodes;
std::vector<std::vector<PatternBorders>> ref_borders;
ov::NodeVector start_ops;
std::unordered_set<std::shared_ptr<ov::Node>> out_nodes;
std::shared_ptr<ov::Node> start_node;
};

View File

@ -0,0 +1,159 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/op/abs.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/clamp.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/relu.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/result.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/split.hpp"
#include "matchers/subgraph/repeat_pattern.hpp"
class Model_3 {
protected:
using PatternBorders = ov::tools::subgraph_dumper::RepeatPatternExtractor::PatternBorders;
std::shared_ptr<ov::Model> model;
ov::NodeVector start_ops;
ov::NodeVector node_queue;
std::vector<std::vector<std::pair<std::shared_ptr<ov::Node>, std::vector<size_t>>>> ordered_patterns;
std::vector<ov::NodeVector> repeats;
public:
Model_3() {
// param_00
// |
// relu_0
// |
// split_1
// |
// +-------------+
// | |
// relu_2 clamp_3
// | |
// split_4 |
// | |
// +------------+ |
// | | |
// relu_5 clamp_6 |
// | | |
// +------------+ |
// | |
// add_7 |
// | |
// concat_8 |
// | |
// +-------------+
// |
// multiply_9 param_01
// |------------------+
// add_10 param_02
// |------------------+
// multiply_11
// |
// result_00
auto param_00 = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 1, 24, 1});
param_00->set_friendly_name("param_00");
auto relu_0 = std::make_shared<ov::op::v0::Relu>(param_00);
relu_0->set_friendly_name("relu_0");
auto axis_split_1 = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{}, std::vector<int64_t>({2}));
auto split_1 = std::make_shared<ov::op::v1::Split>(relu_0, axis_split_1, 2);
split_1->set_friendly_name("split_1");
auto relu_2 = std::make_shared<ov::op::v0::Relu>(split_1->output(0));
relu_2->set_friendly_name("relu_2");
auto clamp_3 = std::make_shared<ov::op::v0::Clamp>(split_1->output(1), 0 , 10);
clamp_3->set_friendly_name("clamp_3");
auto axis_split_4 = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{}, std::vector<int64_t>({2}));
auto split_4 = std::make_shared<ov::op::v1::Split>(relu_2, axis_split_4, 2);
split_4->set_friendly_name("split_4");
auto relu_5 = std::make_shared<ov::op::v0::Relu>(split_4->output(0));
relu_5->set_friendly_name("relu_5");
auto clamp_6 = std::make_shared<ov::op::v0::Clamp>(split_4->output(1), 0, 10);
clamp_6->set_friendly_name("clamp_6");
auto add_7 = std::make_shared<ov::op::v1::Add>(relu_5, clamp_6);
add_7->set_friendly_name("add_7");
auto param_03 = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, add_7->get_shape());
param_03->set_friendly_name("param_03");
auto concat_8 = std::make_shared<ov::op::v0::Concat>(ov::NodeVector{add_7, param_03}, 2);
concat_8->set_friendly_name("concat_8");
auto multiply_9 = std::make_shared<ov::op::v1::Multiply>(concat_8, clamp_3);
multiply_9->set_friendly_name("multiply_9");
auto param_01 = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, multiply_9->get_shape());
param_01->set_friendly_name("param_01");
auto add_10 = std::make_shared<ov::op::v1::Add>(multiply_9, param_01);
add_10->set_friendly_name("add_10");
auto param_02 = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, add_10->get_shape());
param_02->set_friendly_name("param_02");
auto multiply_11 = std::make_shared<ov::op::v1::Multiply>(add_10, param_02);
multiply_11->set_friendly_name("multiply_11");
auto result_00 = std::make_shared<ov::op::v0::Result>(multiply_11);
result_00->set_friendly_name("result_00");
model = std::make_shared<ov::Model>(ov::ResultVector{result_00},
ov::ParameterVector{param_00, param_01, param_02, param_03});
start_ops = {relu_0, relu_2, relu_5};
ordered_patterns = {{
{ relu_0, {}},
{ split_1, {0}},
{ relu_2, {1}},
{ split_4, {2}},
{ relu_5, {3}},
{ clamp_6, {3}},
{ add_7, {4, 5}},
{ concat_8, {6}},
{ clamp_3, {1}},
{ multiply_9, {7, 8}},
{ add_10, {9}},
{ multiply_11, {10}},
}, {
{ relu_2, {}},
{ split_4, {0}},
{ relu_5, {1}},
{ clamp_6, {1}},
{ add_7, {2, 3}},
{ concat_8, {4}},
{ multiply_9, {5}},
{ add_10, {6}},
{ multiply_11, {7}},
}, {
{ relu_5, {}},
{ add_7, {0}},
{ concat_8, {1}},
{ multiply_9, {2}},
{ add_10, {3}},
{ multiply_11, {4}},
}};
repeats = {
{ relu_0, split_1, relu_2 },
{ relu_2, split_4, relu_5 },
};
node_queue = {
relu_0, split_1, relu_2, split_4, relu_5, clamp_6,
add_7, concat_8, clamp_3, multiply_9, add_10, multiply_11,
};
}
std::shared_ptr<ov::Model> get() {
return model;
}
ov::NodeVector
get_start_ops() { return start_ops; }
std::vector<std::vector<std::pair<std::shared_ptr<ov::Node>, std::vector<size_t>>>>
get_ordered_patterns() { return ordered_patterns; }
std::vector<ov::NodeVector>
get_repeats() { return repeats; }
ov::NodeVector
get_queue() { return node_queue; }
};

View File

@ -30,9 +30,8 @@ TEST_F(ModelUtilsTest, generate_0) {
Model_0 test;
std::shared_ptr<ov::Model> test_model = test.get(), recovered_model;
{
std::unordered_set<std::string> checked_ops;
auto func_ops = get_functional_ops(test_model);
auto model_with_in_info = ov::util::generate_model(func_ops, checked_ops);
auto model_with_in_info = ov::util::generate_model(func_ops);
recovered_model = std::get<0>(model_with_in_info);
}
{
@ -44,9 +43,8 @@ TEST_F(ModelUtilsTest, generate_1) {
Model_1 test;
std::shared_ptr<ov::Model> test_model = test.get(), recovered_model;
{
std::unordered_set<std::string> checked_ops;
auto func_ops = get_functional_ops(test_model);
auto model_with_in_info = ov::util::generate_model(func_ops, checked_ops);
auto model_with_in_info = ov::util::generate_model(func_ops);
recovered_model = std::get<0>(model_with_in_info);
}
{
@ -58,9 +56,8 @@ TEST_F(ModelUtilsTest, generate_2) {
Model_2 test;
std::shared_ptr<ov::Model> test_model = test.get(), recovered_model;
{
std::unordered_set<std::string> checked_ops;
auto func_ops = get_functional_ops(test_model);
auto model_with_in_info = ov::util::generate_model(func_ops, checked_ops);
auto model_with_in_info = ov::util::generate_model(func_ops);
recovered_model = std::get<0>(model_with_in_info);
auto in_info = std::get<1>(model_with_in_info);
}
@ -74,10 +71,11 @@ TEST_F(ModelUtilsTest, align_input_info) {
auto in_info_0 = ov::util::get_input_info_by_model(test_model_0.get());
auto in_info_1 = ov::util::get_input_info_by_model(test_model_1.get());
ASSERT_NE(in_info_0, in_info_1);
std::unordered_map<std::string, std::string> a;
ASSERT_NO_THROW(ov::util::align_input_info(test_model_0.get(), test_model_1.get(),
in_info_0, in_info_1));
in_info_0, in_info_1, a));
auto in_info_ref = ov::util::align_input_info(test_model_0.get(), test_model_1.get(),
in_info_0, in_info_1);
in_info_0, in_info_1, a);
ASSERT_EQ(in_info_1, in_info_ref);
}
@ -88,7 +86,7 @@ TEST_F(ModelUtilsTest, align_input_info_for_subgraphs) {
auto in_info_0 = ov::util::get_input_info_by_model(test_model_0);
auto in_info_1 = ov::util::get_input_info_by_model(test_model_1);
ASSERT_NE(in_info_0, in_info_1);
std::map<std::string, std::string> matched_ops;
auto matched_ops = ov::util::ModelComparator::get()->get_matched_ops_in_graphs(test_model_0, test_model_1);
auto params_0 = test_model_0->get_parameters();
auto params_1 = test_model_1->get_parameters();
size_t params_cnt = params_0.size();
@ -96,9 +94,9 @@ TEST_F(ModelUtilsTest, align_input_info_for_subgraphs) {
matched_ops.insert({params_0[param_id]->get_friendly_name(),
params_1[param_id]->get_friendly_name()});
}
ASSERT_NO_THROW(ov::util::align_input_info(test_model_0, test_model_1,
in_info_0, in_info_1,
matched_ops));
// ASSERT_NO_THROW(ov::util::align_input_info(test_model_0, test_model_1,
// in_info_0, in_info_1,
// matched_ops));
auto ref = ov::util::align_input_info(test_model_0, test_model_1,
in_info_0, in_info_1, matched_ops);
ASSERT_EQ(in_info_1, ref);
@ -118,4 +116,13 @@ TEST_F(ModelUtilsTest, get_input_info_by_model) {
ASSERT_EQ(cur, ref);
}
TEST_F(ModelUtilsTest, get_subgraph_set_node) {
Model_1 model;
std::unordered_set<std::shared_ptr<ov::Node>> out_ops;
ov::util::get_subgraph_set_node(out_ops, model.get_test_abs_0());
auto expected = model.get_out_nodes_after_abs_0();
std::set<std::shared_ptr<ov::Node>> orig(out_ops.begin(), out_ops.end()),
ref(expected.begin(), expected.end());
ASSERT_EQ(orig, ref);
}
} // namespace

View File

@ -11,7 +11,8 @@
#include "openvino/op/parameter.hpp"
#include "openvino/op/result.hpp"
#include "openvino/core/model.hpp"
#include "openvino/openvino.hpp"
#include "utils/model.hpp"
namespace {
using namespace ov::tools::subgraph_dumper;
@ -115,10 +116,11 @@ TEST_F(ModelComparatorTest, match_with_low_coeff) {
TEST_F(ModelComparatorTest, match_with_in_info) {
ov::util::ModelComparator::Ptr model_comparator = ov::util::ModelComparator::get();
std::map<std::string, ov::conformance::InputInfo>
test_in_info({{"test_parameter_0", ov::conformance::InputInfo()}}),
test_in_info_1({{"test_parameter_1", ov::conformance::InputInfo({}, 1, 2, true)}});
ASSERT_NO_THROW(model_comparator->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info));
ASSERT_TRUE(std::get<0>(model_comparator->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info)));
test_in_info({{"test_parameter_0", ov::conformance::InputInfo(ov::Shape{1, 2})}}),
test_in_info_({{"test_parameter_0", ov::conformance::InputInfo(ov::Shape{1, 2})}}),
test_in_info_1({{"test_parameter_1", ov::conformance::InputInfo(ov::Shape{2, 5}, 1, 2, true)}});
ASSERT_NO_THROW(model_comparator->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info_));
ASSERT_TRUE(std::get<0>(model_comparator->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info_)));
ASSERT_NO_THROW(model_comparator->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info_1));
ASSERT_FALSE(std::get<0>(model_comparator->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info_1)));
ASSERT_NO_THROW(model_comparator->match(test_model_0_1, test_model_1, test_in_info, test_in_info));