Revert PSROIPooling average logic introduced in 305f0056059b091e0290b… (#3765)
* Revert PSROIPooling average logic introduced in 305f005605
* fix computing bin coords in average mode
This commit is contained in:
parent
0d22328a25
commit
1e418ca549
@ -102,10 +102,10 @@ public:
|
||||
roi_width = roi_end_w - roi_start_w;
|
||||
roi_height = roi_end_h - roi_start_h;
|
||||
} else if (mode_ == "average") {
|
||||
roi_start_w = round(bottom_rois[1] * spatial_scale_);
|
||||
roi_start_h = round(bottom_rois[2] * spatial_scale_);
|
||||
roi_end_w = round(bottom_rois[3] * spatial_scale_) + 1.0f;
|
||||
roi_end_h = round(bottom_rois[4] * spatial_scale_) + 1.0f;
|
||||
roi_start_w = static_cast<float>(round(bottom_rois[1])) * spatial_scale_;
|
||||
roi_start_h = static_cast<float>(round(bottom_rois[2])) * spatial_scale_;
|
||||
roi_end_w = static_cast<float>(round(bottom_rois[3]) + 1.0f) * spatial_scale_;
|
||||
roi_end_h = static_cast<float>(round(bottom_rois[4]) + 1.0f) * spatial_scale_;
|
||||
// Force too small ROIs to be 1x1
|
||||
roi_width = std::max<float>(roi_end_w - roi_start_w, 0.1f); // avoid 0
|
||||
roi_height = std::max<float>(roi_end_h - roi_start_h, 0.1f);
|
||||
|
@ -69,16 +69,23 @@ namespace ngraph
|
||||
{
|
||||
const T* box = rois + roi * 5;
|
||||
int batch_id = box[0];
|
||||
float start_w = box[1] * spatial_scale;
|
||||
float start_h = box[2] * spatial_scale;
|
||||
float end_w = box[3] * spatial_scale;
|
||||
float end_h = box[4] * spatial_scale;
|
||||
if (mode == AVG)
|
||||
float start_w = 0;
|
||||
float start_h = 0;
|
||||
float end_w = 0;
|
||||
float end_h = 0;
|
||||
if (mode == BILINEAR)
|
||||
{
|
||||
start_w = std::roundf(start_w);
|
||||
start_h = std::roundf(start_h);
|
||||
end_w = std::roundf(end_w) + 1;
|
||||
end_h = std::roundf(end_h) + 1;
|
||||
start_w = box[1] * spatial_scale;
|
||||
start_h = box[2] * spatial_scale;
|
||||
end_w = box[3] * spatial_scale;
|
||||
end_h = box[4] * spatial_scale;
|
||||
}
|
||||
else if (mode == AVG)
|
||||
{
|
||||
start_w = std::roundf(box[1]) * spatial_scale;
|
||||
start_h = std::roundf(box[2]) * spatial_scale;
|
||||
end_w = (std::roundf(box[3]) + 1.0f) * spatial_scale;
|
||||
end_h = (std::roundf(box[4]) + 1.0f) * spatial_scale;
|
||||
}
|
||||
float box_width = end_w - start_w;
|
||||
float box_height = end_h - start_h;
|
||||
@ -110,19 +117,19 @@ namespace ngraph
|
||||
if (mode == AVG)
|
||||
{
|
||||
size_t bin_start_w = std::min(
|
||||
static_cast<size_t>(start_w + floorf(pw * bin_width)),
|
||||
static_cast<size_t>(floorf(start_w + pw * bin_width)),
|
||||
width - 1);
|
||||
size_t bin_start_h = std::min(
|
||||
static_cast<size_t>(start_h + floorf(ph * bin_height)),
|
||||
static_cast<size_t>(floorf(start_h + ph * bin_height)),
|
||||
height - 1);
|
||||
size_t current_bin_width =
|
||||
std::min(static_cast<size_t>(start_w +
|
||||
ceilf((pw + 1) * bin_width)),
|
||||
std::min(static_cast<size_t>(
|
||||
ceilf(start_w + (pw + 1) * bin_width)),
|
||||
width) -
|
||||
bin_start_w;
|
||||
size_t current_bin_height =
|
||||
std::min(static_cast<size_t>(start_h +
|
||||
ceilf((ph + 1) * bin_height)),
|
||||
std::min(static_cast<size_t>(
|
||||
ceilf(start_h + (ph + 1) * bin_height)),
|
||||
height) -
|
||||
bin_start_h;
|
||||
T sum = 0;
|
||||
|
@ -104,11 +104,11 @@ NGRAPH_TEST(${BACKEND_NAME}, psroi_pooling_average_spatial_scale)
|
||||
0, 5, 10, 20, 30, 0, 0, 15, 50, 20, 1, 50, 35, 55, 65, 1, 0, 60, 5, 70,
|
||||
};
|
||||
std::vector<float> output{
|
||||
6.2499962, 46.44986, 90.249184, 130.44876, 166.25095, 206.45341, 250.25606, 290.45853,
|
||||
6.3499966, 46.849857, 88.349236, 128.84866, 166.35095, 206.85341, 248.35596, 288.8584,
|
||||
338.11142, 378.21387, 424.11667, 464.21912, 498.12119, 538.21564, 584.10443, 624.19464,
|
||||
345.11185, 385.21429, 427.11685, 467.2193, 505.12161, 545.21393, 587.1037, 627.19391,
|
||||
|
||||
6.24999619, 46.399868, 90.2491837, 130.398758, 166.250946, 206.403397, 250.256058,
|
||||
290.408508, 6.34999657, 46.8498573, 87.3492432, 127.848656, 166.350952, 206.853409,
|
||||
247.355896, 287.858368, 338.11142, 378.163879, 424.116669, 464.169128, 498.121185,
|
||||
538.165649, 584.104431, 624.144653, 345.111847, 385.164307, 427.116852, 467.169312,
|
||||
505.121613, 545.16394, 587.103699, 627.143921,
|
||||
};
|
||||
|
||||
auto tc = test::TestCase<TestEngine>(f);
|
||||
|
Loading…
Reference in New Issue
Block a user