add condition of input_2 which meets EmbeddingSegmentsSum spec (#14690)
Co-authored-by: Irina Efode <irina.efode@intel.com>
This commit is contained in:
parent
f66f31a3b0
commit
0fd354a502
@ -673,6 +673,44 @@ ov::runtime::Tensor generate(const std::shared_ptr<ngraph::op::v9::NonMaxSuppres
|
||||
}
|
||||
}
|
||||
|
||||
ov::runtime::Tensor generate(const std::shared_ptr<ngraph::op::v3::EmbeddingSegmentsSum>& node,
|
||||
size_t port,
|
||||
const ov::element::Type& elemType,
|
||||
const ov::Shape& targetShape) {
|
||||
if (port == 2) {
|
||||
ov::runtime::Tensor tensor = ov::runtime::Tensor(elemType, targetShape);
|
||||
|
||||
const auto &outputShape = node->get_output_shape(0);
|
||||
const size_t range = outputShape[0] - 1; // values in segmentsIds should be less than num_segments
|
||||
const size_t startFrom = 0;
|
||||
const int seed = 1;
|
||||
std::default_random_engine random(seed);
|
||||
switch (elemType) {
|
||||
case element::Type_t::i32: {
|
||||
std::uniform_int_distribution<int32_t> distribution(startFrom, (startFrom + range));
|
||||
|
||||
auto *dataPtr = tensor.data<int32_t>();
|
||||
for (size_t i = 0; i < tensor.get_size(); i++) {
|
||||
dataPtr[i] = distribution(random);
|
||||
}
|
||||
return tensor;
|
||||
}
|
||||
case element::Type_t::i64: {
|
||||
std::uniform_int_distribution<int64_t> distribution(startFrom, (startFrom + range));
|
||||
|
||||
auto *dataPtr = tensor.data<int64_t>();
|
||||
for (size_t i = 0; i < tensor.get_size(); i++) {
|
||||
dataPtr[i] = distribution(random);
|
||||
}
|
||||
return tensor;
|
||||
}
|
||||
default:
|
||||
OPENVINO_UNREACHABLE("Unsupported element type for segment_ids: ", elemType);
|
||||
}
|
||||
}
|
||||
return generate(std::dynamic_pointer_cast<ov::Node>(node), port, elemType, targetShape);
|
||||
}
|
||||
|
||||
ov::runtime::Tensor generate(const std::shared_ptr<ov::op::internal::AUGRUSequence>& node,
|
||||
size_t port,
|
||||
const ov::element::Type& elemType,
|
||||
|
Loading…
Reference in New Issue
Block a user