ShuffleChannels shape propagation unified (#6269)

This commit is contained in:
Evgenya Stepyreva 2021-06-22 14:35:30 +03:00 committed by GitHub
parent 15ee515a88
commit b0e932567d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -63,11 +63,10 @@ void op::ShuffleChannels::validate_and_infer_types()
if (get_input_partial_shape(0).is_static()) if (get_input_partial_shape(0).is_static())
{ {
const auto shape = get_input_shape(0); const auto shape = get_input_shape(0);
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
this, shape.size() >= 1, "The input tensor's shape is expected to be at least 1D."); this, shape.size() >= 1, "The input tensor's shape is expected to be at least 1D.");
size_t axis_zb = get_zero_based_axis();
size_t axis_zb = get_zero_based_axis();
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
axis_zb < shape.size(), axis_zb < shape.size(),
"The 'axis' parameter for ShuffleChannels has to point to one of the " "The 'axis' parameter for ShuffleChannels has to point to one of the "
@ -81,14 +80,8 @@ void op::ShuffleChannels::validate_and_infer_types()
this, this,
channel_dim_size % m_group == 0, channel_dim_size % m_group == 0,
"The channel dimension size has to be a multiple of the groups parameter value."); "The channel dimension size has to be a multiple of the groups parameter value.");
set_output_size(1);
set_output_type(0, data_type, shape);
}
else
{
const auto shape = get_input_partial_shape(0);
set_output_type(0, data_type, shape);
} }
set_output_type(0, data_type, get_input_partial_shape(0));
} }
shared_ptr<Node> op::ShuffleChannels::clone_with_new_inputs(const OutputVector& new_args) const shared_ptr<Node> op::ShuffleChannels::clone_with_new_inputs(const OutputVector& new_args) const