ShuffleChannels shape propagation unified (#6269)
This commit is contained in:
parent
15ee515a88
commit
b0e932567d
@ -63,11 +63,10 @@ void op::ShuffleChannels::validate_and_infer_types()
|
||||
if (get_input_partial_shape(0).is_static())
|
||||
{
|
||||
const auto shape = get_input_shape(0);
|
||||
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, shape.size() >= 1, "The input tensor's shape is expected to be at least 1D.");
|
||||
size_t axis_zb = get_zero_based_axis();
|
||||
|
||||
size_t axis_zb = get_zero_based_axis();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axis_zb < shape.size(),
|
||||
"The 'axis' parameter for ShuffleChannels has to point to one of the "
|
||||
@ -81,14 +80,8 @@ void op::ShuffleChannels::validate_and_infer_types()
|
||||
this,
|
||||
channel_dim_size % m_group == 0,
|
||||
"The channel dimension size has to be a multiple of the groups parameter value.");
|
||||
set_output_size(1);
|
||||
set_output_type(0, data_type, shape);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto shape = get_input_partial_shape(0);
|
||||
set_output_type(0, data_type, shape);
|
||||
}
|
||||
set_output_type(0, data_type, get_input_partial_shape(0));
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::ShuffleChannels::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
|
Loading…
Reference in New Issue
Block a user