[CPU] Improvement for NonZero and Gather (#16641)
parent 35398e339d
commit fc88bed604
@@ -10,6 +10,7 @@
#include <numeric>
#include <openvino/core/validation_util.hpp>
#include <openvino/opsets/opset3.hpp>
#include <openvino/opsets/opset7.hpp>
#include <openvino/opsets/opset8.hpp>
#include <openvino/opsets/opset9.hpp>
#include <openvino/pass/pattern/op/or.hpp>
@@ -25,7 +26,7 @@ using namespace ov;
// `simplify_gather` optimizes Gather if it is gathering the
// whole input tensor
static bool simplify_gather(shared_ptr<Node> node) {
    if (auto gather = ov::as_type_ptr<opset3::Gather>(node)) {
    if (auto gather = ov::as_type_ptr<op::util::GatherBase>(node)) {
        // check if we are gathering the whole input
        auto data = gather->input_value(0);
        auto indices = gather->input_value(1);
@@ -34,10 +35,6 @@ static bool simplify_gather(shared_ptr<Node> node) {
        if (data.get_partial_shape().is_dynamic() || indices.get_partial_shape().is_dynamic()) {
            return false;
        }
        // if the ranks of data and gather output don't match, we will skip
        if (data.get_shape().size() != node->get_shape().size()) {
            return false;
        }

        auto axis = gather->get_axis();
        if (axis == opset3::Gather::AXIS_NOT_SET_VALUE) {
@@ -45,6 +42,22 @@ static bool simplify_gather(shared_ptr<Node> node) {
            return false;
        }

        if (data.get_shape().size() != node->get_shape().size()) {
            auto constant_indices = ov::as_type_ptr<opset3::Constant>(gather->input_value(1).get_node_shared_ptr());
            if (!constant_indices)
                return false;
            // case_3: if input_shape is (1,3,5,5), axis = 0 and indices = 0, then the Gather is just a Squeeze
            const auto const_indices = constant_indices->cast_vector<int64_t>();
            if (data.get_shape()[axis] == 1 && const_indices.size() == 1 && const_indices[0] == 0) {
                auto squeeze = std::make_shared<opset8::Squeeze>(gather->input_value(0), gather->input_value(2));
                squeeze->set_friendly_name(gather->get_friendly_name());
                ov::copy_runtime_info(gather, squeeze);
                ov::replace_node(gather, squeeze);
                return true;
            }
            return false;
        }

        // case_1: if the input tensor is of shape (4, 1, 4)
        // and axis = 1, then the gather would be simply
        // gathering the whole input tensor, so we can optimize this
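In case_3 the Gather selects index 0 along an axis of size 1, so its only effect is to drop that axis, which is exactly what Squeeze does; the replacement therefore reuses the Gather's data and axis inputs. A minimal standalone sketch of the equivalence (the function names and the concrete shape below are illustrative, not taken from the patch):

// Sketch: a Gather that selects index 0 along a size-1 axis is equivalent to Squeeze.
#include <memory>
#include <openvino/openvino.hpp>
#include <openvino/opsets/opset8.hpp>

std::shared_ptr<ov::Model> make_gather_model() {
    auto data = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::Shape{1, 3, 5, 5});
    auto indices = ov::opset8::Constant::create(ov::element::i64, ov::Shape{}, {0});
    auto axis = ov::opset8::Constant::create(ov::element::i64, ov::Shape{}, {0});
    // Output shape is {3, 5, 5}: the gathered axis disappears because the indices are a scalar.
    auto gather = std::make_shared<ov::opset8::Gather>(data, indices, axis);
    return std::make_shared<ov::Model>(ov::NodeVector{gather}, ov::ParameterVector{data});
}

std::shared_ptr<ov::Model> make_equivalent_squeeze_model() {
    auto data = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::Shape{1, 3, 5, 5});
    auto axis = ov::opset8::Constant::create(ov::element::i64, ov::Shape{}, {0});
    // Squeeze removes the size-1 axis and produces the same {3, 5, 5} result.
    auto squeeze = std::make_shared<ov::opset8::Squeeze>(data, axis);
    return std::make_shared<ov::Model>(ov::NodeVector{squeeze}, ov::ParameterVector{data});
}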
@@ -297,7 +310,7 @@ static bool eliminate_unsqueeze(const shared_ptr<Node>& node) {
SIMPLE_MATCHER_PASS_DEFINITION(EliminateReshape, eliminate_reshape_v1, opset3::Reshape);
SIMPLE_MATCHER_PASS_DEFINITION(EliminateUnsqueeze, eliminate_unsqueeze, opset3::Unsqueeze);
SIMPLE_MATCHER_PASS_DEFINITION(EliminateBroadcast, eliminate_nop, op::v1::Broadcast, op::v3::Broadcast);
SIMPLE_MATCHER_PASS_DEFINITION(EliminateGather, simplify_gather, opset3::Gather);
SIMPLE_MATCHER_PASS_DEFINITION(EliminateGather, simplify_gather, opset3::Gather, opset7::Gather, opset8::Gather);

pass::EliminatePad::EliminatePad() {
    MATCHER_SCOPE(EliminatePad);
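With opset3::Gather, opset7::Gather and opset8::Gather all listed, EliminateGather now matches every Gather version, and simplify_gather downcasts to their common base op::util::GatherBase. Conceptually, the macro registers a matcher pass along the lines of the following sketch; the class name and exact structure are illustrative, not the macro's actual expansion:

// Illustrative only: a hand-written matcher pass with the same effect as
// SIMPLE_MATCHER_PASS_DEFINITION(EliminateGather, simplify_gather, ...).
#include <memory>
#include <openvino/opsets/opset3.hpp>
#include <openvino/opsets/opset7.hpp>
#include <openvino/opsets/opset8.hpp>
#include <openvino/pass/graph_rewrite.hpp>
#include <openvino/pass/pattern/op/wrap_type.hpp>

bool simplify_gather(std::shared_ptr<ov::Node> node);  // defined earlier in nop_elimination.cpp

class EliminateGatherSketch : public ov::pass::MatcherPass {
public:
    EliminateGatherSketch() {
        // Match any of the supported Gather versions.
        auto gather = ov::pass::pattern::wrap_type<ov::opset3::Gather, ov::opset7::Gather, ov::opset8::Gather>();
        ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
            // Delegate to the predicate above; it returns true when the Gather
            // was eliminated or rewritten (for example, into a Squeeze).
            return simplify_gather(m.get_match_root());
        };
        register_matcher(std::make_shared<ov::pass::pattern::Matcher>(gather, "EliminateGatherSketch"), callback);
    }
};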
@@ -1310,3 +1310,31 @@ TEST(SplitConcatElimination, no_sequence_found) {
    EXPECT_EQ(count_ops_of_type<ov::opset9::Split>(model), 1) << "SplitConcatElimination transformation has failed. "
                                                                 "The number of Split ops is not 1";
}

TEST(nop_elimination, gather_to_squeeze) {
    auto generate_func = [](int64_t gather_axis) {
        ov::Shape shape{3, 3, 4, 4};
        shape[gather_axis] = 1;
        auto arg = std::make_shared<op::Parameter>(element::f32, shape);
        auto indices = op::Constant::create(element::i64, Shape{}, vector<int64_t>{0});
        auto axis = op::Constant::create(element::i64, Shape{}, vector<int64_t>{gather_axis});
        auto gather = std::make_shared<op::v8::Gather>(arg, indices, axis);
        return std::make_shared<Function>(NodeVector{gather}, ParameterVector{arg});
    };

    auto func_axis_0 = generate_func(0);
    auto func_axis_1 = generate_func(1);
    auto func_axis_2 = generate_func(2);
    auto func_axis_3 = generate_func(3);
    pass::Manager pass_manager;
    pass_manager.register_pass<ov::pass::NopElimination>();
    auto run_and_check = [&](std::shared_ptr<Function>& func) {
        pass_manager.run_passes(func);
        EXPECT_EQ(count_ops_of_type<op::v8::Gather>(func), 0);
        EXPECT_EQ(count_ops_of_type<op::v0::Squeeze>(func), 1);
    };
    run_and_check(func_axis_0);
    run_and_check(func_axis_1);
    run_and_check(func_axis_2);
    run_and_check(func_axis_3);
}
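The test above covers only the positive case. A negative case, in which the gathered axis has size greater than 1 so the data.get_shape()[axis] == 1 guard fails and the Gather has to stay in place, could look like the sketch below; it is not part of the patch and assumes the same headers and usings as the surrounding tests:

// Illustrative negative case (not in the patch): axis 1 has size 3, so the
// Gather is not equivalent to a Squeeze and NopElimination should keep it.
TEST(nop_elimination, gather_not_converted_to_squeeze_sketch) {
    auto arg = std::make_shared<op::Parameter>(element::f32, Shape{3, 3, 4, 4});
    auto indices = op::Constant::create(element::i64, Shape{}, vector<int64_t>{0});
    auto axis = op::Constant::create(element::i64, Shape{}, vector<int64_t>{1});
    auto gather = std::make_shared<op::v8::Gather>(arg, indices, axis);
    auto func = std::make_shared<Function>(NodeVector{gather}, ParameterVector{arg});

    pass::Manager pass_manager;
    pass_manager.register_pass<ov::pass::NopElimination>();
    pass_manager.run_passes(func);

    // The Gather is expected to survive; no Squeeze should be introduced.
    EXPECT_EQ(count_ops_of_type<op::v8::Gather>(func), 1);
    EXPECT_EQ(count_ops_of_type<op::v0::Squeeze>(func), 0);
}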
@@ -82,16 +82,6 @@ std::vector<size_t> NonZero::getNonZeroElementsCount(const T* src, const Shape&
        counts.push_back(count);
        break;
    }
    case 1: {
        size_t count = 0;
        for (size_t i = 0; i < inSize; i++) {
            if (src[i] != zero) {
                count++;
            }
        }
        counts.push_back(count);
        break;
    }
    default: {
        threadsCount = parallel_get_num_threads();
        if (inSize < blockSize * threadsCount)
@@ -174,13 +164,16 @@ void NonZero::executeSpecified() {
        dst[0] = 0;
        break;
    case 1: {
        size_t outputIndex = 0;
        for (int i = 0; i < srcDims[0]; ++i) {
            if (src[i] != zero) {
                dst[outputIndex] = i;
                outputIndex++;
            }
        }
        // if nonZeroCounts.size() > 1, then the 2nd round scan could run in parallel.
        parallel_nt(threadsCount, [&](int ithr, int nthr){
            size_t outputIndex = std::accumulate(nonZeroCounts.begin(), nonZeroCounts.begin() + ithr, 0);
            for_1d(ithr, nthr, inShape.getElementsCount(), [&](size_t i) {
                if (src[i] != zero) {
                    dst[outputIndex] = i;
                    outputIndex++;
                }
            });
        });
        break;
    }
    case 2: {
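The executeSpecified change turns the index-writing scan into a two-pass scheme: getNonZeroElementsCount first counts the non-zero elements per thread chunk, then each thread derives its write offset as the prefix sum of the preceding counts, so all threads can emit indices without synchronization. A standalone sketch of the same idea, assuming a simple contiguous chunking and using std::thread in place of the plugin's parallel_nt/for_1d helpers:

// Two-pass parallel NonZero scan: count per chunk, then write at prefix-sum offsets.
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <thread>
#include <vector>

std::vector<size_t> non_zero_indices_parallel(const std::vector<float>& src, size_t threads) {
    const size_t n = src.size();
    const size_t chunk = (n + threads - 1) / threads;  // threads is assumed to be >= 1

    // Pass 1: each thread counts the non-zero elements in its contiguous chunk.
    std::vector<size_t> counts(threads, 0);
    {
        std::vector<std::thread> pool;
        for (size_t t = 0; t < threads; ++t) {
            pool.emplace_back([&, t] {
                const size_t begin = std::min(t * chunk, n);
                const size_t end = std::min(begin + chunk, n);
                for (size_t i = begin; i < end; ++i)
                    counts[t] += (src[i] != 0.0f) ? 1 : 0;
            });
        }
        for (auto& th : pool)
            th.join();
    }

    // Pass 2: the write offset of thread t is the prefix sum of counts[0..t),
    // so every thread writes its indices into a disjoint slice of dst.
    std::vector<size_t> dst(std::accumulate(counts.begin(), counts.end(), size_t{0}));
    {
        std::vector<std::thread> pool;
        for (size_t t = 0; t < threads; ++t) {
            pool.emplace_back([&, t] {
                size_t out = std::accumulate(counts.begin(), counts.begin() + t, size_t{0});
                const size_t begin = std::min(t * chunk, n);
                const size_t end = std::min(begin + chunk, n);
                for (size_t i = begin; i < end; ++i)
                    if (src[i] != 0.0f)
                        dst[out++] = i;
            });
        }
        for (auto& th : pool)
            th.join();
    }
    return dst;
}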