ShuffleChannels shape propagation unified (#6269)
This commit is contained in:
parent
15ee515a88
commit
b0e932567d
@ -63,11 +63,10 @@ void op::ShuffleChannels::validate_and_infer_types()
|
|||||||
if (get_input_partial_shape(0).is_static())
|
if (get_input_partial_shape(0).is_static())
|
||||||
{
|
{
|
||||||
const auto shape = get_input_shape(0);
|
const auto shape = get_input_shape(0);
|
||||||
|
|
||||||
NODE_VALIDATION_CHECK(
|
NODE_VALIDATION_CHECK(
|
||||||
this, shape.size() >= 1, "The input tensor's shape is expected to be at least 1D.");
|
this, shape.size() >= 1, "The input tensor's shape is expected to be at least 1D.");
|
||||||
size_t axis_zb = get_zero_based_axis();
|
|
||||||
|
|
||||||
|
size_t axis_zb = get_zero_based_axis();
|
||||||
NODE_VALIDATION_CHECK(this,
|
NODE_VALIDATION_CHECK(this,
|
||||||
axis_zb < shape.size(),
|
axis_zb < shape.size(),
|
||||||
"The 'axis' parameter for ShuffleChannels has to point to one of the "
|
"The 'axis' parameter for ShuffleChannels has to point to one of the "
|
||||||
@ -81,14 +80,8 @@ void op::ShuffleChannels::validate_and_infer_types()
|
|||||||
this,
|
this,
|
||||||
channel_dim_size % m_group == 0,
|
channel_dim_size % m_group == 0,
|
||||||
"The channel dimension size has to be a multiple of the groups parameter value.");
|
"The channel dimension size has to be a multiple of the groups parameter value.");
|
||||||
set_output_size(1);
|
|
||||||
set_output_type(0, data_type, shape);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
const auto shape = get_input_partial_shape(0);
|
|
||||||
set_output_type(0, data_type, shape);
|
|
||||||
}
|
}
|
||||||
|
set_output_type(0, data_type, get_input_partial_shape(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
shared_ptr<Node> op::ShuffleChannels::clone_with_new_inputs(const OutputVector& new_args) const
|
shared_ptr<Node> op::ShuffleChannels::clone_with_new_inputs(const OutputVector& new_args) const
|
||||||
|
Loading…
Reference in New Issue
Block a user