Upgrade broadcast from v1 to v3 in SequenceToTi transformation (#12922)
* upgrade broadcast in SequanceToTensorIterator transformation, fix broadcast serialization * codestyle
This commit is contained in:
@@ -42,7 +42,13 @@ ngraph::Output<ngraph::Node> get_masked_value(const std::shared_ptr<ngraph::opse
|
||||
body_params.push_back(aggregated_Y_h_body_param);
|
||||
|
||||
// Create mask node deciding whether or not to mask batch data.
|
||||
auto batch_seq_length = ngraph::builder::opset1::legacy_broadcast_for_binary_operation(data, seq_lengths, 0);
|
||||
auto data_shape = ngraph::op::util::make_try_fold<ngraph::opset5::ShapeOf>(data);
|
||||
auto axis = ngraph::opset5::Constant::create(data_shape->get_element_type(), {1}, {0});
|
||||
auto batch_seq_length =
|
||||
ngraph::op::util::make_try_fold<ngraph::opset5::Broadcast>(seq_lengths,
|
||||
data_shape,
|
||||
axis,
|
||||
ngraph::op::BroadcastType::EXPLICIT);
|
||||
|
||||
auto mask_condition = std::make_shared<ngraph::opset5::Greater>(current_iter, batch_seq_length);
|
||||
auto mask_Y_h = std::make_shared<ngraph::opset5::Equal>(current_iter, batch_seq_length);
|
||||
@@ -54,7 +60,6 @@ ngraph::Output<ngraph::Node> get_masked_value(const std::shared_ptr<ngraph::opse
|
||||
body_results.push_back(aggregated_result);
|
||||
|
||||
auto scalar_mask_value = ngraph::opset5::Constant::create(data.get_element_type(), {}, {0.f});
|
||||
auto data_shape = ngraph::op::util::make_try_fold<ngraph::opset5::ShapeOf>(data);
|
||||
auto mask_value = ngraph::op::util::make_try_fold<ngraph::opset5::Broadcast>(scalar_mask_value, data_shape);
|
||||
return ngraph::op::util::make_try_fold<ngraph::opset5::Select>(mask_condition, mask_value, data);
|
||||
}
|
||||
|
||||
@@ -63,9 +63,9 @@ template <>
|
||||
NGRAPH_API EnumNames<ngraph::op::BroadcastType>& EnumNames<ngraph::op::BroadcastType>::get() {
|
||||
static auto enum_names =
|
||||
EnumNames<ngraph::op::BroadcastType>("ngraph::op::BroadcastType",
|
||||
{{"none", ngraph::op::BroadcastType::NONE},
|
||||
{{"explicit", ngraph::op::BroadcastType::EXPLICIT},
|
||||
{"none", ngraph::op::BroadcastType::NONE},
|
||||
{"numpy", ngraph::op::BroadcastType::NUMPY},
|
||||
{"explicit", ngraph::op::BroadcastType::EXPLICIT},
|
||||
{"pdpd", ngraph::op::BroadcastType::PDPD},
|
||||
{"bidirectional", ngraph::op::BroadcastType::BIDIRECTIONAL}});
|
||||
return enum_names;
|
||||
|
||||
Reference in New Issue
Block a user