[Remove APIs] remove api add_matcher()
Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>
This commit is contained in:
parent
bd4f9746a3
commit
f0d0557e83
@ -33,7 +33,8 @@ target_link_libraries(${TARGET_NAME}_obj PRIVATE openvino::itt)
|
||||
|
||||
target_include_directories(${TARGET_NAME}_obj PRIVATE $<BUILD_INTERFACE:${PUBLIC_HEADERS_DIR}>
|
||||
$<BUILD_INTERFACE:$<TARGET_PROPERTY:ngraph,INTERFACE_INCLUDE_DIRECTORIES>>
|
||||
$<BUILD_INTERFACE:$<TARGET_PROPERTY:inference_engine_transformations,INTERFACE_INCLUDE_DIRECTORIES>>)
|
||||
$<BUILD_INTERFACE:$<TARGET_PROPERTY:inference_engine_transformations,INTERFACE_INCLUDE_DIRECTORIES>>
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../..)
|
||||
|
||||
add_cpplint_target(${TARGET_NAME}_cpplint FOR_TARGETS ${TARGET_NAME}_obj)
|
||||
|
||||
@ -51,7 +52,8 @@ target_link_libraries(${TARGET_NAME} INTERFACE openvino::runtime)
|
||||
target_include_directories(${TARGET_NAME} INTERFACE
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
|
||||
$<BUILD_INTERFACE:$<TARGET_PROPERTY:ngraph,INTERFACE_INCLUDE_DIRECTORIES>>
|
||||
$<BUILD_INTERFACE:$<TARGET_PROPERTY:inference_engine_transformations,INTERFACE_INCLUDE_DIRECTORIES>>)
|
||||
$<BUILD_INTERFACE:$<TARGET_PROPERTY:inference_engine_transformations,INTERFACE_INCLUDE_DIRECTORIES>>
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../..)
|
||||
|
||||
# LTO
|
||||
|
||||
|
@ -4,7 +4,7 @@
|
||||
|
||||
#include <low_precision/layer_transformation.hpp>
|
||||
#include <low_precision/network_helper.hpp>
|
||||
|
||||
#include "core/src/itt.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
@ -448,9 +448,24 @@ void LayerTransformation::addPattern(ngraph::pass::GraphRewrite& pass, Transform
|
||||
};
|
||||
// TODO: better name for matcher? required?
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(patternRoot, matcher_name);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
pass.add_matcher(m, internal_callback, ngraph::pass::PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
auto match_pass = std::make_shared<ov::pass::MatcherPass>(
|
||||
m->get_name(),
|
||||
m,
|
||||
[m, internal_callback](const std::shared_ptr<Node>& node) -> bool {
|
||||
NGRAPH_DEBUG << "Running matcher " << m->get_name() << " on " << node;
|
||||
OV_PASS_CALLBACK(m);
|
||||
if (std::dynamic_pointer_cast<ov::pass::pattern::Matcher>(m)->match(node->output(0))) {
|
||||
NGRAPH_DEBUG << "Matcher " << m->get_name() << " matched " << node;
|
||||
bool status = internal_callback(*m.get());
|
||||
// explicitly clear Matcher state because it holds pointers to matched nodes
|
||||
m->clear_state();
|
||||
return status;
|
||||
}
|
||||
m->clear_state();
|
||||
return false;
|
||||
},
|
||||
ov::pass::PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
pass.add_matcher(match_pass);
|
||||
}
|
||||
|
||||
} // namespace low_precision
|
||||
|
@ -5,7 +5,7 @@
|
||||
#include "low_precision/low_precision.hpp"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "core/src/itt.hpp"
|
||||
#include <ngraph/ngraph.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
@ -134,9 +134,24 @@ void make_matcher_type_relaxed(ngraph::pass::GraphRewrite* transformation) {
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(p_node, matcher_name);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
transformation->add_matcher(m, callback, ngraph::pass::PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
auto match_pass = std::make_shared<ov::pass::MatcherPass>(
|
||||
m->get_name(),
|
||||
m,
|
||||
[m, callback](const std::shared_ptr<Node>& node) -> bool {
|
||||
NGRAPH_DEBUG << "Running matcher " << m->get_name() << " on " << node;
|
||||
if (std::dynamic_pointer_cast<ov::pass::pattern::Matcher>(m)->match(node->output(0))) {
|
||||
NGRAPH_DEBUG << "Matcher " << m->get_name() << " matched " << node;
|
||||
OV_PASS_CALLBACK(m);
|
||||
bool status = callback(*m.get());
|
||||
// explicitly clear Matcher state because it holds pointers to matched nodes
|
||||
m->clear_state();
|
||||
return status;
|
||||
}
|
||||
m->clear_state();
|
||||
return false;
|
||||
},
|
||||
ov::pass::PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
transformation->add_matcher(match_pass);
|
||||
}
|
||||
|
||||
ngraph::pass::low_precision::TypeRelaxedReplacer::TypeRelaxedReplacer() {
|
||||
|
@ -257,14 +257,6 @@ public:
|
||||
return pass;
|
||||
}
|
||||
|
||||
OPENVINO_DEPRECATED("Use MatcherPass instead")
|
||||
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
||||
const graph_rewrite_callback& callback,
|
||||
const PassPropertyMask& property);
|
||||
|
||||
OPENVINO_DEPRECATED("Use MatcherPass instead")
|
||||
void add_matcher(const std::shared_ptr<pattern::Matcher>& m, const ov::graph_rewrite_callback& callback);
|
||||
|
||||
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
|
||||
|
||||
void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) override;
|
||||
|
@ -239,37 +239,6 @@ bool ov::pass::GraphRewrite::apply_matcher_passes(std::shared_ptr<Model> f,
|
||||
return rewritten;
|
||||
}
|
||||
|
||||
void ov::pass::GraphRewrite::add_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
||||
const graph_rewrite_callback& callback,
|
||||
const PassPropertyMask& property) {
|
||||
m_matchers.push_back(std::make_shared<MatcherPass>(
|
||||
m->get_name(),
|
||||
m,
|
||||
[m, callback](const std::shared_ptr<Node>& node) -> bool {
|
||||
NGRAPH_DEBUG << "Running matcher " << m->get_name() << " on " << node;
|
||||
if (m->match(node->output(0))) {
|
||||
NGRAPH_DEBUG << "Matcher " << m->get_name() << " matched " << node;
|
||||
OV_PASS_CALLBACK(m);
|
||||
bool status = callback(*m.get());
|
||||
// explicitly clear Matcher state because it holds pointers to matched nodes
|
||||
m->clear_state();
|
||||
return status;
|
||||
}
|
||||
m->clear_state();
|
||||
return false;
|
||||
},
|
||||
property));
|
||||
}
|
||||
|
||||
void ov::pass::GraphRewrite::add_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
||||
const graph_rewrite_callback& callback) {
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
// TODO: before deprecate this function, by default expect the
|
||||
// callback require static shape.
|
||||
add_matcher(m, callback, {PassProperty::REQUIRE_STATIC_SHAPE});
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
void ov::pass::GraphRewrite::set_pass_config(const std::shared_ptr<PassConfig>& rhs) {
|
||||
auto pass_config = get_pass_config();
|
||||
// We have to preserve disabled passes because in case when we register matchers inside
|
||||
|
@ -111,9 +111,23 @@ public:
|
||||
};
|
||||
|
||||
auto m = make_shared<TestMatcher>(make_shared<op::v1::Multiply>(pattern, iconst1));
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
this->add_matcher(m, callback);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
auto match_pass = std::make_shared<ov::pass::MatcherPass>(
|
||||
m->get_name(),
|
||||
m,
|
||||
[m, callback](const std::shared_ptr<Node>& node) -> bool {
|
||||
NGRAPH_DEBUG << "Running matcher " << m->get_name() << " on " << node;
|
||||
if (std::dynamic_pointer_cast<ov::pass::pattern::Matcher>(m)->match(node->output(0))) {
|
||||
NGRAPH_DEBUG << "Matcher " << m->get_name() << " matched " << node;
|
||||
bool status = callback(*m.get());
|
||||
// explicitly clear Matcher state because it holds pointers to matched nodes
|
||||
m->clear_state();
|
||||
return status;
|
||||
}
|
||||
m->clear_state();
|
||||
return false;
|
||||
},
|
||||
ov::pass::PassProperty::REQUIRE_STATIC_SHAPE);
|
||||
this->add_matcher(match_pass);
|
||||
}
|
||||
|
||||
void construct_add_zero() {
|
||||
@ -156,9 +170,23 @@ public:
|
||||
|
||||
auto add = make_shared<op::v1::Add>(pattern, iconst0);
|
||||
auto m = make_shared<TestMatcher>(add);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
this->add_matcher(m, callback);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
auto match_pass = std::make_shared<ov::pass::MatcherPass>(
|
||||
m->get_name(),
|
||||
m,
|
||||
[m, callback](const std::shared_ptr<Node>& node) -> bool {
|
||||
NGRAPH_DEBUG << "Running matcher " << m->get_name() << " on " << node;
|
||||
if (std::dynamic_pointer_cast<ov::pass::pattern::Matcher>(m)->match(node->output(0))) {
|
||||
NGRAPH_DEBUG << "Matcher " << m->get_name() << " matched " << node;
|
||||
bool status = callback(*m.get());
|
||||
// explicitly clear Matcher state because it holds pointers to matched nodes
|
||||
m->clear_state();
|
||||
return status;
|
||||
}
|
||||
m->clear_state();
|
||||
return false;
|
||||
},
|
||||
ov::pass::PassProperty::REQUIRE_STATIC_SHAPE);
|
||||
this->add_matcher(match_pass);
|
||||
}
|
||||
|
||||
TestGraphRewrite() : GraphRewrite() {
|
||||
|
@ -34,7 +34,9 @@ target_compile_definitions(interpreter_backend
|
||||
)
|
||||
target_link_libraries(interpreter_backend PRIVATE ngraph::builder ngraph::reference openvino::util openvino::runtime::dev)
|
||||
|
||||
target_include_directories(interpreter_backend PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>)
|
||||
target_include_directories(interpreter_backend PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../..
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../../common/conditional_compilation/include)
|
||||
|
||||
file(GLOB_RECURSE all_backends_src "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.hpp")
|
||||
add_clang_format_target(interpreter_backend_clang FOR_SOURCES ${all_backends_src})
|
||||
|
144
src/plugins/template/backend/pass/dyn_elimination.cpp
Normal file
144
src/plugins/template/backend/pass/dyn_elimination.cpp
Normal file
@ -0,0 +1,144 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "dyn_elimination.hpp"
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "ngraph/builder/reshape.hpp"
|
||||
#include "ngraph/op/broadcast.hpp"
|
||||
#include "ngraph/op/range.hpp"
|
||||
#include "ngraph/op/transpose.hpp"
|
||||
#include "ngraph/pattern/matcher.hpp"
|
||||
#include "ngraph/pattern/op/label.hpp"
|
||||
#include "ngraph/runtime/reference/range.hpp"
|
||||
#include "ngraph/slice_plan.hpp"
|
||||
#include "core/src/itt.hpp"
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
pass::DynElimination::DynElimination() : GraphRewrite() {
|
||||
construct_range();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::shared_ptr<op::Constant> make_range_replacement(const element::Type& et,
|
||||
const Shape& shape,
|
||||
const std::shared_ptr<op::Constant>& start_arg,
|
||||
const std::shared_ptr<op::Constant>& step_arg) {
|
||||
std::vector<T> elements(shape_size(shape));
|
||||
std::vector<T> start_vec = start_arg->get_vector<T>();
|
||||
std::vector<T> step_vec = step_arg->get_vector<T>();
|
||||
|
||||
NGRAPH_CHECK(start_vec.size() == 1 && step_vec.size() == 1);
|
||||
|
||||
runtime::reference::range<T>(start_vec.data(), step_vec.data(), shape_size(shape), elements.data());
|
||||
|
||||
return make_shared<op::Constant>(et, shape, elements);
|
||||
}
|
||||
|
||||
void pass::DynElimination::construct_range() {
|
||||
auto start_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
|
||||
auto stop_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
|
||||
auto step_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
|
||||
|
||||
auto range_pat = make_shared<op::Range>(start_arg_label, stop_arg_label, step_arg_label);
|
||||
|
||||
auto range_callback = [start_arg_label, stop_arg_label, step_arg_label](pattern::Matcher& m) {
|
||||
auto pattern_map = m.get_pattern_map();
|
||||
|
||||
auto start_arg = static_pointer_cast<op::Constant>(pattern_map[start_arg_label]);
|
||||
auto step_arg = static_pointer_cast<op::Constant>(pattern_map[step_arg_label]);
|
||||
auto range_node = static_pointer_cast<op::Range>(m.get_match_root());
|
||||
|
||||
NGRAPH_CHECK(start_arg->get_output_partial_shape(0).rank().compatible(0) &&
|
||||
step_arg->get_output_partial_shape(0).rank().compatible(0));
|
||||
|
||||
auto et = range_node->get_output_element_type(0);
|
||||
auto shape = range_node->get_output_shape(0);
|
||||
|
||||
std::shared_ptr<op::Constant> replacement;
|
||||
|
||||
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
|
||||
# pragma GCC diagnostic push
|
||||
# pragma GCC diagnostic error "-Wswitch"
|
||||
# pragma GCC diagnostic error "-Wswitch-enum"
|
||||
#endif
|
||||
switch (et) {
|
||||
case element::Type_t::bf16:
|
||||
replacement = make_range_replacement<bfloat16>(et, shape, start_arg, step_arg);
|
||||
break;
|
||||
case element::Type_t::f16:
|
||||
replacement = make_range_replacement<float16>(et, shape, start_arg, step_arg);
|
||||
break;
|
||||
case element::Type_t::f32:
|
||||
replacement = make_range_replacement<float>(et, shape, start_arg, step_arg);
|
||||
break;
|
||||
case element::Type_t::f64:
|
||||
replacement = make_range_replacement<double>(et, shape, start_arg, step_arg);
|
||||
break;
|
||||
case element::Type_t::i8:
|
||||
replacement = make_range_replacement<int8_t>(et, shape, start_arg, step_arg);
|
||||
break;
|
||||
case element::Type_t::i16:
|
||||
replacement = make_range_replacement<int16_t>(et, shape, start_arg, step_arg);
|
||||
break;
|
||||
case element::Type_t::i32:
|
||||
replacement = make_range_replacement<int32_t>(et, shape, start_arg, step_arg);
|
||||
break;
|
||||
case element::Type_t::i64:
|
||||
replacement = make_range_replacement<int64_t>(et, shape, start_arg, step_arg);
|
||||
break;
|
||||
case element::Type_t::u8:
|
||||
replacement = make_range_replacement<uint8_t>(et, shape, start_arg, step_arg);
|
||||
break;
|
||||
case element::Type_t::u16:
|
||||
replacement = make_range_replacement<uint16_t>(et, shape, start_arg, step_arg);
|
||||
break;
|
||||
case element::Type_t::u32:
|
||||
replacement = make_range_replacement<uint32_t>(et, shape, start_arg, step_arg);
|
||||
break;
|
||||
case element::Type_t::u64:
|
||||
replacement = make_range_replacement<uint64_t>(et, shape, start_arg, step_arg);
|
||||
break;
|
||||
case element::Type_t::i4:
|
||||
case element::Type_t::u4:
|
||||
case element::Type_t::u1:
|
||||
case element::Type_t::undefined:
|
||||
case element::Type_t::dynamic:
|
||||
case element::Type_t::boolean:
|
||||
NGRAPH_CHECK(false, "Internal nGraph error: unsupported element type: ", et);
|
||||
break;
|
||||
}
|
||||
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
|
||||
# pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
replace_node(range_node, replacement);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto range_matcher = make_shared<pattern::Matcher>(range_pat, "DynElimination.Range");
|
||||
auto match_pass = std::make_shared<ov::pass::MatcherPass>(
|
||||
range_matcher->get_name(),
|
||||
range_matcher,
|
||||
[range_matcher, range_callback](const std::shared_ptr<Node>& node) -> bool {
|
||||
NGRAPH_DEBUG << "Running matcher " << range_matcher->get_name() << " on " << node;
|
||||
OV_PASS_CALLBACK(m);
|
||||
if (std::dynamic_pointer_cast<ov::pass::pattern::Matcher>(range_matcher)->match(node->output(0))) {
|
||||
NGRAPH_DEBUG << "Matcher " << range_matcher->get_name() << " matched " << node;
|
||||
bool status = range_callback(*range_matcher.get());
|
||||
// explicitly clear Matcher state because it holds pointers to matched nodes
|
||||
range_matcher->clear_state();
|
||||
return status;
|
||||
}
|
||||
range_matcher->clear_state();
|
||||
return false;
|
||||
},
|
||||
all_pass_property_off);
|
||||
add_matcher(match_pass);
|
||||
}
|
Loading…
Reference in New Issue
Block a user