[Shape Inference] ScatterUpdate - add evaluate_label (#14381)

* Add ScatterUpdate type_prop tests

* Add evaluate_label for ScatterUpdate

* style update

* Reuse has_no_labels

* Adjust the logic to the single output

* Remove redundant size_check

* Migrate from HostTensor to ov Tensor

* Remove deprecation macro

* Use auto instead of HostTensorPtr

* Adjust tensor element type

* Update to reuse labels vector

* Use sizeof size_t to set element type

* Ensure static shapes on inputs
Katarzyna Mitrus 2022-12-16 11:34:25 +01:00 committed by GitHub
parent e0c21ce302
commit 72c39c3e32
3 changed files with 102 additions and 0 deletions
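For context before the diffs: evaluate_label lets ScatterUpdate propagate value labels (the symbolic identifiers that DimensionTracker attaches to dimension values) the same way its regular evaluate propagates values, so a labeled dimension that survives ShapeOf can be scattered into a target shape without losing its label. Below is a minimal sketch of the graph this enables, mirroring the value-label propagation test added in this commit; the label value 42 and the variable names are illustrative, not taken from the change.

#include <memory>

#include "dimension_tracker.hpp"
#include "ngraph/ngraph.hpp"

using namespace ngraph;

int main() {
    // Attach a label to an interval dimension of the input.
    auto dim = Dimension(5, 7);
    ov::DimensionTracker::set_label(dim, 42);  // illustrative label value
    auto data = std::make_shared<op::Parameter>(element::i8, PartialShape{dim});

    // ShapeOf turns the labeled dimension into a labeled value.
    auto shape_of = std::make_shared<op::v3::ShapeOf>(data);

    // ScatterUpdate writes that labeled value into position 1 of {1, 0};
    // with evaluate_label implemented, the label travels with it.
    auto target_shape = std::make_shared<op::v3::ScatterUpdate>(
        op::Constant::create(element::i64, Shape{2}, {1, 0}),  // data
        op::Constant::create(element::i64, Shape{1}, {1}),     // indices
        shape_of,                                              // updates
        op::Constant::create(element::i64, Shape{1}, {0}));    // axis

    // Broadcast consumes the labeled target shape, so its output dimension
    // {5, 7} is expected to carry the label set above.
    auto broadcast = std::make_shared<op::v3::Broadcast>(
        op::Constant::create(element::i64, Shape{1, 1}, {4}), target_shape);
    return 0;
}

Without evaluate_label on ScatterUpdate, the label would be dropped at that node and Broadcast would see an unlabeled dimension.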

@@ -37,6 +37,7 @@ public:
    bool evaluate_lower(const HostTensorVector& outputs) const override;
    bool evaluate_upper(const HostTensorVector& outputs) const override;
    OPENVINO_SUPPRESS_DEPRECATED_END
    bool evaluate_label(TensorLabelVector& output_labels) const override;
    bool has_evaluate() const override;

private:

@@ -122,3 +122,59 @@ bool op::v3::ScatterUpdate::has_evaluate() const {
    }
    return false;
}

namespace {
// Reuses the node's regular evaluate on tensors filled with label values: the labels of
// the data and updates inputs are scattered exactly as the corresponding values would be.
bool scatter_label_evaluator(const Node* node, TensorLabelVector& output_labels) {
    const auto& input_values = node->input_values();
    constexpr auto data_in_idx = 0;
    constexpr auto updates_in_idx = 2;

    std::vector<size_t> data_labels = input_values[data_in_idx].get_tensor().get_value_label();
    std::vector<size_t> updates_labels = input_values[updates_in_idx].get_tensor().get_value_label();
    if (ov::has_no_labels(data_labels) && ov::has_no_labels(updates_labels)) {
        return false;
    }

    // Pick an integer element type wide enough to hold a size_t label.
    constexpr auto element_type = (sizeof(size_t) == 8) ? element::u64 : element::u32;

    std::vector<ov::runtime::Tensor> input_tensors;
    input_tensors.reserve(input_values.size());

    // Build an input tensor holding the label values of the given input.
    auto make_input_label = [&](const Output<Node>& input, TensorLabel& labels) {
        input_tensors.emplace_back(element_type, input.get_shape());
        labels.resize(shape_size(input.get_shape()));
        memcpy(input_tensors.back().data(), labels.data(), input_tensors.back().get_byte_size());
    };

    for (size_t i = 0; i < input_values.size(); ++i) {
        const auto& input = input_values[i];
        if (i == data_in_idx) {
            make_input_label(input, data_labels);
        } else if (i == updates_in_idx) {
            make_input_label(input, updates_labels);
        } else {
            // Indices and axis contribute their (constant) lower-bound values, not labels.
            const auto host_tensor_ptr = input.get_tensor().get_lower_value();
            input_tensors.emplace_back(host_tensor_ptr->get_element_type(),
                                       host_tensor_ptr->get_shape(),
                                       host_tensor_ptr->get_data_ptr());
        }
    }

    // Run the regular ScatterUpdate evaluation on the label tensors and copy the result
    // into the single output's label vector.
    ov::TensorVector output_tensors{ov::Tensor(element_type, node->get_output_shape(0))};
    if (node->evaluate(output_tensors, input_tensors)) {
        size_t* ptr = static_cast<size_t*>(output_tensors[0].data(element_type));
        output_labels[0] = std::vector<size_t>(ptr, ptr + output_tensors[0].get_size());
        return true;
    }
    return false;
}
}  // namespace

bool op::v3::ScatterUpdate::evaluate_label(TensorLabelVector& output_labels) const {
    OV_OP_SCOPE(v3_ScatterUpdate_evaluate_label);
    // Label propagation requires static data/updates shapes and constant bounds on indices and axis.
    if (get_input_partial_shape(0).is_static() && get_input_partial_shape(2).is_static() &&
        get_input_tensor(1).has_and_set_bound() && get_input_tensor(3).has_and_set_bound()) {
        return scatter_label_evaluator(this, output_labels);
    }
    return false;
}

@@ -2,12 +2,15 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "dimension_tracker.hpp"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"

using namespace std;
using namespace ngraph;
using namespace testing;

namespace {
using type = ngraph::element::Type;
@@ -208,6 +211,48 @@ TEST(type_prop, scatter_update_v3_dynamic_data_shape) {
    EXPECT_TRUE(scatter_update->get_output_partial_shape(0).is_dynamic());
}

TEST(type_prop, scatter_update_v3_interval_label_data_shape) {
    auto labeled_dim = Dimension(1, 9);
    size_t label = 222;
    ov::DimensionTracker::set_label(labeled_dim, label);
    PartialShape data_shape = PartialShape{-1, {2, 8}, labeled_dim, 4};
    Shape indices_shape{2, 1};
    Shape updates_shape{3, 2, 1, 2, 4};

    auto data = make_shared<op::Parameter>(element::f32, data_shape);
    auto idx = make_shared<op::Parameter>(element::i32, indices_shape);
    auto updates = make_shared<op::Parameter>(element::f32, updates_shape);
    auto axis = op::Constant::create(element::i32, Shape{}, {1});

    auto scatter_update = make_shared<op::v3::ScatterUpdate>(data, idx, updates, axis);

    const auto& output_shape = scatter_update->get_output_partial_shape(0);
    EXPECT_EQ(output_shape, data_shape);
    EXPECT_THAT(get_shape_labels(output_shape), ElementsAre(ov::no_label, ov::no_label, label, ov::no_label));
    EXPECT_EQ(scatter_update->get_output_element_type(0), element::f32);
}

TEST(type_prop, scatter_update_v3_value_label_propagation) {
    auto labeled_dim = Dimension(5, 7);
    size_t label = 2345664;
    ov::DimensionTracker::set_label(labeled_dim, label);
    PartialShape data_shape = PartialShape{labeled_dim};

    auto data = make_shared<op::Parameter>(element::i8, data_shape);
    auto shape_of = make_shared<op::v3::ShapeOf>(data);
    auto scatter_update = make_shared<op::v3::ScatterUpdate>(op::Constant::create(element::i64, Shape{2}, {1, 0}),
                                                             op::Constant::create(element::i64, Shape{1}, {1}),
                                                             shape_of,
                                                             op::Constant::create(element::i64, Shape{1}, {0}));
    auto broadcast =
        make_shared<op::v3::Broadcast>(op::Constant::create(element::i64, Shape{1, 1}, {4}), scatter_update);

    const auto& output_shape = broadcast->get_output_partial_shape(0);
    EXPECT_EQ(output_shape, PartialShape({1, {5, 7}}));
    EXPECT_EQ(ov::DimensionTracker::get_label(output_shape[0]), ov::no_label);
    EXPECT_EQ(ov::DimensionTracker::get_label(output_shape[1]), label);
}

TEST(type_prop, scatter_update_v3_partial_value_propagation) {
    // strided slice should take from 5 to 7 elements from the 10 elements in the input data
    auto input = make_shared<op::Parameter>(element::i8, PartialShape{ov::Dimension(5, 7)});