[CONFORMANCE TOOLS][SUBAGRAPHS DUMPER] Repeat pattern extractor improvement (#21200)
* Init * Implemented custom sort * algo * extend * time * decrease exec time * fix code * Impove pattern extractor * repe * Improve * partial fix * Patual fix * fix_tests * Remove extra + fix tests * enable model * Patial fix for align_in_info * Exclude the nodes with the same size * Prints * update * fix tests * Fix crash * Update cache.hpp
This commit is contained in:
parent
193cc26658
commit
d5b0f4d2d7
@ -60,7 +60,10 @@ struct InputInfo {
|
||||
}
|
||||
|
||||
InputInfo operator=(const InputInfo& input_info) {
|
||||
if (this->is_const != input_info.is_const) {
|
||||
auto default_in_info = InputInfo();
|
||||
if (input_info == default_in_info) {
|
||||
this->is_const = input_info.is_const;
|
||||
} else if (this->is_const != input_info.is_const) {
|
||||
throw std::runtime_error("Cast Const to Parameter! Impossible to update Input Info!");
|
||||
}
|
||||
this->ranges = input_info.ranges;
|
||||
|
@ -20,6 +20,8 @@ public:
|
||||
|
||||
void set_matchers(const MatchersMap& matchers = {}) { m_matchers = matchers; }
|
||||
void set_shape_strict_match(bool shape_strict_match);
|
||||
void set_match_attributes(bool match_attribute);
|
||||
void set_match_in_types(bool match_in_types);
|
||||
|
||||
const MatchersMap& get_matchers() { return m_matchers; }
|
||||
iMatcherConfig::Ptr get_config(const std::shared_ptr<ov::Node> &node) const;
|
||||
|
@ -21,6 +21,8 @@ public:
|
||||
|
||||
iMatcherConfig::Ptr get_config(const std::shared_ptr<ov::Node> &node) const;
|
||||
void set_strict_shape_match(bool strict_shape_match);
|
||||
void set_match_attrib(bool match_attrib);
|
||||
void set_match_in_types(bool match_in_types);
|
||||
|
||||
protected:
|
||||
virtual void configure(const pugi::xml_document &cfg) {};
|
||||
@ -37,6 +39,8 @@ protected:
|
||||
std::vector<iMatcherConfig::Ptr> default_configs;
|
||||
// match only shape ranks by default;
|
||||
bool is_strict_shape_match = false;
|
||||
bool is_match_attributes = true;
|
||||
bool is_match_in_types = false;
|
||||
};
|
||||
|
||||
} // namespace subgraph_dumper
|
||||
|
@ -11,10 +11,13 @@ namespace ov {
|
||||
namespace tools {
|
||||
namespace subgraph_dumper {
|
||||
|
||||
// note: please set any specific parameters related to graph comparation by `ModelComparetor::get()`
|
||||
// for example node attributes match or shape strict comparation
|
||||
class RepeatPatternExtractor final : public SubgraphExtractor {
|
||||
private:
|
||||
using InputVector = std::vector<ov::Input<ov::Node>>;
|
||||
using OutputVector = std::vector<ov::Output<ov::Node>>;
|
||||
using NodePair = std::pair<std::shared_ptr<ov::Node>, std::vector<size_t>>;
|
||||
|
||||
public:
|
||||
using PatternBorders = std::pair<InputVector, OutputVector>;
|
||||
@ -25,14 +28,28 @@ public:
|
||||
std::vector<std::vector<ov::NodeVector>>
|
||||
get_repeat_node_vectors(const std::shared_ptr<ov::Model> &model);
|
||||
|
||||
void set_recursive_extraction(bool _is_recursive_extraction);
|
||||
std::vector<ExtractedPattern> extract(const std::shared_ptr<ov::Model> &model) override;
|
||||
|
||||
// minimal size of extracted subgraph
|
||||
void set_min_graph_size(size_t _min_graph_size) {
|
||||
min_graph_size = _min_graph_size;
|
||||
}
|
||||
// cut graphs by matched nodes
|
||||
void set_split_by_matched_nodes(bool _is_split_by_matched_nodes) {
|
||||
is_split_by_matched_nodes = _is_split_by_matched_nodes;
|
||||
}
|
||||
// recursive extraction from extracted subgraphs
|
||||
void set_recursive_extraction(bool _is_recursive_extraction) {
|
||||
is_recursive_extraction = _is_recursive_extraction;
|
||||
}
|
||||
|
||||
protected:
|
||||
// {subgraph, node_vector, input_info}
|
||||
using ExtractedRepeatPattern = std::tuple<std::shared_ptr<ov::Model>, ov::NodeVector, std::map<std::string, ov::conformance::InputInfo>>;
|
||||
bool is_recursive_extraction = true;
|
||||
size_t min_graph_size = 2;
|
||||
bool is_split_by_matched_nodes = false, is_recursive_extraction = false;
|
||||
|
||||
// find repeat patterns in model
|
||||
std::list<std::vector<ExtractedRepeatPattern>>
|
||||
find_repeat_patterns(const std::shared_ptr<ov::Model> &model,
|
||||
bool is_save_borders_only = false);
|
||||
@ -40,8 +57,12 @@ protected:
|
||||
std::list<std::vector<ExtractedRepeatPattern>>& secondary_extracted_patterns);
|
||||
void update_extractor_cache(std::list<std::vector<ExtractedRepeatPattern>>& extracted_patterns,
|
||||
const std::shared_ptr<ov::Model>& pattern,
|
||||
const ov::NodeVector& pattern_node_vector,
|
||||
const std::vector<ov::NodeVector>& pattern_node_vector,
|
||||
const std::map<std::string, ov::conformance::InputInfo>& in_info);
|
||||
// extract repeated patterns by start_node
|
||||
std::vector<std::vector<ov::NodeVector>>
|
||||
get_patterns_by_nodes(const std::vector<size_t>& start_op_vec,
|
||||
const ov::NodeVector& ordered_ops);
|
||||
|
||||
};
|
||||
|
||||
|
@ -19,11 +19,15 @@ align_input_info(const std::shared_ptr<ov::Model>& model,
|
||||
const std::shared_ptr<ov::Model>& model_ref,
|
||||
const std::map<std::string, ov::conformance::InputInfo> &in_info,
|
||||
const std::map<std::string, ov::conformance::InputInfo> &in_info_ref,
|
||||
const std::map<std::string, std::string> &matched_op = {});
|
||||
const std::unordered_map<std::string, std::string> &matched_op);
|
||||
|
||||
// get set nodes of subgraph after start_node
|
||||
void
|
||||
get_subgraph_set_node(std::unordered_set<std::shared_ptr<ov::Node>>& nodes_to_check,
|
||||
const std::shared_ptr<ov::Node>& node);
|
||||
|
||||
inline std::pair<std::shared_ptr<ov::Model>, std::map<std::string, ov::conformance::InputInfo>>
|
||||
generate_model(ov::NodeVector& nodes,
|
||||
std::unordered_set<std::string>& checked_ops,
|
||||
bool is_copy_constants = true,
|
||||
bool is_save_only_borders = false) {
|
||||
// map to recover graph using cloned nodes and original connections
|
||||
@ -39,7 +43,6 @@ generate_model(ov::NodeVector& nodes,
|
||||
size_t functional_node_cnt = 0;
|
||||
for (const auto& node : nodes) {
|
||||
auto orig_node_name = node->get_friendly_name();
|
||||
checked_ops.insert(orig_node_name);
|
||||
cloned_node_map.insert({ orig_node_name,
|
||||
clone_node(node, is_copy_constants, false, orig_node_name) });
|
||||
|
||||
|
@ -16,7 +16,7 @@ class ModelComparator {
|
||||
public:
|
||||
using Ptr = std::shared_ptr<ModelComparator>;
|
||||
// { is_match, subgraph, graph, matched_nodes -> {subgraph_op_name, graph_op_name}}
|
||||
using IsSubgraphTuple = std::tuple<bool, std::shared_ptr<ov::Model>, std::shared_ptr<ov::Model>, std::map<std::string, std::string>>;
|
||||
using IsSubgraphTuple = std::tuple<bool, std::shared_ptr<ov::Model>, std::shared_ptr<ov::Model>, std::unordered_map<std::string, std::string>>;
|
||||
using InputInfo = ov::conformance::InputInfo;
|
||||
// { model, subgraph, graph, subgraph_in_info, model_in_info, }
|
||||
using ExtractedSubgraphTuple = std::tuple<bool, std::shared_ptr<ov::Model>, std::shared_ptr<ov::Model>, std::map<std::string, InputInfo>, std::map<std::string, InputInfo>>;
|
||||
@ -46,10 +46,22 @@ public:
|
||||
const std::shared_ptr<ov::Model> &ref_model,
|
||||
const std::map<std::string, InputInfo> &in_info,
|
||||
const std::map<std::string, InputInfo> &in_info_ref);
|
||||
|
||||
// {{matched_node_id}}
|
||||
std::vector<std::vector<size_t>>
|
||||
get_matched_op_patterns(const ov::NodeVector& ordered_nodes);
|
||||
|
||||
// { op_name_subgraph, op_name_graph}
|
||||
std::unordered_map<std::string, std::string>
|
||||
get_matched_ops_in_graphs(const std::shared_ptr<ov::Model>& subgraph,
|
||||
const std::shared_ptr<ov::Model>& graph,
|
||||
bool is_check_inputs = false) const;
|
||||
|
||||
void set_match_coefficient(float _match_coefficient);
|
||||
float get_match_coefficient() { return match_coefficient; }
|
||||
void set_shape_strict_match(bool is_shape_strict_match);
|
||||
void set_match_attributes(bool match_attributes);
|
||||
void set_match_in_types(bool match_in_types);
|
||||
|
||||
protected:
|
||||
ov::tools::subgraph_dumper::MatchersManager m_manager = ov::tools::subgraph_dumper::MatchersManager();
|
||||
|
@ -18,6 +18,9 @@ template <typename dType>
|
||||
inline ov::conformance::InputInfo::Range
|
||||
get_const_ranges(const std::shared_ptr<ov::op::v0::Constant>& node) {
|
||||
size_t elements_count = ov::shape_size(node->get_shape());
|
||||
if (elements_count == 0) {
|
||||
throw std::runtime_error("Impossible to get const ranges! Incorrect const size!");
|
||||
}
|
||||
const auto& const_values = node->cast_vector<dType>();
|
||||
auto max = *std::max_element(const_values.begin(), const_values.end());
|
||||
auto min = *std::min_element(const_values.begin(), const_values.end());
|
||||
|
@ -129,16 +129,34 @@ void GraphCache::update_cache(const std::shared_ptr<ov::Model>& extracted_model,
|
||||
if (subgraph == cached_model.first) {
|
||||
auto meta = m_graph_cache[subgraph];
|
||||
meta.set_input_info(graph_in_info);
|
||||
m_graph_cache_bytesize += (graph->get_graph_size() - subgraph->get_graph_size());
|
||||
m_graph_cache.erase(subgraph);
|
||||
m_graph_cache.insert({graph, meta});
|
||||
m_graph_cache_bytesize += (graph->get_graph_size() - subgraph->get_graph_size());
|
||||
} else {
|
||||
m_graph_cache[cached_model.first].update(model_path,
|
||||
subgraph_in_info,
|
||||
model_op_cnt,
|
||||
this_op_cnt,
|
||||
extractor_name);
|
||||
}
|
||||
m_graph_cache[cached_model.first].update(model_path,
|
||||
subgraph_in_info,
|
||||
model_op_cnt,
|
||||
this_op_cnt,
|
||||
extractor_name);
|
||||
return;
|
||||
} else {
|
||||
auto matched_ops = std::get<3>(m_model_comparator->is_subgraph(extracted_model, cached_model.first));
|
||||
auto cached_model_op_cnt =
|
||||
cached_model.first->get_ops().size() - cached_model.second.get_input_info().size() -
|
||||
cached_model.first->get_results().size();
|
||||
auto extracted_model_op_cnt =
|
||||
extracted_model->get_ops().size() - input_info.size() - extracted_model->get_results().size();
|
||||
if (matched_ops.size() > 0.75 * extracted_model_op_cnt) {
|
||||
if (cached_model_op_cnt > extracted_model_op_cnt) {
|
||||
return;
|
||||
}
|
||||
m_graph_cache_bytesize += (extracted_model->get_graph_size() - cached_model.first->get_graph_size());
|
||||
m_graph_cache.erase(cached_model.first);
|
||||
ov::conformance::MetaInfo meta(model_path, input_info, model_op_cnt, this_op_cnt, extractor_name);
|
||||
m_graph_cache.insert({extracted_model, meta});
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -158,8 +176,10 @@ void GraphCache::update_cache(const std::shared_ptr<ov::Model>& extracted_model,
|
||||
if (pattern_model_size < cached_model_size) {
|
||||
m_graph_cache_bytesize -= (cached_model_size - pattern_model_size);
|
||||
auto meta = m_graph_cache[model_to_update];
|
||||
auto matched_ops = m_model_comparator->get_matched_ops_in_graphs(model_to_update, extracted_model);
|
||||
auto new_in_info = ov::util::align_input_info(model_to_update, extracted_model,
|
||||
m_graph_cache.at(model_to_update).get_input_info(), input_info);
|
||||
m_graph_cache.at(model_to_update).get_input_info(), input_info,
|
||||
matched_ops);
|
||||
meta.set_input_info(new_in_info);
|
||||
m_graph_cache.erase(model_to_update);
|
||||
model_to_update = extracted_model;
|
||||
|
@ -23,6 +23,18 @@ void MatchersManager::set_shape_strict_match(bool shape_strict_match) {
|
||||
}
|
||||
}
|
||||
|
||||
void MatchersManager::set_match_attributes(bool match_attribute) {
|
||||
for (const auto& matcher : m_matchers) {
|
||||
matcher.second->set_match_attrib(match_attribute);
|
||||
}
|
||||
}
|
||||
|
||||
void MatchersManager::set_match_in_types(bool match_in_types) {
|
||||
for (const auto& matcher : m_matchers) {
|
||||
matcher.second->set_match_in_types(match_in_types);
|
||||
}
|
||||
}
|
||||
|
||||
bool MatchersManager::match(const std::shared_ptr<ov::Node> &node,
|
||||
const std::shared_ptr<ov::Node> &ref) const {
|
||||
for (const auto& it : m_matchers) {
|
||||
|
@ -28,6 +28,14 @@ void SingleOpMatcher::set_strict_shape_match(bool strict_shape_match) {
|
||||
is_strict_shape_match = strict_shape_match;
|
||||
}
|
||||
|
||||
void SingleOpMatcher::set_match_attrib(bool match_attrib) {
|
||||
is_match_attributes = match_attrib;
|
||||
}
|
||||
|
||||
void SingleOpMatcher::set_match_in_types(bool match_in_types) {
|
||||
is_match_in_types = match_in_types;
|
||||
}
|
||||
|
||||
bool SingleOpMatcher::match_inputs(const std::shared_ptr<ov::Node> &node,
|
||||
const std::shared_ptr<ov::Node> &ref) const {
|
||||
if (node->get_input_size() != ref->get_input_size()) {
|
||||
@ -52,6 +60,15 @@ bool SingleOpMatcher::match_inputs(const std::shared_ptr<ov::Node> &node,
|
||||
if (partial_shape.is_dynamic() != ref_partial_shape.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
if (is_match_in_types) {
|
||||
const auto& in_node = node->get_input_node_shared_ptr(port_id);
|
||||
const auto& in_node_ref = ref->get_input_node_shared_ptr(port_id);
|
||||
if (ov::util::is_node_to_skip(in_node) || ov::util::is_node_to_skip(in_node_ref)) {
|
||||
continue;
|
||||
} else if (in_node->get_type_info() != in_node_ref->get_type_info()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@ -125,15 +142,17 @@ bool SingleOpMatcher::match(const std::shared_ptr<ov::Node> &node,
|
||||
if (!same_op_type(node, ref)) {
|
||||
return false;
|
||||
}
|
||||
if (is_match_attributes) {
|
||||
if (!match_attrs(node, ref) && !ov::util::is_node_to_skip(node)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (!match_inputs(node, ref)) {
|
||||
return false;
|
||||
}
|
||||
if (!match_outputs(node, ref)) {
|
||||
return false;
|
||||
}
|
||||
if (!match_attrs(node, ref) && !ov::util::is_node_to_skip(node)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -68,16 +68,15 @@ std::vector<FusedNamesExtractor::ExtractedPattern>
|
||||
FusedNamesExtractor::extract(const std::shared_ptr<ov::Model> &model) {
|
||||
auto compiled_op_name = extract_compiled_model_names(model);
|
||||
std::vector<FusedNamesExtractor::ExtractedPattern> matched_patterns;
|
||||
std::unordered_set<std::string> checked_ops;
|
||||
ov::NodeVector nodes;
|
||||
for (const auto& op : model->get_ordered_ops()) {
|
||||
auto op_name = op->get_friendly_name();
|
||||
if (ov::util::is_node_to_skip(op) || checked_ops.count(op_name)) {
|
||||
if (ov::util::is_node_to_skip(op)) {
|
||||
continue;
|
||||
}
|
||||
if (compiled_op_name.count(op_name)) {
|
||||
try {
|
||||
auto extracted_pattern = ov::util::generate_model(nodes, checked_ops, is_save_const);
|
||||
auto extracted_pattern = ov::util::generate_model(nodes, is_save_const);
|
||||
matched_patterns.push_back({ extracted_pattern.first, extracted_pattern.second, extractor_name });
|
||||
} catch(std::exception& e) {
|
||||
if (std::string(e.what()).find("Incorrect node number to create model") == std::string::npos) {
|
||||
@ -111,7 +110,7 @@ FusedNamesExtractor::extract(const std::shared_ptr<ov::Model> &model) {
|
||||
}
|
||||
}
|
||||
try {
|
||||
auto extracted_pattern = ov::util::generate_model(nodes, checked_ops, is_save_const);
|
||||
auto extracted_pattern = ov::util::generate_model(nodes, is_save_const);
|
||||
matched_patterns.push_back({ extracted_pattern.first, extracted_pattern.second, extractor_name });
|
||||
} catch(std::exception& e) {
|
||||
if (std::string(e.what()).find("Incorrect node number to create model") == std::string::npos) {
|
||||
|
@ -11,10 +11,6 @@
|
||||
|
||||
using namespace ov::tools::subgraph_dumper;
|
||||
|
||||
void RepeatPatternExtractor::set_recursive_extraction(bool _is_recursive_extraction) {
|
||||
is_recursive_extraction = _is_recursive_extraction;
|
||||
}
|
||||
|
||||
std::vector<RepeatPatternExtractor::ExtractedPattern>
|
||||
RepeatPatternExtractor::extract(const std::shared_ptr<ov::Model> &model) {
|
||||
std::vector<ExtractedPattern> extracted_patterns;
|
||||
@ -29,23 +25,31 @@ RepeatPatternExtractor::extract(const std::shared_ptr<ov::Model> &model) {
|
||||
std::vector<std::vector<RepeatPatternExtractor::PatternBorders>>
|
||||
RepeatPatternExtractor::get_repeat_pattern_borders(const std::shared_ptr<ov::Model> &model) {
|
||||
std::vector<std::vector<RepeatPatternExtractor::PatternBorders>> extracted_patterns;
|
||||
for (auto& pattern : find_repeat_patterns(model, true)) {
|
||||
auto a = find_repeat_patterns(model, true);
|
||||
for (auto& pattern : a) {
|
||||
std::vector<RepeatPatternExtractor::PatternBorders> same_pattern_borders;
|
||||
for (const auto& pattern_structure : pattern) {
|
||||
std::set<std::string> output_names;
|
||||
for (const auto& result : std::get<0>(pattern_structure)->get_results()) {
|
||||
output_names.insert(result->get_input_node_shared_ptr(0)->get_friendly_name());
|
||||
}
|
||||
|
||||
RepeatPatternExtractor::InputVector in_vec;
|
||||
RepeatPatternExtractor::OutputVector out_vec;
|
||||
for (const auto& node : std::get<1>(pattern_structure)) {
|
||||
if (output_names.count(node->get_friendly_name())) {
|
||||
OutputVector node_outputs = node->outputs();
|
||||
out_vec.insert(out_vec.end(), node_outputs.begin(), node_outputs.end());
|
||||
} else {
|
||||
for (const auto& input : node->inputs()) {
|
||||
in_vec.push_back(input);
|
||||
auto node_vector = std::get<1>(pattern_structure);
|
||||
for (const auto& node : node_vector) {
|
||||
for (size_t out_idx = 0; out_idx < node->outputs().size(); ++out_idx) {
|
||||
size_t idx = 0;
|
||||
const auto target_inputs = node->get_output_target_inputs(out_idx);
|
||||
for (const auto& target_input : target_inputs) {
|
||||
const auto target_in_node = target_input.get_node()->shared_from_this();
|
||||
if (std::find(node_vector.begin(), node_vector.end(), target_in_node) == node_vector.end()) {
|
||||
++idx;
|
||||
}
|
||||
}
|
||||
if (idx == target_inputs.size()) {
|
||||
out_vec.push_back(node->output(out_idx));
|
||||
}
|
||||
}
|
||||
for (size_t in_idx = 0; in_idx < node->inputs().size(); ++in_idx) {
|
||||
auto in_node = node->get_input_node_shared_ptr(in_idx);
|
||||
if (std::find(node_vector.begin(), node_vector.end(), in_node) == node_vector.end()) {
|
||||
in_vec.push_back(node->input(in_idx));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -73,7 +77,7 @@ void
|
||||
RepeatPatternExtractor::update_extractor_cache(
|
||||
std::list<std::vector<RepeatPatternExtractor::ExtractedRepeatPattern>>& extracted_patterns,
|
||||
const std::shared_ptr<ov::Model>& pattern,
|
||||
const ov::NodeVector& pattern_node_vector,
|
||||
const std::vector<ov::NodeVector>& pattern_node_vector,
|
||||
const std::map<std::string, ov::conformance::InputInfo>& pattern_in_info) {
|
||||
for (auto& extracted_pattern : extracted_patterns) {
|
||||
auto& pattern_structure = extracted_pattern.front();
|
||||
@ -81,13 +85,20 @@ RepeatPatternExtractor::update_extractor_cache(
|
||||
if (model_comparator->match(pattern, cached_pattern)) {
|
||||
try {
|
||||
const auto& cached_in_info = std::get<2>(pattern_structure);
|
||||
ov::util::align_input_info(pattern, cached_pattern, pattern_in_info, cached_in_info);
|
||||
extracted_pattern.push_back({ pattern, pattern_node_vector, pattern_in_info });
|
||||
ov::util::align_input_info(pattern, cached_pattern,
|
||||
pattern_in_info, cached_in_info,
|
||||
model_comparator->get_matched_ops_in_graphs(pattern, cached_pattern));
|
||||
for (const auto& p : pattern_node_vector) {
|
||||
extracted_pattern.push_back({ pattern, p, pattern_in_info });
|
||||
}
|
||||
return;
|
||||
} catch(std::exception) {}
|
||||
}
|
||||
}
|
||||
extracted_patterns.push_back({{ pattern, pattern_node_vector, pattern_in_info }});
|
||||
extracted_patterns.push_back({{ pattern, pattern_node_vector.front(), pattern_in_info }});
|
||||
for (size_t i = 1; i < pattern_node_vector.size(); ++i) {
|
||||
extracted_patterns.back().push_back({ pattern, pattern_node_vector[i], pattern_in_info });
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
@ -101,117 +112,191 @@ RepeatPatternExtractor::update_extractor_cache(
|
||||
const auto& pattern = std::get<0>(pattern_structure);
|
||||
const auto& pattern_node_vector = std::get<1>(pattern_structure);
|
||||
const auto& pattern_in_info = std::get<2>(pattern_structure);
|
||||
update_extractor_cache(extracted_patterns, pattern, pattern_node_vector, pattern_in_info);
|
||||
update_extractor_cache(extracted_patterns, pattern, {pattern_node_vector}, pattern_in_info);
|
||||
extern_it->pop_back();
|
||||
}
|
||||
secondary_extracted_patterns.pop_front();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::vector<ov::NodeVector>>
|
||||
RepeatPatternExtractor::get_patterns_by_nodes(const std::vector<size_t>& start_op_vec,
|
||||
const ov::NodeVector& ordered_ops) {
|
||||
// handle case impossible to extract repeat patterns
|
||||
if (start_op_vec.size() < 2 || ordered_ops.size() < 3) {
|
||||
return {{}};
|
||||
}
|
||||
// is_recursive_extraction = true;
|
||||
// prepare node vectors contains potential patterns from start_node to output
|
||||
// first one is biggest subgraph, last one is smallest one
|
||||
auto pattern_cnt = start_op_vec.size();
|
||||
std::vector<ov::NodeVector> patterns(pattern_cnt);
|
||||
for (size_t pattern_idx = 0; pattern_idx < pattern_cnt; ++pattern_idx) {
|
||||
// get only nodes are after start node in graph
|
||||
std::unordered_set<std::shared_ptr<ov::Node>> nodes_to_check;
|
||||
const auto& start_op_idx = start_op_vec[pattern_idx];
|
||||
util::get_subgraph_set_node(nodes_to_check, ordered_ops[start_op_idx]);
|
||||
for (const auto& op : ordered_ops) {
|
||||
if (nodes_to_check.count(op)) {
|
||||
patterns[pattern_idx].push_back(op);
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
// reverse node vectors to anylize from small to big subgraphs
|
||||
std::reverse(patterns.begin(), patterns.end());
|
||||
std::vector<ov::NodeVector> potential_patterns(pattern_cnt);
|
||||
for (size_t i_orig = 0; i_orig < pattern_cnt - 1; ++i_orig) {
|
||||
// skip comparation in case pattern is iniatized
|
||||
if (!potential_patterns[i_orig].empty()) {
|
||||
continue;
|
||||
}
|
||||
for (size_t i_ref = i_orig + 1; i_ref < pattern_cnt; ++i_ref) {
|
||||
if (!potential_patterns[i_ref].empty()) {
|
||||
continue;
|
||||
}
|
||||
// extract minimal intersected patterns
|
||||
auto intersection_len = std::min(patterns[i_orig].size(), patterns[i_ref].size());
|
||||
ov::NodeVector pattern_orig(intersection_len, nullptr), pattern_ref(intersection_len, nullptr);
|
||||
for (size_t j = 0; j < intersection_len; ++j) {
|
||||
if (model_comparator->match(patterns[i_orig][j], patterns[i_ref][j])) {
|
||||
if (patterns[i_orig][j] == patterns[i_ref][j]) {
|
||||
break;
|
||||
}
|
||||
if (is_split_by_matched_nodes) {
|
||||
if (model_comparator->match(patterns[i_orig][0], patterns[i_orig][j]) ||
|
||||
model_comparator->match(patterns[i_ref][0], patterns[i_ref][j])) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// check inputs and matching in case not start_node
|
||||
if (j != 0) {
|
||||
bool is_input_matched = false;
|
||||
for (size_t input_idx = 0; input_idx < patterns[i_orig][j]->inputs().size(); ++input_idx) {
|
||||
auto in_orig = patterns[i_orig][j]->get_input_node_shared_ptr(input_idx);
|
||||
auto in_ref = patterns[i_ref][j]->get_input_node_shared_ptr(input_idx);
|
||||
if (std::find(pattern_orig.begin(), pattern_orig.end(), in_orig) != pattern_orig.end() &&
|
||||
std::find(pattern_ref.begin(), pattern_ref.end(), in_ref) != pattern_ref.end()) {
|
||||
is_input_matched = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!is_input_matched) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
pattern_orig[j] = patterns[i_orig][j];
|
||||
pattern_ref[j] = patterns[i_ref][j];
|
||||
}
|
||||
}
|
||||
// fill vectors only by valid nodes
|
||||
ov::NodeVector orig, ref;
|
||||
for (size_t node_idx = 0; node_idx < pattern_orig.size(); ++node_idx) {
|
||||
if (pattern_orig[node_idx] != 0) {
|
||||
orig.emplace_back(pattern_orig[node_idx]);
|
||||
}
|
||||
if (pattern_ref[node_idx] != 0) {
|
||||
ref.emplace_back(pattern_ref[node_idx]);
|
||||
}
|
||||
}
|
||||
if (orig.size() < min_graph_size) {
|
||||
continue;
|
||||
}
|
||||
potential_patterns[i_orig] = orig;
|
||||
potential_patterns[i_ref] = ref;
|
||||
}
|
||||
}
|
||||
// sort patterns by node vectors size
|
||||
std::sort(potential_patterns.begin(), potential_patterns.end(), [](const ov::NodeVector& a, const ov::NodeVector& b) {
|
||||
return a.size() > b.size();
|
||||
});
|
||||
|
||||
// exclude not repeated pattern
|
||||
while (potential_patterns.rbegin()->size() < 2 && !potential_patterns.empty()) {
|
||||
potential_patterns.pop_back();
|
||||
}
|
||||
patterns = potential_patterns;
|
||||
}
|
||||
|
||||
// group node vectors to the patterns: std::vector<ov::NodeVector>
|
||||
std::vector<std::vector<ov::NodeVector>> pattern_vec;
|
||||
for (size_t pattern_idx = 0; pattern_idx < patterns.size(); ++pattern_idx) {
|
||||
const auto& pattern = patterns[pattern_idx];
|
||||
if (pattern_vec.empty()) {
|
||||
pattern_vec.push_back({{pattern}});
|
||||
} else if (pattern_vec.rbegin()->begin()->size() != pattern.size()) {
|
||||
pattern_vec.push_back({{pattern}});
|
||||
} else {
|
||||
auto it = pattern_vec.rbegin();
|
||||
while (it != pattern_vec.rend()) {
|
||||
auto ref = it->front();
|
||||
if (ref.size() != pattern.size()) {
|
||||
pattern_vec.push_back({{pattern}});
|
||||
break;
|
||||
}
|
||||
bool is_matched = true;
|
||||
for (size_t i = 0; i < pattern.size(); ++i) {
|
||||
if (!model_comparator->match(pattern[i], ref[i])) {
|
||||
is_matched = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (is_matched) {
|
||||
it->push_back(pattern);
|
||||
break;
|
||||
} else {
|
||||
it++;
|
||||
}
|
||||
}
|
||||
if (it == pattern_vec.rend()) {
|
||||
pattern_vec.push_back({{pattern}});
|
||||
}
|
||||
}
|
||||
}
|
||||
return pattern_vec;
|
||||
}
|
||||
|
||||
std::list<std::vector<RepeatPatternExtractor::ExtractedRepeatPattern>>
|
||||
RepeatPatternExtractor::find_repeat_patterns(const std::shared_ptr<ov::Model> &model,
|
||||
bool is_save_borders_only) {
|
||||
std::list<std::vector<RepeatPatternExtractor::ExtractedRepeatPattern>> extracted_patterns;
|
||||
std::unordered_set<std::string> checked_ops;
|
||||
|
||||
auto ordered_ops = model->get_ordered_ops();
|
||||
auto op_cnt = ordered_ops.size();
|
||||
if (ordered_ops.size() < 2) {
|
||||
return extracted_patterns;
|
||||
}
|
||||
auto matched_nodes_pattern = model_comparator->get_matched_op_patterns(ordered_ops);
|
||||
|
||||
for (size_t idx = 0; idx < op_cnt; ++idx) {
|
||||
auto op = ordered_ops[idx];
|
||||
auto op_name = op->get_friendly_name();
|
||||
if (checked_ops.count(op_name)|| ov::util::is_node_to_skip(op)) {
|
||||
for (size_t i = 0; i < matched_nodes_pattern.size(); ++i) {
|
||||
auto matched_nodes = matched_nodes_pattern[i];
|
||||
if (matched_nodes.size() < 2 || i > 0 && matched_nodes.size() == matched_nodes_pattern[i - 1].size()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// find the same nodes
|
||||
std::vector<size_t> start_node_idx{idx};
|
||||
for (size_t i = idx + 1; i < op_cnt; ++i) {
|
||||
if (model_comparator->match(op, ordered_ops[i])) {
|
||||
start_node_idx.push_back(i);
|
||||
}
|
||||
}
|
||||
if (start_node_idx.size() < 2) {
|
||||
checked_ops.insert(op_name);
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<std::set<std::shared_ptr<ov::Node>>> nodes(start_node_idx.size());
|
||||
for (size_t i = 0; i < start_node_idx.size(); ++i) {
|
||||
for (size_t j = i + 1; j < start_node_idx.size(); ++j) {
|
||||
size_t node_idx = start_node_idx[i], ref_node_idx = start_node_idx[j];
|
||||
while (node_idx < op_cnt && ref_node_idx < op_cnt) {
|
||||
auto node = ordered_ops[node_idx];
|
||||
auto ref_node = ordered_ops[ref_node_idx];
|
||||
if (checked_ops.count(node->get_friendly_name()) ||
|
||||
checked_ops.count(ref_node->get_friendly_name())) {
|
||||
break;
|
||||
}
|
||||
if (!ov::util::is_node_to_skip(node) &&
|
||||
!ov::util::is_node_to_skip(ref_node)) {
|
||||
if (node_idx == start_node_idx[i] && ref_node_idx == start_node_idx[j]) {
|
||||
nodes[i].insert(node);
|
||||
nodes[j].insert(ref_node);
|
||||
} else if (model_comparator->match(node, ref_node)) {
|
||||
// check if we met the same node
|
||||
if (model_comparator->match(node, op)) {
|
||||
break;
|
||||
}
|
||||
if (checked_ops.count(node->get_friendly_name()) ||
|
||||
checked_ops.count(ref_node->get_friendly_name())) {
|
||||
break;
|
||||
}
|
||||
// check that any input node is using in graph
|
||||
bool is_input_in_graph = false;
|
||||
for (size_t in_idx = 0; in_idx < node->inputs().size(); ++in_idx) {
|
||||
auto in_node = node->get_input_node_ptr(in_idx)->shared_from_this();
|
||||
auto ref_in_node = ref_node->get_input_node_ptr(in_idx)->shared_from_this();
|
||||
if (nodes[i].count(in_node) && nodes[j].count(ref_in_node)) {
|
||||
is_input_in_graph = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!is_input_in_graph) {
|
||||
break;
|
||||
}
|
||||
|
||||
nodes[i].insert(ordered_ops[node_idx]);
|
||||
nodes[j].insert(ordered_ops[ref_node_idx]);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
++node_idx;
|
||||
++ref_node_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < start_node_idx.size(); ++i) {
|
||||
for (auto& nodes_vector : get_patterns_by_nodes(matched_nodes, ordered_ops)) {
|
||||
try {
|
||||
std::unordered_set<std::string> tmp_checked_ops;
|
||||
// model, in_info, extractor_name
|
||||
ov::NodeVector nodes_vector(nodes[i].begin(), nodes[i].end());
|
||||
auto extracted_pattern = ov::util::generate_model(nodes_vector, tmp_checked_ops, is_save_const, is_save_borders_only);
|
||||
if (nodes_vector.size() < 1) {
|
||||
continue;
|
||||
}
|
||||
auto extracted_pattern = ov::util::generate_model(nodes_vector.front(), is_save_const, is_save_borders_only);
|
||||
auto extracted_model = extracted_pattern.first;
|
||||
if (is_recursive_extraction && nodes_vector.size() > 20) {
|
||||
auto secondary_patterns = find_repeat_patterns(extracted_model, is_save_borders_only);
|
||||
if (!secondary_patterns.empty()) {
|
||||
tmp_checked_ops.clear();
|
||||
update_extractor_cache(extracted_patterns, secondary_patterns);
|
||||
} else {
|
||||
update_extractor_cache(extracted_patterns,
|
||||
extracted_model,
|
||||
nodes_vector,
|
||||
extracted_pattern.second);
|
||||
auto extracted_input_info = extracted_pattern.second;
|
||||
if (extracted_model == nullptr) {
|
||||
continue;
|
||||
}
|
||||
bool is_insert_res = true;
|
||||
if (is_recursive_extraction) {
|
||||
auto tmp_extracted_patterns = find_repeat_patterns(extracted_model, is_save_borders_only);
|
||||
if (!tmp_extracted_patterns.empty()) {
|
||||
is_insert_res = false;
|
||||
update_extractor_cache(extracted_patterns, tmp_extracted_patterns);
|
||||
}
|
||||
} else {
|
||||
}
|
||||
if (is_insert_res) {
|
||||
update_extractor_cache(extracted_patterns,
|
||||
extracted_model,
|
||||
nodes_vector,
|
||||
extracted_pattern.second);
|
||||
extracted_input_info);
|
||||
}
|
||||
nodes[i].clear();
|
||||
checked_ops.insert(tmp_checked_ops.begin(), tmp_checked_ops.end());
|
||||
} catch(std::exception& e) {
|
||||
if (std::string(e.what()).find("Incorrect node number to create model!") == std::string::npos) {
|
||||
// std::cout << "[ WARNING ] Impossible to generate network and add to GraphCache: " <<e.what() << std::endl;
|
||||
@ -219,23 +304,28 @@ RepeatPatternExtractor::find_repeat_patterns(const std::shared_ptr<ov::Model> &m
|
||||
}
|
||||
}
|
||||
if (is_extract_body) {
|
||||
if (std::dynamic_pointer_cast<ov::op::v0::TensorIterator>(op)) {
|
||||
auto ti = ov::as_type_ptr<ov::op::v0::TensorIterator>(op);
|
||||
auto ti_body = ti->get_function();
|
||||
auto secondary_patterns = find_repeat_patterns(ti_body, is_save_borders_only);
|
||||
update_extractor_cache(extracted_patterns, secondary_patterns);
|
||||
} else if (std::dynamic_pointer_cast<ov::op::v5::Loop>(op)) {
|
||||
auto loop = ov::as_type_ptr<ov::op::v5::Loop>(op);
|
||||
auto loop_body = loop->get_function();
|
||||
auto secondary_patterns = find_repeat_patterns(loop_body, is_save_borders_only);
|
||||
update_extractor_cache(extracted_patterns, secondary_patterns);
|
||||
} else if (std::dynamic_pointer_cast<ov::op::v8::If>(op)) {
|
||||
auto if_op = ov::as_type_ptr<ov::op::v8::If>(op);
|
||||
std::vector<std::shared_ptr<ov::Model>> bodies;
|
||||
for (size_t i = 0; i < if_op->get_internal_subgraphs_size(); i++) {
|
||||
auto if_body = if_op->get_function(i);
|
||||
auto secondary_patterns = find_repeat_patterns(if_body, is_save_borders_only);
|
||||
for (const auto& matched_node_idx : matched_nodes) {
|
||||
const auto& matched_node = ordered_ops[matched_node_idx];
|
||||
if (std::dynamic_pointer_cast<ov::op::v0::TensorIterator>(matched_node)) {
|
||||
auto ti = ov::as_type_ptr<ov::op::v0::TensorIterator>(matched_node);
|
||||
auto ti_body = ti->get_function();
|
||||
auto secondary_patterns = find_repeat_patterns(ti_body, is_save_borders_only);
|
||||
update_extractor_cache(extracted_patterns, secondary_patterns);
|
||||
} else if (std::dynamic_pointer_cast<ov::op::v5::Loop>(matched_node)) {
|
||||
auto loop = ov::as_type_ptr<ov::op::v5::Loop>(matched_node);
|
||||
auto loop_body = loop->get_function();
|
||||
auto secondary_patterns = find_repeat_patterns(loop_body, is_save_borders_only);
|
||||
update_extractor_cache(extracted_patterns, secondary_patterns);
|
||||
} else if (std::dynamic_pointer_cast<ov::op::v8::If>(matched_node)) {
|
||||
auto if_op = ov::as_type_ptr<ov::op::v8::If>(matched_node);
|
||||
std::vector<std::shared_ptr<ov::Model>> bodies;
|
||||
for (size_t i = 0; i < if_op->get_internal_subgraphs_size(); i++) {
|
||||
auto if_body = if_op->get_function(i);
|
||||
auto secondary_patterns = find_repeat_patterns(if_body, is_save_borders_only);
|
||||
update_extractor_cache(extracted_patterns, secondary_patterns);
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -41,51 +41,40 @@ align_input_info(const std::shared_ptr<ov::Model>& model,
|
||||
const std::shared_ptr<ov::Model>& model_ref,
|
||||
const std::map<std::string, ov::conformance::InputInfo>& in_info,
|
||||
const std::map<std::string, ov::conformance::InputInfo>& in_info_ref,
|
||||
const std::map<std::string, std::string> &matched_op) {
|
||||
bool is_update_required = !matched_op.empty();
|
||||
if (!is_update_required) {
|
||||
for (const auto& ref_item : in_info_ref) {
|
||||
if (!in_info.count(ref_item.first)) {
|
||||
is_update_required = true;
|
||||
break;
|
||||
} else if (in_info.at(ref_item.first).is_const != ref_item.second.is_const) {
|
||||
throw std::runtime_error("Impossible to update input info!!!");
|
||||
}
|
||||
const std::unordered_map<std::string, std::string> &matched_op) {
|
||||
std::map<std::string, ov::conformance::InputInfo> updated_input_info(in_info_ref);
|
||||
for (const auto& op : model->get_ordered_ops()) {
|
||||
const auto op_name = op->get_friendly_name();
|
||||
if (!in_info.count(op_name)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
std::map<std::string, ov::conformance::InputInfo> updated_input_info = in_info_ref;
|
||||
if (is_update_required) {
|
||||
// align matched model names
|
||||
const auto& ref_model_ops = model_ref->get_ordered_ops();
|
||||
const auto& model_ops = model->get_ordered_ops();
|
||||
size_t ref_ordered_ops_size = ref_model_ops.size();
|
||||
size_t ordered_ops_size = model_ops.size();
|
||||
if (ref_ordered_ops_size != ordered_ops_size && matched_op.empty()) {
|
||||
throw std::runtime_error("Matched models can not be compared according different op numbers!");
|
||||
}
|
||||
for (size_t i = 0; i < ordered_ops_size; ++i) {
|
||||
auto model_op_name = model_ops[i]->get_friendly_name();
|
||||
if (!in_info.count(model_op_name)) {
|
||||
continue;
|
||||
}
|
||||
if (!matched_op.empty()) {
|
||||
if (!matched_op.count(model_op_name)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
auto model_ref_op_name = matched_op.empty() ? ref_model_ops[i]->get_friendly_name() : matched_op.at(model_op_name);
|
||||
|
||||
const auto& in_info_item = in_info.at(model_op_name);
|
||||
const auto& ref_in_info_item = in_info_ref.at(model_ref_op_name);
|
||||
if (in_info_item.is_const != ref_in_info_item.is_const) {
|
||||
throw std::runtime_error("Impossible to update input info!!!");
|
||||
}
|
||||
updated_input_info[model_ref_op_name] = in_info_item;
|
||||
if (matched_op.count(op_name) && in_info_ref.count(matched_op.at(op_name))) {
|
||||
updated_input_info[matched_op.at(op_name)] = in_info.at(op_name);
|
||||
}
|
||||
}
|
||||
return updated_input_info;
|
||||
}
|
||||
|
||||
void
|
||||
get_subgraph_set_node(std::unordered_set<std::shared_ptr<ov::Node>>& nodes_to_check,
|
||||
const std::shared_ptr<ov::Node>& node) {
|
||||
if (nodes_to_check.empty()) {
|
||||
nodes_to_check.insert(node);
|
||||
}
|
||||
for (size_t out_idx = 0; out_idx < node->outputs().size(); ++out_idx) {
|
||||
for (const auto& out : node->get_output_target_inputs(out_idx)) {
|
||||
const auto& output_node = out.get_node()->shared_from_this();
|
||||
if (ov::op::util::is_output(output_node)) {
|
||||
return;
|
||||
}
|
||||
if (!nodes_to_check.count(output_node)) {
|
||||
nodes_to_check.insert(output_node);
|
||||
get_subgraph_set_node(nodes_to_check, output_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
} // namespace util
|
||||
} // namespace ov
|
@ -20,46 +20,43 @@ void ModelComparator::set_shape_strict_match(bool in_is_shape_strict_match) {
|
||||
m_manager.set_shape_strict_match(in_is_shape_strict_match);
|
||||
}
|
||||
|
||||
void ModelComparator::set_match_attributes(bool match_attributes) {
|
||||
m_manager.set_match_attributes(match_attributes);
|
||||
}
|
||||
|
||||
void ModelComparator::set_match_in_types(bool match_in_types) {
|
||||
m_manager.set_match_in_types(true);
|
||||
}
|
||||
|
||||
inline ModelComparator::IsSubgraphTuple
|
||||
prepare_is_subgraph_result(bool is_subgraph,
|
||||
const std::shared_ptr<ov::Model>& subgraph,
|
||||
const std::shared_ptr<ov::Model>& graph,
|
||||
const std::map<std::string, std::string>& matched_ops) {
|
||||
return is_subgraph ?
|
||||
std::make_tuple(is_subgraph, subgraph, graph, matched_ops) :
|
||||
std::make_tuple(is_subgraph, nullptr, nullptr, std::map<std::string, std::string>());
|
||||
const std::unordered_map<std::string, std::string>& matched_ops) {
|
||||
return std::make_tuple(is_subgraph, subgraph, graph, matched_ops);
|
||||
}
|
||||
|
||||
ModelComparator::IsSubgraphTuple
|
||||
ModelComparator::is_subgraph(const std::shared_ptr<ov::Model> &model,
|
||||
const std::shared_ptr<ov::Model> &ref_model) const {
|
||||
std::vector<std::shared_ptr<ov::Node>> ordered_ops = model->get_ordered_ops(),
|
||||
ref_ordered_ops = ref_model->get_ordered_ops();
|
||||
bool is_model = ordered_ops.size() > ref_ordered_ops.size();
|
||||
ov::NodeVector graph_to_check_ops, subgraph_to_check_ops;
|
||||
auto in_info = ov::util::get_input_info_by_model(model);
|
||||
auto in_info_ref = ov::util::get_input_info_by_model(ref_model);
|
||||
size_t ordered_ops_cnt = model->get_ordered_ops().size() - in_info.size() - model->get_results().size(),
|
||||
ref_ordered_ops_cnt = ref_model->get_ordered_ops().size() - in_info_ref.size() - ref_model->get_results().size();
|
||||
bool is_model = ordered_ops_cnt > ref_ordered_ops_cnt;
|
||||
size_t subgraph_to_check_ops_cnt;
|
||||
std::shared_ptr<ov::Model> graph = nullptr, subgraph = nullptr;
|
||||
if (is_model) {
|
||||
graph_to_check_ops = ordered_ops;
|
||||
subgraph_to_check_ops = ref_ordered_ops;
|
||||
graph = model;
|
||||
subgraph = ref_model;
|
||||
subgraph_to_check_ops_cnt = ref_ordered_ops_cnt;
|
||||
} else {
|
||||
graph_to_check_ops = ref_ordered_ops;
|
||||
subgraph_to_check_ops = ordered_ops;
|
||||
graph = ref_model;
|
||||
subgraph = model;
|
||||
subgraph_to_check_ops_cnt = ordered_ops_cnt;
|
||||
}
|
||||
std::map<std::string, std::string> matched_op_names;
|
||||
|
||||
auto graph_it = graph_to_check_ops.begin(), subgraph_it = subgraph_to_check_ops.begin();
|
||||
while (graph_it != graph_to_check_ops.end() && subgraph_it != subgraph_to_check_ops.end()) {
|
||||
if (m_manager.match(*graph_it, *subgraph_it)) {
|
||||
matched_op_names.insert({ (*subgraph_it)->get_friendly_name(), (*graph_it)->get_friendly_name()});
|
||||
++subgraph_it;
|
||||
}
|
||||
++graph_it;
|
||||
}
|
||||
return prepare_is_subgraph_result(subgraph_it == subgraph_to_check_ops.end(), subgraph, graph, matched_op_names);
|
||||
auto matched_op_names = get_matched_ops_in_graphs(subgraph, graph);
|
||||
return prepare_is_subgraph_result(matched_op_names.size() == subgraph_to_check_ops_cnt, subgraph, graph, matched_op_names);
|
||||
}
|
||||
|
||||
bool
|
||||
@ -94,17 +91,17 @@ ModelComparator::is_subgraph(const std::shared_ptr<ov::Model> &model,
|
||||
const std::shared_ptr<ov::Model> &ref_model,
|
||||
const std::map<std::string, InputInfo> &in_info,
|
||||
const std::map<std::string, InputInfo> &in_info_ref) {
|
||||
ModelComparator::ExtractedSubgraphTuple res = { false, nullptr, nullptr, {}, {} };
|
||||
m_manager.set_match_in_types(true);
|
||||
auto extractor_res = is_subgraph(model, ref_model);
|
||||
if (std::get<0>(extractor_res)) {
|
||||
std::map<std::string, InputInfo> graph_in_info, subgraph_in_info;
|
||||
std::shared_ptr<ov::Model> subgraph = nullptr, graph = nullptr;
|
||||
// if (model == subgraph && ref_model == graph)
|
||||
if (std::get<1>(extractor_res) == model && std::get<2>(extractor_res) == ref_model) {
|
||||
subgraph = model;
|
||||
subgraph_in_info = in_info;
|
||||
graph = ref_model;
|
||||
graph_in_info = in_info_ref;
|
||||
// else if (subgraph == ref_model && graph = model)
|
||||
} else if (std::get<1>(extractor_res) == ref_model && std::get<2>(extractor_res) == model) {
|
||||
subgraph = ref_model;
|
||||
subgraph_in_info = in_info_ref;
|
||||
@ -114,11 +111,50 @@ ModelComparator::is_subgraph(const std::shared_ptr<ov::Model> &model,
|
||||
throw std::runtime_error("Generated models are incompatible with original ones!");
|
||||
}
|
||||
try {
|
||||
subgraph_in_info = ov::util::align_input_info(subgraph, graph, subgraph_in_info, graph_in_info);
|
||||
return { true, subgraph, graph, subgraph_in_info, graph_in_info };
|
||||
auto subgraph_in_info_new = ov::util::align_input_info(subgraph, graph,
|
||||
subgraph_in_info, graph_in_info,
|
||||
get_matched_ops_in_graphs(subgraph, graph));
|
||||
res = { true, subgraph, graph, subgraph_in_info_new, graph_in_info };
|
||||
} catch(std::exception) {}
|
||||
}
|
||||
return { false, nullptr, nullptr, {}, {} };
|
||||
m_manager.set_match_in_types(false);
|
||||
return res;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, std::string>
|
||||
ModelComparator::get_matched_ops_in_graphs(const std::shared_ptr<ov::Model>& subgraph,
|
||||
const std::shared_ptr<ov::Model>& graph,
|
||||
bool is_check_inputs) const {
|
||||
std::unordered_map<std::string, std::string> matched_op_names;
|
||||
std::unordered_set<std::string> checked_op;
|
||||
const auto subgraph_to_check_ops = subgraph->get_ordered_ops();
|
||||
const auto graph_to_check_ops = graph->get_ordered_ops();
|
||||
for (const auto& subgraph_op : subgraph_to_check_ops) {
|
||||
for (const auto& graph_op : graph_to_check_ops) {
|
||||
if (ov::util::is_node_to_skip(subgraph_op) ||
|
||||
ov::util::is_node_to_skip(graph_op)) {
|
||||
continue;
|
||||
}
|
||||
if (match(subgraph_op, graph_op) && !checked_op.count(graph_op->get_friendly_name())) {
|
||||
matched_op_names.insert({subgraph_op->get_friendly_name(), graph_op->get_friendly_name()});
|
||||
checked_op.insert(graph_op->get_friendly_name());
|
||||
if (is_check_inputs) {
|
||||
for (size_t idx = 0; idx < graph_op->inputs().size(); ++idx) {
|
||||
auto graph_in = graph_op->get_input_node_shared_ptr(idx);
|
||||
auto subgraph_in = subgraph_op->get_input_node_shared_ptr(idx);
|
||||
if (ov::util::is_node_to_skip(graph_in) && ov::util::is_node_to_skip(subgraph_in)) {
|
||||
if (match(subgraph_in, graph_in)) {
|
||||
matched_op_names.insert({subgraph_in->get_friendly_name(), graph_in->get_friendly_name()});
|
||||
checked_op.insert(graph_in->get_friendly_name());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return matched_op_names;
|
||||
}
|
||||
|
||||
std::pair<bool, std::map<std::string, ov::conformance::InputInfo>>
|
||||
@ -128,9 +164,35 @@ ModelComparator::match(const std::shared_ptr<ov::Model> &model,
|
||||
const std::map<std::string, InputInfo> &in_info_ref) {
|
||||
try {
|
||||
if (match(model, model_ref)) {
|
||||
auto new_input_info = ov::util::align_input_info(model, model_ref, in_info, in_info_ref);
|
||||
auto new_input_info = ov::util::align_input_info(model, model_ref,
|
||||
in_info, in_info_ref,
|
||||
get_matched_ops_in_graphs(model, model_ref, true));
|
||||
return {true, new_input_info};
|
||||
}
|
||||
} catch (std::exception) {}
|
||||
return {false, {}};
|
||||
}
|
||||
|
||||
std::vector<std::vector<size_t>>
|
||||
ModelComparator::get_matched_op_patterns(const ov::NodeVector& ordered_ops) {
|
||||
std::vector<std::vector<size_t>> matched_nodes;
|
||||
for (size_t node_idx = 0; node_idx < ordered_ops.size(); ++node_idx) {
|
||||
bool is_matched = false;
|
||||
for (auto& matched_node_idx : matched_nodes) {
|
||||
if (match(ordered_ops[matched_node_idx.front()], ordered_ops[node_idx])) {
|
||||
matched_node_idx.push_back(node_idx);
|
||||
is_matched = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!is_matched && !ov::util::is_node_to_skip(ordered_ops[node_idx])) {
|
||||
matched_nodes.push_back({node_idx});
|
||||
}
|
||||
}
|
||||
std::sort(matched_nodes.begin(), matched_nodes.end(),
|
||||
[](const std::vector<size_t>& a, const std::vector<size_t>& b){ return a.size() > b.size(); });
|
||||
while (!matched_nodes.empty() && matched_nodes.rbegin()->size() == 1) {
|
||||
matched_nodes.pop_back();
|
||||
}
|
||||
return matched_nodes;
|
||||
}
|
@ -12,14 +12,18 @@
|
||||
#include "test_models/model_0.hpp"
|
||||
#include "test_models/model_1.hpp"
|
||||
#include "test_models/model_2.hpp"
|
||||
#include "test_models/model_3.hpp"
|
||||
|
||||
#include "openvino/pass/manager.hpp"
|
||||
#include "openvino/pass/serialize.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ov::tools::subgraph_dumper;
|
||||
|
||||
|
||||
// ======================= ExtractorsManagerTest Unit tests =======================
|
||||
class RepeatPatternExtractorTest : public SubgraphsDumperBaseTest {
|
||||
// ======================= ExtractorsManagerTest Func tests =======================
|
||||
class RepeatPatternExtractorFuncTest : public SubgraphsDumperBaseTest {
|
||||
protected:
|
||||
RepeatPatternExtractor extractor;
|
||||
|
||||
@ -52,41 +56,57 @@ protected:
|
||||
std::sort(pattern_vec.begin(), pattern_vec.end());
|
||||
}
|
||||
|
||||
// not allowed to sort inputs/outputs according there are not copy constructor
|
||||
// void sort_borders(std::vector<std::vector<RepeatPatternExtractor::PatternBorders>>& pattern_vec) {
|
||||
// for (auto& pattern : pattern_vec) {
|
||||
// for (auto& node_vec : pattern) {
|
||||
// std::sort(node_vec.first.begin(), node_vec.first.end());
|
||||
// std::sort(node_vec.second.begin(), node_vec.second.end());
|
||||
// }
|
||||
// std::sort(pattern.begin(), pattern.end());
|
||||
// }
|
||||
// std::sort(pattern_vec.begin(), pattern_vec.end());
|
||||
// }
|
||||
void
|
||||
is_equal_borders(const std::vector<std::vector<RepeatPatternExtractor::PatternBorders>>& pattern_vec_orig,
|
||||
const std::vector<std::vector<RepeatPatternExtractor::PatternBorders>>& pattern_vec_ref) {
|
||||
ASSERT_EQ(pattern_vec_orig.size(), pattern_vec_ref.size());
|
||||
size_t orig_borders_cnt = 0, ref_borderd_cnt = 0, eq_borders = 0;
|
||||
for (const auto& pattern_orig : pattern_vec_orig) {
|
||||
orig_borders_cnt += pattern_orig.size();
|
||||
ref_borderd_cnt = 0;
|
||||
for (const auto& pattern_ref : pattern_vec_ref) {
|
||||
ref_borderd_cnt += pattern_ref.size();
|
||||
if (pattern_ref.size() != pattern_orig.size()) {
|
||||
continue;
|
||||
}
|
||||
for (const auto& node_vec_orig : pattern_orig) {
|
||||
// size_t eq_pattens = 0;
|
||||
for (const auto& node_vec_ref : pattern_ref) {
|
||||
if (node_vec_orig == node_vec_ref) {
|
||||
++eq_borders;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ASSERT_EQ(orig_borders_cnt, ref_borderd_cnt);
|
||||
ASSERT_EQ(orig_borders_cnt, eq_borders);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(RepeatPatternExtractorTest, extract_0) {
|
||||
TEST_F(RepeatPatternExtractorFuncTest, extract_0) {
|
||||
auto test_model = Model_0();
|
||||
auto models = extractor.extract(test_model.get());
|
||||
auto ref = test_model.get_repeat_pattern_ref();
|
||||
ASSERT_TRUE(is_match(models, ref));
|
||||
}
|
||||
|
||||
TEST_F(RepeatPatternExtractorTest, extract_1) {
|
||||
TEST_F(RepeatPatternExtractorFuncTest, extract_1) {
|
||||
auto test_model = Model_1();
|
||||
auto models = extractor.extract(test_model.get());
|
||||
auto ref = test_model.get_repeat_pattern_ref();
|
||||
ASSERT_TRUE(is_match(models, ref));
|
||||
}
|
||||
|
||||
TEST_F(RepeatPatternExtractorTest, extract_2) {
|
||||
TEST_F(RepeatPatternExtractorFuncTest, extract_2) {
|
||||
auto test_model = Model_2();
|
||||
auto models = extractor.extract(test_model.get());
|
||||
auto ref = test_model.get_repeat_pattern_ref();
|
||||
ASSERT_TRUE(is_match(models, ref));
|
||||
}
|
||||
|
||||
TEST_F(RepeatPatternExtractorTest, get_repeat_node_vectors_model_0) {
|
||||
TEST_F(RepeatPatternExtractorFuncTest, get_repeat_node_vectors_model_0) {
|
||||
auto test_model = Model_0();
|
||||
auto node_vector = extractor.get_repeat_node_vectors(test_model.get());
|
||||
auto ref = test_model.get_ref_node_vector();
|
||||
@ -95,7 +115,7 @@ TEST_F(RepeatPatternExtractorTest, get_repeat_node_vectors_model_0) {
|
||||
ASSERT_EQ(node_vector, ref);
|
||||
}
|
||||
|
||||
TEST_F(RepeatPatternExtractorTest, get_repeat_node_vectors_model_1) {
|
||||
TEST_F(RepeatPatternExtractorFuncTest, get_repeat_node_vectors_model_1) {
|
||||
auto test_model = Model_1();
|
||||
auto node_vector = extractor.get_repeat_node_vectors(test_model.get());
|
||||
auto ref = test_model.get_ref_node_vector();
|
||||
@ -104,7 +124,7 @@ TEST_F(RepeatPatternExtractorTest, get_repeat_node_vectors_model_1) {
|
||||
ASSERT_EQ(node_vector, ref);
|
||||
}
|
||||
|
||||
TEST_F(RepeatPatternExtractorTest, get_repeat_node_vectors_model_2) {
|
||||
TEST_F(RepeatPatternExtractorFuncTest, get_repeat_node_vectors_model_2) {
|
||||
auto test_model = Model_2();
|
||||
auto node_vector = extractor.get_repeat_node_vectors(test_model.get());
|
||||
auto ref = test_model.get_ref_node_vector();
|
||||
@ -113,32 +133,24 @@ TEST_F(RepeatPatternExtractorTest, get_repeat_node_vectors_model_2) {
|
||||
ASSERT_EQ(node_vector, ref);
|
||||
}
|
||||
|
||||
TEST_F(RepeatPatternExtractorTest, get_repeat_pattern_borders_model_0) {
|
||||
TEST_F(RepeatPatternExtractorFuncTest, get_repeat_pattern_borders_model_0) {
|
||||
auto test_model = Model_0();
|
||||
auto extracted_borders = extractor.get_repeat_pattern_borders(test_model.get());
|
||||
auto ref_borders = test_model.get_ref_node_borders();
|
||||
// sort_borders(extracted_borders);
|
||||
// sort_borders(ref_borders);
|
||||
ASSERT_EQ(extracted_borders, ref_borders);
|
||||
is_equal_borders(extracted_borders, ref_borders);
|
||||
}
|
||||
|
||||
TEST_F(RepeatPatternExtractorTest, get_repeat_pattern_borders_model_1) {
|
||||
TEST_F(RepeatPatternExtractorFuncTest, get_repeat_pattern_borders_model_1) {
|
||||
auto test_model = Model_1();
|
||||
auto extracted_borders = extractor.get_repeat_pattern_borders(test_model.get());
|
||||
auto ref_borders = test_model.get_ref_node_borders();
|
||||
// sort_borders(extracted_borders);
|
||||
// sort_borders(ref_borders);
|
||||
ASSERT_EQ(extracted_borders, ref_borders);
|
||||
is_equal_borders(extracted_borders, ref_borders);
|
||||
}
|
||||
|
||||
TEST_F(RepeatPatternExtractorTest, get_repeat_pattern_borders_model_2) {
|
||||
TEST_F(RepeatPatternExtractorFuncTest, get_repeat_pattern_borders_model_2) {
|
||||
auto test_model = Model_2();
|
||||
auto extracted_borders = extractor.get_repeat_pattern_borders(test_model.get());
|
||||
auto ref_borders = test_model.get_ref_node_borders();
|
||||
// sort_borders(extracted_borders);
|
||||
// sort_borders(ref_borders);
|
||||
ASSERT_EQ(extracted_borders, ref_borders);
|
||||
is_equal_borders(extracted_borders, ref_borders);
|
||||
}
|
||||
|
||||
|
||||
} // namespace
|
||||
|
@ -56,7 +56,7 @@ public:
|
||||
{
|
||||
PatternBorders ref_pattern_0 = {test_abs_0->inputs(), test_relu_0->outputs()},
|
||||
ref_pattern_1 = {test_abs_1->inputs(), test_relu_1->outputs()};
|
||||
std::vector<std::vector<PatternBorders>> ref_res = {{ref_pattern_0, ref_pattern_1}};
|
||||
std::vector<std::vector<PatternBorders>> ref_res = {{ref_pattern_1, ref_pattern_0}};
|
||||
ref_borders = std::move(ref_res);
|
||||
}
|
||||
}
|
||||
|
@ -93,7 +93,7 @@ public:
|
||||
|
||||
std::shared_ptr<ov::op::v1::Multiply> test_multiply_0_1 =
|
||||
std::make_shared<ov::op::v1::Multiply>(test_add_0, test_multiply_0_0);
|
||||
test_multiply_0_0->set_friendly_name("Op_" + std::to_string(op_idx++));
|
||||
test_multiply_0_1->set_friendly_name("Op_" + std::to_string(op_idx++));
|
||||
|
||||
std::shared_ptr<ov::op::v0::Relu> test_relu_0_1 =
|
||||
std::make_shared<ov::op::v0::Relu>(test_multiply_0_1);
|
||||
@ -134,11 +134,15 @@ public:
|
||||
ref_pattern_0_1_0 = {test_abs_0_1->inputs(), test_clamp_0_1->outputs()},
|
||||
test_pattern_0_1_1 = {test_multiply_0_1->inputs(), test_relu_0_1->outputs()},
|
||||
test_pattern_1_1 = {test_multiply_1_1->inputs(), test_relu_1_1->outputs()};
|
||||
std::vector<std::vector<PatternBorders>> ref_res = {{ref_pattern_0, ref_pattern_0_0},
|
||||
{ref_pattern_1, ref_pattern_0_1_0},
|
||||
{test_pattern_0_1_1, test_pattern_1_1}};
|
||||
std::vector<std::vector<PatternBorders>> ref_res = {{ref_pattern_0_0, ref_pattern_0},
|
||||
{ref_pattern_0_1_0, ref_pattern_1},
|
||||
{test_pattern_1_1, test_pattern_0_1_1}};
|
||||
ref_borders = std::move(ref_res);
|
||||
}
|
||||
start_ops = {test_abs_0, test_abs_0_0, test_abs_0_1, test_abs_1};
|
||||
out_nodes = {test_abs_0, test_relu_0, test_add_0, test_multiply_0_1,
|
||||
test_relu_0_1, test_add};
|
||||
start_node = test_abs_0;
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> get() {
|
||||
@ -197,8 +201,24 @@ public:
|
||||
std::vector<std::vector<PatternBorders>>
|
||||
get_ref_node_borders() { return ref_borders; }
|
||||
|
||||
ov::NodeVector
|
||||
get_start_ops() { return start_ops; }
|
||||
|
||||
std::unordered_set<std::shared_ptr<ov::Node>>
|
||||
get_out_nodes_after_abs_0() {
|
||||
return out_nodes;
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Node>
|
||||
get_test_abs_0() {
|
||||
return start_node;
|
||||
}
|
||||
|
||||
protected:
|
||||
std::shared_ptr<ov::Model> model;
|
||||
std::vector<std::vector<ov::NodeVector>> ref_nodes;
|
||||
std::vector<std::vector<PatternBorders>> ref_borders;
|
||||
ov::NodeVector start_ops;
|
||||
std::unordered_set<std::shared_ptr<ov::Node>> out_nodes;
|
||||
std::shared_ptr<ov::Node> start_node;
|
||||
};
|
||||
|
@ -0,0 +1,159 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/op/abs.hpp"
|
||||
#include "openvino/op/add.hpp"
|
||||
#include "openvino/op/clamp.hpp"
|
||||
#include "openvino/op/multiply.hpp"
|
||||
#include "openvino/op/relu.hpp"
|
||||
#include "openvino/op/parameter.hpp"
|
||||
#include "openvino/op/result.hpp"
|
||||
#include "openvino/op/concat.hpp"
|
||||
#include "openvino/op/subtract.hpp"
|
||||
#include "openvino/op/split.hpp"
|
||||
#include "matchers/subgraph/repeat_pattern.hpp"
|
||||
|
||||
class Model_3 {
|
||||
protected:
|
||||
using PatternBorders = ov::tools::subgraph_dumper::RepeatPatternExtractor::PatternBorders;
|
||||
std::shared_ptr<ov::Model> model;
|
||||
ov::NodeVector start_ops;
|
||||
ov::NodeVector node_queue;
|
||||
std::vector<std::vector<std::pair<std::shared_ptr<ov::Node>, std::vector<size_t>>>> ordered_patterns;
|
||||
std::vector<ov::NodeVector> repeats;
|
||||
|
||||
public:
|
||||
Model_3() {
|
||||
// param_00
|
||||
// |
|
||||
// relu_0
|
||||
// |
|
||||
// split_1
|
||||
// |
|
||||
// +-------------+
|
||||
// | |
|
||||
// relu_2 clamp_3
|
||||
// | |
|
||||
// split_4 |
|
||||
// | |
|
||||
// +------------+ |
|
||||
// | | |
|
||||
// relu_5 clamp_6 |
|
||||
// | | |
|
||||
// +------------+ |
|
||||
// | |
|
||||
// add_7 |
|
||||
// | |
|
||||
// concat_8 |
|
||||
// | |
|
||||
// +-------------+
|
||||
// |
|
||||
// multiply_9 param_01
|
||||
// |------------------+
|
||||
// add_10 param_02
|
||||
// |------------------+
|
||||
// multiply_11
|
||||
// |
|
||||
// result_00
|
||||
|
||||
auto param_00 = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 1, 24, 1});
|
||||
param_00->set_friendly_name("param_00");
|
||||
auto relu_0 = std::make_shared<ov::op::v0::Relu>(param_00);
|
||||
relu_0->set_friendly_name("relu_0");
|
||||
auto axis_split_1 = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{}, std::vector<int64_t>({2}));
|
||||
auto split_1 = std::make_shared<ov::op::v1::Split>(relu_0, axis_split_1, 2);
|
||||
split_1->set_friendly_name("split_1");
|
||||
auto relu_2 = std::make_shared<ov::op::v0::Relu>(split_1->output(0));
|
||||
relu_2->set_friendly_name("relu_2");
|
||||
auto clamp_3 = std::make_shared<ov::op::v0::Clamp>(split_1->output(1), 0 , 10);
|
||||
clamp_3->set_friendly_name("clamp_3");
|
||||
auto axis_split_4 = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{}, std::vector<int64_t>({2}));
|
||||
auto split_4 = std::make_shared<ov::op::v1::Split>(relu_2, axis_split_4, 2);
|
||||
split_4->set_friendly_name("split_4");
|
||||
auto relu_5 = std::make_shared<ov::op::v0::Relu>(split_4->output(0));
|
||||
relu_5->set_friendly_name("relu_5");
|
||||
auto clamp_6 = std::make_shared<ov::op::v0::Clamp>(split_4->output(1), 0, 10);
|
||||
clamp_6->set_friendly_name("clamp_6");
|
||||
auto add_7 = std::make_shared<ov::op::v1::Add>(relu_5, clamp_6);
|
||||
add_7->set_friendly_name("add_7");
|
||||
auto param_03 = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, add_7->get_shape());
|
||||
param_03->set_friendly_name("param_03");
|
||||
auto concat_8 = std::make_shared<ov::op::v0::Concat>(ov::NodeVector{add_7, param_03}, 2);
|
||||
concat_8->set_friendly_name("concat_8");
|
||||
auto multiply_9 = std::make_shared<ov::op::v1::Multiply>(concat_8, clamp_3);
|
||||
multiply_9->set_friendly_name("multiply_9");
|
||||
auto param_01 = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, multiply_9->get_shape());
|
||||
param_01->set_friendly_name("param_01");
|
||||
auto add_10 = std::make_shared<ov::op::v1::Add>(multiply_9, param_01);
|
||||
add_10->set_friendly_name("add_10");
|
||||
auto param_02 = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, add_10->get_shape());
|
||||
param_02->set_friendly_name("param_02");
|
||||
auto multiply_11 = std::make_shared<ov::op::v1::Multiply>(add_10, param_02);
|
||||
multiply_11->set_friendly_name("multiply_11");
|
||||
auto result_00 = std::make_shared<ov::op::v0::Result>(multiply_11);
|
||||
result_00->set_friendly_name("result_00");
|
||||
|
||||
model = std::make_shared<ov::Model>(ov::ResultVector{result_00},
|
||||
ov::ParameterVector{param_00, param_01, param_02, param_03});
|
||||
|
||||
start_ops = {relu_0, relu_2, relu_5};
|
||||
ordered_patterns = {{
|
||||
{ relu_0, {}},
|
||||
{ split_1, {0}},
|
||||
{ relu_2, {1}},
|
||||
{ split_4, {2}},
|
||||
{ relu_5, {3}},
|
||||
{ clamp_6, {3}},
|
||||
{ add_7, {4, 5}},
|
||||
{ concat_8, {6}},
|
||||
{ clamp_3, {1}},
|
||||
{ multiply_9, {7, 8}},
|
||||
{ add_10, {9}},
|
||||
{ multiply_11, {10}},
|
||||
}, {
|
||||
{ relu_2, {}},
|
||||
{ split_4, {0}},
|
||||
{ relu_5, {1}},
|
||||
{ clamp_6, {1}},
|
||||
{ add_7, {2, 3}},
|
||||
{ concat_8, {4}},
|
||||
{ multiply_9, {5}},
|
||||
{ add_10, {6}},
|
||||
{ multiply_11, {7}},
|
||||
}, {
|
||||
{ relu_5, {}},
|
||||
{ add_7, {0}},
|
||||
{ concat_8, {1}},
|
||||
{ multiply_9, {2}},
|
||||
{ add_10, {3}},
|
||||
{ multiply_11, {4}},
|
||||
}};
|
||||
repeats = {
|
||||
{ relu_0, split_1, relu_2 },
|
||||
{ relu_2, split_4, relu_5 },
|
||||
};
|
||||
node_queue = {
|
||||
relu_0, split_1, relu_2, split_4, relu_5, clamp_6,
|
||||
add_7, concat_8, clamp_3, multiply_9, add_10, multiply_11,
|
||||
};
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> get() {
|
||||
return model;
|
||||
}
|
||||
|
||||
ov::NodeVector
|
||||
get_start_ops() { return start_ops; }
|
||||
|
||||
std::vector<std::vector<std::pair<std::shared_ptr<ov::Node>, std::vector<size_t>>>>
|
||||
get_ordered_patterns() { return ordered_patterns; }
|
||||
|
||||
std::vector<ov::NodeVector>
|
||||
get_repeats() { return repeats; }
|
||||
|
||||
ov::NodeVector
|
||||
get_queue() { return node_queue; }
|
||||
};
|
@ -30,9 +30,8 @@ TEST_F(ModelUtilsTest, generate_0) {
|
||||
Model_0 test;
|
||||
std::shared_ptr<ov::Model> test_model = test.get(), recovered_model;
|
||||
{
|
||||
std::unordered_set<std::string> checked_ops;
|
||||
auto func_ops = get_functional_ops(test_model);
|
||||
auto model_with_in_info = ov::util::generate_model(func_ops, checked_ops);
|
||||
auto model_with_in_info = ov::util::generate_model(func_ops);
|
||||
recovered_model = std::get<0>(model_with_in_info);
|
||||
}
|
||||
{
|
||||
@ -44,9 +43,8 @@ TEST_F(ModelUtilsTest, generate_1) {
|
||||
Model_1 test;
|
||||
std::shared_ptr<ov::Model> test_model = test.get(), recovered_model;
|
||||
{
|
||||
std::unordered_set<std::string> checked_ops;
|
||||
auto func_ops = get_functional_ops(test_model);
|
||||
auto model_with_in_info = ov::util::generate_model(func_ops, checked_ops);
|
||||
auto model_with_in_info = ov::util::generate_model(func_ops);
|
||||
recovered_model = std::get<0>(model_with_in_info);
|
||||
}
|
||||
{
|
||||
@ -58,9 +56,8 @@ TEST_F(ModelUtilsTest, generate_2) {
|
||||
Model_2 test;
|
||||
std::shared_ptr<ov::Model> test_model = test.get(), recovered_model;
|
||||
{
|
||||
std::unordered_set<std::string> checked_ops;
|
||||
auto func_ops = get_functional_ops(test_model);
|
||||
auto model_with_in_info = ov::util::generate_model(func_ops, checked_ops);
|
||||
auto model_with_in_info = ov::util::generate_model(func_ops);
|
||||
recovered_model = std::get<0>(model_with_in_info);
|
||||
auto in_info = std::get<1>(model_with_in_info);
|
||||
}
|
||||
@ -74,10 +71,11 @@ TEST_F(ModelUtilsTest, align_input_info) {
|
||||
auto in_info_0 = ov::util::get_input_info_by_model(test_model_0.get());
|
||||
auto in_info_1 = ov::util::get_input_info_by_model(test_model_1.get());
|
||||
ASSERT_NE(in_info_0, in_info_1);
|
||||
std::unordered_map<std::string, std::string> a;
|
||||
ASSERT_NO_THROW(ov::util::align_input_info(test_model_0.get(), test_model_1.get(),
|
||||
in_info_0, in_info_1));
|
||||
in_info_0, in_info_1, a));
|
||||
auto in_info_ref = ov::util::align_input_info(test_model_0.get(), test_model_1.get(),
|
||||
in_info_0, in_info_1);
|
||||
in_info_0, in_info_1, a);
|
||||
ASSERT_EQ(in_info_1, in_info_ref);
|
||||
}
|
||||
|
||||
@ -88,7 +86,7 @@ TEST_F(ModelUtilsTest, align_input_info_for_subgraphs) {
|
||||
auto in_info_0 = ov::util::get_input_info_by_model(test_model_0);
|
||||
auto in_info_1 = ov::util::get_input_info_by_model(test_model_1);
|
||||
ASSERT_NE(in_info_0, in_info_1);
|
||||
std::map<std::string, std::string> matched_ops;
|
||||
auto matched_ops = ov::util::ModelComparator::get()->get_matched_ops_in_graphs(test_model_0, test_model_1);
|
||||
auto params_0 = test_model_0->get_parameters();
|
||||
auto params_1 = test_model_1->get_parameters();
|
||||
size_t params_cnt = params_0.size();
|
||||
@ -96,9 +94,9 @@ TEST_F(ModelUtilsTest, align_input_info_for_subgraphs) {
|
||||
matched_ops.insert({params_0[param_id]->get_friendly_name(),
|
||||
params_1[param_id]->get_friendly_name()});
|
||||
}
|
||||
ASSERT_NO_THROW(ov::util::align_input_info(test_model_0, test_model_1,
|
||||
in_info_0, in_info_1,
|
||||
matched_ops));
|
||||
// ASSERT_NO_THROW(ov::util::align_input_info(test_model_0, test_model_1,
|
||||
// in_info_0, in_info_1,
|
||||
// matched_ops));
|
||||
auto ref = ov::util::align_input_info(test_model_0, test_model_1,
|
||||
in_info_0, in_info_1, matched_ops);
|
||||
ASSERT_EQ(in_info_1, ref);
|
||||
@ -118,4 +116,13 @@ TEST_F(ModelUtilsTest, get_input_info_by_model) {
|
||||
ASSERT_EQ(cur, ref);
|
||||
}
|
||||
|
||||
TEST_F(ModelUtilsTest, get_subgraph_set_node) {
|
||||
Model_1 model;
|
||||
std::unordered_set<std::shared_ptr<ov::Node>> out_ops;
|
||||
ov::util::get_subgraph_set_node(out_ops, model.get_test_abs_0());
|
||||
auto expected = model.get_out_nodes_after_abs_0();
|
||||
std::set<std::shared_ptr<ov::Node>> orig(out_ops.begin(), out_ops.end()),
|
||||
ref(expected.begin(), expected.end());
|
||||
ASSERT_EQ(orig, ref);
|
||||
}
|
||||
} // namespace
|
||||
|
@ -11,7 +11,8 @@
|
||||
#include "openvino/op/parameter.hpp"
|
||||
#include "openvino/op/result.hpp"
|
||||
#include "openvino/core/model.hpp"
|
||||
|
||||
#include "openvino/openvino.hpp"
|
||||
#include "utils/model.hpp"
|
||||
namespace {
|
||||
|
||||
using namespace ov::tools::subgraph_dumper;
|
||||
@ -115,10 +116,11 @@ TEST_F(ModelComparatorTest, match_with_low_coeff) {
|
||||
TEST_F(ModelComparatorTest, match_with_in_info) {
|
||||
ov::util::ModelComparator::Ptr model_comparator = ov::util::ModelComparator::get();
|
||||
std::map<std::string, ov::conformance::InputInfo>
|
||||
test_in_info({{"test_parameter_0", ov::conformance::InputInfo()}}),
|
||||
test_in_info_1({{"test_parameter_1", ov::conformance::InputInfo({}, 1, 2, true)}});
|
||||
ASSERT_NO_THROW(model_comparator->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info));
|
||||
ASSERT_TRUE(std::get<0>(model_comparator->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info)));
|
||||
test_in_info({{"test_parameter_0", ov::conformance::InputInfo(ov::Shape{1, 2})}}),
|
||||
test_in_info_({{"test_parameter_0", ov::conformance::InputInfo(ov::Shape{1, 2})}}),
|
||||
test_in_info_1({{"test_parameter_1", ov::conformance::InputInfo(ov::Shape{2, 5}, 1, 2, true)}});
|
||||
ASSERT_NO_THROW(model_comparator->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info_));
|
||||
ASSERT_TRUE(std::get<0>(model_comparator->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info_)));
|
||||
ASSERT_NO_THROW(model_comparator->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info_1));
|
||||
ASSERT_FALSE(std::get<0>(model_comparator->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info_1)));
|
||||
ASSERT_NO_THROW(model_comparator->match(test_model_0_1, test_model_1, test_in_info, test_in_info));
|
||||
|
Loading…
Reference in New Issue
Block a user