[Shape Inference] ScatterUpdate - add evaluate_label (#14381)
* Add ScatterUpdate type_prop tests * Add evaluate_label for ScatterUpdate * style update * Reuse has_no_labels * Adjust the logic to the single output * Remove redundant size_check * Migrate from HostTensor to ov Tensor * Remove deprecation macro * Use auto instead of HostTensorPtr * Adjust tensor element type * Update to reuse labels vector * Use sizeof size_t to set element type * Ensure static shapes on inputs
This commit is contained in:
parent
e0c21ce302
commit
72c39c3e32
@ -37,6 +37,7 @@ public:
|
||||
bool evaluate_lower(const HostTensorVector& outputs) const override;
|
||||
bool evaluate_upper(const HostTensorVector& outputs) const override;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
bool evaluate_label(TensorLabelVector& output_labels) const override;
|
||||
bool has_evaluate() const override;
|
||||
|
||||
private:
|
||||
|
@ -122,3 +122,59 @@ bool op::v3::ScatterUpdate::has_evaluate() const {
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
namespace {
|
||||
bool scatter_label_evaluator(const Node* node, TensorLabelVector& output_labels) {
|
||||
const auto& input_values = node->input_values();
|
||||
|
||||
constexpr auto data_in_idx = 0;
|
||||
constexpr auto updates_in_idx = 2;
|
||||
std::vector<size_t> data_labels = input_values[data_in_idx].get_tensor().get_value_label();
|
||||
std::vector<size_t> updates_labels = input_values[updates_in_idx].get_tensor().get_value_label();
|
||||
|
||||
if (ov::has_no_labels(data_labels) && ov::has_no_labels(updates_labels)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
constexpr auto element_type = (sizeof(size_t) == 8) ? element::u64 : element::u32;
|
||||
std::vector<ov::runtime::Tensor> input_tensors;
|
||||
input_tensors.reserve(input_values.size());
|
||||
|
||||
auto make_input_label = [&](const Output<Node>& input, TensorLabel& labels) {
|
||||
input_tensors.emplace_back(element_type, input.get_shape());
|
||||
labels.resize(shape_size(input.get_shape()));
|
||||
memcpy(input_tensors.back().data(), labels.data(), input_tensors.back().get_byte_size());
|
||||
};
|
||||
|
||||
for (size_t i = 0; i < input_values.size(); ++i) {
|
||||
const auto& input = input_values[i];
|
||||
if (i == data_in_idx) {
|
||||
make_input_label(input, data_labels);
|
||||
} else if (i == updates_in_idx) {
|
||||
make_input_label(input, updates_labels);
|
||||
} else {
|
||||
const auto host_tensor_ptr = input.get_tensor().get_lower_value();
|
||||
input_tensors.emplace_back(host_tensor_ptr->get_element_type(),
|
||||
host_tensor_ptr->get_shape(),
|
||||
host_tensor_ptr->get_data_ptr());
|
||||
}
|
||||
}
|
||||
|
||||
ov::TensorVector output_tensors{ov::Tensor(element_type, node->get_output_shape(0))};
|
||||
if (node->evaluate(output_tensors, input_tensors)) {
|
||||
size_t* ptr = static_cast<size_t*>(output_tensors[0].data(element_type));
|
||||
output_labels[0] = std::vector<size_t>(ptr, ptr + output_tensors[0].get_size());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool op::v3::ScatterUpdate::evaluate_label(TensorLabelVector& output_labels) const {
|
||||
OV_OP_SCOPE(v3_ScatterUpdate_evaluate_label);
|
||||
if (get_input_partial_shape(0).is_static() && get_input_partial_shape(2).is_static() &&
|
||||
get_input_tensor(1).has_and_set_bound() && get_input_tensor(3).has_and_set_bound()) {
|
||||
return scatter_label_evaluator(this, output_labels);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
@ -2,12 +2,15 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "dimension_tracker.hpp"
|
||||
#include "gmock/gmock.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
using namespace testing;
|
||||
|
||||
namespace {
|
||||
using type = ngraph::element::Type;
|
||||
@ -208,6 +211,48 @@ TEST(type_prop, scatter_update_v3_dynamic_data_shape) {
|
||||
EXPECT_TRUE(scatter_update->get_output_partial_shape(0).is_dynamic());
|
||||
}
|
||||
|
||||
TEST(type_prop, scatter_update_v3_interval_label_data_shape) {
|
||||
auto labeled_dim = Dimension(1, 9);
|
||||
size_t label = 222;
|
||||
ov::DimensionTracker::set_label(labeled_dim, label);
|
||||
PartialShape data_shape = PartialShape{-1, {2, 8}, labeled_dim, 4};
|
||||
Shape indices_shape{2, 1};
|
||||
Shape updates_shape{3, 2, 1, 2, 4};
|
||||
|
||||
auto data = make_shared<op::Parameter>(element::f32, data_shape);
|
||||
auto idx = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
auto updates = make_shared<op::Parameter>(element::f32, updates_shape);
|
||||
auto axis = op::Constant::create(element::i32, Shape{}, {1});
|
||||
|
||||
auto scatter_update = make_shared<op::v3::ScatterUpdate>(data, idx, updates, axis);
|
||||
|
||||
const auto& output_shape = scatter_update->get_output_partial_shape(0);
|
||||
EXPECT_EQ(output_shape, data_shape);
|
||||
EXPECT_THAT(get_shape_labels(output_shape), ElementsAre(ov::no_label, ov::no_label, label, ov::no_label));
|
||||
EXPECT_EQ(scatter_update->get_output_element_type(0), element::f32);
|
||||
}
|
||||
|
||||
TEST(type_prop, scatter_update_v3_value_label_propagation) {
|
||||
auto labeled_dim = Dimension(5, 7);
|
||||
size_t label = 2345664;
|
||||
ov::DimensionTracker::set_label(labeled_dim, label);
|
||||
PartialShape data_shape = PartialShape{labeled_dim};
|
||||
|
||||
auto data = make_shared<op::Parameter>(element::i8, data_shape);
|
||||
auto shape_of = make_shared<op::v3::ShapeOf>(data);
|
||||
auto scatter_update = make_shared<op::v3::ScatterUpdate>(op::Constant::create(element::i64, Shape{2}, {1, 0}),
|
||||
op::Constant::create(element::i64, Shape{1}, {1}),
|
||||
shape_of,
|
||||
op::Constant::create(element::i64, Shape{1}, {0}));
|
||||
auto broadcast =
|
||||
make_shared<op::v3::Broadcast>(op::Constant::create(element::i64, Shape{1, 1}, {4}), scatter_update);
|
||||
|
||||
const auto& output_shape = broadcast->get_output_partial_shape(0);
|
||||
EXPECT_EQ(output_shape, PartialShape({1, {5, 7}}));
|
||||
EXPECT_EQ(ov::DimensionTracker::get_label(output_shape[0]), ov::no_label);
|
||||
EXPECT_EQ(ov::DimensionTracker::get_label(output_shape[1]), label);
|
||||
}
|
||||
|
||||
TEST(type_prop, scatter_update_v3_partial_value_propagation) {
|
||||
// strided slice should take from 5 to 7 elements from the 10 elements in the input data
|
||||
auto input = make_shared<op::Parameter>(element::i8, PartialShape{ov::Dimension(5, 7)});
|
||||
|
Loading…
Reference in New Issue
Block a user