Implement interval logic for Interepolate-4 output_shape calculation (#4820)

* Written interval logic for Interpolate-4.

* Small changes.

* Added tests for interval logic.

* Small fix.
This commit is contained in:
Vladimir Gavrilov 2021-03-26 15:24:58 +03:00 committed by GitHub
parent 76cf1b2b65
commit 9bd63176f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 46 additions and 6 deletions

View File

@ -163,6 +163,18 @@ std::vector<int64_t> op::v4::Interpolate::get_axes() const
static constexpr float epsilon = 1.0e-6f;
namespace
{
int64_t multiply_bound_and_scale(int64_t bound, float scale)
{
if (bound == -1)
{
return bound;
}
return static_cast<int64_t>(static_cast<float>(bound) * scale);
}
}
void op::v4::Interpolate::infer_using_scales(PartialShape& output_shape,
const std::vector<int64_t>& axes,
const std::vector<float>& scales,
@ -171,12 +183,15 @@ void op::v4::Interpolate::infer_using_scales(PartialShape& output_shape,
size_t i = 0;
for (auto axis : axes)
{
if (padded_input_shape[axis].is_static())
{
float padded_len = static_cast<float>(padded_input_shape[axis].get_length());
int64_t new_dim = static_cast<int64_t>(padded_len * (scales[i] + epsilon));
output_shape[axis] = Dimension(new_dim);
}
const auto& current_dim = padded_input_shape[axis];
float multiplier = scales[i] + epsilon;
int64_t new_lower_bound =
multiply_bound_and_scale(current_dim.get_min_length(), multiplier);
int64_t new_upper_bound =
multiply_bound_and_scale(current_dim.get_max_length(), multiplier);
output_shape[axis] = Dimension(new_lower_bound, new_upper_bound);
++i;
}
}

View File

@ -198,3 +198,28 @@ TEST(type_prop, interpolate_v4_partial_static_rank3)
ASSERT_TRUE(interp->get_output_partial_shape(0).same_scheme(out_shape));
ASSERT_TRUE(interp->get_output_partial_shape(0).rank().is_static());
}
TEST(type_prop, interpolate_v4_interval_logic)
{
auto image = std::make_shared<op::Parameter>(
element::f32, PartialShape{2, 2, Dimension(12, 800), Dimension(0, -1), Dimension(24, -1)});
auto target_shape = std::make_shared<op::Parameter>(element::i32, Shape{3});
auto scales = op::Constant::create<float>(element::f32, Shape{3}, {0.5f, 0.25f, 0.125f});
auto axes = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 4});
const auto out_shape = PartialShape{2, 2, Dimension(6, 400), Dimension(0, -1), Dimension(3, -1)};
InterpolateAttrs attrs;
attrs.mode = InterpolateMode::nearest;
attrs.shape_calculation_mode = ShapeCalcMode::scales;
attrs.coordinate_transformation_mode = CoordinateTransformMode::half_pixel;
attrs.nearest_mode = Nearest_mode::round_prefer_floor;
attrs.antialias = false;
attrs.pads_begin = {0, 0, 0, 0, 0};
attrs.pads_end = {0, 0, 0, 0, 0};
attrs.cube_coeff = -0.75;
auto interp = std::make_shared<op::v4::Interpolate>(image, target_shape, scales, axes, attrs);
EXPECT_EQ(interp->get_element_type(), element::f32);
ASSERT_TRUE(interp->get_output_partial_shape(0).same_scheme(out_shape));
}