add condition of input_2 which meets EmbeddingSegmentsSum spec (#14690)

Co-authored-by: Irina Efode <irina.efode@intel.com>
This commit is contained in:
Wilson Seok 2023-02-01 01:17:06 +09:00 committed by GitHub
parent f66f31a3b0
commit 0fd354a502
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -673,6 +673,44 @@ ov::runtime::Tensor generate(const std::shared_ptr<ngraph::op::v9::NonMaxSuppres
}
}
ov::runtime::Tensor generate(const std::shared_ptr<ngraph::op::v3::EmbeddingSegmentsSum>& node,
size_t port,
const ov::element::Type& elemType,
const ov::Shape& targetShape) {
if (port == 2) {
ov::runtime::Tensor tensor = ov::runtime::Tensor(elemType, targetShape);
const auto &outputShape = node->get_output_shape(0);
const size_t range = outputShape[0] - 1; // values in segmentsIds should be less than num_segments
const size_t startFrom = 0;
const int seed = 1;
std::default_random_engine random(seed);
switch (elemType) {
case element::Type_t::i32: {
std::uniform_int_distribution<int32_t> distribution(startFrom, (startFrom + range));
auto *dataPtr = tensor.data<int32_t>();
for (size_t i = 0; i < tensor.get_size(); i++) {
dataPtr[i] = distribution(random);
}
return tensor;
}
case element::Type_t::i64: {
std::uniform_int_distribution<int64_t> distribution(startFrom, (startFrom + range));
auto *dataPtr = tensor.data<int64_t>();
for (size_t i = 0; i < tensor.get_size(); i++) {
dataPtr[i] = distribution(random);
}
return tensor;
}
default:
OPENVINO_UNREACHABLE("Unsupported element type for segment_ids: ", elemType);
}
}
return generate(std::dynamic_pointer_cast<ov::Node>(node), port, elemType, targetShape);
}
ov::runtime::Tensor generate(const std::shared_ptr<ov::op::internal::AUGRUSequence>& node,
size_t port,
const ov::element::Type& elemType,