Fix merge master error

Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>
This commit is contained in:
Zhai, Xuejun 2023-03-06 15:05:55 +08:00
parent fcbfa114f2
commit 748fcdea3a

View File

@ -1,143 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "dyn_elimination.hpp"
#include <numeric>
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/range.hpp"
#include "ngraph/op/transpose.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/reference/range.hpp"
#include "ngraph/slice_plan.hpp"
NGRAPH_SUPPRESS_DEPRECATED_START
using namespace std;
using namespace ngraph;
pass::DynElimination::DynElimination() : GraphRewrite() {
construct_range();
}
template <typename T>
std::shared_ptr<op::Constant> make_range_replacement(const element::Type& et,
const Shape& shape,
const std::shared_ptr<op::Constant>& start_arg,
const std::shared_ptr<op::Constant>& step_arg) {
std::vector<T> elements(shape_size(shape));
std::vector<T> start_vec = start_arg->get_vector<T>();
std::vector<T> step_vec = step_arg->get_vector<T>();
NGRAPH_CHECK(start_vec.size() == 1 && step_vec.size() == 1);
runtime::reference::range<T>(start_vec.data(), step_vec.data(), shape_size(shape), elements.data());
return make_shared<op::Constant>(et, shape, elements);
}
void pass::DynElimination::construct_range() {
auto start_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
auto stop_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
auto step_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
auto range_pat = make_shared<op::Range>(start_arg_label, stop_arg_label, step_arg_label);
auto range_callback = [start_arg_label, stop_arg_label, step_arg_label](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map();
auto start_arg = static_pointer_cast<op::Constant>(pattern_map[start_arg_label]);
auto step_arg = static_pointer_cast<op::Constant>(pattern_map[step_arg_label]);
auto range_node = static_pointer_cast<op::Range>(m.get_match_root());
NGRAPH_CHECK(start_arg->get_output_partial_shape(0).rank().compatible(0) &&
step_arg->get_output_partial_shape(0).rank().compatible(0));
auto et = range_node->get_output_element_type(0);
auto shape = range_node->get_output_shape(0);
std::shared_ptr<op::Constant> replacement;
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
# pragma GCC diagnostic push
# pragma GCC diagnostic error "-Wswitch"
# pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (et) {
case element::Type_t::bf16:
replacement = make_range_replacement<bfloat16>(et, shape, start_arg, step_arg);
break;
case element::Type_t::f16:
replacement = make_range_replacement<float16>(et, shape, start_arg, step_arg);
break;
case element::Type_t::f32:
replacement = make_range_replacement<float>(et, shape, start_arg, step_arg);
break;
case element::Type_t::f64:
replacement = make_range_replacement<double>(et, shape, start_arg, step_arg);
break;
case element::Type_t::i8:
replacement = make_range_replacement<int8_t>(et, shape, start_arg, step_arg);
break;
case element::Type_t::i16:
replacement = make_range_replacement<int16_t>(et, shape, start_arg, step_arg);
break;
case element::Type_t::i32:
replacement = make_range_replacement<int32_t>(et, shape, start_arg, step_arg);
break;
case element::Type_t::i64:
replacement = make_range_replacement<int64_t>(et, shape, start_arg, step_arg);
break;
case element::Type_t::u8:
replacement = make_range_replacement<uint8_t>(et, shape, start_arg, step_arg);
break;
case element::Type_t::u16:
replacement = make_range_replacement<uint16_t>(et, shape, start_arg, step_arg);
break;
case element::Type_t::u32:
replacement = make_range_replacement<uint32_t>(et, shape, start_arg, step_arg);
break;
case element::Type_t::u64:
replacement = make_range_replacement<uint64_t>(et, shape, start_arg, step_arg);
break;
case element::Type_t::i4:
case element::Type_t::u4:
case element::Type_t::u1:
case element::Type_t::undefined:
case element::Type_t::dynamic:
case element::Type_t::boolean:
NGRAPH_CHECK(false, "Internal nGraph error: unsupported element type: ", et);
break;
}
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
# pragma GCC diagnostic pop
#endif
replace_node(range_node, replacement);
return true;
};
auto range_matcher = make_shared<pattern::Matcher>(range_pat, "DynElimination.Range");
auto match_pass = std::make_shared<ov::pass::MatcherPass>(
range_matcher->get_name(),
range_matcher,
[range_matcher, range_callback](const std::shared_ptr<Node>& node) -> bool {
NGRAPH_DEBUG << "Running matcher " << range_matcher->get_name() << " on " << node;
OV_PASS_CALLBACK(m);
if (std::dynamic_pointer_cast<ov::pass::pattern::Matcher>(range_matcher)->match(node->output(0))) {
NGRAPH_DEBUG << "Matcher " << range_matcher->get_name() << " matched " << node;
bool status = range_callback(*range_matcher.get());
// explicitly clear Matcher state because it holds pointers to matched nodes
range_matcher->clear_state();
return status;
}
range_matcher->clear_state();
return false;
},
all_pass_property_off);
add_matcher(match_pass);
}