[CONFORMANCE][TOOLS] Repeat pattern extractor API (#20293)

* Prepare API

* Refactor api

* Move model comparation to separate component

* Cover by tests

* Move align_in_info to utils

* Change arch diagram
This commit is contained in:
Irina Efode 2023-10-12 21:01:04 +04:00 committed by GitHub
parent 29475c738e
commit 74690d038b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 1318 additions and 773 deletions

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:caa4f76ba61548d1b60d7de1f78fb48dccbf5337117240353a9581f23c88bfa9
size 216595
oid sha256:45578db1c9ac5362340ea35fc8fa024e992c8beeb30e984d969ee80217c9031b
size 342214

View File

@ -6,6 +6,7 @@
#include "cache/cache.hpp"
#include "cache/meta/input_info.hpp"
#include "utils/model_comparator.hpp"
#include "matchers/subgraph/manager.hpp"
#include "matchers/subgraph/subgraph.hpp"
#include "matchers/subgraph/fused_names.hpp"
@ -42,10 +43,12 @@ public:
protected:
std::map<std::shared_ptr<ov::Model>, MetaInfo> m_graph_cache;
ExtractorsManager m_manager = ExtractorsManager();
static std::shared_ptr<GraphCache> m_cache_instance;
// cache byte size
uint64_t m_graph_cache_bytesize = 0;
ExtractorsManager m_manager;
ModelComparator::Ptr m_model_comparator = ModelComparator::get();
std::shared_ptr<ov::Model> model_to_update = nullptr;
static std::shared_ptr<GraphCache> m_cache_instance;
GraphCache(const std::string& device = "") {
ExtractorsManager::ExtractorsMap matchers = {
@ -59,7 +62,7 @@ protected:
void update_cache(const std::shared_ptr<ov::Model>& model,
const std::string& model_path,
std::map<std::string, InputInfo>& input_info,
const std::map<std::string, InputInfo>& input_info,
const std::string& extractor_name,
size_t model_op_cnt);
};

View File

@ -59,6 +59,9 @@ struct InputInfo {
}
InputInfo operator=(const InputInfo& input_info) {
if (this->is_const != input_info.is_const) {
throw std::runtime_error("Cast Const to Parameter! Impossible to update Input Info!");
}
this->ranges = input_info.ranges;
if (ov::shape_size(this->max_shape.get_max_shape()) < ov::shape_size(input_info.max_shape.get_max_shape())) {
this->max_shape = input_info.max_shape;

View File

@ -33,7 +33,10 @@ public:
const std::string& extractor = "",
const std::vector<std::string>& ignored_inputs = {});
std::map<std::string, InputInfo> get_input_info() const;
void set_input_info(const std::map<std::string, InputInfo>& new_in_info) { input_info = new_in_info; };
void set_input_info(const std::map<std::string, InputInfo>& new_in_info) {
input_info.clear();
input_info = new_in_info;
};
std::map<std::string, ModelInfo> get_model_info() const;
std::string get_any_extractor() const { return *extractors.begin(); }

View File

@ -20,6 +20,8 @@ public:
const std::shared_ptr<ov::Node> &ref) const;
void set_matchers(const MatchersMap& matchers = {}) { m_matchers = matchers; }
void set_shape_strict_match(bool shape_strict_match);
const MatchersMap& get_matchers() { return m_matchers; }
iMatcherConfig::Ptr get_config(const std::shared_ptr<ov::Node> &node) const;

View File

@ -21,6 +21,7 @@ public:
const std::shared_ptr<ov::Node> &ref) const;
iMatcherConfig::Ptr get_config(const std::shared_ptr<ov::Node> &node) const;
void set_strict_shape_match(bool strict_shape_match);
protected:
virtual void configure(const pugi::xml_document &cfg) {};
@ -35,6 +36,8 @@ protected:
const std::shared_ptr<ov::Node> &ref) const;
std::vector<iMatcherConfig::Ptr> default_configs;
// match only shape ranks by default;
bool is_strict_shape_match = false;
};
} // namespace subgraph_dumper

View File

@ -17,9 +17,7 @@ public:
FusedNamesExtractor(const std::string& device = "");
~FusedNamesExtractor();
std::list<ExtractedPattern> extract(const std::shared_ptr<ov::Model> &model,
bool is_extract_body = true,
bool is_copy_constants = true) override;
std::vector<ExtractedPattern> extract(const std::shared_ptr<ov::Model> &modele) override;
protected:
std::unordered_set<std::string> extract_compiled_model_names(const std::shared_ptr<ov::Model>& model);

View File

@ -12,38 +12,19 @@ namespace subgraph_dumper {
class ExtractorsManager {
public:
// { model, subgraph, model_in_info, subgraph_in_info }
using ExtractedSubgraphTuple = std::tuple<bool, std::shared_ptr<ov::Model>, std::shared_ptr<ov::Model>, std::map<std::string, InputInfo>, std::map<std::string, InputInfo>>;
using ExtractorsMap = std::map<std::string, SubgraphExtractor::Ptr>;
explicit ExtractorsManager(const ExtractorsMap& extractors = {}) : m_extractors(extractors) {}
bool match(const std::shared_ptr<ov::Model> &model,
const std::shared_ptr<ov::Model> &ref_model,
std::map<std::string, InputInfo> &in_info,
const std::map<std::string, InputInfo> &in_info_ref);
ExtractedSubgraphTuple is_subgraph(const std::shared_ptr<ov::Model> &model,
const std::shared_ptr<ov::Model> &ref_model,
const std::map<std::string, InputInfo> &in_info = {},
const std::map<std::string, InputInfo> &in_info_ref = {});
std::list<ExtractedPattern> extract(const std::shared_ptr<ov::Model> &model,
bool is_extract_body = true,
bool is_copy_constants = true);
std::vector<ExtractedPattern> extract(const std::shared_ptr<ov::Model> &model,
bool is_extract_body = true,
bool is_copy_constants = true);
void set_extractors(const ExtractorsMap& extractors = {}) { m_extractors = extractors; }
ExtractorsMap get_extractors() { return m_extractors; }
std::map<std::string, InputInfo> align_input_info(const std::shared_ptr<ov::Model>& model,
const std::shared_ptr<ov::Model>& model_ref,
const std::map<std::string, InputInfo> &in_info,
const std::map<std::string, InputInfo> &in_info_ref,
const std::map<std::string, std::string> &matched_op = {});
protected:
ExtractorsMap m_extractors = {};
bool match(const std::shared_ptr<ov::Model> &model,
const std::shared_ptr<ov::Model> &ref);
};
} // namespace subgraph_dumper

View File

