diff --git a/ngraph/core/src/op/shuffle_channels.cpp b/ngraph/core/src/op/shuffle_channels.cpp index 71683af4030..d38a560f72b 100644 --- a/ngraph/core/src/op/shuffle_channels.cpp +++ b/ngraph/core/src/op/shuffle_channels.cpp @@ -63,11 +63,10 @@ void op::ShuffleChannels::validate_and_infer_types() if (get_input_partial_shape(0).is_static()) { const auto shape = get_input_shape(0); - NODE_VALIDATION_CHECK( this, shape.size() >= 1, "The input tensor's shape is expected to be at least 1D."); - size_t axis_zb = get_zero_based_axis(); + size_t axis_zb = get_zero_based_axis(); NODE_VALIDATION_CHECK(this, axis_zb < shape.size(), "The 'axis' parameter for ShuffleChannels has to point to one of the " @@ -81,14 +80,8 @@ void op::ShuffleChannels::validate_and_infer_types() this, channel_dim_size % m_group == 0, "The channel dimension size has to be a multiple of the groups parameter value."); - set_output_size(1); - set_output_type(0, data_type, shape); - } - else - { - const auto shape = get_input_partial_shape(0); - set_output_type(0, data_type, shape); } + set_output_type(0, data_type, get_input_partial_shape(0)); } shared_ptr op::ShuffleChannels::clone_with_new_inputs(const OutputVector& new_args) const