Upgrade broadcast from v1 to v3 in SequenceToTi transformation (#12922)

* upgrade broadcast in SequanceToTensorIterator transformation, fix broadcast serialization

* codestyle
This commit is contained in:
Ivan Tikhonov
2022-09-22 06:48:30 +03:00
committed by GitHub
parent 9d56cfd79b
commit f2c0e0b4d7
2 changed files with 9 additions and 4 deletions

View File

@@ -42,7 +42,13 @@ ngraph::Output<ngraph::Node> get_masked_value(const std::shared_ptr<ngraph::opse
body_params.push_back(aggregated_Y_h_body_param);
// Create mask node deciding whether or not to mask batch data.
auto batch_seq_length = ngraph::builder::opset1::legacy_broadcast_for_binary_operation(data, seq_lengths, 0);
auto data_shape = ngraph::op::util::make_try_fold<ngraph::opset5::ShapeOf>(data);
auto axis = ngraph::opset5::Constant::create(data_shape->get_element_type(), {1}, {0});
auto batch_seq_length =
ngraph::op::util::make_try_fold<ngraph::opset5::Broadcast>(seq_lengths,
data_shape,
axis,
ngraph::op::BroadcastType::EXPLICIT);
auto mask_condition = std::make_shared<ngraph::opset5::Greater>(current_iter, batch_seq_length);
auto mask_Y_h = std::make_shared<ngraph::opset5::Equal>(current_iter, batch_seq_length);
@@ -54,7 +60,6 @@ ngraph::Output<ngraph::Node> get_masked_value(const std::shared_ptr<ngraph::opse
body_results.push_back(aggregated_result);
auto scalar_mask_value = ngraph::opset5::Constant::create(data.get_element_type(), {}, {0.f});
auto data_shape = ngraph::op::util::make_try_fold<ngraph::opset5::ShapeOf>(data);
auto mask_value = ngraph::op::util::make_try_fold<ngraph::opset5::Broadcast>(scalar_mask_value, data_shape);
return ngraph::op::util::make_try_fold<ngraph::opset5::Select>(mask_condition, mask_value, data);
}

View File

@@ -63,9 +63,9 @@ template <>
NGRAPH_API EnumNames<ngraph::op::BroadcastType>& EnumNames<ngraph::op::BroadcastType>::get() {
static auto enum_names =
EnumNames<ngraph::op::BroadcastType>("ngraph::op::BroadcastType",
{{"none", ngraph::op::BroadcastType::NONE},
{{"explicit", ngraph::op::BroadcastType::EXPLICIT},
{"none", ngraph::op::BroadcastType::NONE},
{"numpy", ngraph::op::BroadcastType::NUMPY},
{"explicit", ngraph::op::BroadcastType::EXPLICIT},
{"pdpd", ngraph::op::BroadcastType::PDPD},
{"bidirectional", ngraph::op::BroadcastType::BIDIRECTIONAL}});
return enum_names;