@ -5,31 +5,45 @@
#pragma once
#include <utility>
#include "matchers/subgraph/subgraph.hpp"
#include "matchers/single_op/single_op.hpp"
#include "matchers/single_op/convolutions.hpp"
#include "matchers/single_op/manager.hpp"
namespace ov {
namespace tools {
namespace subgraph_dumper {
class RepeatPatternExtractor final : public SubgraphExtractor {
public:
RepeatPatternExtractor() {
MatchersManager::MatchersMap matchers = {
{ "generic_single_op", SingleOpMatcher::Ptr(new SingleOpMatcher) },
{ "convolutions", ConvolutionsMatcher::Ptr(new ConvolutionsMatcher) },
};
manager.set_matchers(matchers);
}
std::list<ExtractedPattern> extract(const std::shared_ptr<ov::Model> &model,
bool is_extract_body = true,
bool is_copy_constants = true) override;
private:
MatchersManager manager;
using InputVector = std::vector<ov::Input<ov::Node>>;
using OutputVector = std::vector<ov::Output<ov::Node>>;
public:
using PatternBorders = std::pair<InputVector, OutputVector>;
ModelComparator::Ptr model_comparator = ModelComparator::get();
std::vector<std::vector<PatternBorders>>
get_repeat_pattern_borders(const std::shared_ptr<ov::Model> &model);
std::vector<std::vector<ov::NodeVector>>
get_repeat_node_vectors(const std::shared_ptr<ov::Model> &model);
void set_recursive_extraction(bool _is_recursive_extraction);
std::vector<ExtractedPattern> extract(const std::shared_ptr<ov::Model> &model) override;
protected:
// {subgraph, node_vector, input_info}
using ExtractedRepeatPattern = std::tuple<std::shared_ptr<ov::Model>, ov::NodeVector, std::map<std::string, InputInfo>>;
bool is_recursive_extraction = true;
std::list<std::vector<ExtractedRepeatPattern>>
find_repeat_patterns(const std::shared_ptr<ov::Model> &model,
bool is_save_borders_only = false);
void update_extractor_cache(std::list<std::vector<ExtractedRepeatPattern>>& extracted_patterns,
std::list<std::vector<ExtractedRepeatPattern>>& secondary_extracted_patterns);
void update_extractor_cache(std::list<std::vector<ExtractedRepeatPattern>>& extracted_patterns,
const std::shared_ptr<ov::Model>& pattern,
const ov::NodeVector& pattern_node_vector,
const std::map<std::string, InputInfo>& in_info);
};
} // namespace subgraph_dumper

View File

@ -7,12 +7,8 @@
#include <utility>
#include "openvino/op/util/op_types.hpp"
#include "common_test_utils/graph_comparator.hpp"
#include "cache/meta/input_info.hpp"
#include "matchers/single_op/single_op.hpp"
#include "matchers/single_op/convolutions.hpp"
#include "matchers/single_op/manager.hpp"
#include "utils/model_comparator.hpp"
namespace ov {
namespace tools {
@ -20,44 +16,19 @@ namespace subgraph_dumper {
class SubgraphExtractor {
public:
// { is_subgraph, model, subgraph, matched_ops{ model_op_name, graph_op_name }}
using IsSubgraphTuple = std::tuple<bool, std::shared_ptr<ov::Model>, std::shared_ptr<ov::Model>, std::map<std::string, std::string>>;
using Ptr = std::shared_ptr<SubgraphExtractor>;
SubgraphExtractor() {
MatchersManager::MatchersMap matchers = {
{ "generic_single_op", SingleOpMatcher::Ptr(new SingleOpMatcher) },
{ "convolutions", ConvolutionsMatcher::Ptr(new ConvolutionsMatcher) },
};
m_manager.set_matchers(matchers);
virtual std::vector<ExtractedPattern> extract(const std::shared_ptr<ov::Model> &model) {
return std::vector<ExtractedPattern>{};
}
bool match(const std::shared_ptr<ov::Model> &model,
const std::shared_ptr<ov::Model> &ref_model) const;
IsSubgraphTuple is_subgraph(const std::shared_ptr<ov::Model> &model,
const std::shared_ptr<ov::Model> &ref_model) const;
virtual std::list<ExtractedPattern> extract(const std::shared_ptr<ov::Model> &model,
bool is_extract_body = true,
bool is_copy_constants = true) {
return std::list<ExtractedPattern>{};
};
void set_extractor_name(const std::string& _extractor_name) { extractor_name = _extractor_name; }
void set_extract_body(bool _is_extract_body) { is_extract_body = _is_extract_body; }
void set_save_const(bool _is_save_const) { is_save_const = _is_save_const; }
protected:
std::string extractor_name = "";
FunctionsComparator comparator = FunctionsComparator::no_default()
.enable(FunctionsComparator::ATTRIBUTES)
.enable(FunctionsComparator::NODES)
.enable(FunctionsComparator::PRECISIONS);
MatchersManager m_manager = MatchersManager();
inline bool is_node_to_skip(const std::shared_ptr<ov::Node>& node) const {
return ov::op::util::is_parameter(node) ||
ov::op::util::is_constant(node) ||
ov::op::util::is_output(node);
}
bool is_extract_body = true, is_save_const = true;
};
} // namespace subgraph_dumper

View File

@ -33,7 +33,7 @@ static std::vector<std::regex> FROTEND_REGEXP = {
std::regex(R"(.*__model__)"),
#endif
#ifdef ENABLE_OV_TF_FRONTEND
std::regex(R"(.*\.pb)"),
std::regex(R"(.*\model.pb)"),
#endif
#ifdef ENABLE_OV_IR_FRONTEND
std::regex(R"(.*\.xml)"),
@ -74,32 +74,24 @@ std::map<ModelCacheStatus, std::vector<std::string>> cache_models(
void save_model_status_to_file(const std::map<ModelCacheStatus, std::vector<std::string>>& caching_status,
const std::string& output_dir);
inline bool is_dynamic_model(const std::shared_ptr<ov::Model>& model) {
for (const auto& parameter : model->get_parameters()) {
if (is_dynamic_node(parameter)) {
return true;
}
}
for (const auto& result : model->get_results()) {
if (is_dynamic_node(result)) {
return true;
}
}
return false;
}
bool is_dynamic_model(const std::shared_ptr<ov::Model>& model);
std::string get_model_type(const std::shared_ptr<ov::Model>& model);
inline std::string get_model_type(const std::shared_ptr<ov::Model>& model) {
if (is_dynamic_model(model)) {
return "dynamic";
}
return "static";
}
std::map<std::string, InputInfo>
get_input_info_by_model(const std::shared_ptr<ov::Model>& model);
inline ExtractedPattern
generate_model(const std::set<std::shared_ptr<ov::Node>>& nodes,
std::map<std::string, InputInfo>
align_input_info(const std::shared_ptr<ov::Model>& model,
const std::shared_ptr<ov::Model>& model_ref,
const std::map<std::string, InputInfo> &in_info,
const std::map<std::string, InputInfo> &in_info_ref,
const std::map<std::string, std::string> &matched_op = {});
inline std::pair<std::shared_ptr<ov::Model>, std::map<std::string, InputInfo>>
generate_model(ov::NodeVector& nodes,
std::unordered_set<std::string>& checked_ops,
const std::string& extractor_name,
bool is_copy_constants = true) {
bool is_copy_constants = true,
bool is_save_only_borders = false) {
// map to recover graph using cloned nodes and original connections
// { original_node_name, cloned_node }
std::unordered_map<std::string, std::shared_ptr<ov::Node>> cloned_node_map;
@ -214,27 +206,51 @@ generate_model(const std::set<std::shared_ptr<ov::Node>>& nodes,
// prepare unique model name based on operations from model
std::string string_to_hash;
for (const auto& op : model->get_ordered_ops()) {
bool is_erase_node = !is_save_only_borders;
std::ostringstream result;
result << op->get_type_info();
for (const auto& in : op->inputs()) {
for (size_t i = 0; i < op->inputs().size(); ++i) {
const auto& in = op->input(i);
if (!is_node_to_skip(op->get_input_node_shared_ptr(i))) {
is_erase_node |= true;
}
result << in.get_element_type();
result << in.get_partial_shape().rank();
result << in.get_partial_shape().is_static();
}
for (const auto& out : op->outputs()) {
for (const auto& target_input : out.get_target_inputs()) {
if (!is_node_to_skip(target_input.get_node()->shared_from_this())) {
is_erase_node |= true;
break;
}
}
result << out.get_element_type();
result << out.get_partial_shape().rank();
result << out.get_partial_shape().is_static();
}
string_to_hash += result.str();
if (is_erase_node) {
cloned_node_map.erase(op->get_friendly_name());
}
}
for (const auto& in : model_input_info) {
string_to_hash += (in.second.is_const ? "1" : "0");
}
auto h1 = std::hash<std::string>{}(string_to_hash);
model->set_friendly_name(std::to_string(h1));
return { model, model_input_info, extractor_name };
{
auto it = nodes.begin();
while (it != nodes.end()) {
if (cloned_node_map.count((*it)->get_friendly_name())) {
nodes.erase(it);
} else {
++it;
}
}
}
return { model, model_input_info };
}
} // namespace subgraph_dumper

View File

@ -0,0 +1,68 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "matchers/single_op/single_op.hpp"
#include "matchers/single_op/convolutions.hpp"
#include "matchers/single_op/manager.hpp"
namespace ov {
namespace tools {
namespace subgraph_dumper {
class ModelComparator {
public:
using Ptr = std::shared_ptr<ModelComparator>;
// { is_match, subgraph, graph, matched_nodes -> {subgraph_op_name, graph_op_name}}
using IsSubgraphTuple = std::tuple<bool, std::shared_ptr<ov::Model>, std::shared_ptr<ov::Model>, std::map<std::string, std::string>>;
// { model, subgraph, graph, subgraph_in_info, model_in_info, }
using ExtractedSubgraphTuple = std::tuple<bool, std::shared_ptr<ov::Model>, std::shared_ptr<ov::Model>, std::map<std::string, InputInfo>, std::map<std::string, InputInfo>>;
static std::shared_ptr<ModelComparator> get(bool in_is_match_shapes = false) {
if (m_instance == nullptr) {
m_instance = std::shared_ptr<ModelComparator>(new ModelComparator);
}
return m_instance;
}
IsSubgraphTuple is_subgraph(const std::shared_ptr<ov::Model> &model,
const std::shared_ptr<ov::Model> &ref_model) const;
bool match(const std::shared_ptr<ov::Node> &node,
const std::shared_ptr<ov::Node> &ref_node) const;
bool match(const std::shared_ptr<ov::Model> &model,
const std::shared_ptr<ov::Model> &ref_model) const;
std::pair<bool, std::map<std::string, InputInfo>>
match(const std::shared_ptr<ov::Model> &model,
const std::shared_ptr<ov::Model> &ref_model,
const std::map<std::string, InputInfo> &in_info,
const std::map<std::string, InputInfo> &in_info_ref);
ExtractedSubgraphTuple
is_subgraph(const std::shared_ptr<ov::Model> &model,
const std::shared_ptr<ov::Model> &ref_model,
const std::map<std::string, InputInfo> &in_info,
const std::map<std::string, InputInfo> &in_info_ref);
void set_match_coefficient(float _match_coefficient);
void set_shape_strict_match(bool is_shape_strict_match);
protected:
MatchersManager m_manager = MatchersManager();
float match_coefficient = 0.9f;
static std::shared_ptr<ModelComparator> m_instance;
ModelComparator() {
MatchersManager::MatchersMap matchers = {
{ "generic_single_op", SingleOpMatcher::Ptr(new SingleOpMatcher) },
{ "convolutions", ConvolutionsMatcher::Ptr(new ConvolutionsMatcher) },
};
m_manager.set_matchers(matchers);
}
};
} // namespace subgraph_dumper
} // namespace tools
} // namespace ov

View File

@ -26,6 +26,9 @@ inline InputInfo::Range get_const_ranges(const std::shared_ptr<ov::op::v0::Const
return InputInfo::Range(static_cast<double>(min), static_cast<double>(max));
}
InputInfo::Range get_const_ranges(const std::shared_ptr<ov::op::v0::Constant>& const_node,
ov::element::Type elem_type);
std::map<std::string, InputInfo> get_input_info_by_node(const std::shared_ptr<ov::Node>& node);
// replace all input node by parameters and constants instead of non input mode types
@ -111,6 +114,12 @@ inline size_t get_node_priority_by_version(const std::shared_ptr<ov::Node>& node
return priority;
}
inline bool is_node_to_skip(const std::shared_ptr<ov::Node>& node) {
return ov::op::util::is_parameter(node) ||
ov::op::util::is_constant(node) ||
ov::op::util::is_output(node);
}
} // namespace subgraph_dumper
} // namespace tools

View File

@ -42,8 +42,8 @@ bool ICache::serialize_model(const std::pair<std::shared_ptr<ov::Model>, MetaInf
meta.serialize(meta_path);
return true;
} catch (std::exception &e) {
// std::cout << "[ ERROR ] Failed to serialize model: " << model_name
// << ". Exception: " << e.what() << std::endl;
std::cout << "[ ERROR ] Failed to serialize model: " << model_name
<< ". Exception: " << e.what() << std::endl;
ov::test::utils::removeFile(xml_path);
ov::test::utils::removeFile(bin_path);
ov::test::utils::removeFile(meta_path);

View File

@ -56,77 +56,92 @@ void GraphCache::update_cache(const std::shared_ptr<ov::Model>& model,
return;
}
while (!extracted_patterns.empty()) {
auto it = *extracted_patterns.begin();
auto it = *extracted_patterns.rbegin();
update_cache(std::get<0>(it), model_meta_data, std::get<1>(it), std::get<2>(it), model_total_op);
extracted_patterns.pop_front();
extracted_patterns.pop_back();
}
}
}
void GraphCache::update_cache(const std::shared_ptr<ov::Model>& extracted_model,
const std::string& model_path,
std::map<std::string, InputInfo>& input_info,
const std::map<std::string, InputInfo>& input_info,
const std::string& extractor_name,
size_t model_op_cnt) {
auto graph_name = extracted_model->get_friendly_name();
auto this_op_cnt = extracted_model->get_ops().size() -
extracted_model->get_parameters().size() - extracted_model->get_results().size();
std::string serialized_model_path = "";
for (const auto& extractor : m_manager.get_extractors()) {
auto tmp_serialized_model_path = ov::util::path_join({ m_serialization_dir, m_cache_subdir, extractor.first, graph_name + ".xml" });
if (ov::util::file_exists(serialized_model_path)) {
serialized_model_path = tmp_serialized_model_path;
break;
std::map<std::string, InputInfo> updated_input_info;
if (!m_graph_cache.empty() && model_to_update != nullptr) {
auto comparator_res = m_model_comparator->match(extracted_model, model_to_update,
input_info, m_graph_cache.at(model_to_update).get_input_info());
if (comparator_res.first) {
updated_input_info = comparator_res.second;
} else {
model_to_update = nullptr;
}
}
std::shared_ptr<ov::Model> model_to_update = nullptr;
// if cached model was serialized
if (!serialized_model_path.empty()) {
// std::cout << "[ GRAPH CACHE ][ INFO ] Reading cached model: " << serialized_model_path << std::endl;
auto bin_path = ov::test::utils::replaceExt(serialized_model_path, ".bin");
auto meta_path = ov::test::utils::replaceExt(serialized_model_path, ".meta");
auto cached_model = ov::test::utils::PluginCache::get().core()->read_model(serialized_model_path);
auto cached_meta = MetaInfo::read_meta_from_file(meta_path);
ov::test::utils::removeFile(serialized_model_path);
ov::test::utils::removeFile(bin_path);
ov::test::utils::removeFile(meta_path);
m_graph_cache.insert({ cached_model, cached_meta });
m_graph_cache_bytesize += cached_model->get_graph_size();
if (m_manager.match(extracted_model, cached_model,
input_info, cached_meta.get_input_info())) {
model_to_update = cached_model;
}
} else {
for (const auto& cached_model : m_graph_cache) {
if (m_manager.match(extracted_model, cached_model.first,
input_info, cached_model.second.get_input_info())) {
model_to_update = cached_model.first;
if (model_to_update == nullptr) {
std::string serialized_model_path = "";
for (const auto& extractor : m_manager.get_extractors()) {
auto tmp_serialized_model_path = ov::util::path_join({ m_serialization_dir, m_cache_subdir, extractor.first, graph_name + ".xml" });
if (ov::util::file_exists(serialized_model_path)) {
serialized_model_path = tmp_serialized_model_path;
break;
} else {
auto is_subgraph = m_manager.is_subgraph(extracted_model, cached_model.first,
input_info, cached_model.second.get_input_info());
// in case if one model is subgraph of other to update model meta info and remove subgraph from cache
if (std::get<0>(is_subgraph)) {
std::shared_ptr<ov::Model> graph, subgraph;
std::map<std::string, InputInfo> graph_in_info, subgraph_in_info;
std::tie(std::ignore, graph, subgraph, graph_in_info, subgraph_in_info) = is_subgraph;
if (subgraph == cached_model.first) {
auto meta = m_graph_cache[subgraph];
meta.set_input_info(graph_in_info);
m_graph_cache.erase(subgraph);
m_graph_cache.insert({graph, meta});
m_graph_cache_bytesize += (graph->get_graph_size() - subgraph->get_graph_size());
}
}
// if cached model was serialized
if (!serialized_model_path.empty()) {
// std::cout << "[ GRAPH CACHE ][ INFO ] Reading cached model: " << serialized_model_path << std::endl;
auto bin_path = ov::test::utils::replaceExt(serialized_model_path, ".bin");
auto meta_path = ov::test::utils::replaceExt(serialized_model_path, ".meta");
auto cached_model = ov::test::utils::PluginCache::get().core()->read_model(serialized_model_path);
auto cached_meta = MetaInfo::read_meta_from_file(meta_path);
ov::test::utils::removeFile(serialized_model_path);
ov::test::utils::removeFile(bin_path);
ov::test::utils::removeFile(meta_path);
m_graph_cache.insert({ cached_model, cached_meta });
m_graph_cache_bytesize += cached_model->get_graph_size();
auto comparator_res = m_model_comparator->match(extracted_model, cached_model,
input_info, cached_meta.get_input_info());
if (comparator_res.first) {
model_to_update = cached_model;
updated_input_info = comparator_res.second;
}
} else {
for (const auto& cached_model : m_graph_cache) {
auto comparator_res = m_model_comparator->match(extracted_model, cached_model.first,
input_info, cached_model.second.get_input_info());
if (comparator_res.first) {
model_to_update = cached_model.first;
updated_input_info = comparator_res.second;
break;
} else {
auto is_subgraph = m_model_comparator->is_subgraph(extracted_model, cached_model.first,
input_info, cached_model.second.get_input_info());
// in case if one model is subgraph of other to update model meta info and remove subgraph from cache
if (std::get<0>(is_subgraph)) {
std::shared_ptr<ov::Model> graph, subgraph;
std::map<std::string, InputInfo> graph_in_info, subgraph_in_info;
std::tie(std::ignore, subgraph, graph, subgraph_in_info, graph_in_info) = is_subgraph;
if (subgraph == cached_model.first) {
auto meta = m_graph_cache[subgraph];
meta.set_input_info(graph_in_info);
m_graph_cache.erase(subgraph);
m_graph_cache.insert({graph, meta});
m_graph_cache_bytesize += (graph->get_graph_size() - subgraph->get_graph_size());
}
m_graph_cache[cached_model.first].update(model_path,
subgraph_in_info,
model_op_cnt,
this_op_cnt,
extractor_name);
return;
}
m_graph_cache[cached_model.first].update(model_path,
subgraph_in_info,
model_op_cnt,
this_op_cnt,
extractor_name);
return;
}
}
}
@ -134,18 +149,22 @@ void GraphCache::update_cache(const std::shared_ptr<ov::Model>& extracted_model,
if (model_to_update == nullptr) {
MetaInfo meta = MetaInfo(model_path, input_info, model_op_cnt, this_op_cnt, extractor_name);
m_graph_cache.insert({ extracted_model, meta });
model_to_update = extracted_model;
m_graph_cache.insert({ model_to_update, meta });
m_graph_cache_bytesize += extracted_model->get_graph_size();
return;
}
m_graph_cache[model_to_update].update(model_path, input_info, model_op_cnt, this_op_cnt, extractor_name);
m_graph_cache[model_to_update].update(model_path, updated_input_info, model_op_cnt, this_op_cnt, extractor_name);
auto cached_model_size = model_to_update->get_graph_size();
auto pattern_model_size = extracted_model->get_graph_size();
if (pattern_model_size < cached_model_size) {
m_graph_cache_bytesize -= (cached_model_size - pattern_model_size);
auto meta = m_graph_cache[model_to_update];
auto new_in_info = align_input_info(model_to_update, extracted_model, m_graph_cache.at(model_to_update).get_input_info(), input_info);
meta.set_input_info(new_in_info);
m_graph_cache.erase(model_to_update);
m_graph_cache.insert({extracted_model, meta});
model_to_update = extracted_model;
m_graph_cache.insert({model_to_update, meta});
}
}

View File

@ -169,7 +169,6 @@ void MetaInfo::update(const std::string& _model_path,
size_t _this_op_cnt,
const std::string& extractor,
const std::vector<std::string>& ignored_inputs) {
bool is_update_in_info = true;
if (input_info.size() != _input_info.size()) {
throw std::runtime_error("Incompatible input info!");
}
@ -193,9 +192,6 @@ void MetaInfo::update(const std::string& _model_path,
if (!extractor.empty()) {
extractors.insert(extractor);
}
if (!is_update_in_info) {
return;
}
for (const auto& in : _input_info) {
if (std::find(ignored_inputs.begin(), ignored_inputs.end(), in.first) != ignored_inputs.begin()) {
continue;

View File

@ -50,8 +50,11 @@ bool ConvolutionsMatcher::match_inputs(const std::shared_ptr<ov::Node> &node,
bool has_groups = std::dynamic_pointer_cast<ov::op::v1::GroupConvolution>(node) ||
std::dynamic_pointer_cast<ov::op::v1::GroupConvolutionBackpropData>(node);
size_t kernel_size_offset = has_groups ? 3 : 2;
auto ref_weights_shape = ref->get_input_tensor(1).get_shape();
auto cur_weights_shape = node->get_input_tensor(1).get_shape();
auto ref_weights_shape = ref->get_input_partial_shape(1).get_shape();
auto cur_weights_shape = node->get_input_partial_shape(1).get_shape();
if (is_strict_shape_match && ref_weights_shape != cur_weights_shape) {
return false;
}
const auto ref_kernel_size = std::vector<size_t>(ref_weights_shape.begin() + kernel_size_offset,
ref_weights_shape.end());
const auto cur_kernel_size = std::vector<size_t>(cur_weights_shape.begin() + kernel_size_offset,

View File

@ -17,9 +17,15 @@ iMatcherConfig::Ptr MatchersManager::get_config(const std::shared_ptr<ov::Node>
return nullptr;
}
void MatchersManager::set_shape_strict_match(bool shape_strict_match) {
for (const auto& matcher : m_matchers) {
matcher.second->set_strict_shape_match(shape_strict_match);
}
}
bool MatchersManager::match(const std::shared_ptr<ov::Node> &node,
const std::shared_ptr<ov::Node> &ref) const {
for (const auto &it : m_matchers) {
for (const auto& it : m_matchers) {
if (it.second->match(node, ref)) {
return true;
}

View File

@ -7,6 +7,7 @@
#include "openvino/op/util/op_types.hpp"
#include "common_test_utils/graph_comparator.hpp"
#include "matchers/single_op/single_op.hpp"
#include "utils/node.hpp"
using namespace ov::tools::subgraph_dumper;
@ -24,6 +25,10 @@ iMatcherConfig::Ptr SingleOpMatcher::get_config(const std::shared_ptr<ov::Node>
return std::make_shared<MatcherConfig<>>();
}
void SingleOpMatcher::set_strict_shape_match(bool strict_shape_match) {
is_strict_shape_match = strict_shape_match;
}
bool SingleOpMatcher::match_inputs(const std::shared_ptr<ov::Node> &node,
const std::shared_ptr<ov::Node> &ref) const {
if (node->get_input_size() != ref->get_input_size()) {
@ -35,21 +40,17 @@ bool SingleOpMatcher::match_inputs(const std::shared_ptr<ov::Node> &node,
if (std::find(ignored_ports.begin(), ignored_ports.end(), port_id) != ignored_ports.end()) {
continue;
}
if (!ov::op::util::is_parameter(node) && !ov::op::util::is_parameter(ref) &&
!ov::op::util::is_constant(node) && !ov::op::util::is_constant(ref)) {
const auto &cur_node_input_type = node->input_value(port_id).get_node_shared_ptr()->get_type_info();
const auto &ref_node_input_type = ref->input_value(port_id).get_node_shared_ptr()->get_type_info();
if (cur_node_input_type != ref_node_input_type) {
return false;
}
}
if (node->get_input_tensor(port_id).get_partial_shape().rank() != ref->get_input_tensor(port_id).get_partial_shape().rank()) {
if (node->get_input_element_type(port_id) != ref->get_input_element_type(port_id)) {
return false;
}
if (node->get_input_tensor(port_id).get_element_type() != ref->get_input_tensor(port_id).get_element_type()) {
const auto& partial_shape = node->get_input_partial_shape(port_id);
const auto& ref_partial_shape = ref->get_input_partial_shape(port_id);
if (is_strict_shape_match && partial_shape != ref_partial_shape) {
return false;
} else if (partial_shape.rank() != ref_partial_shape.rank()) {
return false;
}
if (node->get_input_partial_shape(port_id).is_dynamic() != ref->get_input_partial_shape(port_id).is_dynamic()) {
if (partial_shape.is_dynamic() != ref_partial_shape.is_dynamic()) {
return false;
}
}
@ -63,20 +64,18 @@ SingleOpMatcher::match_outputs(const std::shared_ptr<ov::Node> &node,
return false;
}
for (size_t port_id = 0; port_id < node->get_output_size(); ++port_id) {
if (!ov::op::util::is_output(node) && !ov::op::util::is_output(ref)) {
const auto &cur_node_out_type = node->output(port_id).get_node_shared_ptr()->get_type_info();
const auto &ref_node_out_type = ref->output(port_id).get_node_shared_ptr()->get_type_info();
if (cur_node_out_type != ref_node_out_type) {
return false;
}
}
if (node->get_output_tensor(port_id).get_element_type() != ref->get_output_tensor(port_id).get_element_type()) {
if (node->get_output_element_type(port_id) != ref->get_output_element_type(port_id)) {
return false;
}
if (node->get_output_tensor(port_id).get_partial_shape().is_dynamic() != ref->get_output_tensor(port_id).get_partial_shape().is_dynamic()) {
const auto& partial_shape = node->get_output_partial_shape(port_id);
const auto& ref_partial_shape = ref->get_output_partial_shape(port_id);
if (partial_shape.is_dynamic() != ref_partial_shape.is_dynamic()) {
return false;
}
if (node->get_output_tensor(port_id).get_partial_shape().rank()!= ref->get_output_tensor(port_id).get_partial_shape().rank()) {
if (is_strict_shape_match && partial_shape != ref_partial_shape) {
return false;
} else if (partial_shape.rank() != ref_partial_shape.rank()) {
return false;
}
}
@ -98,17 +97,16 @@ bool SingleOpMatcher::match(const std::shared_ptr<ov::Node> &node,
if (cfg->ignore_matching) {
return false;
}
if (!same_op_type(node, ref)) {
return false;
}
if (!match_inputs(node, ref)) {
return false;
}
if (!match_attrs(node, ref) && !ov::op::util::is_parameter(node) && !ov::op::util::is_parameter(ref)) {
if (!match_outputs(node, ref)) {
return false;
}
if (!match_outputs(node, ref)) {
if (!match_attrs(node, ref) && !is_node_to_skip(node)) {
return false;
}
return true;
@ -121,9 +119,6 @@ bool SingleOpMatcher::same_op_type(const std::shared_ptr<ov::Node> &node,
SingleOpMatcher::SingleOpMatcher() {
default_configs = {
// std::make_shared<MatcherConfig<>>(std::vector<std::string>{}, std::vector<size_t>{0}),
// std::make_shared<MatcherConfig<ov::opset8::FakeQuantize>>(std::vector<std::string>{},
// std::vector<size_t>{0, 1, 2, 3, 4}),
std::make_shared<MatcherConfig<
ov::op::v1::Convolution,
ov::op::v1::ConvolutionBackpropData,

View File

@ -5,6 +5,7 @@
#include "openvino/op/lstm_cell.hpp"
#include "openvino/op/tensor_iterator.hpp"
#include "openvino/op/if.hpp"
#include "openvino/op/loop.hpp"
#include "common_test_utils/common_utils.hpp"
#include "functional_test_utils/ov_plugin_cache.hpp"
@ -56,14 +57,12 @@ FusedNamesExtractor::~FusedNamesExtractor() {
core.reset();
}
std::list<ExtractedPattern>
FusedNamesExtractor::extract(const std::shared_ptr<ov::Model> &model,
bool is_extract_body,
bool is_copy_constants) {
std::vector<ExtractedPattern>
FusedNamesExtractor::extract(const std::shared_ptr<ov::Model> &model) {
auto compiled_op_name = extract_compiled_model_names(model);
std::list<ExtractedPattern> matched_patterns;
std::vector<ExtractedPattern> matched_patterns;
std::unordered_set<std::string> checked_ops;
std::set<std::shared_ptr<ov::Node>> nodes;
ov::NodeVector nodes;
for (const auto& op : model->get_ordered_ops()) {
auto op_name = op->get_friendly_name();
if (is_node_to_skip(op) || checked_ops.count(op_name)) {
@ -71,7 +70,8 @@ FusedNamesExtractor::extract(const std::shared_ptr<ov::Model> &model,
}
if (compiled_op_name.count(op_name)) {
try {
matched_patterns.push_back(generate_model(nodes, checked_ops, extractor_name, is_copy_constants));
auto extracted_pattern = generate_model(nodes, checked_ops, is_save_const);
matched_patterns.push_back({ extracted_pattern.first, extracted_pattern.second, extractor_name });
} catch(std::exception& e) {
if (std::string(e.what()).find("Incorrect node number to create model") == std::string::npos) {
// std::cout << "[ WARNING ] Impossible to generate network and add to GraphCache: " <<e.what() << std::endl;
@ -79,7 +79,7 @@ FusedNamesExtractor::extract(const std::shared_ptr<ov::Model> &model,
}
nodes.clear();
} else {
nodes.insert(op);
nodes.push_back(op);
}
if (is_extract_body) {
if (std::dynamic_pointer_cast<ov::op::v0::TensorIterator>(op)) {
@ -104,7 +104,8 @@ FusedNamesExtractor::extract(const std::shared_ptr<ov::Model> &model,
}
}
try {
matched_patterns.push_back(generate_model(nodes, checked_ops, extractor_name, is_copy_constants));
auto extracted_pattern = generate_model(nodes, checked_ops, is_save_const);
matched_patterns.push_back({ extracted_pattern.first, extracted_pattern.second, extractor_name });
} catch(std::exception& e) {
if (std::string(e.what()).find("Incorrect node number to create model") == std::string::npos) {
// std::cout << "[ WARNING ] Impossible to generate network and add to GraphCache: " <<e.what() << std::endl;

View File

@ -9,122 +9,18 @@
using namespace ov::tools::subgraph_dumper;
bool ExtractorsManager::match(const std::shared_ptr<ov::Model> &model,
const std::shared_ptr<ov::Model> &ref) {
// `match` is not virtual method in base `SubgraphExtractor` class
// we can use function from any `extractor` to avoid of cycle
if (!m_extractors.empty()) {
if (m_extractors.begin()->second->match(model, ref)) {
return true;
}
}
return false;
}
ExtractorsManager::ExtractedSubgraphTuple
ExtractorsManager::is_subgraph(const std::shared_ptr<ov::Model> &model,
const std::shared_ptr<ov::Model> &ref_model,
const std::map<std::string, InputInfo> &in_info,
const std::map<std::string, InputInfo> &in_info_ref) {
if (!m_extractors.empty()) {
// `is_subgraph` is not virtual method in base `SubgraphExtractor` class
// we can use function from any `extractor` to avoid of cycle
auto extractor_res = m_extractors.begin()->second->is_subgraph(model, ref_model);
if (std::get<0>(extractor_res)) {
std::map<std::string, InputInfo> graph_in_info, subgraph_in_info;
if (std::get<1>(extractor_res) == model && std::get<2>(extractor_res) == ref_model) {
graph_in_info = in_info;
subgraph_in_info = in_info_ref;
} else if (std::get<1>(extractor_res) == ref_model && std::get<2>(extractor_res) == model) {
graph_in_info = in_info_ref;
subgraph_in_info = in_info;
} else {
throw std::runtime_error("Generated models are incompatible with original ones!");
}
try {
subgraph_in_info = align_input_info(std::get<2>(extractor_res), std::get<1>(extractor_res), subgraph_in_info, graph_in_info);
} catch(std::exception) {
return { false, nullptr, nullptr, {}, {} };
}
return { true, std::get<1>(extractor_res), std::get<2>(extractor_res), graph_in_info, subgraph_in_info };
}
}
return { false, nullptr, nullptr, {}, {} };
}
bool ExtractorsManager::match(const std::shared_ptr<ov::Model> &model,
const std::shared_ptr<ov::Model> &ref,
std::map<std::string, InputInfo> &in_info,
const std::map<std::string, InputInfo> &in_info_ref) {
if (match(model, ref)) {
try {
in_info = align_input_info(model, ref, in_info, in_info_ref);
return true;
} catch (std::exception) {
return false;
}
}
return false;
}
std::map<std::string, InputInfo>
ExtractorsManager::align_input_info(const std::shared_ptr<ov::Model>& model,
const std::shared_ptr<ov::Model>& model_ref,
const std::map<std::string, InputInfo>& in_info,
const std::map<std::string, InputInfo>& in_info_ref,
const std::map<std::string, std::string> &matched_op) {
std::map<std::string, InputInfo> new_input_info = in_info;
bool is_update_required = false;
for (const auto& in_info_item : in_info_ref) {
if (!in_info.count(in_info_item.first)) {
is_update_required = true;
break;
} else if (in_info.at(in_info_item.first).is_const != in_info_item.second.is_const) {
throw std::runtime_error("Impossible to update input info!!!");
}
}
if (is_update_required) {
// align matched model names
auto ref_model_ops = model_ref->get_ordered_ops();
auto model_ops = model->get_ordered_ops();
size_t ref_ordered_ops_size = ref_model_ops.size();
size_t ordered_ops_size = model_ops.size();
if (ref_ordered_ops_size != ordered_ops_size && matched_op.empty()) {
throw std::runtime_error("Matched models can not be compared according different op numbers!");
}
for (size_t i = 0; i < ref_ordered_ops_size; ++i) {
auto model_op_name = i < ordered_ops_size ? model_ops[i]->get_friendly_name() : "";
auto model_ref_op_name = ref_model_ops[i]->get_friendly_name();
if (!in_info_ref.count(model_ref_op_name) && !in_info.count(model_op_name)) {
continue;
}
auto input_info = matched_op.empty() ? new_input_info[model_op_name] : in_info_ref.at(model_ref_op_name);
std::string input_name = matched_op.count(model_ref_op_name) ? matched_op.at(model_ref_op_name) : model_op_name;
if (new_input_info.count(input_name)) {
if (input_info.is_const != in_info_ref.at(model_ref_op_name).is_const) {
throw std::runtime_error("Impossible to update input info!!!");
}
if (!matched_op.empty()) {
input_info = new_input_info.at(input_name);
}
new_input_info.erase(input_name);
}
new_input_info.insert({ model_ref_op_name, input_info });
}
}
return new_input_info;
}
std::list<ExtractedPattern>
std::vector<ExtractedPattern>
ExtractorsManager::extract(const std::shared_ptr<ov::Model> &model,
bool is_extract_body,
bool is_copy_constants) {
std::list<ExtractedPattern> result;
std::vector<ExtractedPattern> result;
for (const auto &it : m_extractors) {
// extract patterns from original models
auto start = std::chrono::high_resolution_clock::now();
it.second->set_extractor_name(it.first);
auto extracted_patterns = it.second->extract(model, is_extract_body, is_copy_constants);
it.second->set_extract_body(is_extract_body);
it.second->set_save_const(is_copy_constants);
auto extracted_patterns = it.second->extract(model);
result.insert(result.end(), extracted_patterns.begin(), extracted_patterns.end());
auto end = std::chrono::high_resolution_clock::now();
auto delta = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();

View File

@ -8,18 +8,118 @@
#include "openvino/op/lstm_cell.hpp"
#include "openvino/op/tensor_iterator.hpp"
#include "openvino/op/if.hpp"
#include "openvino/op/loop.hpp"
#include "matchers/subgraph/repeat_pattern.hpp"
#include "utils/model.hpp"
#include "utils/model_comparator.hpp"
using namespace ov::tools::subgraph_dumper;
std::list<ExtractedPattern>
RepeatPatternExtractor::extract(const std::shared_ptr<ov::Model> &model,
bool is_extract_body,
bool is_copy_constants) {
void RepeatPatternExtractor::set_recursive_extraction(bool _is_recursive_extraction) {
is_recursive_extraction = _is_recursive_extraction;
}
std::vector<ExtractedPattern>
RepeatPatternExtractor::extract(const std::shared_ptr<ov::Model> &model) {
std::vector<ExtractedPattern> extracted_patterns;
for (const auto& pattern : find_repeat_patterns(model)) {
for (const auto& pattern_structure : pattern) {
extracted_patterns.push_back({std::get<0>(pattern_structure), std::get<2>(pattern_structure), extractor_name});
}
}
return extracted_patterns;
}
std::vector<std::vector<RepeatPatternExtractor::PatternBorders>>
RepeatPatternExtractor::get_repeat_pattern_borders(const std::shared_ptr<ov::Model> &model) {
std::vector<std::vector<RepeatPatternExtractor::PatternBorders>> extracted_patterns;
for (auto& pattern : find_repeat_patterns(model, true)) {
std::vector<RepeatPatternExtractor::PatternBorders> same_pattern_borders;
for (const auto& pattern_structure : pattern) {
std::set<std::string> output_names;
for (const auto& result : std::get<0>(pattern_structure)->get_results()) {
output_names.insert(result->get_input_node_shared_ptr(0)->get_friendly_name());
}
RepeatPatternExtractor::InputVector in_vec;
RepeatPatternExtractor::OutputVector out_vec;
for (const auto& node : std::get<1>(pattern_structure)) {
if (output_names.count(node->get_friendly_name())) {
OutputVector node_outputs = node->outputs();
out_vec.insert(out_vec.end(), node_outputs.begin(), node_outputs.end());
} else {
for (const auto& input : node->inputs()) {
in_vec.push_back(input);
}
}
}
same_pattern_borders.push_back({in_vec, out_vec});
}
extracted_patterns.push_back(same_pattern_borders);
}
return extracted_patterns;
}
std::vector<std::vector<ov::NodeVector>>
RepeatPatternExtractor::get_repeat_node_vectors(const std::shared_ptr<ov::Model> &model) {
std::vector<std::vector<ov::NodeVector>> extracted_patterns;
for (const auto& pattern : find_repeat_patterns(model)) {
std::vector<ov::NodeVector> same_pattern_nodes;
for (const auto& pattern_structure : pattern) {
same_pattern_nodes.push_back(std::get<1>(pattern_structure));
}
extracted_patterns.push_back(same_pattern_nodes);
}
return extracted_patterns;
}
void
RepeatPatternExtractor::update_extractor_cache(
std::list<std::vector<RepeatPatternExtractor::ExtractedRepeatPattern>>& extracted_patterns,
const std::shared_ptr<ov::Model>& pattern,
const ov::NodeVector& pattern_node_vector,
const std::map<std::string, InputInfo>& pattern_in_info) {
for (auto& extracted_pattern : extracted_patterns) {
auto& pattern_structure = extracted_pattern.front();
const auto& cached_pattern = std::get<0>(pattern_structure);
if (model_comparator->match(pattern, cached_pattern)) {
try {
const auto& cached_in_info = std::get<2>(pattern_structure);
align_input_info(pattern, cached_pattern, pattern_in_info, cached_in_info);
extracted_pattern.push_back({ pattern, pattern_node_vector, pattern_in_info });
return;
} catch(std::exception) {}
}
}
extracted_patterns.push_back({{ pattern, pattern_node_vector, pattern_in_info }});
}
void
RepeatPatternExtractor::update_extractor_cache(
std::list<std::vector<RepeatPatternExtractor::ExtractedRepeatPattern>>& extracted_patterns,
std::list<std::vector<RepeatPatternExtractor::ExtractedRepeatPattern>>& secondary_extracted_patterns) {
auto extern_it = secondary_extracted_patterns.begin();
while (!secondary_extracted_patterns.empty()) {
auto it = extern_it->rbegin();
while (!extern_it->empty()) {
auto& pattern_structure = *it;
const auto& pattern = std::get<0>(pattern_structure);
const auto& pattern_node_vector = std::get<1>(pattern_structure);
const auto& pattern_in_info = std::get<2>(pattern_structure);
update_extractor_cache(extracted_patterns, pattern, pattern_node_vector, pattern_in_info);
extern_it->pop_back();
it = extern_it->rbegin();
}
secondary_extracted_patterns.pop_front();
}
}
std::list<std::vector<RepeatPatternExtractor::ExtractedRepeatPattern>>
RepeatPatternExtractor::find_repeat_patterns(const std::shared_ptr<ov::Model> &model,
bool is_save_borders_only) {
std::list<std::vector<RepeatPatternExtractor::ExtractedRepeatPattern>> extracted_patterns;
std::unordered_set<std::string> checked_ops;
std::list<ExtractedPattern> to_cache;
auto ordered_ops = model->get_ordered_ops();
auto op_cnt = ordered_ops.size();
@ -31,9 +131,10 @@ RepeatPatternExtractor::extract(const std::shared_ptr<ov::Model> &model,
continue;
}
// find the same nodes
std::vector<size_t> start_node_idx{idx};
for (size_t i = idx + 1; i < op_cnt; ++i) {
if (manager.match(op, ordered_ops[i])) {
if (model_comparator->match(op, ordered_ops[i])) {
start_node_idx.push_back(i);
}
}
@ -57,9 +158,9 @@ RepeatPatternExtractor::extract(const std::shared_ptr<ov::Model> &model,
if (node_idx == start_node_idx[i] && ref_node_idx == start_node_idx[j]) {
nodes[i].insert(node);
nodes[j].insert(ref_node);
} else if (manager.match(node, ref_node)) {
} else if (model_comparator->match(node, ref_node)) {
// check if we met the same node
if (manager.match(node, op)) {
if (model_comparator->match(node, op)) {
break;
}
if (checked_ops.count(node->get_friendly_name()) ||
@ -94,16 +195,26 @@ RepeatPatternExtractor::extract(const std::shared_ptr<ov::Model> &model,
for (size_t i = 0; i < start_node_idx.size(); ++i) {
try {
std::unordered_set<std::string> tmp_checked_ops;
auto extracted_pattern = generate_model(nodes[i], tmp_checked_ops, extractor_name, is_copy_constants);
auto extracted_model = std::get<0>(extracted_pattern);
std::list<ExtractedPattern> secondary_patterns;
if (nodes[i].size() > 20) {
secondary_patterns = extract(std::get<0>(extracted_pattern), is_extract_body, is_copy_constants);
}
if (secondary_patterns.size() > 1) {
to_cache.insert(to_cache.end(), secondary_patterns.begin(), secondary_patterns.end());
// model, in_info, extractor_name
ov::NodeVector nodes_vector(nodes[i].begin(), nodes[i].end());
auto extracted_pattern = generate_model(nodes_vector, tmp_checked_ops, is_save_const, is_save_borders_only);
auto extracted_model = extracted_pattern.first;
if (is_recursive_extraction && nodes_vector.size() > 20) {
auto secondary_patterns = find_repeat_patterns(extracted_model, is_save_borders_only);
if (!secondary_patterns.empty()) {
tmp_checked_ops.clear();
update_extractor_cache(extracted_patterns, secondary_patterns);
} else {
update_extractor_cache(extracted_patterns,
extracted_model,
nodes_vector,
extracted_pattern.second);
}
} else {
to_cache.push_back(extracted_pattern);
update_extractor_cache(extracted_patterns,
extracted_model,
nodes_vector,
extracted_pattern.second);
}
nodes[i].clear();
checked_ops.insert(tmp_checked_ops.begin(), tmp_checked_ops.end());
@ -117,23 +228,39 @@ RepeatPatternExtractor::extract(const std::shared_ptr<ov::Model> &model,
if (std::dynamic_pointer_cast<ov::op::v0::TensorIterator>(op)) {
auto ti = ov::as_type_ptr<ov::op::v0::TensorIterator>(op);
auto ti_body = ti->get_function();
auto tmp_res = extract(ti_body);
to_cache.insert(to_cache.end(), tmp_res.begin(), tmp_res.end());
auto secondary_patterns = find_repeat_patterns(ti_body, is_save_borders_only);
update_extractor_cache(extracted_patterns, secondary_patterns);
} else if (std::dynamic_pointer_cast<ov::op::v5::Loop>(op)) {
auto loop = ov::as_type_ptr<ov::op::v5::Loop>(op);
auto loop_body = loop->get_function();
auto tmp_res = extract(loop_body);
to_cache.insert(to_cache.end(), tmp_res.begin(), tmp_res.end());
auto secondary_patterns = find_repeat_patterns(loop_body, is_save_borders_only);
update_extractor_cache(extracted_patterns, secondary_patterns);
} else if (std::dynamic_pointer_cast<ov::op::v8::If>(op)) {
auto if_op = ov::as_type_ptr<ov::op::v8::If>(op);
std::vector<std::shared_ptr<ov::Model>> bodies;
for (size_t i = 0; i < if_op->get_internal_subgraphs_size(); i++) {
auto if_body = if_op->get_function(i);
auto tmp_res = extract(if_body);
to_cache.insert(to_cache.end(), tmp_res.begin(), tmp_res.end());
auto secondary_patterns = find_repeat_patterns(if_body, is_save_borders_only);
update_extractor_cache(extracted_patterns, secondary_patterns);
}
}
}
}
return to_cache;
// clean up patterns
{
auto it = extracted_patterns.begin();
size_t elem_cnt = 0;
while (it != extracted_patterns.end()) {
if (it->size() > 1) {
++it;
++elem_cnt;
} else {
extracted_patterns.erase(it);
it = extracted_patterns.begin();
std::advance(it, elem_cnt);
}
}
}
return extracted_patterns;
}

View File

@ -1,76 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <tuple>
#include "matchers/subgraph/subgraph.hpp"
using namespace ov::tools::subgraph_dumper;
bool
SubgraphExtractor::match(const std::shared_ptr<ov::Model> &model,
const std::shared_ptr<ov::Model> &ref_model) const {
bool res = comparator.compare(model, ref_model).valid;
if (res) {
return res;
}
std::vector<std::shared_ptr<ov::Node>> ordered_ops = model->get_ordered_ops(),
ref_ordered_ops = ref_model->get_ordered_ops();
if (ordered_ops.size() != ref_ordered_ops.size()) {
return false;
}
size_t matched_op_cnt = 0, total_op_cnt = ordered_ops.size();
size_t matched_op_cnt_required = round(0.9 * total_op_cnt);
for (size_t i = 0; i < total_op_cnt; ++i) {
if (is_node_to_skip(ordered_ops[i]) &&
is_node_to_skip(ref_ordered_ops[i]) ||
m_manager.match(ordered_ops[i], ref_ordered_ops[i])) {
++matched_op_cnt;
}
if (matched_op_cnt >= matched_op_cnt_required) {
return true;
}
}
return false;
}
inline SubgraphExtractor::IsSubgraphTuple prepare_is_subgraph_result(bool is_subgraph,
const std::shared_ptr<ov::Model>& graph,
const std::shared_ptr<ov::Model>& subgraph,
const std::map<std::string, std::string>& matched_ops) {
return is_subgraph ?
std::make_tuple(is_subgraph, graph, subgraph, matched_ops) :
std::make_tuple(is_subgraph, nullptr, nullptr, std::map<std::string, std::string>());
}
SubgraphExtractor::IsSubgraphTuple
SubgraphExtractor::is_subgraph(const std::shared_ptr<ov::Model> &model,
const std::shared_ptr<ov::Model> &ref_model) const {
std::vector<std::shared_ptr<ov::Node>> ordered_ops = model->get_ordered_ops(),
ref_ordered_ops = ref_model->get_ordered_ops();
bool is_model = ordered_ops.size() > ref_ordered_ops.size();
ov::NodeVector graph_to_check_ops, subgraph_to_check_ops;
std::shared_ptr<ov::Model> graph = nullptr, subgraph = nullptr;
if (is_model) {
graph_to_check_ops = ordered_ops;
subgraph_to_check_ops = ref_ordered_ops;
graph = model;
subgraph = ref_model;
} else {
graph_to_check_ops = ref_ordered_ops;
subgraph_to_check_ops = ordered_ops;
graph = ref_model;
subgraph = model;
}
std::map<std::string, std::string> matched_op_names;
auto graph_it = graph_to_check_ops.begin(), subgraph_it = subgraph_to_check_ops.begin();
while (graph_it != graph_to_check_ops.end() && subgraph_it != subgraph_to_check_ops.end()) {
if (m_manager.match(*graph_it, *subgraph_it)) {
matched_op_names.insert({ (*graph_it)->get_friendly_name(), (*subgraph_it)->get_friendly_name()});
++subgraph_it;
}
++graph_it;
}
return prepare_is_subgraph_result(subgraph_it == subgraph_to_check_ops.end(), graph, subgraph, matched_op_names);
}

View File

@ -70,6 +70,27 @@ find_models(const std::vector<std::string> &dirs, const std::string& regexp) {
return { models, { ModelCacheStatus::NOT_READ, not_read_model } };
}
bool is_dynamic_model(const std::shared_ptr<ov::Model>& model) {
for (const auto& parameter : model->get_parameters()) {
if (is_dynamic_node(parameter)) {
return true;
}
}
for (const auto& result : model->get_results()) {
if (is_dynamic_node(result)) {
return true;
}
}
return false;
}
std::string get_model_type(const std::shared_ptr<ov::Model>& model) {
if (is_dynamic_model(model)) {
return "dynamic";
}
return "static";
}
std::map<ModelCacheStatus, std::vector<std::string>> cache_models(
std::shared_ptr<ICache>& cache,
const std::vector<std::string>& models,
@ -115,6 +136,78 @@ std::map<ModelCacheStatus, std::vector<std::string>> cache_models(
return cache_status;
}
std::map<std::string, InputInfo>
get_input_info_by_model(const std::shared_ptr<ov::Model>& model) {
std::map<std::string, InputInfo> in_info;
for (const auto& node : model->get_ordered_ops()) {
InputInfo::Range ranges(DEFAULT_MIN_VALUE, DEFAULT_MAX_VALUE);
bool is_const = false;
if (ov::op::util::is_constant(node)) {
std::shared_ptr<ov::op::v0::Constant> constant = std::dynamic_pointer_cast<ov::op::v0::Constant>(node);
auto const_ranges = get_const_ranges(constant,
constant->get_default_output().get_element_type());
ranges = const_ranges;
} else if (!ov::op::util::is_parameter(node)) {
continue;
}
auto partial_shape = node->get_default_output().get_partial_shape();
in_info.insert({node->get_friendly_name(),
InputInfo(partial_shape, ranges.min, ranges.max, is_const)});
}
return in_info;
}
std::map<std::string, InputInfo>
align_input_info(const std::shared_ptr<ov::Model>& model,
const std::shared_ptr<ov::Model>& model_ref,
const std::map<std::string, InputInfo>& in_info,
const std::map<std::string, InputInfo>& in_info_ref,
const std::map<std::string, std::string> &matched_op) {
bool is_update_required = !matched_op.empty();
if (!is_update_required) {
for (const auto& ref_item : in_info_ref) {
if (!in_info.count(ref_item.first)) {
is_update_required = true;
break;
} else if (in_info.at(ref_item.first).is_const != ref_item.second.is_const) {
throw std::runtime_error("Impossible to update input info!!!");
}
}
}
std::map<std::string, InputInfo> updated_input_info = in_info_ref;
if (is_update_required) {
// align matched model names
const auto& ref_model_ops = model_ref->get_ordered_ops();
const auto& model_ops = model->get_ordered_ops();
size_t ref_ordered_ops_size = ref_model_ops.size();
size_t ordered_ops_size = model_ops.size();
if (ref_ordered_ops_size != ordered_ops_size && matched_op.empty()) {
throw std::runtime_error("Matched models can not be compared according different op numbers!");
}
for (size_t i = 0; i < ordered_ops_size; ++i) {
auto model_op_name = model_ops[i]->get_friendly_name();
if (!in_info.count(model_op_name)) {
continue;
}
if (!matched_op.empty()) {
if (!matched_op.count(model_op_name)) {
continue;
}
}
auto model_ref_op_name = matched_op.empty() ? ref_model_ops[i]->get_friendly_name() : matched_op.at(model_op_name);
const auto& in_info_item = in_info.at(model_op_name);
const auto& ref_in_info_item = in_info_ref.at(model_ref_op_name);
if (in_info_item.is_const != ref_in_info_item.is_const) {
throw std::runtime_error("Impossible to update input info!!!");
}
updated_input_info[model_ref_op_name] = in_info_item;
}
}
return updated_input_info;
}
} // namespace subgraph_dumper
} // namespace tools
} // namespace ov

View File

@ -0,0 +1,136 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "utils/model_comparator.hpp"
#include "utils/model.hpp"
using namespace ov::tools::subgraph_dumper;
std::shared_ptr<ModelComparator> ModelComparator::m_instance = nullptr;
void ModelComparator::set_match_coefficient(float _match_coefficient) {
if (_match_coefficient < 0 || _match_coefficient > 1) {
throw std::runtime_error("[ ERROR ] Match coefficient should be from 0 to 1!");
}
match_coefficient = _match_coefficient;
}
void ModelComparator::set_shape_strict_match(bool in_is_shape_strict_match) {
m_manager.set_shape_strict_match(in_is_shape_strict_match);
}
inline ModelComparator::IsSubgraphTuple
prepare_is_subgraph_result(bool is_subgraph,
const std::shared_ptr<ov::Model>& subgraph,
const std::shared_ptr<ov::Model>& graph,
const std::map<std::string, std::string>& matched_ops) {
return is_subgraph ?
std::make_tuple(is_subgraph, subgraph, graph, matched_ops) :
std::make_tuple(is_subgraph, nullptr, nullptr, std::map<std::string, std::string>());
}
ModelComparator::IsSubgraphTuple
ModelComparator::is_subgraph(const std::shared_ptr<ov::Model> &model,
const std::shared_ptr<ov::Model> &ref_model) const {
std::vector<std::shared_ptr<ov::Node>> ordered_ops = model->get_ordered_ops(),
ref_ordered_ops = ref_model->get_ordered_ops();
bool is_model = ordered_ops.size() > ref_ordered_ops.size();
ov::NodeVector graph_to_check_ops, subgraph_to_check_ops;
std::shared_ptr<ov::Model> graph = nullptr, subgraph = nullptr;
if (is_model) {
graph_to_check_ops = ordered_ops;
subgraph_to_check_ops = ref_ordered_ops;
graph = model;
subgraph = ref_model;
} else {
graph_to_check_ops = ref_ordered_ops;
subgraph_to_check_ops = ordered_ops;
graph = ref_model;
subgraph = model;
}
std::map<std::string, std::string> matched_op_names;
auto graph_it = graph_to_check_ops.begin(), subgraph_it = subgraph_to_check_ops.begin();
while (graph_it != graph_to_check_ops.end() && subgraph_it != subgraph_to_check_ops.end()) {
if (m_manager.match(*graph_it, *subgraph_it)) {
matched_op_names.insert({ (*subgraph_it)->get_friendly_name(), (*graph_it)->get_friendly_name()});
++subgraph_it;
}
++graph_it;
}
return prepare_is_subgraph_result(subgraph_it == subgraph_to_check_ops.end(), subgraph, graph, matched_op_names);
}
bool
ModelComparator::match(const std::shared_ptr<ov::Node> &node,
const std::shared_ptr<ov::Node> &ref_node) const {
return m_manager.match(node, ref_node);
}
bool
ModelComparator::match(const std::shared_ptr<ov::Model> &model,
const std::shared_ptr<ov::Model> &ref_model) const {
std::vector<std::shared_ptr<ov::Node>> ordered_ops = model->get_ordered_ops(),
ref_ordered_ops = ref_model->get_ordered_ops();
if (ordered_ops.size() != ref_ordered_ops.size()) {
return false;
}
size_t matched_op_cnt = 0, total_op_cnt = ordered_ops.size();
size_t matched_op_cnt_required = round(match_coefficient * total_op_cnt);
for (size_t i = 0; i < total_op_cnt; ++i) {
if (m_manager.match(ordered_ops[i], ref_ordered_ops[i])) {
++matched_op_cnt;
}
if (matched_op_cnt >= matched_op_cnt_required) {
return true;
}
}
return false;
}
ModelComparator::ExtractedSubgraphTuple
ModelComparator::is_subgraph(const std::shared_ptr<ov::Model> &model,
const std::shared_ptr<ov::Model> &ref_model,
const std::map<std::string, InputInfo> &in_info,
const std::map<std::string, InputInfo> &in_info_ref) {
auto extractor_res = is_subgraph(model, ref_model);
if (std::get<0>(extractor_res)) {
std::map<std::string, InputInfo> graph_in_info, subgraph_in_info;
std::shared_ptr<ov::Model> subgraph = nullptr, graph = nullptr;
// if (model == subgraph && ref_model == graph)
if (std::get<1>(extractor_res) == model && std::get<2>(extractor_res) == ref_model) {
subgraph = model;
subgraph_in_info = in_info;
graph = ref_model;
graph_in_info = in_info_ref;
// else if (subgraph == ref_model && graph = model)
} else if (std::get<1>(extractor_res) == ref_model && std::get<2>(extractor_res) == model) {
subgraph = ref_model;
subgraph_in_info = in_info_ref;
graph = model;
graph_in_info = in_info;
} else {
throw std::runtime_error("Generated models are incompatible with original ones!");
}
try {
subgraph_in_info = align_input_info(subgraph, graph, subgraph_in_info, graph_in_info);
return { true, subgraph, graph, subgraph_in_info, graph_in_info };
} catch(std::exception) {}
}
return { false, nullptr, nullptr, {}, {} };
}
std::pair<bool, std::map<std::string, InputInfo>>
ModelComparator::match(const std::shared_ptr<ov::Model> &model,
const std::shared_ptr<ov::Model> &model_ref,
const std::map<std::string, InputInfo> &in_info,
const std::map<std::string, InputInfo> &in_info_ref) {
try {
if (match(model, model_ref)) {
auto new_input_info = align_input_info(model, model_ref, in_info, in_info_ref);
return {true, new_input_info};
}
} catch (std::exception) {}
return {false, {}};
}

View File

@ -7,6 +7,73 @@ namespace ov {
namespace tools {
namespace subgraph_dumper {
InputInfo::Range get_const_ranges(const std::shared_ptr<ov::op::v0::Constant>& const_node,
ov::element::Type elem_type) {
InputInfo::Range ranges(DEFAULT_MIN_VALUE, DEFAULT_MAX_VALUE);
switch (elem_type) {
case ov::element::Type_t::boolean: {
ranges = get_const_ranges<bool>(const_node);
break;
}
case ov::element::Type_t::bf16: {
ranges = get_const_ranges<ov::bfloat16>(const_node);
break;
}
case ov::element::Type_t::f16: {
ranges = get_const_ranges<ov::float16>(const_node);
break;
}
case ov::element::Type_t::f32: {
ranges = get_const_ranges<float>(const_node);
break;
}
case ov::element::Type_t::f64: {
ranges = get_const_ranges<double>(const_node);
break;
}
case ov::element::Type_t::i8: {
ranges = get_const_ranges<int8_t>(const_node);
break;
}
case ov::element::Type_t::i16: {
ranges = get_const_ranges<int16_t>(const_node);
break;
}
case ov::element::Type_t::i32: {
ranges = get_const_ranges<int32_t>(const_node);
break;
}
case ov::element::Type_t::i64: {
ranges = get_const_ranges<int64_t>(const_node);
break;
}
// TODO cast_vector doesn't support u1 now
// case ov::element::Type_t::u1:
// return get_const_ranges<char>(const_node);
case ov::element::Type_t::u8: {
ranges = get_const_ranges<uint8_t>(const_node);
break;
}
case ov::element::Type_t::u16: {
ranges = get_const_ranges<uint16_t>(const_node);
break;
}
case ov::element::Type_t::u32: {
ranges = get_const_ranges<uint32_t>(const_node);
break;
}
case ov::element::Type_t::u64: {
ranges = get_const_ranges<uint64_t>(const_node);
break;
}
default: {
std::cout << "Can't get ranges.. Unsupported data type" << std::endl;
break;
}
}
return ranges;
}
std::map<std::string, InputInfo> get_input_info_by_node(const std::shared_ptr<ov::Node>& node) {
std::map<std::string, InputInfo> input_info;
for (size_t port_id = 0; port_id < node->get_input_size(); ++port_id) {
@ -19,71 +86,12 @@ std::map<std::string, InputInfo> get_input_info_by_node(const std::shared_ptr<ov
if (std::dynamic_pointer_cast<ov::op::v0::Constant>(input_node)) {
if (ov::shape_size(input_node->get_output_shape(0)) == 0)
continue;
auto const_node =
std::dynamic_pointer_cast<ov::op::v0::Constant>(input_node);
auto const_node = ov::as_type_ptr<ov::op::v0::Constant>(input_node);
in_info.is_const = true;
switch (node->get_output_element_type(0)) {
case ov::element::Type_t::boolean: {
in_info.ranges = get_const_ranges<bool>(const_node);
break;
}
case ov::element::Type_t::bf16: {
in_info.ranges = get_const_ranges<ov::bfloat16>(const_node);
break;
}
case ov::element::Type_t::f16: {
in_info.ranges = get_const_ranges<ov::float16>(const_node);
break;
}
case ov::element::Type_t::f32: {
in_info.ranges = get_const_ranges<float>(const_node);
break;
}
case ov::element::Type_t::f64: {
in_info.ranges = get_const_ranges<double>(const_node);
break;
}
case ov::element::Type_t::i8: {
in_info.ranges = get_const_ranges<int8_t>(const_node);
break;
}
case ov::element::Type_t::i16: {
in_info.ranges = get_const_ranges<int16_t>(const_node);
break;
}
case ov::element::Type_t::i32: {
in_info.ranges = get_const_ranges<int32_t>(const_node);
break;
}
case ov::element::Type_t::i64: {
in_info.ranges = get_const_ranges<int64_t>(const_node);
break;
}
// TODO cast_vector doesn't support u1 now
// case ov::element::Type_t::u1:
// return get_const_ranges<char>(const_node);
case ov::element::Type_t::u8: {
in_info.ranges = get_const_ranges<uint8_t>(const_node);
break;
}
case ov::element::Type_t::u16: {
in_info.ranges = get_const_ranges<uint16_t>(const_node);
break;
}
case ov::element::Type_t::u32: {
in_info.ranges = get_const_ranges<uint32_t>(const_node);
break;
}
case ov::element::Type_t::u64: {
in_info.ranges = get_const_ranges<uint64_t>(const_node);
break;
}
default: {
std::cout << "Can't get ranges.. Unsupported data type" << std::endl;
break;
}}
in_info.ranges = get_const_ranges(const_node,
const_node->get_default_output().get_element_type());
}
input_info.insert({ input_name, in_info });
input_info.insert({input_name, in_info});
}
return input_info;
}
@ -128,9 +136,10 @@ std::shared_ptr<ov::Node> clone_node(std::shared_ptr<ov::Node> node,
std::shared_ptr<ov::Node> cloned_node = nullptr;
if (!has_parameters && !is_copy_const_node && !inputs.empty()) {
cloned_node = clone_node(node, true, true, node_name);
// std::cout << "The operation: " + node->get_friendly_name() + " does not have parameters! Replace first input to parameter!" << std::endl;
auto param =
std::make_shared<ov::op::v0::Parameter>(cloned_node->get_input_element_type(0), cloned_node->get_input_partial_shape(0));
// std::cout << "The operation: " + node->get_friendly_name() + " does not have parameters! Replace first input
// to parameter!" << std::endl;
auto param = std::make_shared<ov::op::v0::Parameter>(cloned_node->get_input_element_type(0),
cloned_node->get_input_partial_shape(0));
std::string param_name = node_name + "_0";
param->set_friendly_name(param_name);
auto node_to_replace = cloned_node->get_input_node_shared_ptr(0);
@ -142,10 +151,11 @@ std::shared_ptr<ov::Node> clone_node(std::shared_ptr<ov::Node> node,
return cloned_node;
}
std::shared_ptr<ov::op::v0::Parameter> convert_const_to_param(const std::shared_ptr<ov::op::v0::Constant>& op_to_replace) {
std::shared_ptr<ov::op::v0::Parameter> convert_const_to_param(
const std::shared_ptr<ov::op::v0::Constant>& op_to_replace) {
if (op_to_replace->get_byte_size() > 1024) {
auto param = std::make_shared<ov::op::v0::Parameter>(
op_to_replace->get_output_element_type(0), op_to_replace->get_output_partial_shape(0));
auto param = std::make_shared<ov::op::v0::Parameter>(op_to_replace->get_output_element_type(0),
op_to_replace->get_output_partial_shape(0));
param->set_friendly_name(op_to_replace->get_friendly_name());
if (param != nullptr) {
ov::replace_node(op_to_replace, param);

View File

@ -93,4 +93,22 @@ TEST_F(ICacheUnitTest, serialize_model) {
}
}
TEST_F(ICacheUnitTest, is_model_large_to_read) {
this->mem_size = 0;
ASSERT_NO_THROW(this->is_model_large_to_read(test_model, test_model_path));
ASSERT_TRUE(this->is_model_large_to_read(test_model, test_model_path));
this->mem_size = 1 << 30;
ASSERT_NO_THROW(this->is_model_large_to_read(test_model, test_model_path));
ASSERT_FALSE(this->is_model_large_to_read(test_model, test_model_path));
}
TEST_F(ICacheUnitTest, is_model_large_to_store_const) {
this->mem_size = 0;
ASSERT_NO_THROW(this->is_model_large_to_store_const(test_model));
ASSERT_TRUE(this->is_model_large_to_store_const(test_model));
this->mem_size = 1 << 30;
ASSERT_NO_THROW(this->is_model_large_to_store_const(test_model));
ASSERT_FALSE(this->is_model_large_to_store_const(test_model));
}
} // namespace

View File

@ -8,6 +8,7 @@
#include "matchers/subgraph/fused_names.hpp"
#include "utils/model.hpp"
#include "utils/model_comparator.hpp"
#include "test_models/model_0.hpp"
#include "test_models/model_1.hpp"
@ -32,7 +33,7 @@ protected:
auto it_model_2 = models_2.begin();
while (it_model_1 != models_1.end() || it_model_2 != models_2.end()) {
SubgraphExtractor extractor;
ASSERT_TRUE(extractor.match(std::get<0>(*it_model_1), std::get<0>(*it_model_2)));
ASSERT_TRUE(ModelComparator::get()->match(std::get<0>(*it_model_1), std::get<0>(*it_model_2)));
auto in_info_1 = std::get<1>(*it_model_1);
auto in_info_2 = std::get<1>(*it_model_2);
for (const auto& in_info : in_info_1) {

View File

@ -35,31 +35,10 @@ protected:
test_model_0_0 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
ov::ParameterVector{test_parameter});
}
{
std::shared_ptr<ov::op::v0::Parameter> test_parameter =
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 5});
test_parameter->set_friendly_name("test_parameter_1");
std::shared_ptr<ov::op::v0::Abs> test_abs =
std::make_shared<ov::op::v0::Abs>(test_parameter);
std::shared_ptr<ov::op::v0::Result> test_res =
std::make_shared<ov::op::v0::Result>(test_abs);
test_model_0_1 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
ov::ParameterVector{test_parameter});
}
{
std::shared_ptr<ov::op::v0::Parameter> test_parameter =
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 5});
std::shared_ptr<ov::op::v0::Relu> test_abs =
std::make_shared<ov::op::v0::Relu>(test_parameter);
std::shared_ptr<ov::op::v0::Result> test_res =
std::make_shared<ov::op::v0::Result>(test_abs);
test_model_1 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
ov::ParameterVector{test_parameter});
}
}
ExtractorsManager::ExtractorsMap test_map;
std::shared_ptr<ov::Model> test_model_0_0, test_model_0_1, test_model_1;
std::shared_ptr<ov::Model> test_model_0_0;
};
TEST_F(ExtractorsManagerTest, constructor) {
@ -78,57 +57,9 @@ TEST_F(ExtractorsManagerTest, get_extractors) {
ASSERT_EQ(this->m_extractors, this->get_extractors());
}
TEST_F(ExtractorsManagerTest, match) {
this->set_extractors(test_map);
ASSERT_NO_THROW(this->match(test_model_0_0, test_model_0_1));
ASSERT_TRUE(this->match(test_model_0_0, test_model_0_1));
ASSERT_NO_THROW(this->match(test_model_0_0, test_model_1));
ASSERT_FALSE(this->match(test_model_0_0, test_model_1));
ASSERT_NO_THROW(this->match(test_model_0_1, test_model_1));
ASSERT_FALSE(this->match(test_model_0_1, test_model_1));
}
TEST_F(ExtractorsManagerTest, is_subgraph) {
this->set_extractors(test_map);
ASSERT_NO_THROW(this->is_subgraph(test_model_0_0, test_model_0_1));
auto is_subgraph = this->is_subgraph(test_model_0_0, test_model_0_1);
ASSERT_TRUE(std::get<0>(is_subgraph));
ASSERT_NO_THROW(this->is_subgraph(test_model_0_0, test_model_1));
ASSERT_FALSE(std::get<0>(this->is_subgraph(test_model_0_0, test_model_1)));
ASSERT_NO_THROW(this->is_subgraph(test_model_0_1, test_model_1));
ASSERT_FALSE(std::get<0>(this->is_subgraph(test_model_0_1, test_model_1)));
}
TEST_F(ExtractorsManagerTest, match_with_in_info) {
this->set_extractors(test_map);
std::map<std::string, InputInfo> test_in_info({{"test_parameter_0", InputInfo()}}), test_in_info_1({{"test_parameter_1", InputInfo({}, 1, 2, true)}});
ASSERT_NO_THROW(this->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info));
ASSERT_TRUE(this->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info));
ASSERT_NO_THROW(this->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info_1));
ASSERT_FALSE(this->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info_1));
ASSERT_NO_THROW(this->match(test_model_0_1, test_model_1, test_in_info, test_in_info));
ASSERT_FALSE(this->match(test_model_0_1, test_model_1, test_in_info, test_in_info));
}
TEST_F(ExtractorsManagerTest, extract) {
this->set_extractors(test_map);
ASSERT_NO_THROW(this->extract(test_model_0_0));
}
TEST_F(ExtractorsManagerTest, align_input_info) {
std::map<std::string, InputInfo> test_in_info({{"test_parameter_0", InputInfo()}}), test_in_info_ref({{"test_parameter_1", InputInfo()}});
ASSERT_NE(test_in_info, test_in_info_ref);
ASSERT_NO_THROW(this->align_input_info(test_model_0_0, test_model_0_1, test_in_info, test_in_info_ref));
auto c = this->align_input_info(test_model_0_0, test_model_0_1, test_in_info, test_in_info_ref);
ASSERT_EQ(c, test_in_info_ref);
}
TEST_F(ExtractorsManagerTest, align_input_info_for_subgraphs) {
std::map<std::string, InputInfo> test_in_info({{"test_parameter_0", InputInfo()}}), test_in_info_ref({{"test_parameter_1", InputInfo()}});
ASSERT_NE(test_in_info, test_in_info_ref);
ASSERT_NO_THROW(this->align_input_info(test_model_0_0, test_model_0_1, test_in_info, test_in_info_ref, {{"test_parameter_0", "test_parameter_1"}}));
auto c = this->align_input_info(test_model_0_0, test_model_0_1, test_in_info, test_in_info_ref);
ASSERT_EQ(c, test_in_info_ref);
}
} // namespace

View File

@ -6,6 +6,7 @@
#include "matchers/subgraph/repeat_pattern.hpp"
#include "utils/model.hpp"
#include "utils/model_comparator.hpp"
#include "base_test.hpp"
#include "test_models/model_0.hpp"
@ -22,13 +23,13 @@ class RepeatPatternExtractorTest : public SubgraphsDumperBaseTest {
protected:
RepeatPatternExtractor extractor;
bool is_match(const std::list<ExtractedPattern>& models,
bool is_match(const std::vector<ExtractedPattern>& models,
const std::vector<std::shared_ptr<ov::Model>>& ref_models) {
size_t match_numbers = 0;
for (const auto& model : models) {
bool is_match = false;
for (const auto& ref_model : ref_models) {
if (extractor.match(std::get<0>(model), ref_model)) {
if (ModelComparator::get()->match(std::get<0>(model), ref_model)) {
is_match = true;
++match_numbers;
break;
@ -40,6 +41,28 @@ protected:
}
return match_numbers == models.size();
}
void sort_node_vec(std::vector<std::vector<ov::NodeVector>>& pattern_vec) {
for (auto& pattern : pattern_vec) {
for (auto& node_vec : pattern) {
std::sort(node_vec.begin(), node_vec.end());
}
std::sort(pattern.begin(), pattern.end());
}
std::sort(pattern_vec.begin(), pattern_vec.end());
}
// not allowed to sort inputs/outputs according there are not copy constructor
// void sort_borders(std::vector<std::vector<RepeatPatternExtractor::PatternBorders>>& pattern_vec) {
// for (auto& pattern : pattern_vec) {
// for (auto& node_vec : pattern) {
// std::sort(node_vec.first.begin(), node_vec.first.end());
// std::sort(node_vec.second.begin(), node_vec.second.end());
// }
// std::sort(pattern.begin(), pattern.end());
// }
// std::sort(pattern_vec.begin(), pattern_vec.end());
// }
};
TEST_F(RepeatPatternExtractorTest, extract_0) {
@ -63,4 +86,59 @@ TEST_F(RepeatPatternExtractorTest, extract_2) {
ASSERT_TRUE(is_match(models, ref));
}
TEST_F(RepeatPatternExtractorTest, get_repeat_node_vectors_model_0) {
auto test_model = Model_0();
auto node_vector = extractor.get_repeat_node_vectors(test_model.get());
auto ref = test_model.get_ref_node_vector();
sort_node_vec(node_vector);
sort_node_vec(ref);
ASSERT_EQ(node_vector, ref);
}
TEST_F(RepeatPatternExtractorTest, get_repeat_node_vectors_model_1) {
auto test_model = Model_1();
auto node_vector = extractor.get_repeat_node_vectors(test_model.get());
auto ref = test_model.get_ref_node_vector();
sort_node_vec(node_vector);
sort_node_vec(ref);
ASSERT_EQ(node_vector, ref);
}
TEST_F(RepeatPatternExtractorTest, get_repeat_node_vectors_model_2) {
auto test_model = Model_2();
auto node_vector = extractor.get_repeat_node_vectors(test_model.get());
auto ref = test_model.get_ref_node_vector();
sort_node_vec(node_vector);
sort_node_vec(ref);
ASSERT_EQ(node_vector, ref);
}
TEST_F(RepeatPatternExtractorTest, get_repeat_pattern_borders_model_0) {
auto test_model = Model_0();
auto extracted_borders = extractor.get_repeat_pattern_borders(test_model.get());
auto ref_borders = test_model.get_ref_node_borders();
// sort_borders(extracted_borders);
// sort_borders(ref_borders);
ASSERT_EQ(extracted_borders, ref_borders);
}
TEST_F(RepeatPatternExtractorTest, get_repeat_pattern_borders_model_1) {
auto test_model = Model_1();
auto extracted_borders = extractor.get_repeat_pattern_borders(test_model.get());
auto ref_borders = test_model.get_ref_node_borders();
// sort_borders(extracted_borders);
// sort_borders(ref_borders);
ASSERT_EQ(extracted_borders, ref_borders);
}
TEST_F(RepeatPatternExtractorTest, get_repeat_pattern_borders_model_2) {
auto test_model = Model_2();
auto extracted_borders = extractor.get_repeat_pattern_borders(test_model.get());
auto ref_borders = test_model.get_ref_node_borders();
// sort_borders(extracted_borders);
// sort_borders(ref_borders);
ASSERT_EQ(extracted_borders, ref_borders);
}
} // namespace

View File

@ -1,209 +1,209 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
// // Copyright (C) 2018-2023 Intel Corporation
// // SPDX-License-Identifier: Apache-2.0
// //
#include "matchers/subgraph/subgraph.hpp"
#include "base_test.hpp"
// #include "matchers/subgraph/subgraph.hpp"
// #include "base_test.hpp"
#include "openvino/op/abs.hpp"
#include "openvino/op/relu.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/result.hpp"
// #include "openvino/op/abs.hpp"
// #include "openvino/op/relu.hpp"
// #include "openvino/op/parameter.hpp"
// #include "openvino/op/result.hpp"
namespace {
// namespace {
using namespace ov::tools::subgraph_dumper;
// using namespace ov::tools::subgraph_dumper;
// ======================= ExtractorsManagerTest Unit tests =======================
class SubgraphExtractorTest : public SubgraphExtractor,
public SubgraphsDumperBaseTest {
protected:
void SetUp() override {
SubgraphsDumperBaseTest::SetUp();
{
std::shared_ptr<ov::op::v0::Parameter> test_parameter =
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2});
std::shared_ptr<ov::op::v0::Abs> test_abs =
std::make_shared<ov::op::v0::Abs>(test_parameter);
std::shared_ptr<ov::op::v0::Result> test_res =
std::make_shared<ov::op::v0::Result>(test_abs);
test_model_0_0 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
ov::ParameterVector{test_parameter});
}
{
std::shared_ptr<ov::op::v0::Parameter> test_parameter =
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 5});
std::shared_ptr<ov::op::v0::Abs> test_abs =
std::make_shared<ov::op::v0::Abs>(test_parameter);
std::shared_ptr<ov::op::v0::Result> test_res =
std::make_shared<ov::op::v0::Result>(test_abs);
test_model_0_1 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
ov::ParameterVector{test_parameter});
}
{
std::shared_ptr<ov::op::v0::Parameter> test_parameter =
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 5});
std::shared_ptr<ov::op::v0::Relu> test_abs =
std::make_shared<ov::op::v0::Relu>(test_parameter);
std::shared_ptr<ov::op::v0::Result> test_res =
std::make_shared<ov::op::v0::Result>(test_abs);
test_model_1 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
ov::ParameterVector{test_parameter});
}
}
// // ======================= ExtractorsManagerTest Unit tests =======================
// class SubgraphExtractorTest : public SubgraphExtractor,
// public SubgraphsDumperBaseTest {
// protected:
// void SetUp() override {
// SubgraphsDumperBaseTest::SetUp();
// {
// std::shared_ptr<ov::op::v0::Parameter> test_parameter =
// std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2});
// std::shared_ptr<ov::op::v0::Abs> test_abs =
// std::make_shared<ov::op::v0::Abs>(test_parameter);
// std::shared_ptr<ov::op::v0::Result> test_res =
// std::make_shared<ov::op::v0::Result>(test_abs);
// test_model_0_0 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
// ov::ParameterVector{test_parameter});
// }
// {
// std::shared_ptr<ov::op::v0::Parameter> test_parameter =
// std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 5});
// std::shared_ptr<ov::op::v0::Abs> test_abs =
// std::make_shared<ov::op::v0::Abs>(test_parameter);
// std::shared_ptr<ov::op::v0::Result> test_res =
// std::make_shared<ov::op::v0::Result>(test_abs);
// test_model_0_1 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
// ov::ParameterVector{test_parameter});
// }
// {
// std::shared_ptr<ov::op::v0::Parameter> test_parameter =
// std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 5});
// std::shared_ptr<ov::op::v0::Relu> test_abs =
// std::make_shared<ov::op::v0::Relu>(test_parameter);
// std::shared_ptr<ov::op::v0::Result> test_res =
// std::make_shared<ov::op::v0::Result>(test_abs);
// test_model_1 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
// ov::ParameterVector{test_parameter});
// }
// }
std::shared_ptr<ov::Model> test_model_0_0, test_model_0_1, test_model_1;
};
// std::shared_ptr<ov::Model> test_model_0_0, test_model_0_1, test_model_1;
// };
TEST_F(SubgraphExtractorTest, match) {
ASSERT_NO_THROW(this->match(test_model_0_0, test_model_0_1));
ASSERT_TRUE(this->match(test_model_0_0, test_model_0_1));
ASSERT_NO_THROW(this->match(test_model_0_0, test_model_1));
ASSERT_FALSE(this->match(test_model_0_0, test_model_1));
ASSERT_NO_THROW(this->match(test_model_0_1, test_model_1));
ASSERT_FALSE(this->match(test_model_0_1, test_model_1));
}
// TEST_F(SubgraphExtractorTest, match) {
// ASSERT_NO_THROW(this->match(test_model_0_0, test_model_0_1));
// ASSERT_TRUE(this->match(test_model_0_0, test_model_0_1));
// ASSERT_NO_THROW(this->match(test_model_0_0, test_model_1));
// ASSERT_FALSE(this->match(test_model_0_0, test_model_1));
// ASSERT_NO_THROW(this->match(test_model_0_1, test_model_1));
// ASSERT_FALSE(this->match(test_model_0_1, test_model_1));
// }
TEST_F(SubgraphExtractorTest, match_90_percent) {
{
std::shared_ptr<ov::op::v0::Parameter> test_parameter =
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2});
std::shared_ptr<ov::op::v0::Abs> test_abs_0 =
std::make_shared<ov::op::v0::Abs>(test_parameter);
std::shared_ptr<ov::op::v0::Abs> test_abs_1 =
std::make_shared<ov::op::v0::Abs>(test_abs_0);
std::shared_ptr<ov::op::v0::Abs> test_abs_2 =
std::make_shared<ov::op::v0::Abs>(test_abs_1);
std::shared_ptr<ov::op::v0::Abs> test_abs_3 =
std::make_shared<ov::op::v0::Abs>(test_abs_2);
std::shared_ptr<ov::op::v0::Abs> test_abs_4 =
std::make_shared<ov::op::v0::Abs>(test_abs_3);
std::shared_ptr<ov::op::v0::Abs> test_abs_5 =
std::make_shared<ov::op::v0::Abs>(test_abs_4);
std::shared_ptr<ov::op::v0::Abs> test_abs_6 =
std::make_shared<ov::op::v0::Abs>(test_abs_5);
std::shared_ptr<ov::op::v0::Abs> test_abs_7 =
std::make_shared<ov::op::v0::Abs>(test_abs_6);
std::shared_ptr<ov::op::v0::Abs> test_abs_8 =
std::make_shared<ov::op::v0::Abs>(test_abs_7);
std::shared_ptr<ov::op::v0::Abs> test_abs_9 =
std::make_shared<ov::op::v0::Abs>(test_abs_8);
std::shared_ptr<ov::op::v0::Abs> test_abs_10 =
std::make_shared<ov::op::v0::Abs>(test_abs_9);
std::shared_ptr<ov::op::v0::Result> test_res =
std::make_shared<ov::op::v0::Result>(test_abs_10);
test_model_0_0 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
ov::ParameterVector{test_parameter});
}
{
std::shared_ptr<ov::op::v0::Parameter> test_parameter =
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2});
std::shared_ptr<ov::op::v0::Abs> test_abs_0 =
std::make_shared<ov::op::v0::Abs>(test_parameter);
std::shared_ptr<ov::op::v0::Abs> test_abs_1 =
std::make_shared<ov::op::v0::Abs>(test_abs_0);
std::shared_ptr<ov::op::v0::Abs> test_abs_2 =
std::make_shared<ov::op::v0::Abs>(test_abs_1);
std::shared_ptr<ov::op::v0::Abs> test_abs_3 =
std::make_shared<ov::op::v0::Abs>(test_abs_2);
std::shared_ptr<ov::op::v0::Abs> test_abs_4 =
std::make_shared<ov::op::v0::Abs>(test_abs_3);
std::shared_ptr<ov::op::v0::Abs> test_abs_5 =
std::make_shared<ov::op::v0::Abs>(test_abs_4);
std::shared_ptr<ov::op::v0::Abs> test_abs_6 =
std::make_shared<ov::op::v0::Abs>(test_abs_5);
std::shared_ptr<ov::op::v0::Abs> test_abs_7 =
std::make_shared<ov::op::v0::Abs>(test_abs_6);
std::shared_ptr<ov::op::v0::Abs> test_abs_8 =
std::make_shared<ov::op::v0::Abs>(test_abs_7);
std::shared_ptr<ov::op::v0::Abs> test_abs_9 =
std::make_shared<ov::op::v0::Abs>(test_abs_8);
std::shared_ptr<ov::op::v0::Relu> test_abs_10 =
std::make_shared<ov::op::v0::Relu>(test_abs_9);
std::shared_ptr<ov::op::v0::Result> test_res =
std::make_shared<ov::op::v0::Result>(test_abs_10);
test_model_0_1 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
ov::ParameterVector{test_parameter});
}
{
std::shared_ptr<ov::op::v0::Parameter> test_parameter =
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2});
std::shared_ptr<ov::op::v0::Abs> test_abs_0 =
std::make_shared<ov::op::v0::Abs>(test_parameter);
std::shared_ptr<ov::op::v0::Relu> test_abs_1 =
std::make_shared<ov::op::v0::Relu>(test_abs_0);
std::shared_ptr<ov::op::v0::Abs> test_abs_2 =
std::make_shared<ov::op::v0::Abs>(test_abs_1);
std::shared_ptr<ov::op::v0::Relu> test_abs_3 =
std::make_shared<ov::op::v0::Relu>(test_abs_2);
std::shared_ptr<ov::op::v0::Abs> test_abs_4 =
std::make_shared<ov::op::v0::Abs>(test_abs_3);
std::shared_ptr<ov::op::v0::Relu> test_abs_5 =
std::make_shared<ov::op::v0::Relu>(test_abs_4);
std::shared_ptr<ov::op::v0::Abs> test_abs_6 =
std::make_shared<ov::op::v0::Abs>(test_abs_5);
std::shared_ptr<ov::op::v0::Abs> test_abs_7 =
std::make_shared<ov::op::v0::Abs>(test_abs_6);
std::shared_ptr<ov::op::v0::Relu> test_abs_8 =
std::make_shared<ov::op::v0::Relu>(test_abs_7);
std::shared_ptr<ov::op::v0::Abs> test_abs_9 =
std::make_shared<ov::op::v0::Abs>(test_abs_8);
std::shared_ptr<ov::op::v0::Abs> test_abs_10 =
std::make_shared<ov::op::v0::Abs>(test_abs_9);
std::shared_ptr<ov::op::v0::Result> test_res =
std::make_shared<ov::op::v0::Result>(test_abs_10);
test_model_1 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
ov::ParameterVector{test_parameter});
}
ASSERT_NO_THROW(this->match(test_model_0_0, test_model_0_1));
ASSERT_TRUE(this->match(test_model_0_0, test_model_0_1));
ASSERT_NO_THROW(this->match(test_model_0_0, test_model_1));
ASSERT_FALSE(this->match(test_model_0_0, test_model_1));
ASSERT_NO_THROW(this->match(test_model_0_1, test_model_1));
ASSERT_FALSE(this->match(test_model_0_1, test_model_1));
}
// TEST_F(SubgraphExtractorTest, match_90_percent) {
// {
// std::shared_ptr<ov::op::v0::Parameter> test_parameter =
// std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2});
// std::shared_ptr<ov::op::v0::Abs> test_abs_0 =
// std::make_shared<ov::op::v0::Abs>(test_parameter);
// std::shared_ptr<ov::op::v0::Abs> test_abs_1 =
// std::make_shared<ov::op::v0::Abs>(test_abs_0);
// std::shared_ptr<ov::op::v0::Abs> test_abs_2 =
// std::make_shared<ov::op::v0::Abs>(test_abs_1);
// std::shared_ptr<ov::op::v0::Abs> test_abs_3 =
// std::make_shared<ov::op::v0::Abs>(test_abs_2);
// std::shared_ptr<ov::op::v0::Abs> test_abs_4 =
// std::make_shared<ov::op::v0::Abs>(test_abs_3);
// std::shared_ptr<ov::op::v0::Abs> test_abs_5 =
// std::make_shared<ov::op::v0::Abs>(test_abs_4);
// std::shared_ptr<ov::op::v0::Abs> test_abs_6 =
// std::make_shared<ov::op::v0::Abs>(test_abs_5);
// std::shared_ptr<ov::op::v0::Abs> test_abs_7 =
// std::make_shared<ov::op::v0::Abs>(test_abs_6);
// std::shared_ptr<ov::op::v0::Abs> test_abs_8 =
// std::make_shared<ov::op::v0::Abs>(test_abs_7);
// std::shared_ptr<ov::op::v0::Abs> test_abs_9 =
// std::make_shared<ov::op::v0::Abs>(test_abs_8);
// std::shared_ptr<ov::op::v0::Abs> test_abs_10 =
// std::make_shared<ov::op::v0::Abs>(test_abs_9);
// std::shared_ptr<ov::op::v0::Result> test_res =
// std::make_shared<ov::op::v0::Result>(test_abs_10);
// test_model_0_0 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
// ov::ParameterVector{test_parameter});
// }
// {
// std::shared_ptr<ov::op::v0::Parameter> test_parameter =
// std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2});
// std::shared_ptr<ov::op::v0::Abs> test_abs_0 =
// std::make_shared<ov::op::v0::Abs>(test_parameter);
// std::shared_ptr<ov::op::v0::Abs> test_abs_1 =
// std::make_shared<ov::op::v0::Abs>(test_abs_0);
// std::shared_ptr<ov::op::v0::Abs> test_abs_2 =
// std::make_shared<ov::op::v0::Abs>(test_abs_1);
// std::shared_ptr<ov::op::v0::Abs> test_abs_3 =
// std::make_shared<ov::op::v0::Abs>(test_abs_2);
// std::shared_ptr<ov::op::v0::Abs> test_abs_4 =
// std::make_shared<ov::op::v0::Abs>(test_abs_3);
// std::shared_ptr<ov::op::v0::Abs> test_abs_5 =
// std::make_shared<ov::op::v0::Abs>(test_abs_4);
// std::shared_ptr<ov::op::v0::Abs> test_abs_6 =
// std::make_shared<ov::op::v0::Abs>(test_abs_5);
// std::shared_ptr<ov::op::v0::Abs> test_abs_7 =
// std::make_shared<ov::op::v0::Abs>(test_abs_6);
// std::shared_ptr<ov::op::v0::Abs> test_abs_8 =
// std::make_shared<ov::op::v0::Abs>(test_abs_7);
// std::shared_ptr<ov::op::v0::Abs> test_abs_9 =
// std::make_shared<ov::op::v0::Abs>(test_abs_8);
// std::shared_ptr<ov::op::v0::Relu> test_abs_10 =
// std::make_shared<ov::op::v0::Relu>(test_abs_9);
// std::shared_ptr<ov::op::v0::Result> test_res =
// std::make_shared<ov::op::v0::Result>(test_abs_10);
// test_model_0_1 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
// ov::ParameterVector{test_parameter});
// }
// {
// std::shared_ptr<ov::op::v0::Parameter> test_parameter =
// std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2});
// std::shared_ptr<ov::op::v0::Abs> test_abs_0 =
// std::make_shared<ov::op::v0::Abs>(test_parameter);
// std::shared_ptr<ov::op::v0::Relu> test_abs_1 =
// std::make_shared<ov::op::v0::Relu>(test_abs_0);
// std::shared_ptr<ov::op::v0::Abs> test_abs_2 =
// std::make_shared<ov::op::v0::Abs>(test_abs_1);
// std::shared_ptr<ov::op::v0::Relu> test_abs_3 =
// std::make_shared<ov::op::v0::Relu>(test_abs_2);
// std::shared_ptr<ov::op::v0::Abs> test_abs_4 =
// std::make_shared<ov::op::v0::Abs>(test_abs_3);
// std::shared_ptr<ov::op::v0::Relu> test_abs_5 =
// std::make_shared<ov::op::v0::Relu>(test_abs_4);
// std::shared_ptr<ov::op::v0::Abs> test_abs_6 =
// std::make_shared<ov::op::v0::Abs>(test_abs_5);
// std::shared_ptr<ov::op::v0::Abs> test_abs_7 =
// std::make_shared<ov::op::v0::Abs>(test_abs_6);
// std::shared_ptr<ov::op::v0::Relu> test_abs_8 =
// std::make_shared<ov::op::v0::Relu>(test_abs_7);
// std::shared_ptr<ov::op::v0::Abs> test_abs_9 =
// std::make_shared<ov::op::v0::Abs>(test_abs_8);
// std::shared_ptr<ov::op::v0::Abs> test_abs_10 =
// std::make_shared<ov::op::v0::Abs>(test_abs_9);
// std::shared_ptr<ov::op::v0::Result> test_res =
// std::make_shared<ov::op::v0::Result>(test_abs_10);
// test_model_1 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
// ov::ParameterVector{test_parameter});
// }
// ASSERT_NO_THROW(this->match(test_model_0_0, test_model_0_1));
// ASSERT_TRUE(this->match(test_model_0_0, test_model_0_1));
// ASSERT_NO_THROW(this->match(test_model_0_0, test_model_1));
// ASSERT_FALSE(this->match(test_model_0_0, test_model_1));
// ASSERT_NO_THROW(this->match(test_model_0_1, test_model_1));
// ASSERT_FALSE(this->match(test_model_0_1, test_model_1));
// }
TEST_F(SubgraphExtractorTest, extract) {
ASSERT_NO_THROW(this->extract(test_model_0_0));
ASSERT_NO_THROW(this->extract(test_model_0_1));
ASSERT_NO_THROW(this->extract(test_model_1));
}
// TEST_F(SubgraphExtractorTest, extract) {
// ASSERT_NO_THROW(this->extract(test_model_0_0));
// ASSERT_NO_THROW(this->extract(test_model_0_1));
// ASSERT_NO_THROW(this->extract(test_model_1));
// }
TEST_F(SubgraphExtractorTest, is_subgraph) {
auto is_subgraph = this->is_subgraph(test_model_0_0, test_model_0_0);
ASSERT_NO_THROW(this->is_subgraph(test_model_0_0, test_model_0_0));
ASSERT_TRUE(std::get<0>(is_subgraph));
ASSERT_NO_THROW(this->is_subgraph(test_model_0_0, test_model_1));
is_subgraph = this->is_subgraph(test_model_0_0, test_model_1);
ASSERT_FALSE(std::get<0>(is_subgraph));
ASSERT_NO_THROW(this->is_subgraph(test_model_0_1, test_model_1));
is_subgraph = this->is_subgraph(test_model_0_1, test_model_1);
ASSERT_FALSE(std::get<0>(is_subgraph));
{
std::shared_ptr<ov::op::v0::Parameter> test_parameter =
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2});
std::shared_ptr<ov::op::v0::Abs> test_abs_0 =
std::make_shared<ov::op::v0::Abs>(test_parameter);
std::shared_ptr<ov::op::v0::Abs> test_abs_1 =
std::make_shared<ov::op::v0::Abs>(test_abs_0);
std::shared_ptr<ov::op::v0::Result> test_res =
std::make_shared<ov::op::v0::Result>(test_abs_1);
auto big_model_0 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
ov::ParameterVector{test_parameter});
is_subgraph = this->is_subgraph(test_model_0_0, big_model_0);
ASSERT_NO_THROW(this->is_subgraph(test_model_0_0, big_model_0));
ASSERT_TRUE(std::get<0>(is_subgraph));
ASSERT_EQ(std::get<1>(is_subgraph), big_model_0);
ASSERT_EQ(std::get<2>(is_subgraph), test_model_0_0);
// TEST_F(SubgraphExtractorTest, is_subgraph) {
// auto is_subgraph = this->is_subgraph(test_model_0_0, test_model_0_0);
// ASSERT_NO_THROW(this->is_subgraph(test_model_0_0, test_model_0_0));
// ASSERT_TRUE(std::get<0>(is_subgraph));
// ASSERT_NO_THROW(this->is_subgraph(test_model_0_0, test_model_1));
// is_subgraph = this->is_subgraph(test_model_0_0, test_model_1);
// ASSERT_FALSE(std::get<0>(is_subgraph));
// ASSERT_NO_THROW(this->is_subgraph(test_model_0_1, test_model_1));
// is_subgraph = this->is_subgraph(test_model_0_1, test_model_1);
// ASSERT_FALSE(std::get<0>(is_subgraph));
// {
// std::shared_ptr<ov::op::v0::Parameter> test_parameter =
// std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2});
// std::shared_ptr<ov::op::v0::Abs> test_abs_0 =
// std::make_shared<ov::op::v0::Abs>(test_parameter);
// std::shared_ptr<ov::op::v0::Abs> test_abs_1 =
// std::make_shared<ov::op::v0::Abs>(test_abs_0);
// std::shared_ptr<ov::op::v0::Result> test_res =
// std::make_shared<ov::op::v0::Result>(test_abs_1);
// auto big_model_0 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
// ov::ParameterVector{test_parameter});
// is_subgraph = this->is_subgraph(test_model_0_0, big_model_0);
// ASSERT_NO_THROW(this->is_subgraph(test_model_0_0, big_model_0));
// ASSERT_TRUE(std::get<0>(is_subgraph));
// ASSERT_EQ(std::get<1>(is_subgraph), big_model_0);
// ASSERT_EQ(std::get<2>(is_subgraph), test_model_0_0);
is_subgraph = this->is_subgraph(test_model_0_1, big_model_0);
ASSERT_NO_THROW(this->is_subgraph(test_model_0_1, big_model_0));
ASSERT_TRUE(std::get<0>(is_subgraph));
ASSERT_EQ(std::get<1>(is_subgraph), big_model_0);
ASSERT_EQ(std::get<2>(is_subgraph), test_model_0_1);
ASSERT_NO_THROW(this->is_subgraph(test_model_1, big_model_0));
ASSERT_FALSE(std::get<0>(this->is_subgraph(test_model_1, big_model_0)));
}
}
// is_subgraph = this->is_subgraph(test_model_0_1, big_model_0);
// ASSERT_NO_THROW(this->is_subgraph(test_model_0_1, big_model_0));
// ASSERT_TRUE(std::get<0>(is_subgraph));
// ASSERT_EQ(std::get<1>(is_subgraph), big_model_0);
// ASSERT_EQ(std::get<2>(is_subgraph), test_model_0_1);
// ASSERT_NO_THROW(this->is_subgraph(test_model_1, big_model_0));
// ASSERT_FALSE(std::get<0>(this->is_subgraph(test_model_1, big_model_0)));
// }
// }
} // namespace
// } // namespace

View File

@ -11,8 +11,12 @@
#include "openvino/op/relu.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/result.hpp"
#include "matchers/subgraph/repeat_pattern.hpp"
class Model_0 {
private:
using PatternBorders = ov::tools::subgraph_dumper::RepeatPatternExtractor::PatternBorders;
public:
Model_0() {
// param param
@ -48,6 +52,13 @@ public:
std::make_shared<ov::op::v0::Result>(test_add_0);
model = std::make_shared<ov::Model>(ov::ResultVector{test_res},
ov::ParameterVector{test_parameter_0, test_parameter_1});
ref_nodes = {{{test_abs_0, test_relu_0}, {test_abs_1, test_relu_1}}};
{
PatternBorders ref_pattern_0 = {test_abs_0->inputs(), test_relu_0->outputs()},
ref_pattern_1 = {test_abs_1->inputs(), test_relu_1->outputs()};
std::vector<std::vector<PatternBorders>> ref_res = {{ref_pattern_0, ref_pattern_1}};
ref_borders = std::move(ref_res);
}
}
std::shared_ptr<ov::Model> get() {
@ -72,6 +83,14 @@ public:
return ref;
}
std::vector<std::vector<ov::NodeVector>>
get_ref_node_vector() { return ref_nodes; }
std::vector<std::vector<PatternBorders>>
get_ref_node_borders() { return ref_borders; }
protected:
std::shared_ptr<ov::Model> model;
std::vector<std::vector<ov::NodeVector>> ref_nodes;
std::vector<std::vector<PatternBorders>> ref_borders;
};

View File

@ -12,8 +12,12 @@
#include "openvino/op/parameter.hpp"
#include "openvino/op/result.hpp"
#include "openvino/op/subtract.hpp"
#include "matchers/subgraph/repeat_pattern.hpp"
class Model_1 {
private:
using PatternBorders = ov::tools::subgraph_dumper::RepeatPatternExtractor::PatternBorders;
public:
Model_1() {
// param param param param
@ -119,6 +123,22 @@ public:
ov::ParameterVector{test_parameter_0, test_parameter_1,
test_parameter_0_0, test_parameter_0_1,
test_parameter_1_0, test_parameter_1_1});
ref_nodes = {{{test_abs_0, test_relu_0}, {test_abs_0_0, test_relu_0_0}},
{{test_abs_1, test_clamp_1}, {test_abs_0_1, test_clamp_0_1}},
{{test_multiply_0_1, test_relu_0_1}, {test_multiply_1_1, test_relu_1_1}}};
{
PatternBorders ref_pattern_0 = {test_abs_0->inputs(), test_relu_0->outputs()},
ref_pattern_0_0 = {test_abs_0_0->inputs(), test_relu_0_0->outputs()},
ref_pattern_1 = {test_abs_1->inputs(), test_clamp_1->outputs()},
ref_pattern_0_1_0 = {test_abs_0_1->inputs(), test_clamp_0_1->outputs()},
test_pattern_0_1_1 = {test_multiply_0_1->inputs(), test_relu_0_1->outputs()},
test_pattern_1_1 = {test_multiply_1_1->inputs(), test_relu_1_1->outputs()};
std::vector<std::vector<PatternBorders>> ref_res = {{ref_pattern_0, ref_pattern_0_0},
{ref_pattern_1, ref_pattern_0_1_0},
{test_pattern_0_1_1, test_pattern_1_1}};
ref_borders = std::move(ref_res);
}
}
std::shared_ptr<ov::Model> get() {
@ -166,10 +186,19 @@ public:
std::make_shared<ov::op::v0::Result>(test_relu_1);
auto ref_model = std::make_shared<ov::Model>(ov::ResultVector{res},
ov::ParameterVector{test_parameter_1_0, test_parameter_1_1});
ref.push_back(ref_model);
}
return ref;
}
std::vector<std::vector<ov::NodeVector>>
get_ref_node_vector() { return ref_nodes; }
std::vector<std::vector<PatternBorders>>
get_ref_node_borders() { return ref_borders; }
protected:
std::shared_ptr<ov::Model> model;
std::vector<std::vector<ov::NodeVector>> ref_nodes;
std::vector<std::vector<PatternBorders>> ref_borders;
};

View File

@ -13,6 +13,9 @@
#include "openvino/op/result.hpp"
class Model_2 {
private:
using PatternBorders = ov::tools::subgraph_dumper::RepeatPatternExtractor::PatternBorders;
public:
Model_2() {
// param
@ -55,9 +58,17 @@ public:
}
std::vector<std::shared_ptr<ov::Model>> get_repeat_pattern_ref() {
return {};
return std::vector<std::shared_ptr<ov::Model>>();
}
std::vector<std::vector<ov::NodeVector>>
get_ref_node_vector() { return ref_nodes; }
std::vector<std::vector<PatternBorders>>
get_ref_node_borders() { return ref_borders; }
protected:
std::shared_ptr<ov::Model> model;
std::vector<std::vector<ov::NodeVector>> ref_nodes;
std::vector<std::vector<PatternBorders>> ref_borders;
};

View File

@ -4,6 +4,7 @@
#include "openvino/op/util/op_types.hpp"
#include "utils/model.hpp"
#include "utils/model_comparator.hpp"
#include "matchers/subgraph/subgraph.hpp"
#include "test_models/model_0.hpp"
#include "test_models/model_1.hpp"
@ -16,11 +17,11 @@ using namespace ov::tools::subgraph_dumper;
using ModelUtilsTest = SubgraphsDumperBaseTest;
std::set<std::shared_ptr<ov::Node>>
ov::NodeVector
get_functional_ops(const std::shared_ptr<ov::Model>& model) {
std::set<std::shared_ptr<ov::Node>> nodes;
std::vector<std::shared_ptr<ov::Node>> nodes;
for (const auto& op : model->get_ordered_ops()) {
nodes.insert(op);
nodes.push_back(op);
}
return nodes;
}
@ -31,12 +32,11 @@ TEST_F(ModelUtilsTest, generate_0) {
{
std::unordered_set<std::string> checked_ops;
auto func_ops = get_functional_ops(test_model);
auto model_with_in_info = generate_model(func_ops, checked_ops, "test_extractor");
auto model_with_in_info = generate_model(func_ops, checked_ops);
recovered_model = std::get<0>(model_with_in_info);
}
{
SubgraphExtractor extractor;
ASSERT_TRUE(extractor.match(test_model, recovered_model));
ASSERT_TRUE(ModelComparator::get()->match(test_model, recovered_model));
}
}
@ -46,12 +46,11 @@ TEST_F(ModelUtilsTest, generate_1) {
{
std::unordered_set<std::string> checked_ops;
auto func_ops = get_functional_ops(test_model);
auto model_with_in_info = generate_model(func_ops, checked_ops, "test_extractor");
auto model_with_in_info = generate_model(func_ops, checked_ops);
recovered_model = std::get<0>(model_with_in_info);
}
{
SubgraphExtractor extractor;
ASSERT_TRUE(extractor.match(test_model, recovered_model));
ASSERT_TRUE(ModelComparator::get()->match(test_model, recovered_model));
}
}
@ -61,14 +60,59 @@ TEST_F(ModelUtilsTest, generate_2) {
{
std::unordered_set<std::string> checked_ops;
auto func_ops = get_functional_ops(test_model);
auto model_with_in_info = generate_model(func_ops, checked_ops, "extract_model");
auto model_with_in_info = generate_model(func_ops, checked_ops);
recovered_model = std::get<0>(model_with_in_info);
auto in_info = std::get<1>(model_with_in_info);
}
{
SubgraphExtractor extractor;
ASSERT_TRUE(extractor.match(test_model, recovered_model));
ASSERT_TRUE(ModelComparator::get()->match(test_model, recovered_model));
}
}
TEST_F(ModelUtilsTest, align_input_info) {
Model_0 test_model_0, test_model_1;
auto in_info_0 = get_input_info_by_model(test_model_0.get());
auto in_info_1 = get_input_info_by_model(test_model_1.get());
ASSERT_NE(in_info_0, in_info_1);
ASSERT_NO_THROW(align_input_info(test_model_0.get(), test_model_1.get(), in_info_0, in_info_1));
auto in_info_ref = align_input_info(test_model_0.get(), test_model_1.get(), in_info_0, in_info_1);
ASSERT_EQ(in_info_1, in_info_ref);
}
TEST_F(ModelUtilsTest, align_input_info_for_subgraphs) {
Model_0 model_0, model_1;
auto test_model_0 = model_0.get();
auto test_model_1 = model_1.get();
auto in_info_0 = get_input_info_by_model(test_model_0);
auto in_info_1 = get_input_info_by_model(test_model_1);
ASSERT_NE(in_info_0, in_info_1);
std::map<std::string, std::string> matched_ops;
auto params_0 = test_model_0->get_parameters();
auto params_1 = test_model_1->get_parameters();
size_t params_cnt = params_0.size();
for (size_t param_id = 0; param_id < params_cnt; ++param_id) {
matched_ops.insert({params_0[param_id]->get_friendly_name(),
params_1[param_id]->get_friendly_name()});
}
ASSERT_NO_THROW(align_input_info(test_model_0, test_model_1,
in_info_0, in_info_1,
matched_ops));
auto ref = align_input_info(test_model_0, test_model_1, in_info_0, in_info_1, matched_ops);
ASSERT_EQ(in_info_1, ref);
}
TEST_F(ModelUtilsTest, get_input_info_by_model) {
Model_1 model;
auto test_model = model.get();
size_t param_idx = 0;
std::map<std::string, InputInfo> ref;
for (auto& param : test_model->get_parameters()) {
std::string param_name = "parameter_" + std::to_string(param_idx++);
param->set_friendly_name(param_name);
ref.insert({param_name, InputInfo(param->get_default_output().get_partial_shape())});
}
auto cur = get_input_info_by_model(test_model);
ASSERT_EQ(cur, ref);
}
} // namespace

View File

@ -0,0 +1,137 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "matchers/subgraph/subgraph.hpp"
#include "utils/model_comparator.hpp"
#include "base_test.hpp"
#include "openvino/op/abs.hpp"
#include "openvino/op/relu.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/result.hpp"
namespace {
using namespace ov::tools::subgraph_dumper;
// ======================= ExtractorsManagerTest Unit tests =======================
class ModelComparatorTest : public SubgraphsDumperBaseTest {
protected:
void SetUp() override {
SubgraphsDumperBaseTest::SetUp();
{
std::shared_ptr<ov::op::v0::Parameter> test_parameter =
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2});
test_parameter->set_friendly_name("test_parameter_0");
std::shared_ptr<ov::op::v0::Abs> test_abs =
std::make_shared<ov::op::v0::Abs>(test_parameter);
std::shared_ptr<ov::op::v0::Result> test_res =
std::make_shared<ov::op::v0::Result>(test_abs);
test_model_0_0 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
ov::ParameterVector{test_parameter});
}
{
std::shared_ptr<ov::op::v0::Parameter> test_parameter =
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 5});
test_parameter->set_friendly_name("test_parameter_1");
std::shared_ptr<ov::op::v0::Abs> test_abs =
std::make_shared<ov::op::v0::Abs>(test_parameter);
std::shared_ptr<ov::op::v0::Result> test_res =
std::make_shared<ov::op::v0::Result>(test_abs);
test_model_0_1 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
ov::ParameterVector{test_parameter});
}
{
std::shared_ptr<ov::op::v0::Parameter> test_parameter =
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 5});
std::shared_ptr<ov::op::v0::Relu> test_abs =
std::make_shared<ov::op::v0::Relu>(test_parameter);
std::shared_ptr<ov::op::v0::Result> test_res =
std::make_shared<ov::op::v0::Result>(test_abs);
test_model_1 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
ov::ParameterVector{test_parameter});
}
}
void TearDown() override {
ModelComparator::Ptr model_comparator = ModelComparator::get();
model_comparator->set_shape_strict_match(false);
model_comparator->set_match_coefficient(0.9f);
}
std::shared_ptr<ov::Model> test_model_0_0, test_model_0_1, test_model_1;
};
TEST_F(ModelComparatorTest, get) {
ModelComparator::Ptr model_comparator = nullptr;
ASSERT_NO_THROW(model_comparator = ModelComparator::get());
ASSERT_EQ(model_comparator, ModelComparator::get());
}
TEST_F(ModelComparatorTest, match) {
ModelComparator::Ptr model_comparator = ModelComparator::get();
ASSERT_NO_THROW(model_comparator->match(test_model_0_0, test_model_0_1));
ASSERT_TRUE(model_comparator->match(test_model_0_0, test_model_0_1));
ASSERT_NO_THROW(model_comparator->match(test_model_0_0, test_model_1));
ASSERT_FALSE(model_comparator->match(test_model_0_0, test_model_1));
ASSERT_NO_THROW(model_comparator->match(test_model_0_1, test_model_1));
ASSERT_FALSE(model_comparator->match(test_model_0_1, test_model_1));
}
TEST_F(ModelComparatorTest, match_strict_shape) {
ModelComparator::Ptr model_comparator = ModelComparator::get();
ASSERT_NO_THROW(model_comparator->set_shape_strict_match(true));
ASSERT_NO_THROW(model_comparator->match(test_model_0_0, test_model_0_1));
ASSERT_FALSE(model_comparator->match(test_model_0_0, test_model_0_1));
{
{
std::shared_ptr<ov::op::v0::Parameter> test_parameter =
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2});
test_parameter->set_friendly_name("test_parameter_1");
std::shared_ptr<ov::op::v0::Abs> test_abs =
std::make_shared<ov::op::v0::Abs>(test_parameter);
std::shared_ptr<ov::op::v0::Result> test_res =
std::make_shared<ov::op::v0::Result>(test_abs);
test_model_0_1 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
ov::ParameterVector{test_parameter});
}
ASSERT_TRUE(model_comparator->match(test_model_0_0, test_model_0_1));
}
}
TEST_F(ModelComparatorTest, match_with_low_coeff) {
ModelComparator::Ptr model_comparator = ModelComparator::get();
model_comparator->set_match_coefficient(0.5f);
ASSERT_NO_THROW(model_comparator->match(test_model_0_0, test_model_0_1));
ASSERT_TRUE(model_comparator->match(test_model_0_0, test_model_0_1));
ASSERT_NO_THROW(model_comparator->match(test_model_0_0, test_model_1));
ASSERT_TRUE(model_comparator->match(test_model_0_0, test_model_1));
ASSERT_NO_THROW(model_comparator->match(test_model_0_1, test_model_1));
ASSERT_TRUE(model_comparator->match(test_model_0_1, test_model_1));
}
TEST_F(ModelComparatorTest, match_with_in_info) {
ModelComparator::Ptr model_comparator = ModelComparator::get();
std::map<std::string, InputInfo> test_in_info({{"test_parameter_0", InputInfo()}}),
test_in_info_1({{"test_parameter_1", InputInfo({}, 1, 2, true)}});
ASSERT_NO_THROW(model_comparator->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info));
ASSERT_TRUE(std::get<0>(model_comparator->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info)));
ASSERT_NO_THROW(model_comparator->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info_1));
ASSERT_FALSE(std::get<0>(model_comparator->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info_1)));
ASSERT_NO_THROW(model_comparator->match(test_model_0_1, test_model_1, test_in_info, test_in_info));
ASSERT_FALSE(std::get<0>(model_comparator->match(test_model_0_1, test_model_1, test_in_info, test_in_info)));
}
TEST_F(ModelComparatorTest, is_subgraph) {
ModelComparator::Ptr model_comparator = ModelComparator::get();
ASSERT_NO_THROW(model_comparator->is_subgraph(test_model_0_0, test_model_0_1));
auto is_subgraph = model_comparator->is_subgraph(test_model_0_0, test_model_0_1);
ASSERT_TRUE(std::get<0>(is_subgraph));
ASSERT_NO_THROW(model_comparator->is_subgraph(test_model_0_0, test_model_1));
ASSERT_FALSE(std::get<0>(model_comparator->is_subgraph(test_model_0_0, test_model_1)));
ASSERT_NO_THROW(model_comparator->is_subgraph(test_model_0_1, test_model_1));
ASSERT_FALSE(std::get<0>(model_comparator->is_subgraph(test_model_0_1, test_model_1)));
}
} // namespace