remove resize asserts (#11234)

This commit is contained in:
Mateusz Bencer
2022-03-30 18:13:55 +02:00
committed by GitHub
parent a635150b9d
commit 7a0d85a067

View File

@@ -142,7 +142,6 @@ OutputVector resize(const onnx_import::Node& node) {
// in "tf_crop_and_resize" which is not handled now
const auto inputs = node.get_ng_inputs();
const auto& data = inputs.at(0);
const auto& data_shape = data.get_partial_shape();
auto attrs = get_resize_attrs(node);
@@ -150,12 +149,6 @@ OutputVector resize(const onnx_import::Node& node) {
{
attrs.shape_calculation_mode = default_opset::Interpolate::ShapeCalcMode::SIZES;
const auto& sizes = inputs.at(3);
const auto& sizes_shape = sizes.get_partial_shape();
CHECK_VALID_NODE(node,
(sizes_shape.is_static() || data_shape.rank().is_static()),
" Data rank or shape of sizes input is required to be static.");
const auto scales = calculate_scales_based_on_sizes(data, sizes);
return {std::make_shared<default_opset::Interpolate>(data, sizes, scales, attrs)};
@@ -164,12 +157,6 @@ OutputVector resize(const onnx_import::Node& node) {
attrs.shape_calculation_mode = default_opset::Interpolate::ShapeCalcMode::SCALES;
const auto& scales = inputs.at(2);
const auto& scales_shape = scales.get_partial_shape();
CHECK_VALID_NODE(node,
(scales_shape.is_static() || data_shape.rank().is_static()),
" Data rank or shape of scales input is required to be static.");
const auto output_shape = calculate_output_shape_based_on_scales(data, scales);
return {std::make_shared<default_opset::Interpolate>(data, output_shape, scales, attrs)};
}
@@ -181,9 +168,6 @@ OutputVector resize(const onnx_import::Node& node) {
const auto& data = inputs.at(0);
const auto& scales = inputs.at(1);
const auto& data_shape = data.get_partial_shape();
const auto& scales_shape = scales.get_partial_shape();
auto attrs = get_resize_attrs(node);
if (attrs.mode == InterpolateMode::NEAREST) {
@@ -193,10 +177,6 @@ OutputVector resize(const onnx_import::Node& node) {
attrs.coordinate_transformation_mode = Transform_mode::ASYMMETRIC;
}
CHECK_VALID_NODE(node,
(scales_shape.is_static() || data_shape.rank().is_static()),
" Data rank or shape of scales input is required to be static.");
const auto output_shape = calculate_output_shape_based_on_scales(data, scales);
return {std::make_shared<default_opset::Interpolate>(data, output_shape, scales, attrs)};
}