Implement interval logic for Interepolate-4 output_shape calculation (#4820)
* Written interval logic for Interpolate-4. * Small changes. * Added tests for interval logic. * Small fix.
This commit is contained in:
parent
76cf1b2b65
commit
9bd63176f8
@ -163,6 +163,18 @@ std::vector<int64_t> op::v4::Interpolate::get_axes() const
|
||||
|
||||
static constexpr float epsilon = 1.0e-6f;
|
||||
|
||||
namespace
|
||||
{
|
||||
int64_t multiply_bound_and_scale(int64_t bound, float scale)
|
||||
{
|
||||
if (bound == -1)
|
||||
{
|
||||
return bound;
|
||||
}
|
||||
return static_cast<int64_t>(static_cast<float>(bound) * scale);
|
||||
}
|
||||
}
|
||||
|
||||
void op::v4::Interpolate::infer_using_scales(PartialShape& output_shape,
|
||||
const std::vector<int64_t>& axes,
|
||||
const std::vector<float>& scales,
|
||||
@ -171,12 +183,15 @@ void op::v4::Interpolate::infer_using_scales(PartialShape& output_shape,
|
||||
size_t i = 0;
|
||||
for (auto axis : axes)
|
||||
{
|
||||
if (padded_input_shape[axis].is_static())
|
||||
{
|
||||
float padded_len = static_cast<float>(padded_input_shape[axis].get_length());
|
||||
int64_t new_dim = static_cast<int64_t>(padded_len * (scales[i] + epsilon));
|
||||
output_shape[axis] = Dimension(new_dim);
|
||||
}
|
||||
const auto& current_dim = padded_input_shape[axis];
|
||||
float multiplier = scales[i] + epsilon;
|
||||
|
||||
int64_t new_lower_bound =
|
||||
multiply_bound_and_scale(current_dim.get_min_length(), multiplier);
|
||||
int64_t new_upper_bound =
|
||||
multiply_bound_and_scale(current_dim.get_max_length(), multiplier);
|
||||
|
||||
output_shape[axis] = Dimension(new_lower_bound, new_upper_bound);
|
||||
++i;
|
||||
}
|
||||
}
|
||||
|
@ -198,3 +198,28 @@ TEST(type_prop, interpolate_v4_partial_static_rank3)
|
||||
ASSERT_TRUE(interp->get_output_partial_shape(0).same_scheme(out_shape));
|
||||
ASSERT_TRUE(interp->get_output_partial_shape(0).rank().is_static());
|
||||
}
|
||||
|
||||
TEST(type_prop, interpolate_v4_interval_logic)
|
||||
{
|
||||
auto image = std::make_shared<op::Parameter>(
|
||||
element::f32, PartialShape{2, 2, Dimension(12, 800), Dimension(0, -1), Dimension(24, -1)});
|
||||
auto target_shape = std::make_shared<op::Parameter>(element::i32, Shape{3});
|
||||
auto scales = op::Constant::create<float>(element::f32, Shape{3}, {0.5f, 0.25f, 0.125f});
|
||||
auto axes = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 4});
|
||||
|
||||
const auto out_shape = PartialShape{2, 2, Dimension(6, 400), Dimension(0, -1), Dimension(3, -1)};
|
||||
|
||||
InterpolateAttrs attrs;
|
||||
attrs.mode = InterpolateMode::nearest;
|
||||
attrs.shape_calculation_mode = ShapeCalcMode::scales;
|
||||
attrs.coordinate_transformation_mode = CoordinateTransformMode::half_pixel;
|
||||
attrs.nearest_mode = Nearest_mode::round_prefer_floor;
|
||||
attrs.antialias = false;
|
||||
attrs.pads_begin = {0, 0, 0, 0, 0};
|
||||
attrs.pads_end = {0, 0, 0, 0, 0};
|
||||
attrs.cube_coeff = -0.75;
|
||||
auto interp = std::make_shared<op::v4::Interpolate>(image, target_shape, scales, axes, attrs);
|
||||
|
||||
EXPECT_EQ(interp->get_element_type(), element::f32);
|
||||
ASSERT_TRUE(interp->get_output_partial_shape(0).same_scheme(out_shape));
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user