[CONFORMANCE][TOOLS] Repeat pattern extractor API (#20293)
* Prepare API * Refactor api * Move model comparation to separate component * Cover by tests * Move align_in_info to utils * Change arch diagram
This commit is contained in:
parent
29475c738e
commit
74690d038b
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:caa4f76ba61548d1b60d7de1f78fb48dccbf5337117240353a9581f23c88bfa9
|
||||
size 216595
|
||||
oid sha256:45578db1c9ac5362340ea35fc8fa024e992c8beeb30e984d969ee80217c9031b
|
||||
size 342214
|
||||
|
@ -6,6 +6,7 @@
|
||||
|
||||
#include "cache/cache.hpp"
|
||||
#include "cache/meta/input_info.hpp"
|
||||
#include "utils/model_comparator.hpp"
|
||||
#include "matchers/subgraph/manager.hpp"
|
||||
#include "matchers/subgraph/subgraph.hpp"
|
||||
#include "matchers/subgraph/fused_names.hpp"
|
||||
@ -42,10 +43,12 @@ public:
|
||||
|
||||
protected:
|
||||
std::map<std::shared_ptr<ov::Model>, MetaInfo> m_graph_cache;
|
||||
ExtractorsManager m_manager = ExtractorsManager();
|
||||
static std::shared_ptr<GraphCache> m_cache_instance;
|
||||
// cache byte size
|
||||
uint64_t m_graph_cache_bytesize = 0;
|
||||
ExtractorsManager m_manager;
|
||||
ModelComparator::Ptr m_model_comparator = ModelComparator::get();
|
||||
std::shared_ptr<ov::Model> model_to_update = nullptr;
|
||||
static std::shared_ptr<GraphCache> m_cache_instance;
|
||||
|
||||
GraphCache(const std::string& device = "") {
|
||||
ExtractorsManager::ExtractorsMap matchers = {
|
||||
@ -59,7 +62,7 @@ protected:
|
||||
|
||||
void update_cache(const std::shared_ptr<ov::Model>& model,
|
||||
const std::string& model_path,
|
||||
std::map<std::string, InputInfo>& input_info,
|
||||
const std::map<std::string, InputInfo>& input_info,
|
||||
const std::string& extractor_name,
|
||||
size_t model_op_cnt);
|
||||
};
|
||||
|
@ -59,6 +59,9 @@ struct InputInfo {
|
||||
}
|
||||
|
||||
InputInfo operator=(const InputInfo& input_info) {
|
||||
if (this->is_const != input_info.is_const) {
|
||||
throw std::runtime_error("Cast Const to Parameter! Impossible to update Input Info!");
|
||||
}
|
||||
this->ranges = input_info.ranges;
|
||||
if (ov::shape_size(this->max_shape.get_max_shape()) < ov::shape_size(input_info.max_shape.get_max_shape())) {
|
||||
this->max_shape = input_info.max_shape;
|
||||
|
@ -33,7 +33,10 @@ public:
|
||||
const std::string& extractor = "",
|
||||
const std::vector<std::string>& ignored_inputs = {});
|
||||
std::map<std::string, InputInfo> get_input_info() const;
|
||||
void set_input_info(const std::map<std::string, InputInfo>& new_in_info) { input_info = new_in_info; };
|
||||
void set_input_info(const std::map<std::string, InputInfo>& new_in_info) {
|
||||
input_info.clear();
|
||||
input_info = new_in_info;
|
||||
};
|
||||
std::map<std::string, ModelInfo> get_model_info() const;
|
||||
std::string get_any_extractor() const { return *extractors.begin(); }
|
||||
|
||||
|
@ -20,6 +20,8 @@ public:
|
||||
const std::shared_ptr<ov::Node> &ref) const;
|
||||
|
||||
void set_matchers(const MatchersMap& matchers = {}) { m_matchers = matchers; }
|
||||
void set_shape_strict_match(bool shape_strict_match);
|
||||
|
||||
const MatchersMap& get_matchers() { return m_matchers; }
|
||||
iMatcherConfig::Ptr get_config(const std::shared_ptr<ov::Node> &node) const;
|
||||
|
||||
|
@ -21,6 +21,7 @@ public:
|
||||
const std::shared_ptr<ov::Node> &ref) const;
|
||||
|
||||
iMatcherConfig::Ptr get_config(const std::shared_ptr<ov::Node> &node) const;
|
||||
void set_strict_shape_match(bool strict_shape_match);
|
||||
|
||||
protected:
|
||||
virtual void configure(const pugi::xml_document &cfg) {};
|
||||
@ -35,6 +36,8 @@ protected:
|
||||
const std::shared_ptr<ov::Node> &ref) const;
|
||||
|
||||
std::vector<iMatcherConfig::Ptr> default_configs;
|
||||
// match only shape ranks by default;
|
||||
bool is_strict_shape_match = false;
|
||||
};
|
||||
|
||||
} // namespace subgraph_dumper
|
||||
|
@ -17,9 +17,7 @@ public:
|
||||
FusedNamesExtractor(const std::string& device = "");
|
||||
~FusedNamesExtractor();
|
||||
|
||||
std::list<ExtractedPattern> extract(const std::shared_ptr<ov::Model> &model,
|
||||
bool is_extract_body = true,
|
||||
bool is_copy_constants = true) override;
|
||||
std::vector<ExtractedPattern> extract(const std::shared_ptr<ov::Model> &modele) override;
|
||||
|
||||
protected:
|
||||
std::unordered_set<std::string> extract_compiled_model_names(const std::shared_ptr<ov::Model>& model);
|
||||
|
@ -12,38 +12,19 @@ namespace subgraph_dumper {
|
||||
|
||||
class ExtractorsManager {
|
||||
public:
|
||||
// { model, subgraph, model_in_info, subgraph_in_info }
|
||||
using ExtractedSubgraphTuple = std::tuple<bool, std::shared_ptr<ov::Model>, std::shared_ptr<ov::Model>, std::map<std::string, InputInfo>, std::map<std::string, InputInfo>>;
|
||||
using ExtractorsMap = std::map<std::string, SubgraphExtractor::Ptr>;
|
||||
|
||||
explicit ExtractorsManager(const ExtractorsMap& extractors = {}) : m_extractors(extractors) {}
|
||||
|
||||
bool match(const std::shared_ptr<ov::Model> &model,
|
||||
const std::shared_ptr<ov::Model> &ref_model,
|
||||
std::map<std::string, InputInfo> &in_info,
|
||||
const std::map<std::string, InputInfo> &in_info_ref);
|
||||
ExtractedSubgraphTuple is_subgraph(const std::shared_ptr<ov::Model> &model,
|
||||
const std::shared_ptr<ov::Model> &ref_model,
|
||||
const std::map<std::string, InputInfo> &in_info = {},
|
||||
const std::map<std::string, InputInfo> &in_info_ref = {});
|
||||
std::list<ExtractedPattern> extract(const std::shared_ptr<ov::Model> &model,
|
||||
bool is_extract_body = true,
|
||||
bool is_copy_constants = true);
|
||||
std::vector<ExtractedPattern> extract(const std::shared_ptr<ov::Model> &model,
|
||||
bool is_extract_body = true,
|
||||
bool is_copy_constants = true);
|
||||
|
||||
void set_extractors(const ExtractorsMap& extractors = {}) { m_extractors = extractors; }
|
||||
ExtractorsMap get_extractors() { return m_extractors; }
|
||||
|
||||
std::map<std::string, InputInfo> align_input_info(const std::shared_ptr<ov::Model>& model,
|
||||
const std::shared_ptr<ov::Model>& model_ref,
|
||||
const std::map<std::string, InputInfo> &in_info,
|
||||
const std::map<std::string, InputInfo> &in_info_ref,
|
||||
const std::map<std::string, std::string> &matched_op = {});
|
||||
|
||||
protected:
|
||||
ExtractorsMap m_extractors = {};
|
||||
|
||||
bool match(const std::shared_ptr<ov::Model> &model,
|
||||
const std::shared_ptr<ov::Model> &ref);
|
||||
};
|
||||
|
||||
} // namespace subgraph_dumper
|
||||
|
@ -5,31 +5,45 @@
|
||||
#pragma once
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "matchers/subgraph/subgraph.hpp"
|
||||
#include "matchers/single_op/single_op.hpp"
|
||||
#include "matchers/single_op/convolutions.hpp"
|
||||
#include "matchers/single_op/manager.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace tools {
|
||||
namespace subgraph_dumper {
|
||||
|
||||
class RepeatPatternExtractor final : public SubgraphExtractor {
|
||||
public:
|
||||
RepeatPatternExtractor() {
|
||||
MatchersManager::MatchersMap matchers = {
|
||||
{ "generic_single_op", SingleOpMatcher::Ptr(new SingleOpMatcher) },
|
||||
{ "convolutions", ConvolutionsMatcher::Ptr(new ConvolutionsMatcher) },
|
||||
};
|
||||
manager.set_matchers(matchers);
|
||||
}
|
||||
|
||||
std::list<ExtractedPattern> extract(const std::shared_ptr<ov::Model> &model,
|
||||
bool is_extract_body = true,
|
||||
bool is_copy_constants = true) override;
|
||||
|
||||
private:
|
||||
MatchersManager manager;
|
||||
using InputVector = std::vector<ov::Input<ov::Node>>;
|
||||
using OutputVector = std::vector<ov::Output<ov::Node>>;
|
||||
|
||||
public:
|
||||
using PatternBorders = std::pair<InputVector, OutputVector>;
|
||||
ModelComparator::Ptr model_comparator = ModelComparator::get();
|
||||
|
||||
std::vector<std::vector<PatternBorders>>
|
||||
get_repeat_pattern_borders(const std::shared_ptr<ov::Model> &model);
|
||||
std::vector<std::vector<ov::NodeVector>>
|
||||
get_repeat_node_vectors(const std::shared_ptr<ov::Model> &model);
|
||||
|
||||
void set_recursive_extraction(bool _is_recursive_extraction);
|
||||
std::vector<ExtractedPattern> extract(const std::shared_ptr<ov::Model> &model) override;
|
||||
|
||||
protected:
|
||||
// {subgraph, node_vector, input_info}
|
||||
using ExtractedRepeatPattern = std::tuple<std::shared_ptr<ov::Model>, ov::NodeVector, std::map<std::string, InputInfo>>;
|
||||
bool is_recursive_extraction = true;
|
||||
|
||||
std::list<std::vector<ExtractedRepeatPattern>>
|
||||
find_repeat_patterns(const std::shared_ptr<ov::Model> &model,
|
||||
bool is_save_borders_only = false);
|
||||
void update_extractor_cache(std::list<std::vector<ExtractedRepeatPattern>>& extracted_patterns,
|
||||
std::list<std::vector<ExtractedRepeatPattern>>& secondary_extracted_patterns);
|
||||
void update_extractor_cache(std::list<std::vector<ExtractedRepeatPattern>>& extracted_patterns,
|
||||
const std::shared_ptr<ov::Model>& pattern,
|
||||
const ov::NodeVector& pattern_node_vector,
|
||||
const std::map<std::string, InputInfo>& in_info);
|
||||
|
||||
};
|
||||
|
||||
} // namespace subgraph_dumper
|
||||
|
@ -7,12 +7,8 @@
|
||||
#include <utility>
|
||||
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "common_test_utils/graph_comparator.hpp"
|
||||
|
||||
#include "cache/meta/input_info.hpp"
|
||||
#include "matchers/single_op/single_op.hpp"
|
||||
#include "matchers/single_op/convolutions.hpp"
|
||||
#include "matchers/single_op/manager.hpp"
|
||||
#include "utils/model_comparator.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace tools {
|
||||
@ -20,44 +16,19 @@ namespace subgraph_dumper {
|
||||
|
||||
class SubgraphExtractor {
|
||||
public:
|
||||
// { is_subgraph, model, subgraph, matched_ops{ model_op_name, graph_op_name }}
|
||||
using IsSubgraphTuple = std::tuple<bool, std::shared_ptr<ov::Model>, std::shared_ptr<ov::Model>, std::map<std::string, std::string>>;
|
||||
using Ptr = std::shared_ptr<SubgraphExtractor>;
|
||||
|
||||
SubgraphExtractor() {
|
||||
MatchersManager::MatchersMap matchers = {
|
||||
{ "generic_single_op", SingleOpMatcher::Ptr(new SingleOpMatcher) },
|
||||
{ "convolutions", ConvolutionsMatcher::Ptr(new ConvolutionsMatcher) },
|
||||
};
|
||||
m_manager.set_matchers(matchers);
|
||||
virtual std::vector<ExtractedPattern> extract(const std::shared_ptr<ov::Model> &model) {
|
||||
return std::vector<ExtractedPattern>{};
|
||||
}
|
||||
|
||||
bool match(const std::shared_ptr<ov::Model> &model,
|
||||
const std::shared_ptr<ov::Model> &ref_model) const;
|
||||
IsSubgraphTuple is_subgraph(const std::shared_ptr<ov::Model> &model,
|
||||
const std::shared_ptr<ov::Model> &ref_model) const;
|
||||
|
||||
virtual std::list<ExtractedPattern> extract(const std::shared_ptr<ov::Model> &model,
|
||||
bool is_extract_body = true,
|
||||
bool is_copy_constants = true) {
|
||||
return std::list<ExtractedPattern>{};
|
||||
};
|
||||
|
||||
void set_extractor_name(const std::string& _extractor_name) { extractor_name = _extractor_name; }
|
||||
void set_extract_body(bool _is_extract_body) { is_extract_body = _is_extract_body; }
|
||||
void set_save_const(bool _is_save_const) { is_save_const = _is_save_const; }
|
||||
|
||||
protected:
|
||||
std::string extractor_name = "";
|
||||
FunctionsComparator comparator = FunctionsComparator::no_default()
|
||||
.enable(FunctionsComparator::ATTRIBUTES)
|
||||
.enable(FunctionsComparator::NODES)
|
||||
.enable(FunctionsComparator::PRECISIONS);
|
||||
MatchersManager m_manager = MatchersManager();
|
||||
|
||||
inline bool is_node_to_skip(const std::shared_ptr<ov::Node>& node) const {
|
||||
return ov::op::util::is_parameter(node) ||
|
||||
ov::op::util::is_constant(node) ||
|
||||
ov::op::util::is_output(node);
|
||||
}
|
||||
bool is_extract_body = true, is_save_const = true;
|
||||
};
|
||||
|
||||
} // namespace subgraph_dumper
|
||||
|
@ -33,7 +33,7 @@ static std::vector<std::regex> FROTEND_REGEXP = {
|
||||
std::regex(R"(.*__model__)"),
|
||||
#endif
|
||||
#ifdef ENABLE_OV_TF_FRONTEND
|
||||
std::regex(R"(.*\.pb)"),
|
||||
std::regex(R"(.*\model.pb)"),
|
||||
#endif
|
||||
#ifdef ENABLE_OV_IR_FRONTEND
|
||||
std::regex(R"(.*\.xml)"),
|
||||
@ -74,32 +74,24 @@ std::map<ModelCacheStatus, std::vector<std::string>> cache_models(
|
||||
void save_model_status_to_file(const std::map<ModelCacheStatus, std::vector<std::string>>& caching_status,
|
||||
const std::string& output_dir);
|
||||
|
||||
inline bool is_dynamic_model(const std::shared_ptr<ov::Model>& model) {
|
||||
for (const auto& parameter : model->get_parameters()) {
|
||||
if (is_dynamic_node(parameter)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
for (const auto& result : model->get_results()) {
|
||||
if (is_dynamic_node(result)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
bool is_dynamic_model(const std::shared_ptr<ov::Model>& model);
|
||||
std::string get_model_type(const std::shared_ptr<ov::Model>& model);
|
||||
|
||||
inline std::string get_model_type(const std::shared_ptr<ov::Model>& model) {
|
||||
if (is_dynamic_model(model)) {
|
||||
return "dynamic";
|
||||
}
|
||||
return "static";
|
||||
}
|
||||
std::map<std::string, InputInfo>
|
||||
get_input_info_by_model(const std::shared_ptr<ov::Model>& model);
|
||||
|
||||
inline ExtractedPattern
|
||||
generate_model(const std::set<std::shared_ptr<ov::Node>>& nodes,
|
||||
std::map<std::string, InputInfo>
|
||||
align_input_info(const std::shared_ptr<ov::Model>& model,
|
||||
const std::shared_ptr<ov::Model>& model_ref,
|
||||
const std::map<std::string, InputInfo> &in_info,
|
||||
const std::map<std::string, InputInfo> &in_info_ref,
|
||||
const std::map<std::string, std::string> &matched_op = {});
|
||||
|
||||
inline std::pair<std::shared_ptr<ov::Model>, std::map<std::string, InputInfo>>
|
||||
generate_model(ov::NodeVector& nodes,
|
||||
std::unordered_set<std::string>& checked_ops,
|
||||
const std::string& extractor_name,
|
||||
bool is_copy_constants = true) {
|
||||
bool is_copy_constants = true,
|
||||
bool is_save_only_borders = false) {
|
||||
// map to recover graph using cloned nodes and original connections
|
||||
// { original_node_name, cloned_node }
|
||||
std::unordered_map<std::string, std::shared_ptr<ov::Node>> cloned_node_map;
|
||||
@ -214,27 +206,51 @@ generate_model(const std::set<std::shared_ptr<ov::Node>>& nodes,
|
||||
// prepare unique model name based on operations from model
|
||||
std::string string_to_hash;
|
||||
for (const auto& op : model->get_ordered_ops()) {
|
||||
bool is_erase_node = !is_save_only_borders;
|
||||
std::ostringstream result;
|
||||
result << op->get_type_info();
|
||||
for (const auto& in : op->inputs()) {
|
||||
for (size_t i = 0; i < op->inputs().size(); ++i) {
|
||||
const auto& in = op->input(i);
|
||||
if (!is_node_to_skip(op->get_input_node_shared_ptr(i))) {
|
||||
is_erase_node |= true;
|
||||
}
|
||||
result << in.get_element_type();
|
||||
result << in.get_partial_shape().rank();
|
||||
result << in.get_partial_shape().is_static();
|
||||
}
|
||||
for (const auto& out : op->outputs()) {
|
||||
for (const auto& target_input : out.get_target_inputs()) {
|
||||
if (!is_node_to_skip(target_input.get_node()->shared_from_this())) {
|
||||
is_erase_node |= true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
result << out.get_element_type();
|
||||
result << out.get_partial_shape().rank();
|
||||
result << out.get_partial_shape().is_static();
|
||||
}
|
||||
string_to_hash += result.str();
|
||||
if (is_erase_node) {
|
||||
cloned_node_map.erase(op->get_friendly_name());
|
||||
}
|
||||
}
|
||||
for (const auto& in : model_input_info) {
|
||||
string_to_hash += (in.second.is_const ? "1" : "0");
|
||||
}
|
||||
auto h1 = std::hash<std::string>{}(string_to_hash);
|
||||
model->set_friendly_name(std::to_string(h1));
|
||||
|
||||
return { model, model_input_info, extractor_name };
|
||||
{
|
||||
auto it = nodes.begin();
|
||||
while (it != nodes.end()) {
|
||||
if (cloned_node_map.count((*it)->get_friendly_name())) {
|
||||
nodes.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { model, model_input_info };
|
||||
}
|
||||
|
||||
} // namespace subgraph_dumper
|
||||
|
@ -0,0 +1,68 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "matchers/single_op/single_op.hpp"
|
||||
#include "matchers/single_op/convolutions.hpp"
|
||||
#include "matchers/single_op/manager.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace tools {
|
||||
namespace subgraph_dumper {
|
||||
|
||||
class ModelComparator {
|
||||
public:
|
||||
using Ptr = std::shared_ptr<ModelComparator>;
|
||||
// { is_match, subgraph, graph, matched_nodes -> {subgraph_op_name, graph_op_name}}
|
||||
using IsSubgraphTuple = std::tuple<bool, std::shared_ptr<ov::Model>, std::shared_ptr<ov::Model>, std::map<std::string, std::string>>;
|
||||
// { model, subgraph, graph, subgraph_in_info, model_in_info, }
|
||||
using ExtractedSubgraphTuple = std::tuple<bool, std::shared_ptr<ov::Model>, std::shared_ptr<ov::Model>, std::map<std::string, InputInfo>, std::map<std::string, InputInfo>>;
|
||||
|
||||
static std::shared_ptr<ModelComparator> get(bool in_is_match_shapes = false) {
|
||||
if (m_instance == nullptr) {
|
||||
m_instance = std::shared_ptr<ModelComparator>(new ModelComparator);
|
||||
}
|
||||
return m_instance;
|
||||
}
|
||||
|
||||
IsSubgraphTuple is_subgraph(const std::shared_ptr<ov::Model> &model,
|
||||
const std::shared_ptr<ov::Model> &ref_model) const;
|
||||
|
||||
bool match(const std::shared_ptr<ov::Node> &node,
|
||||
const std::shared_ptr<ov::Node> &ref_node) const;
|
||||
bool match(const std::shared_ptr<ov::Model> &model,
|
||||
const std::shared_ptr<ov::Model> &ref_model) const;
|
||||
|
||||
std::pair<bool, std::map<std::string, InputInfo>>
|
||||
match(const std::shared_ptr<ov::Model> &model,
|
||||
const std::shared_ptr<ov::Model> &ref_model,
|
||||
const std::map<std::string, InputInfo> &in_info,
|
||||
const std::map<std::string, InputInfo> &in_info_ref);
|
||||
ExtractedSubgraphTuple
|
||||
is_subgraph(const std::shared_ptr<ov::Model> &model,
|
||||
const std::shared_ptr<ov::Model> &ref_model,
|
||||
const std::map<std::string, InputInfo> &in_info,
|
||||
const std::map<std::string, InputInfo> &in_info_ref);
|
||||
|
||||
void set_match_coefficient(float _match_coefficient);
|
||||
void set_shape_strict_match(bool is_shape_strict_match);
|
||||
|
||||
protected:
|
||||
MatchersManager m_manager = MatchersManager();
|
||||
float match_coefficient = 0.9f;
|
||||
static std::shared_ptr<ModelComparator> m_instance;
|
||||
|
||||
ModelComparator() {
|
||||
MatchersManager::MatchersMap matchers = {
|
||||
{ "generic_single_op", SingleOpMatcher::Ptr(new SingleOpMatcher) },
|
||||
{ "convolutions", ConvolutionsMatcher::Ptr(new ConvolutionsMatcher) },
|
||||
};
|
||||
m_manager.set_matchers(matchers);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace subgraph_dumper
|
||||
} // namespace tools
|
||||
} // namespace ov
|
@ -26,6 +26,9 @@ inline InputInfo::Range get_const_ranges(const std::shared_ptr<ov::op::v0::Const
|
||||
return InputInfo::Range(static_cast<double>(min), static_cast<double>(max));
|
||||
}
|
||||
|
||||
InputInfo::Range get_const_ranges(const std::shared_ptr<ov::op::v0::Constant>& const_node,
|
||||
ov::element::Type elem_type);
|
||||
|
||||
std::map<std::string, InputInfo> get_input_info_by_node(const std::shared_ptr<ov::Node>& node);
|
||||
|
||||
// replace all input node by parameters and constants instead of non input mode types
|
||||
@ -111,6 +114,12 @@ inline size_t get_node_priority_by_version(const std::shared_ptr<ov::Node>& node
|
||||
|
||||
return priority;
|
||||
}
|
||||
|
||||
inline bool is_node_to_skip(const std::shared_ptr<ov::Node>& node) {
|
||||
return ov::op::util::is_parameter(node) ||
|
||||
ov::op::util::is_constant(node) ||
|
||||
ov::op::util::is_output(node);
|
||||
}
|
||||
|
||||
} // namespace subgraph_dumper
|
||||
} // namespace tools
|
||||
|
@ -42,8 +42,8 @@ bool ICache::serialize_model(const std::pair<std::shared_ptr<ov::Model>, MetaInf
|
||||
meta.serialize(meta_path);
|
||||
return true;
|
||||
} catch (std::exception &e) {
|
||||
// std::cout << "[ ERROR ] Failed to serialize model: " << model_name
|
||||
// << ". Exception: " << e.what() << std::endl;
|
||||
std::cout << "[ ERROR ] Failed to serialize model: " << model_name
|
||||
<< ". Exception: " << e.what() << std::endl;
|
||||
ov::test::utils::removeFile(xml_path);
|
||||
ov::test::utils::removeFile(bin_path);
|
||||
ov::test::utils::removeFile(meta_path);
|
||||
|
@ -56,77 +56,92 @@ void GraphCache::update_cache(const std::shared_ptr<ov::Model>& model,
|
||||
return;
|
||||
}
|
||||
while (!extracted_patterns.empty()) {
|
||||
auto it = *extracted_patterns.begin();
|
||||
auto it = *extracted_patterns.rbegin();
|
||||
update_cache(std::get<0>(it), model_meta_data, std::get<1>(it), std::get<2>(it), model_total_op);
|
||||
extracted_patterns.pop_front();
|
||||
extracted_patterns.pop_back();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GraphCache::update_cache(const std::shared_ptr<ov::Model>& extracted_model,
|
||||
const std::string& model_path,
|
||||
std::map<std::string, InputInfo>& input_info,
|
||||
const std::map<std::string, InputInfo>& input_info,
|
||||
const std::string& extractor_name,
|
||||
size_t model_op_cnt) {
|
||||
auto graph_name = extracted_model->get_friendly_name();
|
||||
auto this_op_cnt = extracted_model->get_ops().size() -
|
||||
extracted_model->get_parameters().size() - extracted_model->get_results().size();
|
||||
std::string serialized_model_path = "";
|
||||
for (const auto& extractor : m_manager.get_extractors()) {
|
||||
auto tmp_serialized_model_path = ov::util::path_join({ m_serialization_dir, m_cache_subdir, extractor.first, graph_name + ".xml" });
|
||||
if (ov::util::file_exists(serialized_model_path)) {
|
||||
serialized_model_path = tmp_serialized_model_path;
|
||||
break;
|
||||
std::map<std::string, InputInfo> updated_input_info;
|
||||
if (!m_graph_cache.empty() && model_to_update != nullptr) {
|
||||
auto comparator_res = m_model_comparator->match(extracted_model, model_to_update,
|
||||
input_info, m_graph_cache.at(model_to_update).get_input_info());
|
||||
if (comparator_res.first) {
|
||||
updated_input_info = comparator_res.second;
|
||||
} else {
|
||||
model_to_update = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> model_to_update = nullptr;
|
||||
// if cached model was serialized
|
||||
if (!serialized_model_path.empty()) {
|
||||
// std::cout << "[ GRAPH CACHE ][ INFO ] Reading cached model: " << serialized_model_path << std::endl;
|
||||
auto bin_path = ov::test::utils::replaceExt(serialized_model_path, ".bin");
|
||||
auto meta_path = ov::test::utils::replaceExt(serialized_model_path, ".meta");
|
||||
auto cached_model = ov::test::utils::PluginCache::get().core()->read_model(serialized_model_path);
|
||||
auto cached_meta = MetaInfo::read_meta_from_file(meta_path);
|
||||
|
||||
ov::test::utils::removeFile(serialized_model_path);
|
||||
ov::test::utils::removeFile(bin_path);
|
||||
ov::test::utils::removeFile(meta_path);
|
||||
|
||||
m_graph_cache.insert({ cached_model, cached_meta });
|
||||
m_graph_cache_bytesize += cached_model->get_graph_size();
|
||||
|
||||
if (m_manager.match(extracted_model, cached_model,
|
||||
input_info, cached_meta.get_input_info())) {
|
||||
model_to_update = cached_model;
|
||||
}
|
||||
} else {
|
||||
for (const auto& cached_model : m_graph_cache) {
|
||||
if (m_manager.match(extracted_model, cached_model.first,
|
||||
input_info, cached_model.second.get_input_info())) {
|
||||
model_to_update = cached_model.first;
|
||||
if (model_to_update == nullptr) {
|
||||
std::string serialized_model_path = "";
|
||||
for (const auto& extractor : m_manager.get_extractors()) {
|
||||
auto tmp_serialized_model_path = ov::util::path_join({ m_serialization_dir, m_cache_subdir, extractor.first, graph_name + ".xml" });
|
||||
if (ov::util::file_exists(serialized_model_path)) {
|
||||
serialized_model_path = tmp_serialized_model_path;
|
||||
break;
|
||||
} else {
|
||||
auto is_subgraph = m_manager.is_subgraph(extracted_model, cached_model.first,
|
||||
input_info, cached_model.second.get_input_info());
|
||||
// in case if one model is subgraph of other to update model meta info and remove subgraph from cache
|
||||
if (std::get<0>(is_subgraph)) {
|
||||
std::shared_ptr<ov::Model> graph, subgraph;
|
||||
std::map<std::string, InputInfo> graph_in_info, subgraph_in_info;
|
||||
std::tie(std::ignore, graph, subgraph, graph_in_info, subgraph_in_info) = is_subgraph;
|
||||
if (subgraph == cached_model.first) {
|
||||
auto meta = m_graph_cache[subgraph];
|
||||
meta.set_input_info(graph_in_info);
|
||||
m_graph_cache.erase(subgraph);
|
||||
m_graph_cache.insert({graph, meta});
|
||||
m_graph_cache_bytesize += (graph->get_graph_size() - subgraph->get_graph_size());
|
||||
}
|
||||
}
|
||||
// if cached model was serialized
|
||||
if (!serialized_model_path.empty()) {
|
||||
// std::cout << "[ GRAPH CACHE ][ INFO ] Reading cached model: " << serialized_model_path << std::endl;
|
||||
auto bin_path = ov::test::utils::replaceExt(serialized_model_path, ".bin");
|
||||
auto meta_path = ov::test::utils::replaceExt(serialized_model_path, ".meta");
|
||||
auto cached_model = ov::test::utils::PluginCache::get().core()->read_model(serialized_model_path);
|
||||
auto cached_meta = MetaInfo::read_meta_from_file(meta_path);
|
||||
|
||||
ov::test::utils::removeFile(serialized_model_path);
|
||||
ov::test::utils::removeFile(bin_path);
|
||||
ov::test::utils::removeFile(meta_path);
|
||||
|
||||
m_graph_cache.insert({ cached_model, cached_meta });
|
||||
m_graph_cache_bytesize += cached_model->get_graph_size();
|
||||
|
||||
auto comparator_res = m_model_comparator->match(extracted_model, cached_model,
|
||||
input_info, cached_meta.get_input_info());
|
||||
if (comparator_res.first) {
|
||||
model_to_update = cached_model;
|
||||
updated_input_info = comparator_res.second;
|
||||
}
|
||||
} else {
|
||||
for (const auto& cached_model : m_graph_cache) {
|
||||
auto comparator_res = m_model_comparator->match(extracted_model, cached_model.first,
|
||||
input_info, cached_model.second.get_input_info());
|
||||
if (comparator_res.first) {
|
||||
model_to_update = cached_model.first;
|
||||
updated_input_info = comparator_res.second;
|
||||
break;
|
||||
} else {
|
||||
auto is_subgraph = m_model_comparator->is_subgraph(extracted_model, cached_model.first,
|
||||
input_info, cached_model.second.get_input_info());
|
||||
// in case if one model is subgraph of other to update model meta info and remove subgraph from cache
|
||||
if (std::get<0>(is_subgraph)) {
|
||||
std::shared_ptr<ov::Model> graph, subgraph;
|
||||
std::map<std::string, InputInfo> graph_in_info, subgraph_in_info;
|
||||
std::tie(std::ignore, subgraph, graph, subgraph_in_info, graph_in_info) = is_subgraph;
|
||||
if (subgraph == cached_model.first) {
|
||||
auto meta = m_graph_cache[subgraph];
|
||||
meta.set_input_info(graph_in_info);
|
||||
m_graph_cache.erase(subgraph);
|
||||
m_graph_cache.insert({graph, meta});
|
||||
m_graph_cache_bytesize += (graph->get_graph_size() - subgraph->get_graph_size());
|
||||
}
|
||||
m_graph_cache[cached_model.first].update(model_path,
|
||||
subgraph_in_info,
|
||||
model_op_cnt,
|
||||
this_op_cnt,
|
||||
extractor_name);
|
||||
return;
|
||||
}
|
||||
m_graph_cache[cached_model.first].update(model_path,
|
||||
subgraph_in_info,
|
||||
model_op_cnt,
|
||||
this_op_cnt,
|
||||
extractor_name);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -134,18 +149,22 @@ void GraphCache::update_cache(const std::shared_ptr<ov::Model>& extracted_model,
|
||||
|
||||
if (model_to_update == nullptr) {
|
||||
MetaInfo meta = MetaInfo(model_path, input_info, model_op_cnt, this_op_cnt, extractor_name);
|
||||
m_graph_cache.insert({ extracted_model, meta });
|
||||
model_to_update = extracted_model;
|
||||
m_graph_cache.insert({ model_to_update, meta });
|
||||
m_graph_cache_bytesize += extracted_model->get_graph_size();
|
||||
return;
|
||||
}
|
||||
m_graph_cache[model_to_update].update(model_path, input_info, model_op_cnt, this_op_cnt, extractor_name);
|
||||
m_graph_cache[model_to_update].update(model_path, updated_input_info, model_op_cnt, this_op_cnt, extractor_name);
|
||||
auto cached_model_size = model_to_update->get_graph_size();
|
||||
auto pattern_model_size = extracted_model->get_graph_size();
|
||||
if (pattern_model_size < cached_model_size) {
|
||||
m_graph_cache_bytesize -= (cached_model_size - pattern_model_size);
|
||||
auto meta = m_graph_cache[model_to_update];
|
||||
auto new_in_info = align_input_info(model_to_update, extracted_model, m_graph_cache.at(model_to_update).get_input_info(), input_info);
|
||||
meta.set_input_info(new_in_info);
|
||||
m_graph_cache.erase(model_to_update);
|
||||
m_graph_cache.insert({extracted_model, meta});
|
||||
model_to_update = extracted_model;
|
||||
m_graph_cache.insert({model_to_update, meta});
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -169,7 +169,6 @@ void MetaInfo::update(const std::string& _model_path,
|
||||
size_t _this_op_cnt,
|
||||
const std::string& extractor,
|
||||
const std::vector<std::string>& ignored_inputs) {
|
||||
bool is_update_in_info = true;
|
||||
if (input_info.size() != _input_info.size()) {
|
||||
throw std::runtime_error("Incompatible input info!");
|
||||
}
|
||||
@ -193,9 +192,6 @@ void MetaInfo::update(const std::string& _model_path,
|
||||
if (!extractor.empty()) {
|
||||
extractors.insert(extractor);
|
||||
}
|
||||
if (!is_update_in_info) {
|
||||
return;
|
||||
}
|
||||
for (const auto& in : _input_info) {
|
||||
if (std::find(ignored_inputs.begin(), ignored_inputs.end(), in.first) != ignored_inputs.begin()) {
|
||||
continue;
|
||||
|
@ -50,8 +50,11 @@ bool ConvolutionsMatcher::match_inputs(const std::shared_ptr<ov::Node> &node,
|
||||
bool has_groups = std::dynamic_pointer_cast<ov::op::v1::GroupConvolution>(node) ||
|
||||
std::dynamic_pointer_cast<ov::op::v1::GroupConvolutionBackpropData>(node);
|
||||
size_t kernel_size_offset = has_groups ? 3 : 2;
|
||||
auto ref_weights_shape = ref->get_input_tensor(1).get_shape();
|
||||
auto cur_weights_shape = node->get_input_tensor(1).get_shape();
|
||||
auto ref_weights_shape = ref->get_input_partial_shape(1).get_shape();
|
||||
auto cur_weights_shape = node->get_input_partial_shape(1).get_shape();
|
||||
if (is_strict_shape_match && ref_weights_shape != cur_weights_shape) {
|
||||
return false;
|
||||
}
|
||||
const auto ref_kernel_size = std::vector<size_t>(ref_weights_shape.begin() + kernel_size_offset,
|
||||
ref_weights_shape.end());
|
||||
const auto cur_kernel_size = std::vector<size_t>(cur_weights_shape.begin() + kernel_size_offset,
|
||||
|
@ -17,9 +17,15 @@ iMatcherConfig::Ptr MatchersManager::get_config(const std::shared_ptr<ov::Node>
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void MatchersManager::set_shape_strict_match(bool shape_strict_match) {
|
||||
for (const auto& matcher : m_matchers) {
|
||||
matcher.second->set_strict_shape_match(shape_strict_match);
|
||||
}
|
||||
}
|
||||
|
||||
bool MatchersManager::match(const std::shared_ptr<ov::Node> &node,
|
||||
const std::shared_ptr<ov::Node> &ref) const {
|
||||
for (const auto &it : m_matchers) {
|
||||
for (const auto& it : m_matchers) {
|
||||
if (it.second->match(node, ref)) {
|
||||
return true;
|
||||
}
|
||||
|
@ -7,6 +7,7 @@
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "common_test_utils/graph_comparator.hpp"
|
||||
#include "matchers/single_op/single_op.hpp"
|
||||
#include "utils/node.hpp"
|
||||
|
||||
using namespace ov::tools::subgraph_dumper;
|
||||
|
||||
@ -24,6 +25,10 @@ iMatcherConfig::Ptr SingleOpMatcher::get_config(const std::shared_ptr<ov::Node>
|
||||
return std::make_shared<MatcherConfig<>>();
|
||||
}
|
||||
|
||||
void SingleOpMatcher::set_strict_shape_match(bool strict_shape_match) {
|
||||
is_strict_shape_match = strict_shape_match;
|
||||
}
|
||||
|
||||
bool SingleOpMatcher::match_inputs(const std::shared_ptr<ov::Node> &node,
|
||||
const std::shared_ptr<ov::Node> &ref) const {
|
||||
if (node->get_input_size() != ref->get_input_size()) {
|
||||
@ -35,21 +40,17 @@ bool SingleOpMatcher::match_inputs(const std::shared_ptr<ov::Node> &node,
|
||||
if (std::find(ignored_ports.begin(), ignored_ports.end(), port_id) != ignored_ports.end()) {
|
||||
continue;
|
||||
}
|
||||
if (!ov::op::util::is_parameter(node) && !ov::op::util::is_parameter(ref) &&
|
||||
!ov::op::util::is_constant(node) && !ov::op::util::is_constant(ref)) {
|
||||
const auto &cur_node_input_type = node->input_value(port_id).get_node_shared_ptr()->get_type_info();
|
||||
const auto &ref_node_input_type = ref->input_value(port_id).get_node_shared_ptr()->get_type_info();
|
||||
if (cur_node_input_type != ref_node_input_type) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (node->get_input_tensor(port_id).get_partial_shape().rank() != ref->get_input_tensor(port_id).get_partial_shape().rank()) {
|
||||
if (node->get_input_element_type(port_id) != ref->get_input_element_type(port_id)) {
|
||||
return false;
|
||||
}
|
||||
if (node->get_input_tensor(port_id).get_element_type() != ref->get_input_tensor(port_id).get_element_type()) {
|
||||
const auto& partial_shape = node->get_input_partial_shape(port_id);
|
||||
const auto& ref_partial_shape = ref->get_input_partial_shape(port_id);
|
||||
if (is_strict_shape_match && partial_shape != ref_partial_shape) {
|
||||
return false;
|
||||
} else if (partial_shape.rank() != ref_partial_shape.rank()) {
|
||||
return false;
|
||||
}
|
||||
if (node->get_input_partial_shape(port_id).is_dynamic() != ref->get_input_partial_shape(port_id).is_dynamic()) {
|
||||
if (partial_shape.is_dynamic() != ref_partial_shape.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@ -63,20 +64,18 @@ SingleOpMatcher::match_outputs(const std::shared_ptr<ov::Node> &node,
|
||||
return false;
|
||||
}
|
||||
for (size_t port_id = 0; port_id < node->get_output_size(); ++port_id) {
|
||||
if (!ov::op::util::is_output(node) && !ov::op::util::is_output(ref)) {
|
||||
const auto &cur_node_out_type = node->output(port_id).get_node_shared_ptr()->get_type_info();
|
||||
const auto &ref_node_out_type = ref->output(port_id).get_node_shared_ptr()->get_type_info();
|
||||
if (cur_node_out_type != ref_node_out_type) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (node->get_output_tensor(port_id).get_element_type() != ref->get_output_tensor(port_id).get_element_type()) {
|
||||
if (node->get_output_element_type(port_id) != ref->get_output_element_type(port_id)) {
|
||||
return false;
|
||||
}
|
||||
if (node->get_output_tensor(port_id).get_partial_shape().is_dynamic() != ref->get_output_tensor(port_id).get_partial_shape().is_dynamic()) {
|
||||
|
||||
const auto& partial_shape = node->get_output_partial_shape(port_id);
|
||||
const auto& ref_partial_shape = ref->get_output_partial_shape(port_id);
|
||||
if (partial_shape.is_dynamic() != ref_partial_shape.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
if (node->get_output_tensor(port_id).get_partial_shape().rank()!= ref->get_output_tensor(port_id).get_partial_shape().rank()) {
|
||||
if (is_strict_shape_match && partial_shape != ref_partial_shape) {
|
||||
return false;
|
||||
} else if (partial_shape.rank() != ref_partial_shape.rank()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@ -98,17 +97,16 @@ bool SingleOpMatcher::match(const std::shared_ptr<ov::Node> &node,
|
||||
if (cfg->ignore_matching) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!same_op_type(node, ref)) {
|
||||
return false;
|
||||
}
|
||||
if (!match_inputs(node, ref)) {
|
||||
return false;
|
||||
}
|
||||
if (!match_attrs(node, ref) && !ov::op::util::is_parameter(node) && !ov::op::util::is_parameter(ref)) {
|
||||
if (!match_outputs(node, ref)) {
|
||||
return false;
|
||||
}
|
||||
if (!match_outputs(node, ref)) {
|
||||
if (!match_attrs(node, ref) && !is_node_to_skip(node)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
@ -121,9 +119,6 @@ bool SingleOpMatcher::same_op_type(const std::shared_ptr<ov::Node> &node,
|
||||
|
||||
SingleOpMatcher::SingleOpMatcher() {
|
||||
default_configs = {
|
||||
// std::make_shared<MatcherConfig<>>(std::vector<std::string>{}, std::vector<size_t>{0}),
|
||||
// std::make_shared<MatcherConfig<ov::opset8::FakeQuantize>>(std::vector<std::string>{},
|
||||
// std::vector<size_t>{0, 1, 2, 3, 4}),
|
||||
std::make_shared<MatcherConfig<
|
||||
ov::op::v1::Convolution,
|
||||
ov::op::v1::ConvolutionBackpropData,
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include "openvino/op/lstm_cell.hpp"
|
||||
#include "openvino/op/tensor_iterator.hpp"
|
||||
#include "openvino/op/if.hpp"
|
||||
#include "openvino/op/loop.hpp"
|
||||
|
||||
#include "common_test_utils/common_utils.hpp"
|
||||
#include "functional_test_utils/ov_plugin_cache.hpp"
|
||||
@ -56,14 +57,12 @@ FusedNamesExtractor::~FusedNamesExtractor() {
|
||||
core.reset();
|
||||
}
|
||||
|
||||
std::list<ExtractedPattern>
|
||||
FusedNamesExtractor::extract(const std::shared_ptr<ov::Model> &model,
|
||||
bool is_extract_body,
|
||||
bool is_copy_constants) {
|
||||
std::vector<ExtractedPattern>
|
||||
FusedNamesExtractor::extract(const std::shared_ptr<ov::Model> &model) {
|
||||
auto compiled_op_name = extract_compiled_model_names(model);
|
||||
std::list<ExtractedPattern> matched_patterns;
|
||||
std::vector<ExtractedPattern> matched_patterns;
|
||||
std::unordered_set<std::string> checked_ops;
|
||||
std::set<std::shared_ptr<ov::Node>> nodes;
|
||||
ov::NodeVector nodes;
|
||||
for (const auto& op : model->get_ordered_ops()) {
|
||||
auto op_name = op->get_friendly_name();
|
||||
if (is_node_to_skip(op) || checked_ops.count(op_name)) {
|
||||
@ -71,7 +70,8 @@ FusedNamesExtractor::extract(const std::shared_ptr<ov::Model> &model,
|
||||
}
|
||||
if (compiled_op_name.count(op_name)) {
|
||||
try {
|
||||
matched_patterns.push_back(generate_model(nodes, checked_ops, extractor_name, is_copy_constants));
|
||||
auto extracted_pattern = generate_model(nodes, checked_ops, is_save_const);
|
||||
matched_patterns.push_back({ extracted_pattern.first, extracted_pattern.second, extractor_name });
|
||||
} catch(std::exception& e) {
|
||||
if (std::string(e.what()).find("Incorrect node number to create model") == std::string::npos) {
|
||||
// std::cout << "[ WARNING ] Impossible to generate network and add to GraphCache: " <<e.what() << std::endl;
|
||||
@ -79,7 +79,7 @@ FusedNamesExtractor::extract(const std::shared_ptr<ov::Model> &model,
|
||||
}
|
||||
nodes.clear();
|
||||
} else {
|
||||
nodes.insert(op);
|
||||
nodes.push_back(op);
|
||||
}
|
||||
if (is_extract_body) {
|
||||
if (std::dynamic_pointer_cast<ov::op::v0::TensorIterator>(op)) {
|
||||
@ -104,7 +104,8 @@ FusedNamesExtractor::extract(const std::shared_ptr<ov::Model> &model,
|
||||
}
|
||||
}
|
||||
try {
|
||||
matched_patterns.push_back(generate_model(nodes, checked_ops, extractor_name, is_copy_constants));
|
||||
auto extracted_pattern = generate_model(nodes, checked_ops, is_save_const);
|
||||
matched_patterns.push_back({ extracted_pattern.first, extracted_pattern.second, extractor_name });
|
||||
} catch(std::exception& e) {
|
||||
if (std::string(e.what()).find("Incorrect node number to create model") == std::string::npos) {
|
||||
// std::cout << "[ WARNING ] Impossible to generate network and add to GraphCache: " <<e.what() << std::endl;
|
||||
|
@ -9,122 +9,18 @@
|
||||
|
||||
using namespace ov::tools::subgraph_dumper;
|
||||
|
||||
bool ExtractorsManager::match(const std::shared_ptr<ov::Model> &model,
|
||||
const std::shared_ptr<ov::Model> &ref) {
|
||||
// `match` is not virtual method in base `SubgraphExtractor` class
|
||||
// we can use function from any `extractor` to avoid of cycle
|
||||
if (!m_extractors.empty()) {
|
||||
if (m_extractors.begin()->second->match(model, ref)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
ExtractorsManager::ExtractedSubgraphTuple
|
||||
ExtractorsManager::is_subgraph(const std::shared_ptr<ov::Model> &model,
|
||||
const std::shared_ptr<ov::Model> &ref_model,
|
||||
const std::map<std::string, InputInfo> &in_info,
|
||||
const std::map<std::string, InputInfo> &in_info_ref) {
|
||||
if (!m_extractors.empty()) {
|
||||
// `is_subgraph` is not virtual method in base `SubgraphExtractor` class
|
||||
// we can use function from any `extractor` to avoid of cycle
|
||||
auto extractor_res = m_extractors.begin()->second->is_subgraph(model, ref_model);
|
||||
if (std::get<0>(extractor_res)) {
|
||||
std::map<std::string, InputInfo> graph_in_info, subgraph_in_info;
|
||||
if (std::get<1>(extractor_res) == model && std::get<2>(extractor_res) == ref_model) {
|
||||
graph_in_info = in_info;
|
||||
subgraph_in_info = in_info_ref;
|
||||
} else if (std::get<1>(extractor_res) == ref_model && std::get<2>(extractor_res) == model) {
|
||||
graph_in_info = in_info_ref;
|
||||
subgraph_in_info = in_info;
|
||||
} else {
|
||||
throw std::runtime_error("Generated models are incompatible with original ones!");
|
||||
}
|
||||
try {
|
||||
subgraph_in_info = align_input_info(std::get<2>(extractor_res), std::get<1>(extractor_res), subgraph_in_info, graph_in_info);
|
||||
} catch(std::exception) {
|
||||
return { false, nullptr, nullptr, {}, {} };
|
||||
}
|
||||
return { true, std::get<1>(extractor_res), std::get<2>(extractor_res), graph_in_info, subgraph_in_info };
|
||||
}
|
||||
}
|
||||
return { false, nullptr, nullptr, {}, {} };
|
||||
}
|
||||
|
||||
bool ExtractorsManager::match(const std::shared_ptr<ov::Model> &model,
|
||||
const std::shared_ptr<ov::Model> &ref,
|
||||
std::map<std::string, InputInfo> &in_info,
|
||||
const std::map<std::string, InputInfo> &in_info_ref) {
|
||||
if (match(model, ref)) {
|
||||
try {
|
||||
in_info = align_input_info(model, ref, in_info, in_info_ref);
|
||||
return true;
|
||||
} catch (std::exception) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::map<std::string, InputInfo>
|
||||
ExtractorsManager::align_input_info(const std::shared_ptr<ov::Model>& model,
|
||||
const std::shared_ptr<ov::Model>& model_ref,
|
||||
const std::map<std::string, InputInfo>& in_info,
|
||||
const std::map<std::string, InputInfo>& in_info_ref,
|
||||
const std::map<std::string, std::string> &matched_op) {
|
||||
std::map<std::string, InputInfo> new_input_info = in_info;
|
||||
bool is_update_required = false;
|
||||
for (const auto& in_info_item : in_info_ref) {
|
||||
if (!in_info.count(in_info_item.first)) {
|
||||
is_update_required = true;
|
||||
break;
|
||||
} else if (in_info.at(in_info_item.first).is_const != in_info_item.second.is_const) {
|
||||
throw std::runtime_error("Impossible to update input info!!!");
|
||||
}
|
||||
}
|
||||
if (is_update_required) {
|
||||
// align matched model names
|
||||
auto ref_model_ops = model_ref->get_ordered_ops();
|
||||
auto model_ops = model->get_ordered_ops();
|
||||
size_t ref_ordered_ops_size = ref_model_ops.size();
|
||||
size_t ordered_ops_size = model_ops.size();
|
||||
if (ref_ordered_ops_size != ordered_ops_size && matched_op.empty()) {
|
||||
throw std::runtime_error("Matched models can not be compared according different op numbers!");
|
||||
}
|
||||
for (size_t i = 0; i < ref_ordered_ops_size; ++i) {
|
||||
auto model_op_name = i < ordered_ops_size ? model_ops[i]->get_friendly_name() : "";
|
||||
auto model_ref_op_name = ref_model_ops[i]->get_friendly_name();
|
||||
if (!in_info_ref.count(model_ref_op_name) && !in_info.count(model_op_name)) {
|
||||
continue;
|
||||
}
|
||||
auto input_info = matched_op.empty() ? new_input_info[model_op_name] : in_info_ref.at(model_ref_op_name);
|
||||
std::string input_name = matched_op.count(model_ref_op_name) ? matched_op.at(model_ref_op_name) : model_op_name;
|
||||
if (new_input_info.count(input_name)) {
|
||||
if (input_info.is_const != in_info_ref.at(model_ref_op_name).is_const) {
|
||||
throw std::runtime_error("Impossible to update input info!!!");
|
||||
}
|
||||
if (!matched_op.empty()) {
|
||||
input_info = new_input_info.at(input_name);
|
||||
}
|
||||
new_input_info.erase(input_name);
|
||||
}
|
||||
new_input_info.insert({ model_ref_op_name, input_info });
|
||||
}
|
||||
}
|
||||
return new_input_info;
|
||||
}
|
||||
|
||||
std::list<ExtractedPattern>
|
||||
std::vector<ExtractedPattern>
|
||||
ExtractorsManager::extract(const std::shared_ptr<ov::Model> &model,
|
||||
bool is_extract_body,
|
||||
bool is_copy_constants) {
|
||||
std::list<ExtractedPattern> result;
|
||||
std::vector<ExtractedPattern> result;
|
||||
for (const auto &it : m_extractors) {
|
||||
// extract patterns from original models
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
it.second->set_extractor_name(it.first);
|
||||
auto extracted_patterns = it.second->extract(model, is_extract_body, is_copy_constants);
|
||||
it.second->set_extract_body(is_extract_body);
|
||||
it.second->set_save_const(is_copy_constants);
|
||||
auto extracted_patterns = it.second->extract(model);
|
||||
result.insert(result.end(), extracted_patterns.begin(), extracted_patterns.end());
|
||||
auto end = std::chrono::high_resolution_clock::now();
|
||||
auto delta = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
|
||||
|
@ -8,18 +8,118 @@
|
||||
#include "openvino/op/lstm_cell.hpp"
|
||||
#include "openvino/op/tensor_iterator.hpp"
|
||||
#include "openvino/op/if.hpp"
|
||||
#include "openvino/op/loop.hpp"
|
||||
|
||||
#include "matchers/subgraph/repeat_pattern.hpp"
|
||||
#include "utils/model.hpp"
|
||||
#include "utils/model_comparator.hpp"
|
||||
|
||||
using namespace ov::tools::subgraph_dumper;
|
||||
|
||||
std::list<ExtractedPattern>
|
||||
RepeatPatternExtractor::extract(const std::shared_ptr<ov::Model> &model,
|
||||
bool is_extract_body,
|
||||
bool is_copy_constants) {
|
||||
void RepeatPatternExtractor::set_recursive_extraction(bool _is_recursive_extraction) {
|
||||
is_recursive_extraction = _is_recursive_extraction;
|
||||
}
|
||||
|
||||
std::vector<ExtractedPattern>
|
||||
RepeatPatternExtractor::extract(const std::shared_ptr<ov::Model> &model) {
|
||||
std::vector<ExtractedPattern> extracted_patterns;
|
||||
for (const auto& pattern : find_repeat_patterns(model)) {
|
||||
for (const auto& pattern_structure : pattern) {
|
||||
extracted_patterns.push_back({std::get<0>(pattern_structure), std::get<2>(pattern_structure), extractor_name});
|
||||
}
|
||||
}
|
||||
return extracted_patterns;
|
||||
}
|
||||
|
||||
std::vector<std::vector<RepeatPatternExtractor::PatternBorders>>
|
||||
RepeatPatternExtractor::get_repeat_pattern_borders(const std::shared_ptr<ov::Model> &model) {
|
||||
std::vector<std::vector<RepeatPatternExtractor::PatternBorders>> extracted_patterns;
|
||||
for (auto& pattern : find_repeat_patterns(model, true)) {
|
||||
std::vector<RepeatPatternExtractor::PatternBorders> same_pattern_borders;
|
||||
for (const auto& pattern_structure : pattern) {
|
||||
std::set<std::string> output_names;
|
||||
for (const auto& result : std::get<0>(pattern_structure)->get_results()) {
|
||||
output_names.insert(result->get_input_node_shared_ptr(0)->get_friendly_name());
|
||||
}
|
||||
|
||||
RepeatPatternExtractor::InputVector in_vec;
|
||||
RepeatPatternExtractor::OutputVector out_vec;
|
||||
for (const auto& node : std::get<1>(pattern_structure)) {
|
||||
if (output_names.count(node->get_friendly_name())) {
|
||||
OutputVector node_outputs = node->outputs();
|
||||
out_vec.insert(out_vec.end(), node_outputs.begin(), node_outputs.end());
|
||||
} else {
|
||||
for (const auto& input : node->inputs()) {
|
||||
in_vec.push_back(input);
|
||||
}
|
||||
}
|
||||
}
|
||||
same_pattern_borders.push_back({in_vec, out_vec});
|
||||
}
|
||||
extracted_patterns.push_back(same_pattern_borders);
|
||||
}
|
||||
return extracted_patterns;
|
||||
}
|
||||
|
||||
std::vector<std::vector<ov::NodeVector>>
|
||||
RepeatPatternExtractor::get_repeat_node_vectors(const std::shared_ptr<ov::Model> &model) {
|
||||
std::vector<std::vector<ov::NodeVector>> extracted_patterns;
|
||||
for (const auto& pattern : find_repeat_patterns(model)) {
|
||||
std::vector<ov::NodeVector> same_pattern_nodes;
|
||||
for (const auto& pattern_structure : pattern) {
|
||||
same_pattern_nodes.push_back(std::get<1>(pattern_structure));
|
||||
}
|
||||
extracted_patterns.push_back(same_pattern_nodes);
|
||||
}
|
||||
return extracted_patterns;
|
||||
}
|
||||
|
||||
void
|
||||
RepeatPatternExtractor::update_extractor_cache(
|
||||
std::list<std::vector<RepeatPatternExtractor::ExtractedRepeatPattern>>& extracted_patterns,
|
||||
const std::shared_ptr<ov::Model>& pattern,
|
||||
const ov::NodeVector& pattern_node_vector,
|
||||
const std::map<std::string, InputInfo>& pattern_in_info) {
|
||||
for (auto& extracted_pattern : extracted_patterns) {
|
||||
auto& pattern_structure = extracted_pattern.front();
|
||||
const auto& cached_pattern = std::get<0>(pattern_structure);
|
||||
if (model_comparator->match(pattern, cached_pattern)) {
|
||||
try {
|
||||
const auto& cached_in_info = std::get<2>(pattern_structure);
|
||||
align_input_info(pattern, cached_pattern, pattern_in_info, cached_in_info);
|
||||
extracted_pattern.push_back({ pattern, pattern_node_vector, pattern_in_info });
|
||||
return;
|
||||
} catch(std::exception) {}
|
||||
}
|
||||
}
|
||||
extracted_patterns.push_back({{ pattern, pattern_node_vector, pattern_in_info }});
|
||||
}
|
||||
|
||||
void
|
||||
RepeatPatternExtractor::update_extractor_cache(
|
||||
std::list<std::vector<RepeatPatternExtractor::ExtractedRepeatPattern>>& extracted_patterns,
|
||||
std::list<std::vector<RepeatPatternExtractor::ExtractedRepeatPattern>>& secondary_extracted_patterns) {
|
||||
auto extern_it = secondary_extracted_patterns.begin();
|
||||
while (!secondary_extracted_patterns.empty()) {
|
||||
auto it = extern_it->rbegin();
|
||||
while (!extern_it->empty()) {
|
||||
auto& pattern_structure = *it;
|
||||
const auto& pattern = std::get<0>(pattern_structure);
|
||||
const auto& pattern_node_vector = std::get<1>(pattern_structure);
|
||||
const auto& pattern_in_info = std::get<2>(pattern_structure);
|
||||
update_extractor_cache(extracted_patterns, pattern, pattern_node_vector, pattern_in_info);
|
||||
extern_it->pop_back();
|
||||
it = extern_it->rbegin();
|
||||
}
|
||||
secondary_extracted_patterns.pop_front();
|
||||
}
|
||||
}
|
||||
|
||||
std::list<std::vector<RepeatPatternExtractor::ExtractedRepeatPattern>>
|
||||
RepeatPatternExtractor::find_repeat_patterns(const std::shared_ptr<ov::Model> &model,
|
||||
bool is_save_borders_only) {
|
||||
std::list<std::vector<RepeatPatternExtractor::ExtractedRepeatPattern>> extracted_patterns;
|
||||
std::unordered_set<std::string> checked_ops;
|
||||
std::list<ExtractedPattern> to_cache;
|
||||
|
||||
auto ordered_ops = model->get_ordered_ops();
|
||||
auto op_cnt = ordered_ops.size();
|
||||
@ -31,9 +131,10 @@ RepeatPatternExtractor::extract(const std::shared_ptr<ov::Model> &model,
|
||||
continue;
|
||||
}
|
||||
|
||||
// find the same nodes
|
||||
std::vector<size_t> start_node_idx{idx};
|
||||
for (size_t i = idx + 1; i < op_cnt; ++i) {
|
||||
if (manager.match(op, ordered_ops[i])) {
|
||||
if (model_comparator->match(op, ordered_ops[i])) {
|
||||
start_node_idx.push_back(i);
|
||||
}
|
||||
}
|
||||
@ -57,9 +158,9 @@ RepeatPatternExtractor::extract(const std::shared_ptr<ov::Model> &model,
|
||||
if (node_idx == start_node_idx[i] && ref_node_idx == start_node_idx[j]) {
|
||||
nodes[i].insert(node);
|
||||
nodes[j].insert(ref_node);
|
||||
} else if (manager.match(node, ref_node)) {
|
||||
} else if (model_comparator->match(node, ref_node)) {
|
||||
// check if we met the same node
|
||||
if (manager.match(node, op)) {
|
||||
if (model_comparator->match(node, op)) {
|
||||
break;
|
||||
}
|
||||
if (checked_ops.count(node->get_friendly_name()) ||
|
||||
@ -94,16 +195,26 @@ RepeatPatternExtractor::extract(const std::shared_ptr<ov::Model> &model,
|
||||
for (size_t i = 0; i < start_node_idx.size(); ++i) {
|
||||
try {
|
||||
std::unordered_set<std::string> tmp_checked_ops;
|
||||
auto extracted_pattern = generate_model(nodes[i], tmp_checked_ops, extractor_name, is_copy_constants);
|
||||
auto extracted_model = std::get<0>(extracted_pattern);
|
||||
std::list<ExtractedPattern> secondary_patterns;
|
||||
if (nodes[i].size() > 20) {
|
||||
secondary_patterns = extract(std::get<0>(extracted_pattern), is_extract_body, is_copy_constants);
|
||||
}
|
||||
if (secondary_patterns.size() > 1) {
|
||||
to_cache.insert(to_cache.end(), secondary_patterns.begin(), secondary_patterns.end());
|
||||
// model, in_info, extractor_name
|
||||
ov::NodeVector nodes_vector(nodes[i].begin(), nodes[i].end());
|
||||
auto extracted_pattern = generate_model(nodes_vector, tmp_checked_ops, is_save_const, is_save_borders_only);
|
||||
auto extracted_model = extracted_pattern.first;
|
||||
if (is_recursive_extraction && nodes_vector.size() > 20) {
|
||||
auto secondary_patterns = find_repeat_patterns(extracted_model, is_save_borders_only);
|
||||
if (!secondary_patterns.empty()) {
|
||||
tmp_checked_ops.clear();
|
||||
update_extractor_cache(extracted_patterns, secondary_patterns);
|
||||
} else {
|
||||
update_extractor_cache(extracted_patterns,
|
||||
extracted_model,
|
||||
nodes_vector,
|
||||
extracted_pattern.second);
|
||||
}
|
||||
} else {
|
||||
to_cache.push_back(extracted_pattern);
|
||||
update_extractor_cache(extracted_patterns,
|
||||
extracted_model,
|
||||
nodes_vector,
|
||||
extracted_pattern.second);
|
||||
}
|
||||
nodes[i].clear();
|
||||
checked_ops.insert(tmp_checked_ops.begin(), tmp_checked_ops.end());
|
||||
@ -117,23 +228,39 @@ RepeatPatternExtractor::extract(const std::shared_ptr<ov::Model> &model,
|
||||
if (std::dynamic_pointer_cast<ov::op::v0::TensorIterator>(op)) {
|
||||
auto ti = ov::as_type_ptr<ov::op::v0::TensorIterator>(op);
|
||||
auto ti_body = ti->get_function();
|
||||
auto tmp_res = extract(ti_body);
|
||||
to_cache.insert(to_cache.end(), tmp_res.begin(), tmp_res.end());
|
||||
auto secondary_patterns = find_repeat_patterns(ti_body, is_save_borders_only);
|
||||
update_extractor_cache(extracted_patterns, secondary_patterns);
|
||||
} else if (std::dynamic_pointer_cast<ov::op::v5::Loop>(op)) {
|
||||
auto loop = ov::as_type_ptr<ov::op::v5::Loop>(op);
|
||||
auto loop_body = loop->get_function();
|
||||
auto tmp_res = extract(loop_body);
|
||||
to_cache.insert(to_cache.end(), tmp_res.begin(), tmp_res.end());
|
||||
auto secondary_patterns = find_repeat_patterns(loop_body, is_save_borders_only);
|
||||
update_extractor_cache(extracted_patterns, secondary_patterns);
|
||||
} else if (std::dynamic_pointer_cast<ov::op::v8::If>(op)) {
|
||||
auto if_op = ov::as_type_ptr<ov::op::v8::If>(op);
|
||||
std::vector<std::shared_ptr<ov::Model>> bodies;
|
||||
for (size_t i = 0; i < if_op->get_internal_subgraphs_size(); i++) {
|
||||
auto if_body = if_op->get_function(i);
|
||||
auto tmp_res = extract(if_body);
|
||||
to_cache.insert(to_cache.end(), tmp_res.begin(), tmp_res.end());
|
||||
auto secondary_patterns = find_repeat_patterns(if_body, is_save_borders_only);
|
||||
update_extractor_cache(extracted_patterns, secondary_patterns);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return to_cache;
|
||||
|
||||
// clean up patterns
|
||||
{
|
||||
auto it = extracted_patterns.begin();
|
||||
size_t elem_cnt = 0;
|
||||
while (it != extracted_patterns.end()) {
|
||||
if (it->size() > 1) {
|
||||
++it;
|
||||
++elem_cnt;
|
||||
} else {
|
||||
extracted_patterns.erase(it);
|
||||
it = extracted_patterns.begin();
|
||||
std::advance(it, elem_cnt);
|
||||
}
|
||||
}
|
||||
}
|
||||
return extracted_patterns;
|
||||
}
|
||||
|
@ -1,76 +0,0 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <tuple>
|
||||
#include "matchers/subgraph/subgraph.hpp"
|
||||
|
||||
using namespace ov::tools::subgraph_dumper;
|
||||
|
||||
bool
|
||||
SubgraphExtractor::match(const std::shared_ptr<ov::Model> &model,
|
||||
const std::shared_ptr<ov::Model> &ref_model) const {
|
||||
bool res = comparator.compare(model, ref_model).valid;
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
std::vector<std::shared_ptr<ov::Node>> ordered_ops = model->get_ordered_ops(),
|
||||
ref_ordered_ops = ref_model->get_ordered_ops();
|
||||
if (ordered_ops.size() != ref_ordered_ops.size()) {
|
||||
return false;
|
||||
}
|
||||
size_t matched_op_cnt = 0, total_op_cnt = ordered_ops.size();
|
||||
size_t matched_op_cnt_required = round(0.9 * total_op_cnt);
|
||||
for (size_t i = 0; i < total_op_cnt; ++i) {
|
||||
if (is_node_to_skip(ordered_ops[i]) &&
|
||||
is_node_to_skip(ref_ordered_ops[i]) ||
|
||||
m_manager.match(ordered_ops[i], ref_ordered_ops[i])) {
|
||||
++matched_op_cnt;
|
||||
}
|
||||
if (matched_op_cnt >= matched_op_cnt_required) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
inline SubgraphExtractor::IsSubgraphTuple prepare_is_subgraph_result(bool is_subgraph,
|
||||
const std::shared_ptr<ov::Model>& graph,
|
||||
const std::shared_ptr<ov::Model>& subgraph,
|
||||
const std::map<std::string, std::string>& matched_ops) {
|
||||
return is_subgraph ?
|
||||
std::make_tuple(is_subgraph, graph, subgraph, matched_ops) :
|
||||
std::make_tuple(is_subgraph, nullptr, nullptr, std::map<std::string, std::string>());
|
||||
}
|
||||
|
||||
SubgraphExtractor::IsSubgraphTuple
|
||||
SubgraphExtractor::is_subgraph(const std::shared_ptr<ov::Model> &model,
|
||||
const std::shared_ptr<ov::Model> &ref_model) const {
|
||||
std::vector<std::shared_ptr<ov::Node>> ordered_ops = model->get_ordered_ops(),
|
||||
ref_ordered_ops = ref_model->get_ordered_ops();
|
||||
bool is_model = ordered_ops.size() > ref_ordered_ops.size();
|
||||
ov::NodeVector graph_to_check_ops, subgraph_to_check_ops;
|
||||
std::shared_ptr<ov::Model> graph = nullptr, subgraph = nullptr;
|
||||
if (is_model) {
|
||||
graph_to_check_ops = ordered_ops;
|
||||
subgraph_to_check_ops = ref_ordered_ops;
|
||||
graph = model;
|
||||
subgraph = ref_model;
|
||||
} else {
|
||||
graph_to_check_ops = ref_ordered_ops;
|
||||
subgraph_to_check_ops = ordered_ops;
|
||||
graph = ref_model;
|
||||
subgraph = model;
|
||||
}
|
||||
std::map<std::string, std::string> matched_op_names;
|
||||
|
||||
auto graph_it = graph_to_check_ops.begin(), subgraph_it = subgraph_to_check_ops.begin();
|
||||
while (graph_it != graph_to_check_ops.end() && subgraph_it != subgraph_to_check_ops.end()) {
|
||||
if (m_manager.match(*graph_it, *subgraph_it)) {
|
||||
matched_op_names.insert({ (*graph_it)->get_friendly_name(), (*subgraph_it)->get_friendly_name()});
|
||||
++subgraph_it;
|
||||
}
|
||||
++graph_it;
|
||||
}
|
||||
return prepare_is_subgraph_result(subgraph_it == subgraph_to_check_ops.end(), graph, subgraph, matched_op_names);
|
||||
}
|
@ -70,6 +70,27 @@ find_models(const std::vector<std::string> &dirs, const std::string& regexp) {
|
||||
return { models, { ModelCacheStatus::NOT_READ, not_read_model } };
|
||||
}
|
||||
|
||||
bool is_dynamic_model(const std::shared_ptr<ov::Model>& model) {
|
||||
for (const auto& parameter : model->get_parameters()) {
|
||||
if (is_dynamic_node(parameter)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
for (const auto& result : model->get_results()) {
|
||||
if (is_dynamic_node(result)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string get_model_type(const std::shared_ptr<ov::Model>& model) {
|
||||
if (is_dynamic_model(model)) {
|
||||
return "dynamic";
|
||||
}
|
||||
return "static";
|
||||
}
|
||||
|
||||
std::map<ModelCacheStatus, std::vector<std::string>> cache_models(
|
||||
std::shared_ptr<ICache>& cache,
|
||||
const std::vector<std::string>& models,
|
||||
@ -115,6 +136,78 @@ std::map<ModelCacheStatus, std::vector<std::string>> cache_models(
|
||||
return cache_status;
|
||||
}
|
||||
|
||||
std::map<std::string, InputInfo>
|
||||
get_input_info_by_model(const std::shared_ptr<ov::Model>& model) {
|
||||
std::map<std::string, InputInfo> in_info;
|
||||
for (const auto& node : model->get_ordered_ops()) {
|
||||
InputInfo::Range ranges(DEFAULT_MIN_VALUE, DEFAULT_MAX_VALUE);
|
||||
bool is_const = false;
|
||||
if (ov::op::util::is_constant(node)) {
|
||||
std::shared_ptr<ov::op::v0::Constant> constant = std::dynamic_pointer_cast<ov::op::v0::Constant>(node);
|
||||
auto const_ranges = get_const_ranges(constant,
|
||||
constant->get_default_output().get_element_type());
|
||||
ranges = const_ranges;
|
||||
} else if (!ov::op::util::is_parameter(node)) {
|
||||
continue;
|
||||
}
|
||||
auto partial_shape = node->get_default_output().get_partial_shape();
|
||||
in_info.insert({node->get_friendly_name(),
|
||||
InputInfo(partial_shape, ranges.min, ranges.max, is_const)});
|
||||
}
|
||||
return in_info;
|
||||
}
|
||||
|
||||
std::map<std::string, InputInfo>
|
||||
align_input_info(const std::shared_ptr<ov::Model>& model,
|
||||
const std::shared_ptr<ov::Model>& model_ref,
|
||||
const std::map<std::string, InputInfo>& in_info,
|
||||
const std::map<std::string, InputInfo>& in_info_ref,
|
||||
const std::map<std::string, std::string> &matched_op) {
|
||||
bool is_update_required = !matched_op.empty();
|
||||
if (!is_update_required) {
|
||||
for (const auto& ref_item : in_info_ref) {
|
||||
if (!in_info.count(ref_item.first)) {
|
||||
is_update_required = true;
|
||||
break;
|
||||
} else if (in_info.at(ref_item.first).is_const != ref_item.second.is_const) {
|
||||
throw std::runtime_error("Impossible to update input info!!!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::map<std::string, InputInfo> updated_input_info = in_info_ref;
|
||||
if (is_update_required) {
|
||||
// align matched model names
|
||||
const auto& ref_model_ops = model_ref->get_ordered_ops();
|
||||
const auto& model_ops = model->get_ordered_ops();
|
||||
size_t ref_ordered_ops_size = ref_model_ops.size();
|
||||
size_t ordered_ops_size = model_ops.size();
|
||||
if (ref_ordered_ops_size != ordered_ops_size && matched_op.empty()) {
|
||||
throw std::runtime_error("Matched models can not be compared according different op numbers!");
|
||||
}
|
||||
for (size_t i = 0; i < ordered_ops_size; ++i) {
|
||||
auto model_op_name = model_ops[i]->get_friendly_name();
|
||||
if (!in_info.count(model_op_name)) {
|
||||
continue;
|
||||
}
|
||||
if (!matched_op.empty()) {
|
||||
if (!matched_op.count(model_op_name)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
auto model_ref_op_name = matched_op.empty() ? ref_model_ops[i]->get_friendly_name() : matched_op.at(model_op_name);
|
||||
|
||||
const auto& in_info_item = in_info.at(model_op_name);
|
||||
const auto& ref_in_info_item = in_info_ref.at(model_ref_op_name);
|
||||
if (in_info_item.is_const != ref_in_info_item.is_const) {
|
||||
throw std::runtime_error("Impossible to update input info!!!");
|
||||
}
|
||||
updated_input_info[model_ref_op_name] = in_info_item;
|
||||
}
|
||||
}
|
||||
return updated_input_info;
|
||||
}
|
||||
|
||||
} // namespace subgraph_dumper
|
||||
} // namespace tools
|
||||
} // namespace ov
|
@ -0,0 +1,136 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "utils/model_comparator.hpp"
|
||||
#include "utils/model.hpp"
|
||||
|
||||
using namespace ov::tools::subgraph_dumper;
|
||||
|
||||
std::shared_ptr<ModelComparator> ModelComparator::m_instance = nullptr;
|
||||
|
||||
void ModelComparator::set_match_coefficient(float _match_coefficient) {
|
||||
if (_match_coefficient < 0 || _match_coefficient > 1) {
|
||||
throw std::runtime_error("[ ERROR ] Match coefficient should be from 0 to 1!");
|
||||
}
|
||||
match_coefficient = _match_coefficient;
|
||||
}
|
||||
|
||||
void ModelComparator::set_shape_strict_match(bool in_is_shape_strict_match) {
|
||||
m_manager.set_shape_strict_match(in_is_shape_strict_match);
|
||||
}
|
||||
|
||||
inline ModelComparator::IsSubgraphTuple
|
||||
prepare_is_subgraph_result(bool is_subgraph,
|
||||
const std::shared_ptr<ov::Model>& subgraph,
|
||||
const std::shared_ptr<ov::Model>& graph,
|
||||
const std::map<std::string, std::string>& matched_ops) {
|
||||
return is_subgraph ?
|
||||
std::make_tuple(is_subgraph, subgraph, graph, matched_ops) :
|
||||
std::make_tuple(is_subgraph, nullptr, nullptr, std::map<std::string, std::string>());
|
||||
}
|
||||
|
||||
ModelComparator::IsSubgraphTuple
|
||||
ModelComparator::is_subgraph(const std::shared_ptr<ov::Model> &model,
|
||||
const std::shared_ptr<ov::Model> &ref_model) const {
|
||||
std::vector<std::shared_ptr<ov::Node>> ordered_ops = model->get_ordered_ops(),
|
||||
ref_ordered_ops = ref_model->get_ordered_ops();
|
||||
bool is_model = ordered_ops.size() > ref_ordered_ops.size();
|
||||
ov::NodeVector graph_to_check_ops, subgraph_to_check_ops;
|
||||
std::shared_ptr<ov::Model> graph = nullptr, subgraph = nullptr;
|
||||
if (is_model) {
|
||||
graph_to_check_ops = ordered_ops;
|
||||
subgraph_to_check_ops = ref_ordered_ops;
|
||||
graph = model;
|
||||
subgraph = ref_model;
|
||||
} else {
|
||||
graph_to_check_ops = ref_ordered_ops;
|
||||
subgraph_to_check_ops = ordered_ops;
|
||||
graph = ref_model;
|
||||
subgraph = model;
|
||||
}
|
||||
std::map<std::string, std::string> matched_op_names;
|
||||
|
||||
auto graph_it = graph_to_check_ops.begin(), subgraph_it = subgraph_to_check_ops.begin();
|
||||
while (graph_it != graph_to_check_ops.end() && subgraph_it != subgraph_to_check_ops.end()) {
|
||||
if (m_manager.match(*graph_it, *subgraph_it)) {
|
||||
matched_op_names.insert({ (*subgraph_it)->get_friendly_name(), (*graph_it)->get_friendly_name()});
|
||||
++subgraph_it;
|
||||
}
|
||||
++graph_it;
|
||||
}
|
||||
return prepare_is_subgraph_result(subgraph_it == subgraph_to_check_ops.end(), subgraph, graph, matched_op_names);
|
||||
}
|
||||
|
||||
bool
|
||||
ModelComparator::match(const std::shared_ptr<ov::Node> &node,
|
||||
const std::shared_ptr<ov::Node> &ref_node) const {
|
||||
return m_manager.match(node, ref_node);
|
||||
}
|
||||
|
||||
bool
|
||||
ModelComparator::match(const std::shared_ptr<ov::Model> &model,
|
||||
const std::shared_ptr<ov::Model> &ref_model) const {
|
||||
std::vector<std::shared_ptr<ov::Node>> ordered_ops = model->get_ordered_ops(),
|
||||
ref_ordered_ops = ref_model->get_ordered_ops();
|
||||
if (ordered_ops.size() != ref_ordered_ops.size()) {
|
||||
return false;
|
||||
}
|
||||
size_t matched_op_cnt = 0, total_op_cnt = ordered_ops.size();
|
||||
size_t matched_op_cnt_required = round(match_coefficient * total_op_cnt);
|
||||
for (size_t i = 0; i < total_op_cnt; ++i) {
|
||||
if (m_manager.match(ordered_ops[i], ref_ordered_ops[i])) {
|
||||
++matched_op_cnt;
|
||||
}
|
||||
if (matched_op_cnt >= matched_op_cnt_required) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
ModelComparator::ExtractedSubgraphTuple
|
||||
ModelComparator::is_subgraph(const std::shared_ptr<ov::Model> &model,
|
||||
const std::shared_ptr<ov::Model> &ref_model,
|
||||
const std::map<std::string, InputInfo> &in_info,
|
||||
const std::map<std::string, InputInfo> &in_info_ref) {
|
||||
auto extractor_res = is_subgraph(model, ref_model);
|
||||
if (std::get<0>(extractor_res)) {
|
||||
std::map<std::string, InputInfo> graph_in_info, subgraph_in_info;
|
||||
std::shared_ptr<ov::Model> subgraph = nullptr, graph = nullptr;
|
||||
// if (model == subgraph && ref_model == graph)
|
||||
if (std::get<1>(extractor_res) == model && std::get<2>(extractor_res) == ref_model) {
|
||||
subgraph = model;
|
||||
subgraph_in_info = in_info;
|
||||
graph = ref_model;
|
||||
graph_in_info = in_info_ref;
|
||||
// else if (subgraph == ref_model && graph = model)
|
||||
} else if (std::get<1>(extractor_res) == ref_model && std::get<2>(extractor_res) == model) {
|
||||
subgraph = ref_model;
|
||||
subgraph_in_info = in_info_ref;
|
||||
graph = model;
|
||||
graph_in_info = in_info;
|
||||
} else {
|
||||
throw std::runtime_error("Generated models are incompatible with original ones!");
|
||||
}
|
||||
try {
|
||||
subgraph_in_info = align_input_info(subgraph, graph, subgraph_in_info, graph_in_info);
|
||||
return { true, subgraph, graph, subgraph_in_info, graph_in_info };
|
||||
} catch(std::exception) {}
|
||||
}
|
||||
return { false, nullptr, nullptr, {}, {} };
|
||||
}
|
||||
|
||||
std::pair<bool, std::map<std::string, InputInfo>>
|
||||
ModelComparator::match(const std::shared_ptr<ov::Model> &model,
|
||||
const std::shared_ptr<ov::Model> &model_ref,
|
||||
const std::map<std::string, InputInfo> &in_info,
|
||||
const std::map<std::string, InputInfo> &in_info_ref) {
|
||||
try {
|
||||
if (match(model, model_ref)) {
|
||||
auto new_input_info = align_input_info(model, model_ref, in_info, in_info_ref);
|
||||
return {true, new_input_info};
|
||||
}
|
||||
} catch (std::exception) {}
|
||||
return {false, {}};
|
||||
}
|
@ -7,6 +7,73 @@ namespace ov {
|
||||
namespace tools {
|
||||
namespace subgraph_dumper {
|
||||
|
||||
InputInfo::Range get_const_ranges(const std::shared_ptr<ov::op::v0::Constant>& const_node,
|
||||
ov::element::Type elem_type) {
|
||||
InputInfo::Range ranges(DEFAULT_MIN_VALUE, DEFAULT_MAX_VALUE);
|
||||
switch (elem_type) {
|
||||
case ov::element::Type_t::boolean: {
|
||||
ranges = get_const_ranges<bool>(const_node);
|
||||
break;
|
||||
}
|
||||
case ov::element::Type_t::bf16: {
|
||||
ranges = get_const_ranges<ov::bfloat16>(const_node);
|
||||
break;
|
||||
}
|
||||
case ov::element::Type_t::f16: {
|
||||
ranges = get_const_ranges<ov::float16>(const_node);
|
||||
break;
|
||||
}
|
||||
case ov::element::Type_t::f32: {
|
||||
ranges = get_const_ranges<float>(const_node);
|
||||
break;
|
||||
}
|
||||
case ov::element::Type_t::f64: {
|
||||
ranges = get_const_ranges<double>(const_node);
|
||||
break;
|
||||
}
|
||||
case ov::element::Type_t::i8: {
|
||||
ranges = get_const_ranges<int8_t>(const_node);
|
||||
break;
|
||||
}
|
||||
case ov::element::Type_t::i16: {
|
||||
ranges = get_const_ranges<int16_t>(const_node);
|
||||
break;
|
||||
}
|
||||
case ov::element::Type_t::i32: {
|
||||
ranges = get_const_ranges<int32_t>(const_node);
|
||||
break;
|
||||
}
|
||||
case ov::element::Type_t::i64: {
|
||||
ranges = get_const_ranges<int64_t>(const_node);
|
||||
break;
|
||||
}
|
||||
// TODO cast_vector doesn't support u1 now
|
||||
// case ov::element::Type_t::u1:
|
||||
// return get_const_ranges<char>(const_node);
|
||||
case ov::element::Type_t::u8: {
|
||||
ranges = get_const_ranges<uint8_t>(const_node);
|
||||
break;
|
||||
}
|
||||
case ov::element::Type_t::u16: {
|
||||
ranges = get_const_ranges<uint16_t>(const_node);
|
||||
break;
|
||||
}
|
||||
case ov::element::Type_t::u32: {
|
||||
ranges = get_const_ranges<uint32_t>(const_node);
|
||||
break;
|
||||
}
|
||||
case ov::element::Type_t::u64: {
|
||||
ranges = get_const_ranges<uint64_t>(const_node);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
std::cout << "Can't get ranges.. Unsupported data type" << std::endl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return ranges;
|
||||
}
|
||||
|
||||
std::map<std::string, InputInfo> get_input_info_by_node(const std::shared_ptr<ov::Node>& node) {
|
||||
std::map<std::string, InputInfo> input_info;
|
||||
for (size_t port_id = 0; port_id < node->get_input_size(); ++port_id) {
|
||||
@ -19,71 +86,12 @@ std::map<std::string, InputInfo> get_input_info_by_node(const std::shared_ptr<ov
|
||||
if (std::dynamic_pointer_cast<ov::op::v0::Constant>(input_node)) {
|
||||
if (ov::shape_size(input_node->get_output_shape(0)) == 0)
|
||||
continue;
|
||||
auto const_node =
|
||||
std::dynamic_pointer_cast<ov::op::v0::Constant>(input_node);
|
||||
auto const_node = ov::as_type_ptr<ov::op::v0::Constant>(input_node);
|
||||
in_info.is_const = true;
|
||||
switch (node->get_output_element_type(0)) {
|
||||
case ov::element::Type_t::boolean: {
|
||||
in_info.ranges = get_const_ranges<bool>(const_node);
|
||||
break;
|
||||
}
|
||||
case ov::element::Type_t::bf16: {
|
||||
in_info.ranges = get_const_ranges<ov::bfloat16>(const_node);
|
||||
break;
|
||||
}
|
||||
case ov::element::Type_t::f16: {
|
||||
in_info.ranges = get_const_ranges<ov::float16>(const_node);
|
||||
break;
|
||||
}
|
||||
case ov::element::Type_t::f32: {
|
||||
in_info.ranges = get_const_ranges<float>(const_node);
|
||||
break;
|
||||
}
|
||||
case ov::element::Type_t::f64: {
|
||||
in_info.ranges = get_const_ranges<double>(const_node);
|
||||
break;
|
||||
}
|
||||
case ov::element::Type_t::i8: {
|
||||
in_info.ranges = get_const_ranges<int8_t>(const_node);
|
||||
break;
|
||||
}
|
||||
case ov::element::Type_t::i16: {
|
||||
in_info.ranges = get_const_ranges<int16_t>(const_node);
|
||||
break;
|
||||
}
|
||||
case ov::element::Type_t::i32: {
|
||||
in_info.ranges = get_const_ranges<int32_t>(const_node);
|
||||
break;
|
||||
}
|
||||
case ov::element::Type_t::i64: {
|
||||
in_info.ranges = get_const_ranges<int64_t>(const_node);
|
||||
break;
|
||||
}
|
||||
// TODO cast_vector doesn't support u1 now
|
||||
// case ov::element::Type_t::u1:
|
||||
// return get_const_ranges<char>(const_node);
|
||||
case ov::element::Type_t::u8: {
|
||||
in_info.ranges = get_const_ranges<uint8_t>(const_node);
|
||||
break;
|
||||
}
|
||||
case ov::element::Type_t::u16: {
|
||||
in_info.ranges = get_const_ranges<uint16_t>(const_node);
|
||||
break;
|
||||
}
|
||||
case ov::element::Type_t::u32: {
|
||||
in_info.ranges = get_const_ranges<uint32_t>(const_node);
|
||||
break;
|
||||
}
|
||||
case ov::element::Type_t::u64: {
|
||||
in_info.ranges = get_const_ranges<uint64_t>(const_node);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
std::cout << "Can't get ranges.. Unsupported data type" << std::endl;
|
||||
break;
|
||||
}}
|
||||
in_info.ranges = get_const_ranges(const_node,
|
||||
const_node->get_default_output().get_element_type());
|
||||
}
|
||||
input_info.insert({ input_name, in_info });
|
||||
input_info.insert({input_name, in_info});
|
||||
}
|
||||
return input_info;
|
||||
}
|
||||
@ -128,9 +136,10 @@ std::shared_ptr<ov::Node> clone_node(std::shared_ptr<ov::Node> node,
|
||||
std::shared_ptr<ov::Node> cloned_node = nullptr;
|
||||
if (!has_parameters && !is_copy_const_node && !inputs.empty()) {
|
||||
cloned_node = clone_node(node, true, true, node_name);
|
||||
// std::cout << "The operation: " + node->get_friendly_name() + " does not have parameters! Replace first input to parameter!" << std::endl;
|
||||
auto param =
|
||||
std::make_shared<ov::op::v0::Parameter>(cloned_node->get_input_element_type(0), cloned_node->get_input_partial_shape(0));
|
||||
// std::cout << "The operation: " + node->get_friendly_name() + " does not have parameters! Replace first input
|
||||
// to parameter!" << std::endl;
|
||||
auto param = std::make_shared<ov::op::v0::Parameter>(cloned_node->get_input_element_type(0),
|
||||
cloned_node->get_input_partial_shape(0));
|
||||
std::string param_name = node_name + "_0";
|
||||
param->set_friendly_name(param_name);
|
||||
auto node_to_replace = cloned_node->get_input_node_shared_ptr(0);
|
||||
@ -142,10 +151,11 @@ std::shared_ptr<ov::Node> clone_node(std::shared_ptr<ov::Node> node,
|
||||
return cloned_node;
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::op::v0::Parameter> convert_const_to_param(const std::shared_ptr<ov::op::v0::Constant>& op_to_replace) {
|
||||
std::shared_ptr<ov::op::v0::Parameter> convert_const_to_param(
|
||||
const std::shared_ptr<ov::op::v0::Constant>& op_to_replace) {
|
||||
if (op_to_replace->get_byte_size() > 1024) {
|
||||
auto param = std::make_shared<ov::op::v0::Parameter>(
|
||||
op_to_replace->get_output_element_type(0), op_to_replace->get_output_partial_shape(0));
|
||||
auto param = std::make_shared<ov::op::v0::Parameter>(op_to_replace->get_output_element_type(0),
|
||||
op_to_replace->get_output_partial_shape(0));
|
||||
param->set_friendly_name(op_to_replace->get_friendly_name());
|
||||
if (param != nullptr) {
|
||||
ov::replace_node(op_to_replace, param);
|
||||
|
@ -93,4 +93,22 @@ TEST_F(ICacheUnitTest, serialize_model) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ICacheUnitTest, is_model_large_to_read) {
|
||||
this->mem_size = 0;
|
||||
ASSERT_NO_THROW(this->is_model_large_to_read(test_model, test_model_path));
|
||||
ASSERT_TRUE(this->is_model_large_to_read(test_model, test_model_path));
|
||||
this->mem_size = 1 << 30;
|
||||
ASSERT_NO_THROW(this->is_model_large_to_read(test_model, test_model_path));
|
||||
ASSERT_FALSE(this->is_model_large_to_read(test_model, test_model_path));
|
||||
}
|
||||
|
||||
TEST_F(ICacheUnitTest, is_model_large_to_store_const) {
|
||||
this->mem_size = 0;
|
||||
ASSERT_NO_THROW(this->is_model_large_to_store_const(test_model));
|
||||
ASSERT_TRUE(this->is_model_large_to_store_const(test_model));
|
||||
this->mem_size = 1 << 30;
|
||||
ASSERT_NO_THROW(this->is_model_large_to_store_const(test_model));
|
||||
ASSERT_FALSE(this->is_model_large_to_store_const(test_model));
|
||||
}
|
||||
|
||||
} // namespace
|
@ -8,6 +8,7 @@
|
||||
|
||||
#include "matchers/subgraph/fused_names.hpp"
|
||||
#include "utils/model.hpp"
|
||||
#include "utils/model_comparator.hpp"
|
||||
|
||||
#include "test_models/model_0.hpp"
|
||||
#include "test_models/model_1.hpp"
|
||||
@ -32,7 +33,7 @@ protected:
|
||||
auto it_model_2 = models_2.begin();
|
||||
while (it_model_1 != models_1.end() || it_model_2 != models_2.end()) {
|
||||
SubgraphExtractor extractor;
|
||||
ASSERT_TRUE(extractor.match(std::get<0>(*it_model_1), std::get<0>(*it_model_2)));
|
||||
ASSERT_TRUE(ModelComparator::get()->match(std::get<0>(*it_model_1), std::get<0>(*it_model_2)));
|
||||
auto in_info_1 = std::get<1>(*it_model_1);
|
||||
auto in_info_2 = std::get<1>(*it_model_2);
|
||||
for (const auto& in_info : in_info_1) {
|
||||
|
@ -35,31 +35,10 @@ protected:
|
||||
test_model_0_0 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
|
||||
ov::ParameterVector{test_parameter});
|
||||
}
|
||||
{
|
||||
std::shared_ptr<ov::op::v0::Parameter> test_parameter =
|
||||
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 5});
|
||||
test_parameter->set_friendly_name("test_parameter_1");
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs =
|
||||
std::make_shared<ov::op::v0::Abs>(test_parameter);
|
||||
std::shared_ptr<ov::op::v0::Result> test_res =
|
||||
std::make_shared<ov::op::v0::Result>(test_abs);
|
||||
test_model_0_1 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
|
||||
ov::ParameterVector{test_parameter});
|
||||
}
|
||||
{
|
||||
std::shared_ptr<ov::op::v0::Parameter> test_parameter =
|
||||
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 5});
|
||||
std::shared_ptr<ov::op::v0::Relu> test_abs =
|
||||
std::make_shared<ov::op::v0::Relu>(test_parameter);
|
||||
std::shared_ptr<ov::op::v0::Result> test_res =
|
||||
std::make_shared<ov::op::v0::Result>(test_abs);
|
||||
test_model_1 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
|
||||
ov::ParameterVector{test_parameter});
|
||||
}
|
||||
}
|
||||
|
||||
ExtractorsManager::ExtractorsMap test_map;
|
||||
std::shared_ptr<ov::Model> test_model_0_0, test_model_0_1, test_model_1;
|
||||
std::shared_ptr<ov::Model> test_model_0_0;
|
||||
};
|
||||
|
||||
TEST_F(ExtractorsManagerTest, constructor) {
|
||||
@ -78,57 +57,9 @@ TEST_F(ExtractorsManagerTest, get_extractors) {
|
||||
ASSERT_EQ(this->m_extractors, this->get_extractors());
|
||||
}
|
||||
|
||||
TEST_F(ExtractorsManagerTest, match) {
|
||||
this->set_extractors(test_map);
|
||||
ASSERT_NO_THROW(this->match(test_model_0_0, test_model_0_1));
|
||||
ASSERT_TRUE(this->match(test_model_0_0, test_model_0_1));
|
||||
ASSERT_NO_THROW(this->match(test_model_0_0, test_model_1));
|
||||
ASSERT_FALSE(this->match(test_model_0_0, test_model_1));
|
||||
ASSERT_NO_THROW(this->match(test_model_0_1, test_model_1));
|
||||
ASSERT_FALSE(this->match(test_model_0_1, test_model_1));
|
||||
}
|
||||
|
||||
TEST_F(ExtractorsManagerTest, is_subgraph) {
|
||||
this->set_extractors(test_map);
|
||||
ASSERT_NO_THROW(this->is_subgraph(test_model_0_0, test_model_0_1));
|
||||
auto is_subgraph = this->is_subgraph(test_model_0_0, test_model_0_1);
|
||||
ASSERT_TRUE(std::get<0>(is_subgraph));
|
||||
ASSERT_NO_THROW(this->is_subgraph(test_model_0_0, test_model_1));
|
||||
ASSERT_FALSE(std::get<0>(this->is_subgraph(test_model_0_0, test_model_1)));
|
||||
ASSERT_NO_THROW(this->is_subgraph(test_model_0_1, test_model_1));
|
||||
ASSERT_FALSE(std::get<0>(this->is_subgraph(test_model_0_1, test_model_1)));
|
||||
}
|
||||
|
||||
TEST_F(ExtractorsManagerTest, match_with_in_info) {
|
||||
this->set_extractors(test_map);
|
||||
std::map<std::string, InputInfo> test_in_info({{"test_parameter_0", InputInfo()}}), test_in_info_1({{"test_parameter_1", InputInfo({}, 1, 2, true)}});
|
||||
ASSERT_NO_THROW(this->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info));
|
||||
ASSERT_TRUE(this->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info));
|
||||
ASSERT_NO_THROW(this->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info_1));
|
||||
ASSERT_FALSE(this->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info_1));
|
||||
ASSERT_NO_THROW(this->match(test_model_0_1, test_model_1, test_in_info, test_in_info));
|
||||
ASSERT_FALSE(this->match(test_model_0_1, test_model_1, test_in_info, test_in_info));
|
||||
}
|
||||
|
||||
TEST_F(ExtractorsManagerTest, extract) {
|
||||
this->set_extractors(test_map);
|
||||
ASSERT_NO_THROW(this->extract(test_model_0_0));
|
||||
}
|
||||
|
||||
TEST_F(ExtractorsManagerTest, align_input_info) {
|
||||
std::map<std::string, InputInfo> test_in_info({{"test_parameter_0", InputInfo()}}), test_in_info_ref({{"test_parameter_1", InputInfo()}});
|
||||
ASSERT_NE(test_in_info, test_in_info_ref);
|
||||
ASSERT_NO_THROW(this->align_input_info(test_model_0_0, test_model_0_1, test_in_info, test_in_info_ref));
|
||||
auto c = this->align_input_info(test_model_0_0, test_model_0_1, test_in_info, test_in_info_ref);
|
||||
ASSERT_EQ(c, test_in_info_ref);
|
||||
}
|
||||
|
||||
TEST_F(ExtractorsManagerTest, align_input_info_for_subgraphs) {
|
||||
std::map<std::string, InputInfo> test_in_info({{"test_parameter_0", InputInfo()}}), test_in_info_ref({{"test_parameter_1", InputInfo()}});
|
||||
ASSERT_NE(test_in_info, test_in_info_ref);
|
||||
ASSERT_NO_THROW(this->align_input_info(test_model_0_0, test_model_0_1, test_in_info, test_in_info_ref, {{"test_parameter_0", "test_parameter_1"}}));
|
||||
auto c = this->align_input_info(test_model_0_0, test_model_0_1, test_in_info, test_in_info_ref);
|
||||
ASSERT_EQ(c, test_in_info_ref);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -6,6 +6,7 @@
|
||||
|
||||
#include "matchers/subgraph/repeat_pattern.hpp"
|
||||
#include "utils/model.hpp"
|
||||
#include "utils/model_comparator.hpp"
|
||||
|
||||
#include "base_test.hpp"
|
||||
#include "test_models/model_0.hpp"
|
||||
@ -22,13 +23,13 @@ class RepeatPatternExtractorTest : public SubgraphsDumperBaseTest {
|
||||
protected:
|
||||
RepeatPatternExtractor extractor;
|
||||
|
||||
bool is_match(const std::list<ExtractedPattern>& models,
|
||||
bool is_match(const std::vector<ExtractedPattern>& models,
|
||||
const std::vector<std::shared_ptr<ov::Model>>& ref_models) {
|
||||
size_t match_numbers = 0;
|
||||
for (const auto& model : models) {
|
||||
bool is_match = false;
|
||||
for (const auto& ref_model : ref_models) {
|
||||
if (extractor.match(std::get<0>(model), ref_model)) {
|
||||
if (ModelComparator::get()->match(std::get<0>(model), ref_model)) {
|
||||
is_match = true;
|
||||
++match_numbers;
|
||||
break;
|
||||
@ -40,6 +41,28 @@ protected:
|
||||
}
|
||||
return match_numbers == models.size();
|
||||
}
|
||||
|
||||
void sort_node_vec(std::vector<std::vector<ov::NodeVector>>& pattern_vec) {
|
||||
for (auto& pattern : pattern_vec) {
|
||||
for (auto& node_vec : pattern) {
|
||||
std::sort(node_vec.begin(), node_vec.end());
|
||||
}
|
||||
std::sort(pattern.begin(), pattern.end());
|
||||
}
|
||||
std::sort(pattern_vec.begin(), pattern_vec.end());
|
||||
}
|
||||
|
||||
// not allowed to sort inputs/outputs according there are not copy constructor
|
||||
// void sort_borders(std::vector<std::vector<RepeatPatternExtractor::PatternBorders>>& pattern_vec) {
|
||||
// for (auto& pattern : pattern_vec) {
|
||||
// for (auto& node_vec : pattern) {
|
||||
// std::sort(node_vec.first.begin(), node_vec.first.end());
|
||||
// std::sort(node_vec.second.begin(), node_vec.second.end());
|
||||
// }
|
||||
// std::sort(pattern.begin(), pattern.end());
|
||||
// }
|
||||
// std::sort(pattern_vec.begin(), pattern_vec.end());
|
||||
// }
|
||||
};
|
||||
|
||||
TEST_F(RepeatPatternExtractorTest, extract_0) {
|
||||
@ -63,4 +86,59 @@ TEST_F(RepeatPatternExtractorTest, extract_2) {
|
||||
ASSERT_TRUE(is_match(models, ref));
|
||||
}
|
||||
|
||||
TEST_F(RepeatPatternExtractorTest, get_repeat_node_vectors_model_0) {
|
||||
auto test_model = Model_0();
|
||||
auto node_vector = extractor.get_repeat_node_vectors(test_model.get());
|
||||
auto ref = test_model.get_ref_node_vector();
|
||||
sort_node_vec(node_vector);
|
||||
sort_node_vec(ref);
|
||||
ASSERT_EQ(node_vector, ref);
|
||||
}
|
||||
|
||||
TEST_F(RepeatPatternExtractorTest, get_repeat_node_vectors_model_1) {
|
||||
auto test_model = Model_1();
|
||||
auto node_vector = extractor.get_repeat_node_vectors(test_model.get());
|
||||
auto ref = test_model.get_ref_node_vector();
|
||||
sort_node_vec(node_vector);
|
||||
sort_node_vec(ref);
|
||||
ASSERT_EQ(node_vector, ref);
|
||||
}
|
||||
|
||||
TEST_F(RepeatPatternExtractorTest, get_repeat_node_vectors_model_2) {
|
||||
auto test_model = Model_2();
|
||||
auto node_vector = extractor.get_repeat_node_vectors(test_model.get());
|
||||
auto ref = test_model.get_ref_node_vector();
|
||||
sort_node_vec(node_vector);
|
||||
sort_node_vec(ref);
|
||||
ASSERT_EQ(node_vector, ref);
|
||||
}
|
||||
|
||||
TEST_F(RepeatPatternExtractorTest, get_repeat_pattern_borders_model_0) {
|
||||
auto test_model = Model_0();
|
||||
auto extracted_borders = extractor.get_repeat_pattern_borders(test_model.get());
|
||||
auto ref_borders = test_model.get_ref_node_borders();
|
||||
// sort_borders(extracted_borders);
|
||||
// sort_borders(ref_borders);
|
||||
ASSERT_EQ(extracted_borders, ref_borders);
|
||||
}
|
||||
|
||||
TEST_F(RepeatPatternExtractorTest, get_repeat_pattern_borders_model_1) {
|
||||
auto test_model = Model_1();
|
||||
auto extracted_borders = extractor.get_repeat_pattern_borders(test_model.get());
|
||||
auto ref_borders = test_model.get_ref_node_borders();
|
||||
// sort_borders(extracted_borders);
|
||||
// sort_borders(ref_borders);
|
||||
ASSERT_EQ(extracted_borders, ref_borders);
|
||||
}
|
||||
|
||||
TEST_F(RepeatPatternExtractorTest, get_repeat_pattern_borders_model_2) {
|
||||
auto test_model = Model_2();
|
||||
auto extracted_borders = extractor.get_repeat_pattern_borders(test_model.get());
|
||||
auto ref_borders = test_model.get_ref_node_borders();
|
||||
// sort_borders(extracted_borders);
|
||||
// sort_borders(ref_borders);
|
||||
ASSERT_EQ(extracted_borders, ref_borders);
|
||||
}
|
||||
|
||||
|
||||
} // namespace
|
||||
|
@ -1,209 +1,209 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// // Copyright (C) 2018-2023 Intel Corporation
|
||||
// // SPDX-License-Identifier: Apache-2.0
|
||||
// //
|
||||
|
||||
#include "matchers/subgraph/subgraph.hpp"
|
||||
#include "base_test.hpp"
|
||||
// #include "matchers/subgraph/subgraph.hpp"
|
||||
// #include "base_test.hpp"
|
||||
|
||||
#include "openvino/op/abs.hpp"
|
||||
#include "openvino/op/relu.hpp"
|
||||
#include "openvino/op/parameter.hpp"
|
||||
#include "openvino/op/result.hpp"
|
||||
// #include "openvino/op/abs.hpp"
|
||||
// #include "openvino/op/relu.hpp"
|
||||
// #include "openvino/op/parameter.hpp"
|
||||
// #include "openvino/op/result.hpp"
|
||||
|
||||
namespace {
|
||||
// namespace {
|
||||
|
||||
using namespace ov::tools::subgraph_dumper;
|
||||
// using namespace ov::tools::subgraph_dumper;
|
||||
|
||||
// ======================= ExtractorsManagerTest Unit tests =======================
|
||||
class SubgraphExtractorTest : public SubgraphExtractor,
|
||||
public SubgraphsDumperBaseTest {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
SubgraphsDumperBaseTest::SetUp();
|
||||
{
|
||||
std::shared_ptr<ov::op::v0::Parameter> test_parameter =
|
||||
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2});
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs =
|
||||
std::make_shared<ov::op::v0::Abs>(test_parameter);
|
||||
std::shared_ptr<ov::op::v0::Result> test_res =
|
||||
std::make_shared<ov::op::v0::Result>(test_abs);
|
||||
test_model_0_0 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
|
||||
ov::ParameterVector{test_parameter});
|
||||
}
|
||||
{
|
||||
std::shared_ptr<ov::op::v0::Parameter> test_parameter =
|
||||
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 5});
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs =
|
||||
std::make_shared<ov::op::v0::Abs>(test_parameter);
|
||||
std::shared_ptr<ov::op::v0::Result> test_res =
|
||||
std::make_shared<ov::op::v0::Result>(test_abs);
|
||||
test_model_0_1 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
|
||||
ov::ParameterVector{test_parameter});
|
||||
}
|
||||
{
|
||||
std::shared_ptr<ov::op::v0::Parameter> test_parameter =
|
||||
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 5});
|
||||
std::shared_ptr<ov::op::v0::Relu> test_abs =
|
||||
std::make_shared<ov::op::v0::Relu>(test_parameter);
|
||||
std::shared_ptr<ov::op::v0::Result> test_res =
|
||||
std::make_shared<ov::op::v0::Result>(test_abs);
|
||||
test_model_1 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
|
||||
ov::ParameterVector{test_parameter});
|
||||
}
|
||||
}
|
||||
// // ======================= ExtractorsManagerTest Unit tests =======================
|
||||
// class SubgraphExtractorTest : public SubgraphExtractor,
|
||||
// public SubgraphsDumperBaseTest {
|
||||
// protected:
|
||||
// void SetUp() override {
|
||||
// SubgraphsDumperBaseTest::SetUp();
|
||||
// {
|
||||
// std::shared_ptr<ov::op::v0::Parameter> test_parameter =
|
||||
// std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2});
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_parameter);
|
||||
// std::shared_ptr<ov::op::v0::Result> test_res =
|
||||
// std::make_shared<ov::op::v0::Result>(test_abs);
|
||||
// test_model_0_0 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
|
||||
// ov::ParameterVector{test_parameter});
|
||||
// }
|
||||
// {
|
||||
// std::shared_ptr<ov::op::v0::Parameter> test_parameter =
|
||||
// std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 5});
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_parameter);
|
||||
// std::shared_ptr<ov::op::v0::Result> test_res =
|
||||
// std::make_shared<ov::op::v0::Result>(test_abs);
|
||||
// test_model_0_1 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
|
||||
// ov::ParameterVector{test_parameter});
|
||||
// }
|
||||
// {
|
||||
// std::shared_ptr<ov::op::v0::Parameter> test_parameter =
|
||||
// std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 5});
|
||||
// std::shared_ptr<ov::op::v0::Relu> test_abs =
|
||||
// std::make_shared<ov::op::v0::Relu>(test_parameter);
|
||||
// std::shared_ptr<ov::op::v0::Result> test_res =
|
||||
// std::make_shared<ov::op::v0::Result>(test_abs);
|
||||
// test_model_1 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
|
||||
// ov::ParameterVector{test_parameter});
|
||||
// }
|
||||
// }
|
||||
|
||||
std::shared_ptr<ov::Model> test_model_0_0, test_model_0_1, test_model_1;
|
||||
};
|
||||
// std::shared_ptr<ov::Model> test_model_0_0, test_model_0_1, test_model_1;
|
||||
// };
|
||||
|
||||
TEST_F(SubgraphExtractorTest, match) {
|
||||
ASSERT_NO_THROW(this->match(test_model_0_0, test_model_0_1));
|
||||
ASSERT_TRUE(this->match(test_model_0_0, test_model_0_1));
|
||||
ASSERT_NO_THROW(this->match(test_model_0_0, test_model_1));
|
||||
ASSERT_FALSE(this->match(test_model_0_0, test_model_1));
|
||||
ASSERT_NO_THROW(this->match(test_model_0_1, test_model_1));
|
||||
ASSERT_FALSE(this->match(test_model_0_1, test_model_1));
|
||||
}
|
||||
// TEST_F(SubgraphExtractorTest, match) {
|
||||
// ASSERT_NO_THROW(this->match(test_model_0_0, test_model_0_1));
|
||||
// ASSERT_TRUE(this->match(test_model_0_0, test_model_0_1));
|
||||
// ASSERT_NO_THROW(this->match(test_model_0_0, test_model_1));
|
||||
// ASSERT_FALSE(this->match(test_model_0_0, test_model_1));
|
||||
// ASSERT_NO_THROW(this->match(test_model_0_1, test_model_1));
|
||||
// ASSERT_FALSE(this->match(test_model_0_1, test_model_1));
|
||||
// }
|
||||
|
||||
TEST_F(SubgraphExtractorTest, match_90_percent) {
|
||||
{
|
||||
std::shared_ptr<ov::op::v0::Parameter> test_parameter =
|
||||
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2});
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_0 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_parameter);
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_1 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_abs_0);
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_2 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_abs_1);
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_3 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_abs_2);
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_4 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_abs_3);
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_5 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_abs_4);
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_6 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_abs_5);
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_7 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_abs_6);
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_8 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_abs_7);
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_9 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_abs_8);
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_10 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_abs_9);
|
||||
std::shared_ptr<ov::op::v0::Result> test_res =
|
||||
std::make_shared<ov::op::v0::Result>(test_abs_10);
|
||||
test_model_0_0 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
|
||||
ov::ParameterVector{test_parameter});
|
||||
}
|
||||
{
|
||||
std::shared_ptr<ov::op::v0::Parameter> test_parameter =
|
||||
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2});
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_0 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_parameter);
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_1 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_abs_0);
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_2 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_abs_1);
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_3 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_abs_2);
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_4 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_abs_3);
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_5 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_abs_4);
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_6 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_abs_5);
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_7 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_abs_6);
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_8 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_abs_7);
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_9 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_abs_8);
|
||||
std::shared_ptr<ov::op::v0::Relu> test_abs_10 =
|
||||
std::make_shared<ov::op::v0::Relu>(test_abs_9);
|
||||
std::shared_ptr<ov::op::v0::Result> test_res =
|
||||
std::make_shared<ov::op::v0::Result>(test_abs_10);
|
||||
test_model_0_1 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
|
||||
ov::ParameterVector{test_parameter});
|
||||
}
|
||||
{
|
||||
std::shared_ptr<ov::op::v0::Parameter> test_parameter =
|
||||
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2});
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_0 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_parameter);
|
||||
std::shared_ptr<ov::op::v0::Relu> test_abs_1 =
|
||||
std::make_shared<ov::op::v0::Relu>(test_abs_0);
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_2 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_abs_1);
|
||||
std::shared_ptr<ov::op::v0::Relu> test_abs_3 =
|
||||
std::make_shared<ov::op::v0::Relu>(test_abs_2);
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_4 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_abs_3);
|
||||
std::shared_ptr<ov::op::v0::Relu> test_abs_5 =
|
||||
std::make_shared<ov::op::v0::Relu>(test_abs_4);
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_6 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_abs_5);
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_7 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_abs_6);
|
||||
std::shared_ptr<ov::op::v0::Relu> test_abs_8 =
|
||||
std::make_shared<ov::op::v0::Relu>(test_abs_7);
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_9 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_abs_8);
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_10 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_abs_9);
|
||||
std::shared_ptr<ov::op::v0::Result> test_res =
|
||||
std::make_shared<ov::op::v0::Result>(test_abs_10);
|
||||
test_model_1 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
|
||||
ov::ParameterVector{test_parameter});
|
||||
}
|
||||
ASSERT_NO_THROW(this->match(test_model_0_0, test_model_0_1));
|
||||
ASSERT_TRUE(this->match(test_model_0_0, test_model_0_1));
|
||||
ASSERT_NO_THROW(this->match(test_model_0_0, test_model_1));
|
||||
ASSERT_FALSE(this->match(test_model_0_0, test_model_1));
|
||||
ASSERT_NO_THROW(this->match(test_model_0_1, test_model_1));
|
||||
ASSERT_FALSE(this->match(test_model_0_1, test_model_1));
|
||||
}
|
||||
// TEST_F(SubgraphExtractorTest, match_90_percent) {
|
||||
// {
|
||||
// std::shared_ptr<ov::op::v0::Parameter> test_parameter =
|
||||
// std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2});
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_0 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_parameter);
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_1 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_abs_0);
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_2 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_abs_1);
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_3 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_abs_2);
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_4 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_abs_3);
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_5 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_abs_4);
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_6 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_abs_5);
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_7 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_abs_6);
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_8 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_abs_7);
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_9 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_abs_8);
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_10 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_abs_9);
|
||||
// std::shared_ptr<ov::op::v0::Result> test_res =
|
||||
// std::make_shared<ov::op::v0::Result>(test_abs_10);
|
||||
// test_model_0_0 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
|
||||
// ov::ParameterVector{test_parameter});
|
||||
// }
|
||||
// {
|
||||
// std::shared_ptr<ov::op::v0::Parameter> test_parameter =
|
||||
// std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2});
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_0 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_parameter);
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_1 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_abs_0);
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_2 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_abs_1);
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_3 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_abs_2);
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_4 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_abs_3);
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_5 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_abs_4);
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_6 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_abs_5);
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_7 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_abs_6);
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_8 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_abs_7);
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_9 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_abs_8);
|
||||
// std::shared_ptr<ov::op::v0::Relu> test_abs_10 =
|
||||
// std::make_shared<ov::op::v0::Relu>(test_abs_9);
|
||||
// std::shared_ptr<ov::op::v0::Result> test_res =
|
||||
// std::make_shared<ov::op::v0::Result>(test_abs_10);
|
||||
// test_model_0_1 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
|
||||
// ov::ParameterVector{test_parameter});
|
||||
// }
|
||||
// {
|
||||
// std::shared_ptr<ov::op::v0::Parameter> test_parameter =
|
||||
// std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2});
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_0 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_parameter);
|
||||
// std::shared_ptr<ov::op::v0::Relu> test_abs_1 =
|
||||
// std::make_shared<ov::op::v0::Relu>(test_abs_0);
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_2 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_abs_1);
|
||||
// std::shared_ptr<ov::op::v0::Relu> test_abs_3 =
|
||||
// std::make_shared<ov::op::v0::Relu>(test_abs_2);
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_4 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_abs_3);
|
||||
// std::shared_ptr<ov::op::v0::Relu> test_abs_5 =
|
||||
// std::make_shared<ov::op::v0::Relu>(test_abs_4);
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_6 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_abs_5);
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_7 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_abs_6);
|
||||
// std::shared_ptr<ov::op::v0::Relu> test_abs_8 =
|
||||
// std::make_shared<ov::op::v0::Relu>(test_abs_7);
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_9 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_abs_8);
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_10 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_abs_9);
|
||||
// std::shared_ptr<ov::op::v0::Result> test_res =
|
||||
// std::make_shared<ov::op::v0::Result>(test_abs_10);
|
||||
// test_model_1 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
|
||||
// ov::ParameterVector{test_parameter});
|
||||
// }
|
||||
// ASSERT_NO_THROW(this->match(test_model_0_0, test_model_0_1));
|
||||
// ASSERT_TRUE(this->match(test_model_0_0, test_model_0_1));
|
||||
// ASSERT_NO_THROW(this->match(test_model_0_0, test_model_1));
|
||||
// ASSERT_FALSE(this->match(test_model_0_0, test_model_1));
|
||||
// ASSERT_NO_THROW(this->match(test_model_0_1, test_model_1));
|
||||
// ASSERT_FALSE(this->match(test_model_0_1, test_model_1));
|
||||
// }
|
||||
|
||||
TEST_F(SubgraphExtractorTest, extract) {
|
||||
ASSERT_NO_THROW(this->extract(test_model_0_0));
|
||||
ASSERT_NO_THROW(this->extract(test_model_0_1));
|
||||
ASSERT_NO_THROW(this->extract(test_model_1));
|
||||
}
|
||||
// TEST_F(SubgraphExtractorTest, extract) {
|
||||
// ASSERT_NO_THROW(this->extract(test_model_0_0));
|
||||
// ASSERT_NO_THROW(this->extract(test_model_0_1));
|
||||
// ASSERT_NO_THROW(this->extract(test_model_1));
|
||||
// }
|
||||
|
||||
TEST_F(SubgraphExtractorTest, is_subgraph) {
|
||||
auto is_subgraph = this->is_subgraph(test_model_0_0, test_model_0_0);
|
||||
ASSERT_NO_THROW(this->is_subgraph(test_model_0_0, test_model_0_0));
|
||||
ASSERT_TRUE(std::get<0>(is_subgraph));
|
||||
ASSERT_NO_THROW(this->is_subgraph(test_model_0_0, test_model_1));
|
||||
is_subgraph = this->is_subgraph(test_model_0_0, test_model_1);
|
||||
ASSERT_FALSE(std::get<0>(is_subgraph));
|
||||
ASSERT_NO_THROW(this->is_subgraph(test_model_0_1, test_model_1));
|
||||
is_subgraph = this->is_subgraph(test_model_0_1, test_model_1);
|
||||
ASSERT_FALSE(std::get<0>(is_subgraph));
|
||||
{
|
||||
std::shared_ptr<ov::op::v0::Parameter> test_parameter =
|
||||
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2});
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_0 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_parameter);
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs_1 =
|
||||
std::make_shared<ov::op::v0::Abs>(test_abs_0);
|
||||
std::shared_ptr<ov::op::v0::Result> test_res =
|
||||
std::make_shared<ov::op::v0::Result>(test_abs_1);
|
||||
auto big_model_0 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
|
||||
ov::ParameterVector{test_parameter});
|
||||
is_subgraph = this->is_subgraph(test_model_0_0, big_model_0);
|
||||
ASSERT_NO_THROW(this->is_subgraph(test_model_0_0, big_model_0));
|
||||
ASSERT_TRUE(std::get<0>(is_subgraph));
|
||||
ASSERT_EQ(std::get<1>(is_subgraph), big_model_0);
|
||||
ASSERT_EQ(std::get<2>(is_subgraph), test_model_0_0);
|
||||
// TEST_F(SubgraphExtractorTest, is_subgraph) {
|
||||
// auto is_subgraph = this->is_subgraph(test_model_0_0, test_model_0_0);
|
||||
// ASSERT_NO_THROW(this->is_subgraph(test_model_0_0, test_model_0_0));
|
||||
// ASSERT_TRUE(std::get<0>(is_subgraph));
|
||||
// ASSERT_NO_THROW(this->is_subgraph(test_model_0_0, test_model_1));
|
||||
// is_subgraph = this->is_subgraph(test_model_0_0, test_model_1);
|
||||
// ASSERT_FALSE(std::get<0>(is_subgraph));
|
||||
// ASSERT_NO_THROW(this->is_subgraph(test_model_0_1, test_model_1));
|
||||
// is_subgraph = this->is_subgraph(test_model_0_1, test_model_1);
|
||||
// ASSERT_FALSE(std::get<0>(is_subgraph));
|
||||
// {
|
||||
// std::shared_ptr<ov::op::v0::Parameter> test_parameter =
|
||||
// std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2});
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_0 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_parameter);
|
||||
// std::shared_ptr<ov::op::v0::Abs> test_abs_1 =
|
||||
// std::make_shared<ov::op::v0::Abs>(test_abs_0);
|
||||
// std::shared_ptr<ov::op::v0::Result> test_res =
|
||||
// std::make_shared<ov::op::v0::Result>(test_abs_1);
|
||||
// auto big_model_0 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
|
||||
// ov::ParameterVector{test_parameter});
|
||||
// is_subgraph = this->is_subgraph(test_model_0_0, big_model_0);
|
||||
// ASSERT_NO_THROW(this->is_subgraph(test_model_0_0, big_model_0));
|
||||
// ASSERT_TRUE(std::get<0>(is_subgraph));
|
||||
// ASSERT_EQ(std::get<1>(is_subgraph), big_model_0);
|
||||
// ASSERT_EQ(std::get<2>(is_subgraph), test_model_0_0);
|
||||
|
||||
is_subgraph = this->is_subgraph(test_model_0_1, big_model_0);
|
||||
ASSERT_NO_THROW(this->is_subgraph(test_model_0_1, big_model_0));
|
||||
ASSERT_TRUE(std::get<0>(is_subgraph));
|
||||
ASSERT_EQ(std::get<1>(is_subgraph), big_model_0);
|
||||
ASSERT_EQ(std::get<2>(is_subgraph), test_model_0_1);
|
||||
ASSERT_NO_THROW(this->is_subgraph(test_model_1, big_model_0));
|
||||
ASSERT_FALSE(std::get<0>(this->is_subgraph(test_model_1, big_model_0)));
|
||||
}
|
||||
}
|
||||
// is_subgraph = this->is_subgraph(test_model_0_1, big_model_0);
|
||||
// ASSERT_NO_THROW(this->is_subgraph(test_model_0_1, big_model_0));
|
||||
// ASSERT_TRUE(std::get<0>(is_subgraph));
|
||||
// ASSERT_EQ(std::get<1>(is_subgraph), big_model_0);
|
||||
// ASSERT_EQ(std::get<2>(is_subgraph), test_model_0_1);
|
||||
// ASSERT_NO_THROW(this->is_subgraph(test_model_1, big_model_0));
|
||||
// ASSERT_FALSE(std::get<0>(this->is_subgraph(test_model_1, big_model_0)));
|
||||
// }
|
||||
// }
|
||||
|
||||
} // namespace
|
||||
// } // namespace
|
||||
|
@ -11,8 +11,12 @@
|
||||
#include "openvino/op/relu.hpp"
|
||||
#include "openvino/op/parameter.hpp"
|
||||
#include "openvino/op/result.hpp"
|
||||
#include "matchers/subgraph/repeat_pattern.hpp"
|
||||
|
||||
class Model_0 {
|
||||
private:
|
||||
using PatternBorders = ov::tools::subgraph_dumper::RepeatPatternExtractor::PatternBorders;
|
||||
|
||||
public:
|
||||
Model_0() {
|
||||
// param param
|
||||
@ -48,6 +52,13 @@ public:
|
||||
std::make_shared<ov::op::v0::Result>(test_add_0);
|
||||
model = std::make_shared<ov::Model>(ov::ResultVector{test_res},
|
||||
ov::ParameterVector{test_parameter_0, test_parameter_1});
|
||||
ref_nodes = {{{test_abs_0, test_relu_0}, {test_abs_1, test_relu_1}}};
|
||||
{
|
||||
PatternBorders ref_pattern_0 = {test_abs_0->inputs(), test_relu_0->outputs()},
|
||||
ref_pattern_1 = {test_abs_1->inputs(), test_relu_1->outputs()};
|
||||
std::vector<std::vector<PatternBorders>> ref_res = {{ref_pattern_0, ref_pattern_1}};
|
||||
ref_borders = std::move(ref_res);
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> get() {
|
||||
@ -72,6 +83,14 @@ public:
|
||||
return ref;
|
||||
}
|
||||
|
||||
std::vector<std::vector<ov::NodeVector>>
|
||||
get_ref_node_vector() { return ref_nodes; }
|
||||
|
||||
std::vector<std::vector<PatternBorders>>
|
||||
get_ref_node_borders() { return ref_borders; }
|
||||
|
||||
protected:
|
||||
std::shared_ptr<ov::Model> model;
|
||||
std::vector<std::vector<ov::NodeVector>> ref_nodes;
|
||||
std::vector<std::vector<PatternBorders>> ref_borders;
|
||||
};
|
||||
|
@ -12,8 +12,12 @@
|
||||
#include "openvino/op/parameter.hpp"
|
||||
#include "openvino/op/result.hpp"
|
||||
#include "openvino/op/subtract.hpp"
|
||||
#include "matchers/subgraph/repeat_pattern.hpp"
|
||||
|
||||
class Model_1 {
|
||||
private:
|
||||
using PatternBorders = ov::tools::subgraph_dumper::RepeatPatternExtractor::PatternBorders;
|
||||
|
||||
public:
|
||||
Model_1() {
|
||||
// param param param param
|
||||
@ -119,6 +123,22 @@ public:
|
||||
ov::ParameterVector{test_parameter_0, test_parameter_1,
|
||||
test_parameter_0_0, test_parameter_0_1,
|
||||
test_parameter_1_0, test_parameter_1_1});
|
||||
|
||||
ref_nodes = {{{test_abs_0, test_relu_0}, {test_abs_0_0, test_relu_0_0}},
|
||||
{{test_abs_1, test_clamp_1}, {test_abs_0_1, test_clamp_0_1}},
|
||||
{{test_multiply_0_1, test_relu_0_1}, {test_multiply_1_1, test_relu_1_1}}};
|
||||
{
|
||||
PatternBorders ref_pattern_0 = {test_abs_0->inputs(), test_relu_0->outputs()},
|
||||
ref_pattern_0_0 = {test_abs_0_0->inputs(), test_relu_0_0->outputs()},
|
||||
ref_pattern_1 = {test_abs_1->inputs(), test_clamp_1->outputs()},
|
||||
ref_pattern_0_1_0 = {test_abs_0_1->inputs(), test_clamp_0_1->outputs()},
|
||||
test_pattern_0_1_1 = {test_multiply_0_1->inputs(), test_relu_0_1->outputs()},
|
||||
test_pattern_1_1 = {test_multiply_1_1->inputs(), test_relu_1_1->outputs()};
|
||||
std::vector<std::vector<PatternBorders>> ref_res = {{ref_pattern_0, ref_pattern_0_0},
|
||||
{ref_pattern_1, ref_pattern_0_1_0},
|
||||
{test_pattern_0_1_1, test_pattern_1_1}};
|
||||
ref_borders = std::move(ref_res);
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> get() {
|
||||
@ -166,10 +186,19 @@ public:
|
||||
std::make_shared<ov::op::v0::Result>(test_relu_1);
|
||||
auto ref_model = std::make_shared<ov::Model>(ov::ResultVector{res},
|
||||
ov::ParameterVector{test_parameter_1_0, test_parameter_1_1});
|
||||
ref.push_back(ref_model);
|
||||
}
|
||||
return ref;
|
||||
}
|
||||
|
||||
std::vector<std::vector<ov::NodeVector>>
|
||||
get_ref_node_vector() { return ref_nodes; }
|
||||
|
||||
std::vector<std::vector<PatternBorders>>
|
||||
get_ref_node_borders() { return ref_borders; }
|
||||
|
||||
protected:
|
||||
std::shared_ptr<ov::Model> model;
|
||||
std::vector<std::vector<ov::NodeVector>> ref_nodes;
|
||||
std::vector<std::vector<PatternBorders>> ref_borders;
|
||||
};
|
||||
|
@ -13,6 +13,9 @@
|
||||
#include "openvino/op/result.hpp"
|
||||
|
||||
class Model_2 {
|
||||
private:
|
||||
using PatternBorders = ov::tools::subgraph_dumper::RepeatPatternExtractor::PatternBorders;
|
||||
|
||||
public:
|
||||
Model_2() {
|
||||
// param
|
||||
@ -55,9 +58,17 @@ public:
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<ov::Model>> get_repeat_pattern_ref() {
|
||||
return {};
|
||||
return std::vector<std::shared_ptr<ov::Model>>();
|
||||
}
|
||||
|
||||
std::vector<std::vector<ov::NodeVector>>
|
||||
get_ref_node_vector() { return ref_nodes; }
|
||||
|
||||
std::vector<std::vector<PatternBorders>>
|
||||
get_ref_node_borders() { return ref_borders; }
|
||||
|
||||
protected:
|
||||
std::shared_ptr<ov::Model> model;
|
||||
std::vector<std::vector<ov::NodeVector>> ref_nodes;
|
||||
std::vector<std::vector<PatternBorders>> ref_borders;
|
||||
};
|
||||
|
@ -4,6 +4,7 @@
|
||||
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "utils/model.hpp"
|
||||
#include "utils/model_comparator.hpp"
|
||||
#include "matchers/subgraph/subgraph.hpp"
|
||||
#include "test_models/model_0.hpp"
|
||||
#include "test_models/model_1.hpp"
|
||||
@ -16,11 +17,11 @@ using namespace ov::tools::subgraph_dumper;
|
||||
|
||||
using ModelUtilsTest = SubgraphsDumperBaseTest;
|
||||
|
||||
std::set<std::shared_ptr<ov::Node>>
|
||||
ov::NodeVector
|
||||
get_functional_ops(const std::shared_ptr<ov::Model>& model) {
|
||||
std::set<std::shared_ptr<ov::Node>> nodes;
|
||||
std::vector<std::shared_ptr<ov::Node>> nodes;
|
||||
for (const auto& op : model->get_ordered_ops()) {
|
||||
nodes.insert(op);
|
||||
nodes.push_back(op);
|
||||
}
|
||||
return nodes;
|
||||
}
|
||||
@ -31,12 +32,11 @@ TEST_F(ModelUtilsTest, generate_0) {
|
||||
{
|
||||
std::unordered_set<std::string> checked_ops;
|
||||
auto func_ops = get_functional_ops(test_model);
|
||||
auto model_with_in_info = generate_model(func_ops, checked_ops, "test_extractor");
|
||||
auto model_with_in_info = generate_model(func_ops, checked_ops);
|
||||
recovered_model = std::get<0>(model_with_in_info);
|
||||
}
|
||||
{
|
||||
SubgraphExtractor extractor;
|
||||
ASSERT_TRUE(extractor.match(test_model, recovered_model));
|
||||
ASSERT_TRUE(ModelComparator::get()->match(test_model, recovered_model));
|
||||
}
|
||||
}
|
||||
|
||||
@ -46,12 +46,11 @@ TEST_F(ModelUtilsTest, generate_1) {
|
||||
{
|
||||
std::unordered_set<std::string> checked_ops;
|
||||
auto func_ops = get_functional_ops(test_model);
|
||||
auto model_with_in_info = generate_model(func_ops, checked_ops, "test_extractor");
|
||||
auto model_with_in_info = generate_model(func_ops, checked_ops);
|
||||
recovered_model = std::get<0>(model_with_in_info);
|
||||
}
|
||||
{
|
||||
SubgraphExtractor extractor;
|
||||
ASSERT_TRUE(extractor.match(test_model, recovered_model));
|
||||
ASSERT_TRUE(ModelComparator::get()->match(test_model, recovered_model));
|
||||
}
|
||||
}
|
||||
|
||||
@ -61,14 +60,59 @@ TEST_F(ModelUtilsTest, generate_2) {
|
||||
{
|
||||
std::unordered_set<std::string> checked_ops;
|
||||
auto func_ops = get_functional_ops(test_model);
|
||||
auto model_with_in_info = generate_model(func_ops, checked_ops, "extract_model");
|
||||
auto model_with_in_info = generate_model(func_ops, checked_ops);
|
||||
recovered_model = std::get<0>(model_with_in_info);
|
||||
auto in_info = std::get<1>(model_with_in_info);
|
||||
}
|
||||
{
|
||||
SubgraphExtractor extractor;
|
||||
ASSERT_TRUE(extractor.match(test_model, recovered_model));
|
||||
ASSERT_TRUE(ModelComparator::get()->match(test_model, recovered_model));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ModelUtilsTest, align_input_info) {
|
||||
Model_0 test_model_0, test_model_1;
|
||||
auto in_info_0 = get_input_info_by_model(test_model_0.get());
|
||||
auto in_info_1 = get_input_info_by_model(test_model_1.get());
|
||||
ASSERT_NE(in_info_0, in_info_1);
|
||||
ASSERT_NO_THROW(align_input_info(test_model_0.get(), test_model_1.get(), in_info_0, in_info_1));
|
||||
auto in_info_ref = align_input_info(test_model_0.get(), test_model_1.get(), in_info_0, in_info_1);
|
||||
ASSERT_EQ(in_info_1, in_info_ref);
|
||||
}
|
||||
|
||||
TEST_F(ModelUtilsTest, align_input_info_for_subgraphs) {
|
||||
Model_0 model_0, model_1;
|
||||
auto test_model_0 = model_0.get();
|
||||
auto test_model_1 = model_1.get();
|
||||
auto in_info_0 = get_input_info_by_model(test_model_0);
|
||||
auto in_info_1 = get_input_info_by_model(test_model_1);
|
||||
ASSERT_NE(in_info_0, in_info_1);
|
||||
std::map<std::string, std::string> matched_ops;
|
||||
auto params_0 = test_model_0->get_parameters();
|
||||
auto params_1 = test_model_1->get_parameters();
|
||||
size_t params_cnt = params_0.size();
|
||||
for (size_t param_id = 0; param_id < params_cnt; ++param_id) {
|
||||
matched_ops.insert({params_0[param_id]->get_friendly_name(),
|
||||
params_1[param_id]->get_friendly_name()});
|
||||
}
|
||||
ASSERT_NO_THROW(align_input_info(test_model_0, test_model_1,
|
||||
in_info_0, in_info_1,
|
||||
matched_ops));
|
||||
auto ref = align_input_info(test_model_0, test_model_1, in_info_0, in_info_1, matched_ops);
|
||||
ASSERT_EQ(in_info_1, ref);
|
||||
}
|
||||
|
||||
TEST_F(ModelUtilsTest, get_input_info_by_model) {
|
||||
Model_1 model;
|
||||
auto test_model = model.get();
|
||||
size_t param_idx = 0;
|
||||
std::map<std::string, InputInfo> ref;
|
||||
for (auto& param : test_model->get_parameters()) {
|
||||
std::string param_name = "parameter_" + std::to_string(param_idx++);
|
||||
param->set_friendly_name(param_name);
|
||||
ref.insert({param_name, InputInfo(param->get_default_output().get_partial_shape())});
|
||||
}
|
||||
auto cur = get_input_info_by_model(test_model);
|
||||
ASSERT_EQ(cur, ref);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -0,0 +1,137 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "matchers/subgraph/subgraph.hpp"
|
||||
#include "utils/model_comparator.hpp"
|
||||
#include "base_test.hpp"
|
||||
|
||||
#include "openvino/op/abs.hpp"
|
||||
#include "openvino/op/relu.hpp"
|
||||
#include "openvino/op/parameter.hpp"
|
||||
#include "openvino/op/result.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ov::tools::subgraph_dumper;
|
||||
|
||||
// ======================= ExtractorsManagerTest Unit tests =======================
|
||||
class ModelComparatorTest : public SubgraphsDumperBaseTest {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
SubgraphsDumperBaseTest::SetUp();
|
||||
{
|
||||
std::shared_ptr<ov::op::v0::Parameter> test_parameter =
|
||||
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2});
|
||||
test_parameter->set_friendly_name("test_parameter_0");
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs =
|
||||
std::make_shared<ov::op::v0::Abs>(test_parameter);
|
||||
std::shared_ptr<ov::op::v0::Result> test_res =
|
||||
std::make_shared<ov::op::v0::Result>(test_abs);
|
||||
test_model_0_0 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
|
||||
ov::ParameterVector{test_parameter});
|
||||
}
|
||||
{
|
||||
std::shared_ptr<ov::op::v0::Parameter> test_parameter =
|
||||
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 5});
|
||||
test_parameter->set_friendly_name("test_parameter_1");
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs =
|
||||
std::make_shared<ov::op::v0::Abs>(test_parameter);
|
||||
std::shared_ptr<ov::op::v0::Result> test_res =
|
||||
std::make_shared<ov::op::v0::Result>(test_abs);
|
||||
test_model_0_1 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
|
||||
ov::ParameterVector{test_parameter});
|
||||
}
|
||||
{
|
||||
std::shared_ptr<ov::op::v0::Parameter> test_parameter =
|
||||
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 5});
|
||||
std::shared_ptr<ov::op::v0::Relu> test_abs =
|
||||
std::make_shared<ov::op::v0::Relu>(test_parameter);
|
||||
std::shared_ptr<ov::op::v0::Result> test_res =
|
||||
std::make_shared<ov::op::v0::Result>(test_abs);
|
||||
test_model_1 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
|
||||
ov::ParameterVector{test_parameter});
|
||||
}
|
||||
}
|
||||
|
||||
void TearDown() override {
|
||||
ModelComparator::Ptr model_comparator = ModelComparator::get();
|
||||
model_comparator->set_shape_strict_match(false);
|
||||
model_comparator->set_match_coefficient(0.9f);
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> test_model_0_0, test_model_0_1, test_model_1;
|
||||
};
|
||||
|
||||
TEST_F(ModelComparatorTest, get) {
|
||||
ModelComparator::Ptr model_comparator = nullptr;
|
||||
ASSERT_NO_THROW(model_comparator = ModelComparator::get());
|
||||
ASSERT_EQ(model_comparator, ModelComparator::get());
|
||||
}
|
||||
|
||||
TEST_F(ModelComparatorTest, match) {
|
||||
ModelComparator::Ptr model_comparator = ModelComparator::get();
|
||||
ASSERT_NO_THROW(model_comparator->match(test_model_0_0, test_model_0_1));
|
||||
ASSERT_TRUE(model_comparator->match(test_model_0_0, test_model_0_1));
|
||||
ASSERT_NO_THROW(model_comparator->match(test_model_0_0, test_model_1));
|
||||
ASSERT_FALSE(model_comparator->match(test_model_0_0, test_model_1));
|
||||
ASSERT_NO_THROW(model_comparator->match(test_model_0_1, test_model_1));
|
||||
ASSERT_FALSE(model_comparator->match(test_model_0_1, test_model_1));
|
||||
}
|
||||
|
||||
TEST_F(ModelComparatorTest, match_strict_shape) {
|
||||
ModelComparator::Ptr model_comparator = ModelComparator::get();
|
||||
ASSERT_NO_THROW(model_comparator->set_shape_strict_match(true));
|
||||
ASSERT_NO_THROW(model_comparator->match(test_model_0_0, test_model_0_1));
|
||||
ASSERT_FALSE(model_comparator->match(test_model_0_0, test_model_0_1));
|
||||
{
|
||||
{
|
||||
std::shared_ptr<ov::op::v0::Parameter> test_parameter =
|
||||
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2});
|
||||
test_parameter->set_friendly_name("test_parameter_1");
|
||||
std::shared_ptr<ov::op::v0::Abs> test_abs =
|
||||
std::make_shared<ov::op::v0::Abs>(test_parameter);
|
||||
std::shared_ptr<ov::op::v0::Result> test_res =
|
||||
std::make_shared<ov::op::v0::Result>(test_abs);
|
||||
test_model_0_1 = std::make_shared<ov::Model>(ov::ResultVector{test_res},
|
||||
ov::ParameterVector{test_parameter});
|
||||
}
|
||||
ASSERT_TRUE(model_comparator->match(test_model_0_0, test_model_0_1));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ModelComparatorTest, match_with_low_coeff) {
|
||||
ModelComparator::Ptr model_comparator = ModelComparator::get();
|
||||
model_comparator->set_match_coefficient(0.5f);
|
||||
ASSERT_NO_THROW(model_comparator->match(test_model_0_0, test_model_0_1));
|
||||
ASSERT_TRUE(model_comparator->match(test_model_0_0, test_model_0_1));
|
||||
ASSERT_NO_THROW(model_comparator->match(test_model_0_0, test_model_1));
|
||||
ASSERT_TRUE(model_comparator->match(test_model_0_0, test_model_1));
|
||||
ASSERT_NO_THROW(model_comparator->match(test_model_0_1, test_model_1));
|
||||
ASSERT_TRUE(model_comparator->match(test_model_0_1, test_model_1));
|
||||
}
|
||||
|
||||
TEST_F(ModelComparatorTest, match_with_in_info) {
|
||||
ModelComparator::Ptr model_comparator = ModelComparator::get();
|
||||
std::map<std::string, InputInfo> test_in_info({{"test_parameter_0", InputInfo()}}),
|
||||
test_in_info_1({{"test_parameter_1", InputInfo({}, 1, 2, true)}});
|
||||
ASSERT_NO_THROW(model_comparator->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info));
|
||||
ASSERT_TRUE(std::get<0>(model_comparator->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info)));
|
||||
ASSERT_NO_THROW(model_comparator->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info_1));
|
||||
ASSERT_FALSE(std::get<0>(model_comparator->match(test_model_0_0, test_model_0_1, test_in_info, test_in_info_1)));
|
||||
ASSERT_NO_THROW(model_comparator->match(test_model_0_1, test_model_1, test_in_info, test_in_info));
|
||||
ASSERT_FALSE(std::get<0>(model_comparator->match(test_model_0_1, test_model_1, test_in_info, test_in_info)));
|
||||
}
|
||||
|
||||
TEST_F(ModelComparatorTest, is_subgraph) {
|
||||
ModelComparator::Ptr model_comparator = ModelComparator::get();
|
||||
ASSERT_NO_THROW(model_comparator->is_subgraph(test_model_0_0, test_model_0_1));
|
||||
auto is_subgraph = model_comparator->is_subgraph(test_model_0_0, test_model_0_1);
|
||||
ASSERT_TRUE(std::get<0>(is_subgraph));
|
||||
ASSERT_NO_THROW(model_comparator->is_subgraph(test_model_0_0, test_model_1));
|
||||
ASSERT_FALSE(std::get<0>(model_comparator->is_subgraph(test_model_0_0, test_model_1)));
|
||||
ASSERT_NO_THROW(model_comparator->is_subgraph(test_model_0_1, test_model_1));
|
||||
ASSERT_FALSE(std::get<0>(model_comparator->is_subgraph(test_model_0_1, test_model_1)));
|
||||
}
|
||||
|
||||
} // namespace
|
Loading…
Reference in New Issue
Block a user