ROIAlign fix - Unify sample_value calculation between max/avg mode (#7710)
* Unify sample_value between max/avg mode * Fix mkldnn roi_align impl * Update test * Revert missing assert
This commit is contained in:
parent
414c3dc133
commit
0eeaf6b2e4
@ -317,25 +317,20 @@ void MKLDNNROIAlignNode::executeSpecified() {
|
|||||||
pointVector[sampleIndex + 3].second * wInputStride + blockResidual_;
|
pointVector[sampleIndex + 3].second * wInputStride + blockResidual_;
|
||||||
float part4 = srcData[part4Index];
|
float part4 = srcData[part4Index];
|
||||||
|
|
||||||
|
float sampleValue =
|
||||||
|
weightVector[sampleIndex] * part1 +
|
||||||
|
weightVector[sampleIndex + 1] * part2 +
|
||||||
|
weightVector[sampleIndex + 2] * part3 +
|
||||||
|
weightVector[sampleIndex + 3] * part4;
|
||||||
switch (getAlgorithm()) {
|
switch (getAlgorithm()) {
|
||||||
case Algorithm::ROIAlignMax:
|
case Algorithm::ROIAlignMax:
|
||||||
{
|
{
|
||||||
float sampleValue = std::max(
|
|
||||||
{weightVector[sampleIndex] * part1,
|
|
||||||
weightVector[sampleIndex + 1] * part2,
|
|
||||||
weightVector[sampleIndex + 2] * part3,
|
|
||||||
weightVector[sampleIndex + 3] * part4});
|
|
||||||
pooledValue = sampleValue > pooledValue ? sampleValue : pooledValue;
|
pooledValue = sampleValue > pooledValue ? sampleValue : pooledValue;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case Algorithm::ROIAlignAvg:
|
case Algorithm::ROIAlignAvg:
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
float sampleValue =
|
|
||||||
weightVector[sampleIndex] * part1 +
|
|
||||||
weightVector[sampleIndex + 1] * part2 +
|
|
||||||
weightVector[sampleIndex + 2] * part3 +
|
|
||||||
weightVector[sampleIndex + 3] * part4;
|
|
||||||
pooledValue += sampleValue / numSamplesInBin;
|
pooledValue += sampleValue / numSamplesInBin;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -158,22 +158,17 @@ void roi_align(const T* feature_maps,
|
|||||||
pooling_points[sample_index + 3].first,
|
pooling_points[sample_index + 3].first,
|
||||||
pooling_points[sample_index + 3].second})];
|
pooling_points[sample_index + 3].second})];
|
||||||
|
|
||||||
|
T sample_value = pooling_weights[sample_index] * sample_part_1 +
|
||||||
|
pooling_weights[sample_index + 1] * sample_part_2 +
|
||||||
|
pooling_weights[sample_index + 2] * sample_part_3 +
|
||||||
|
pooling_weights[sample_index + 3] * sample_part_4;
|
||||||
switch (pooling_mode) {
|
switch (pooling_mode) {
|
||||||
case ROIPoolingMode::MAX: {
|
case ROIPoolingMode::MAX: {
|
||||||
T sample_value = std::max({pooling_weights[sample_index] * sample_part_1,
|
|
||||||
pooling_weights[sample_index + 1] * sample_part_2,
|
|
||||||
pooling_weights[sample_index + 2] * sample_part_3,
|
|
||||||
pooling_weights[sample_index + 3] * sample_part_4});
|
|
||||||
|
|
||||||
pooled_value = sample_value > pooled_value ? sample_value : pooled_value;
|
pooled_value = sample_value > pooled_value ? sample_value : pooled_value;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case ROIPoolingMode::AVG:
|
case ROIPoolingMode::AVG:
|
||||||
default: {
|
default: {
|
||||||
T sample_value = pooling_weights[sample_index] * sample_part_1 +
|
|
||||||
pooling_weights[sample_index + 1] * sample_part_2 +
|
|
||||||
pooling_weights[sample_index + 2] * sample_part_3 +
|
|
||||||
pooling_weights[sample_index + 3] * sample_part_4;
|
|
||||||
pooled_value += sample_value / (num_samples_in_bin);
|
pooled_value += sample_value / (num_samples_in_bin);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -119,23 +119,21 @@ TEST(op_eval, roi_align_max_pool) {
|
|||||||
make_host_tensor<element::Type_t::i64>(Shape{num_rois})}));
|
make_host_tensor<element::Type_t::i64>(Shape{num_rois})}));
|
||||||
|
|
||||||
std::vector<float> expected_vec{
|
std::vector<float> expected_vec{
|
||||||
2.10938f, 2.95313f, 3.375f, 2.53125f, 3.35938f, 4.70313f, 5.375f, 4.03125f, 3.51563f, 4.92188f, 5.625f,
|
3.4375, 3.6875, 3.9375, 4.1875, 5.10417, 5.35417, 5.60417, 5.85417, 6.77083, 7.02083, 7.27083, 7.52083,
|
||||||
4.21875f, 10.8984f, 15.2578f, 17.4375f, 13.0781f, 17.3568f, 24.2995f, 27.7708f, 20.8281f, 18.1641f, 25.4297f,
|
28.4375, 28.6875, 28.9375, 29.1875, 30.1042, 30.3542, 30.6042, 30.8542, 31.7708, 32.0208, 32.2708, 32.5208,
|
||||||
29.0625f, 21.7969f, 19.6875f, 27.5625f, 31.5f, 23.625f, 31.3542f, 43.8958f, 50.1667f, 37.625f, 32.8125f,
|
53.4375, 53.6875, 53.9375, 54.1875, 55.1042, 55.3542, 55.6042, 55.8542, 56.7708, 57.0208, 57.2708, 57.5208,
|
||||||
45.9375f, 52.5f, 39.375f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
0.f, 0.f, 0.f, 0.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f,
|
25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
|
||||||
25.f, 25.f, 25.f, 25.f, 25.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f,
|
50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,
|
||||||
50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 5.625f, 5.625f, 5.625f, 4.57031f, 8.95833f,
|
7.8125, 7.8125, 7.875, 8.125, 9.47917, 9.47917, 9.54167, 9.79167, 11.1458, 11.1458, 11.2083, 11.4583,
|
||||||
8.95833f, 8.95833f, 7.27865f, 9.375f, 9.375f, 9.375f, 7.61719f, 19.6875f, 19.6875f, 19.6875f, 15.9961f,
|
32.8125, 32.8125, 32.875, 33.125, 34.4792, 34.4792, 34.5417, 34.7917, 36.1458, 36.1458, 36.2083, 36.4583,
|
||||||
31.3542f, 31.3542f, 31.3542f, 25.4753f, 32.8125f, 32.8125f, 32.8125f, 26.6602f, 33.75f, 33.75f, 33.75f,
|
57.8125, 57.8125, 57.875, 58.125, 59.4792, 59.4792, 59.5417, 59.7917, 61.1458, 61.1458, 61.2083, 61.4583,
|
||||||
27.4219f, 53.75f, 53.75f, 53.75f, 43.6719f, 56.25f, 56.25f, 56.25f, 45.7031f, 4.5f, 3.9375f,
|
4.75, 5, 5.25, 5.5, 6.41667, 6.66667, 6.91667, 7.16667, 8.08333, 8.33333, 8.58333, 8.83333,
|
||||||
2.8125f, 3.9375f, 5.5f, 4.8125f, 3.4375f, 4.8125f, 4.58333f, 4.01042f, 2.86458f, 3.9375f, 23.25f,
|
29.75, 30, 30.25, 30.5, 31.4167, 31.6667, 31.9167, 32.1667, 33.0833, 33.3333, 33.5833, 33.8333,
|
||||||
20.3438f, 14.5313f, 18.f, 28.4167f, 24.86458f, 17.76042f, 22.f, 23.25f, 20.3437f, 14.5312f, 18.f,
|
54.75, 55, 55.25, 55.5, 56.4167, 56.6667, 56.9167, 57.1667, 58.0833, 58.3333, 58.5833, 58.8333,
|
||||||
42.f, 36.75f, 26.25f, 32.0625f, 51.3333f, 44.9167f, 32.08333f, 39.1875f, 42.f, 36.75f, 26.25f,
|
7.1875, 7.1875, 7.1875, 7.25, 8.85417, 8.85417, 8.85417, 8.91667, 10.5208, 10.5208, 10.5208, 10.5833,
|
||||||
32.0625f, 4.375f, 4.375f, 4.375f, 4.375f, 7.70833f, 7.70833f, 7.70833f, 7.70833f, 9.375f, 9.375f,
|
32.1875, 32.1875, 32.1875, 32.25, 33.8542, 33.8542, 33.8542, 33.9167, 35.5208, 35.5208, 35.5208, 35.5833,
|
||||||
9.375f, 9.375f, 21.875f, 21.875f, 21.875f, 21.875f, 26.9792f, 26.9792f, 26.9792f, 26.9792f, 32.8125f,
|
57.1875, 57.1875, 57.1875, 57.25, 58.8542, 58.8542, 58.8542, 58.9167, 60.5208, 60.5208, 60.5208, 60.5833};
|
||||||
32.8125f, 32.8125f, 32.8125f, 40.1042f, 40.1042f, 40.1042f, 40.1042f, 46.25f, 46.25f, 46.25f, 46.25f,
|
|
||||||
56.25f, 56.25f, 56.25f, 56.25f};
|
|
||||||
const auto expected_shape = Shape{num_rois, C, pooled_height, pooled_width};
|
const auto expected_shape = Shape{num_rois, C, pooled_height, pooled_width};
|
||||||
|
|
||||||
EXPECT_EQ(result->get_element_type(), element::f32);
|
EXPECT_EQ(result->get_element_type(), element::f32);
|
||||||
|
Loading…
Reference in New Issue
Block a user