Implement reference nGraph implementation for operation ExperimentalDetectronROIFeatureExtractor (#6484)

* Written reference implementation of the operation ExperimentalDetectronROIFeatureExtractor.

* Small fixes.

* Started to write tests for evaluation of the operation ExperimentalDetectronROIfeatureExtractor.

* Written test for evaluation of the nGraph operation ExperimentalDetectronROIFeatureExtractor.

* Some changes.

* Added debug prints to evaluates.map.

* Added more debug prints.

* Added another debug prints.

* Added more debug prints.

* Added more debug prints.

* Added more debug prints.

* Inserted additional static_casts.

* Added more static_casts.

* Commented some debug prints.

* Some reversion.

* Deleted some debug prints.

* Deleted some debug prints.

* Deleted more debug prints.

* Added some casts and debug prints.

* Some changes.

* Small changes.

* Some changes.

* Added png files.

* Small changes.

* Code style fixes.

* Code style fixes.

* Rewritten some auxiliary functions.

* Corrected the body of the function experimental_detectron_roi_feature_extractor().

* Some code style fixes.

* Code style fixes.

* Small code style fixes.

* Commented one debug print.

* Small changes.

* Added some debug print.

* Small changes.

* Added more debug prints.

* Small fixes.

* Added more debug prints.

* Commented some code.

* Indexing operation [] was replaced by .at() method in the function pre_calc_for_bilinear_interpolate().

* Deleted unneeded variables w1, w2, w3, w4.

* Deleted variable xx.

* Added GCC pragma before the function pre_calc_for_bilinear_interpolate().

* Fixes in macros.

* Fixed pragma before the function pre_calc_for_bilinear_interpolate().

* Deleted some debug prints.

* Deleted more debug prints and fixed some code style issues.

* Deleted redundant assert.

* Deleted redundant assert in the function split_points().

* Started to move tests for nGraph reference implementation of ExperimentalDetectronROIFeatureExtractor to template plugin.

* Enabled test INTERPRETER.onnx_model_experimental_detectron_roi_feature_extractor.

* Deleted backend tests for the reference nGraph implementation of the operation ExperimentalDetectronROIFeatureExtractor.

* Deleted commented code.

* Fixed typo.

* Some fixes.

* Some fixes.

* Some fixes.

* Some fixes.

* Some fixes.

* Renamed the function that calculates ROIAlign.

* Deleted redundant usings.

* Now input shapes are parameters of test.

* Small fix.

* Now element type is also test parameter.

* Deleted some commented code.

* Added test for float16 case.

* Small fix.

* Added test for bfloat16 case.

* Deleted redundant parameters of tests.

* Deleted commented code.

* Deleted redundant structure.

* Small fix.

* Some reverting.
This commit is contained in:
Vladimir Gavrilov 2021-09-27 12:49:18 +03:00 committed by GitHub
parent 7fa9bbf6fc
commit 1d3df63d64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 730 additions and 1 deletions

View File

@ -0,0 +1,230 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <ie_core.hpp>
#include <ie_ngraph_utils.hpp>
#include <ngraph/ngraph.hpp>
#include <shared_test_classes/base/layer_test_utils.hpp>
#include <tuple>
#include "base_reference_test.hpp"
using namespace ngraph;
using namespace InferenceEngine;
using namespace reference_tests;
struct ExperimentalROIParams {
ExperimentalROIParams(const std::vector<Tensor>& experimental_detectron_roi_feature_inputs,
const std::vector<Tensor>& expected_results,
const std::string& test_case_name)
: inputs{experimental_detectron_roi_feature_inputs},
expected_results{expected_results},
test_case_name{test_case_name} {}
std::vector<Tensor> inputs;
std::vector<Tensor> expected_results;
std::string test_case_name;
};
class ReferenceExperimentalROILayerTest : public testing::TestWithParam<ExperimentalROIParams>, public CommonReferenceTest {
public:
void SetUp() override {
auto params = GetParam();
function = create_function(params.inputs);
inputData.reserve(params.inputs.size());
refOutData.reserve(params.expected_results.size());
for (const auto& input_tensor : params.inputs) {
inputData.push_back(input_tensor.data);
}
for (const auto& expected_tensor : params.expected_results) {
refOutData.push_back(expected_tensor.data);
}
}
static std::string getTestCaseName(const testing::TestParamInfo<ExperimentalROIParams>& obj) {
auto param = obj.param;
return param.test_case_name;
}
private:
std::shared_ptr<Function> create_function(const std::vector<Tensor>& inputs) {
op::v6::ExperimentalDetectronROIFeatureExtractor::Attributes attrs;
attrs.aligned = false;
attrs.output_size = 3;
attrs.sampling_ratio = 2;
attrs.pyramid_scales = {4};
const size_t num_of_inputs = inputs.size();
NodeVector node_vector(num_of_inputs);
ParameterVector parameter_vector(num_of_inputs);
for (size_t i = 0; i < num_of_inputs; ++i) {
const auto& current_input = inputs[i];
auto current_parameter = std::make_shared<op::Parameter>(current_input.type, current_input.shape);
node_vector[i] = current_parameter;
parameter_vector[i] = current_parameter;
}
auto roi = std::make_shared<op::v6::ExperimentalDetectronROIFeatureExtractor>(node_vector, attrs);
auto fun = std::make_shared<Function>(OutputVector{roi->output(0), roi->output(1)}, parameter_vector);
return fun;
}
};
TEST_P(ReferenceExperimentalROILayerTest, ExperimentalROIWithHardcodedRefs) {
Exec();
}
INSTANTIATE_TEST_SUITE_P(
smoke_ExperimentalROI_With_Hardcoded_Refs,
ReferenceExperimentalROILayerTest,
::testing::Values(
ExperimentalROIParams(
std::vector<Tensor>{Tensor(Shape{2, 4},
ngraph::element::f32,
std::vector<float>{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}),
Tensor(Shape{1, 2, 2, 3},
ngraph::element::f32,
std::vector<float>{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0})},
std::vector<Tensor>{Tensor(Shape{2, 2, 3, 3},
ngraph::element::f32,
std::vector<float>{1.416667,
1.75,
2.083333,
2.416667,
2.75,
3.083333,
3.166667,
3.5,
3.833333,
7.416667,
7.75,
8.083333,
8.416667,
8.75,
9.083334,
9.166666,
9.5,
9.833334,
4.166667,
4.5,
4.833333,
4.166667,
4.5,
4.833333,
2.083333,
2.25,
2.416667,
10.16667,
10.5,
10.83333,
10.16667,
10.5,
10.83333,
5.083333,
5.25,
5.416667}),
Tensor(Shape{2, 4},
ngraph::element::f32,
std::vector<float>{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0})},
"experimental_detectron_roi_feature_eval_f32"),
ExperimentalROIParams(
std::vector<Tensor>{Tensor(Shape{2, 4},
ngraph::element::f16,
std::vector<ngraph::float16>{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}),
Tensor(Shape{1, 2, 2, 3},
ngraph::element::f16,
std::vector<ngraph::float16>{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0})},
std::vector<Tensor>{Tensor(Shape{2, 2, 3, 3},
ngraph::element::f16,
std::vector<ngraph::float16>{1.416667,
1.75,
2.083333,
2.416667,
2.75,
3.083333,
3.166667,
3.5,
3.833333,
7.416667,
7.75,
8.083333,
8.416667,
8.75,
9.083334,
9.166666,
9.5,
9.833334,
4.166667,
4.5,
4.833333,
4.166667,
4.5,
4.833333,
2.083333,
2.25,
2.416667,
10.16667,
10.5,
10.83333,
10.16667,
10.5,
10.83333,
5.083333,
5.25,
5.416667}),
Tensor(Shape{2, 4},
ngraph::element::f16,
std::vector<ngraph::float16>{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0})},
"experimental_detectron_roi_feature_eval_f16"),
ExperimentalROIParams(
std::vector<Tensor>{Tensor(Shape{2, 4},
ngraph::element::bf16,
std::vector<ngraph::bfloat16>{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}),
Tensor(Shape{1, 2, 2, 3},
ngraph::element::bf16,
std::vector<ngraph::bfloat16>{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0})},
std::vector<Tensor>{Tensor(Shape{2, 2, 3, 3},
ngraph::element::bf16,
std::vector<ngraph::bfloat16>{1.416667,
1.75,
2.083333,
2.416667,
2.75,
3.083333,
3.166667,
3.5,
3.833333,
7.416667,
7.75,
8.083333,
8.416667,
8.75,
9.083334,
9.166666,
9.5,
9.833334,
4.166667,
4.5,
4.833333,
4.166667,
4.5,
4.833333,
2.083333,
2.25,
2.416667,
10.16667,
10.5,
10.83333,
10.16667,
10.5,
10.83333,
5.083333,
5.25,
5.416667}),
Tensor(Shape{2, 4},
ngraph::element::bf16,
std::vector<ngraph::bfloat16>{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0})},
"experimental_detectron_roi_feature_eval_bf16")));

View File

@ -0,0 +1,36 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <cstddef>
#include <cstdint>
#include <ngraph/runtime/host_tensor.hpp>
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/ops.hpp"
#include "ngraph/shape_util.hpp"
namespace ngraph {
namespace runtime {
namespace reference {
void experimental_detectron_roi_feature_extractor(
const std::vector<std::vector<float>>& inputs,
const std::vector<Shape>& input_shapes,
const op::v6::ExperimentalDetectronROIFeatureExtractor::Attributes& attrs,
float* output_rois_features,
float* output_rois);
void experimental_detectron_roi_feature_extractor_postprocessing(void* prois_features,
void* prois,
const ngraph::element::Type output_type,
const std::vector<float>& output_roi_features,
const std::vector<float>& output_rois,
const Shape& output_roi_features_shape,
const Shape& output_rois_shape);
} // namespace reference
} // namespace runtime
} // namespace ngraph

View File

@ -0,0 +1,387 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/runtime/reference/experimental_detectron_roi_feature_extractor.hpp"
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstring>
#include <numeric>
#include "ngraph/op/experimental_detectron_roi_feature.hpp"
#include "ngraph/shape.hpp"
#if defined(__GNUC__) && !defined(__clang__)
# if defined(__linux__) && defined(__i386__) && (__GNUC__ == 7 && __GNUC_MINOR__ == 5 && __GNUC_PATCHLEVEL__ == 0)
# define NEED_FIX 1
# else
# define NEED_FIX 0
# endif
#else
# define NEED_FIX 0
#endif
namespace {
constexpr int64_t input_rois_port = 0;
constexpr int64_t input_features_start_port = 1;
void redistribute_rois(const std::vector<float>& rois, std::vector<int64_t>& level_ids, const int64_t levels_num) {
const float canonical_scale = 224.0f;
const int64_t canonical_level = 2;
const size_t num_rois = level_ids.size();
for (size_t i = 0; i < num_rois; ++i) {
const float x0 = rois[4 * i + 0];
const float y0 = rois[4 * i + 1];
const float x1 = rois[4 * i + 2];
const float y1 = rois[4 * i + 3];
int64_t target_level = levels_num;
float area = (x1 - x0) * (y1 - y0);
if (area > 0) {
area = std::sqrt(area) / canonical_scale;
area = std::log2(area + 1e-6f);
target_level = static_cast<int64_t>(std::floor(area + canonical_level));
target_level = std::max(static_cast<int64_t>(0), std::min(levels_num - 1, target_level));
}
level_ids[i] = target_level;
}
}
void reord(const std::vector<float>& src_data,
const std::vector<int64_t>& ranks,
const int64_t step,
float* dst_data,
std::vector<int64_t>& dst_mapping) {
int64_t n = static_cast<int64_t>(ranks.size());
std::iota(dst_mapping.begin(), dst_mapping.end(), 0);
std::sort(dst_mapping.begin(), dst_mapping.end(), [&ranks](int64_t i1, int64_t i2) {
return ranks[i1] < ranks[i2];
});
for (int64_t i = 0; i < n; ++i) {
const int64_t j = dst_mapping[i];
memcpy(dst_data + i * step, src_data.data() + j * step, sizeof(float) * step);
}
}
void split_points(const std::vector<int64_t>& ids, std::vector<int64_t>& rois_per_level, const int64_t levels_num) {
rois_per_level.clear();
rois_per_level.resize(levels_num, 0);
for (size_t i = 0; i < ids.size(); ++i) {
rois_per_level[ids[i]]++;
}
for (int64_t i = 1; i < levels_num; ++i) {
rois_per_level[i] += rois_per_level[i - 1];
}
rois_per_level.insert(rois_per_level.begin(), 0);
}
// implementation taken from Caffe2
template <typename T>
struct PreCalc {
int64_t pos1;
int64_t pos2;
int64_t pos3;
int64_t pos4;
T w1;
T w2;
T w3;
T w4;
};
// The function pre_calc_for_bilinear_interpolate() gives incorrect results for -O3 optimization level, when IE
// is compiled using GCC 7.5.0 on Ubuntu 18.04 32-bit. But results are correct, for example, if we use Clang 10.0
// on Ubuntu 18.04 32-bit with -O3 optimization level. Next, the function pre_calc_for_bilinear_interpolate()
// gives incorrect results after compiling by GCC 7.5.0 or Clang 10 in Ubuntu 18.04 32-bit, if the optimization
// level is -O1 or -O2. Finally, the function gives correct result in Ubuntu 18.04 32-bit, if the optimization
// level is -O0.
#if NEED_FIX
# pragma GCC push_options
# pragma GCC optimize("-O0")
#endif
template <typename T>
void pre_calc_for_bilinear_interpolate(const int64_t height,
const int64_t width,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t iy_upper,
const int64_t ix_upper,
T roi_start_h,
T roi_start_w,
T bin_size_h,
T bin_size_w,
int64_t roi_bin_grid_h,
int64_t roi_bin_grid_w,
std::vector<PreCalc<T>>& pre_calc) {
int64_t pre_calc_index = 0;
for (int64_t ph = 0; ph < pooled_height; ph++) {
for (int64_t pw = 0; pw < pooled_width; pw++) {
for (int64_t iy = 0; iy < iy_upper; iy++) {
for (int64_t ix = 0; ix < ix_upper; ix++) {
T y = roi_start_h + static_cast<T>(ph) * bin_size_h +
(static_cast<T>(iy) + static_cast<T>(0.5f)) * bin_size_h / static_cast<T>(roi_bin_grid_h);
T x = roi_start_w + static_cast<T>(pw) * bin_size_w +
(static_cast<T>(ix) + static_cast<T>(0.5f)) * bin_size_w / static_cast<T>(roi_bin_grid_w);
// deal with: inverse elements are out of feature map boundary
if (y < static_cast<T>(-1.0f) || y > static_cast<T>(height) || x < static_cast<T>(-1.0f) ||
x > static_cast<T>(width)) {
// empty
pre_calc_index += 1;
continue;
}
y = std::max(y, static_cast<T>(0.0f));
x = std::max(x, static_cast<T>(0.0f));
int64_t y_low = static_cast<int64_t>(y);
int64_t x_low = static_cast<int64_t>(x);
int64_t y_high = 0;
int64_t x_high = 0;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = static_cast<T>(y_low);
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = static_cast<T>(x_low);
} else {
x_high = x_low + 1;
}
T ly = y - y_low;
T lx = x - x_low;
T hy = static_cast<T>(1.0) - ly;
T hx = static_cast<T>(1.0) - lx;
// save weights and indeces
PreCalc<T> pc;
pc.pos1 = y_low * width + x_low;
pc.pos2 = y_low * width + x_high;
pc.pos3 = y_high * width + x_low;
pc.pos4 = y_high * width + x_high;
pc.w1 = hy * hx;
pc.w2 = hy * lx;
pc.w3 = ly * hx;
pc.w4 = ly * lx;
pre_calc.at(pre_calc_index) = pc;
pre_calc_index += 1;
}
}
}
}
}
#if NEED_FIX
# pragma GCC pop_options
#endif
template <typename T>
void ROIAlignForward(const int64_t nthreads,
const T* bottom_data,
const T& spatial_scale,
const int64_t channels,
const int64_t height,
const int64_t width,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio,
const T* bottom_rois,
const bool aligned,
T* top_data) {
int64_t roi_cols = 4;
int64_t n_rois = nthreads / channels / pooled_width / pooled_height;
// (n, c, ph, pw) is an element in the pooled output
for (int64_t n = 0; n < n_rois; ++n) {
int64_t index_n = n * channels * pooled_width * pooled_height;
// roi could have 4 or 5 columns
const T* offset_bottom_rois = bottom_rois + n * roi_cols;
int64_t roi_batch_ind = 0;
if (roi_cols == 5) {
roi_batch_ind = static_cast<int64_t>(offset_bottom_rois[0]);
offset_bottom_rois++;
}
T offset = aligned ? static_cast<T>(0.5) : static_cast<T>(0.0);
// Do not use rounding; this implementation detail is critical
T roi_start_w = offset_bottom_rois[0] * spatial_scale - offset;
T roi_start_h = offset_bottom_rois[1] * spatial_scale - offset;
T roi_end_w = offset_bottom_rois[2] * spatial_scale - offset;
T roi_end_h = offset_bottom_rois[3] * spatial_scale - offset;
// Force malformed ROIs to be 1x1
T roi_width = std::max(roi_end_w - roi_start_w, static_cast<T>(1.0));
T roi_height = std::max(roi_end_h - roi_start_h, static_cast<T>(1.0));
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
// We use roi_bin_grid to sample the grid and mimic integral
int64_t roi_bin_grid_h =
(sampling_ratio > 0) ? sampling_ratio : static_cast<int64_t>(std::ceil(roi_height / pooled_height));
int64_t roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : static_cast<int64_t>(std::ceil(roi_width / pooled_width));
// We do average (integral) pooling inside a bin
const T count = static_cast<T>(roi_bin_grid_h * roi_bin_grid_w);
// we want to precalculate indices and weights shared by all channels,
// this is the key point of optimization
std::vector<PreCalc<T>> pre_calc(roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
pre_calc_for_bilinear_interpolate<T>(height,
width,
pooled_height,
pooled_width,
roi_bin_grid_h,
roi_bin_grid_w,
roi_start_h,
roi_start_w,
bin_size_h,
bin_size_w,
roi_bin_grid_h,
roi_bin_grid_w,
pre_calc);
for (int64_t c = 0; c < channels; c++) {
int64_t index_n_c = index_n + c * pooled_width * pooled_height;
const T* offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width;
int64_t pre_calc_index = 0;
for (int64_t ph = 0; ph < pooled_height; ph++) {
for (int64_t pw = 0; pw < pooled_width; pw++) {
int64_t index = index_n_c + ph * pooled_width + pw;
T output_val = 0.;
for (int64_t iy = 0; iy < roi_bin_grid_h; iy++) {
for (int64_t ix = 0; ix < roi_bin_grid_w; ix++) {
PreCalc<T> pc = pre_calc[pre_calc_index];
output_val += pc.w1 * offset_bottom_data[pc.pos1] + pc.w2 * offset_bottom_data[pc.pos2] +
pc.w3 * offset_bottom_data[pc.pos3] + pc.w4 * offset_bottom_data[pc.pos4];
pre_calc_index += 1;
}
}
output_val /= count;
top_data[index] = output_val;
} // for pw
} // for ph
} // for c
}
}
} // namespace
namespace ngraph {
namespace runtime {
namespace reference {
void experimental_detectron_roi_feature_extractor(
const std::vector<std::vector<float>>& inputs,
const std::vector<Shape>& input_shapes,
const op::v6::ExperimentalDetectronROIFeatureExtractor::Attributes& attrs,
float* output_rois_features,
float* output_rois) {
int64_t output_dim = attrs.output_size;
auto pyramid_scales = attrs.pyramid_scales;
int64_t sampling_ratio = attrs.sampling_ratio;
bool aligned = attrs.aligned;
int64_t pooled_height = output_dim;
int64_t pooled_width = output_dim;
const int64_t levels_num = static_cast<int64_t>(inputs.size() - input_features_start_port);
const int64_t num_rois = static_cast<int64_t>(input_shapes[input_rois_port][0]);
const int64_t channels_num = static_cast<int64_t>(input_shapes[input_features_start_port][1]);
const int64_t feaxels_per_roi = pooled_height * pooled_width * channels_num;
const float* input_rois = inputs[input_rois_port].data();
std::vector<int64_t> level_ids(num_rois, 0);
redistribute_rois(inputs[input_rois_port], level_ids, levels_num);
std::vector<float> reordered_rois(4 * num_rois, 0);
std::vector<int64_t> original_rois_mapping(num_rois, 0);
reord(inputs[input_rois_port], level_ids, 4, reordered_rois.data(), original_rois_mapping);
std::vector<int64_t> rois_per_level;
split_points(level_ids, rois_per_level, levels_num + 1);
std::vector<float> output_rois_features_temp(feaxels_per_roi * num_rois, 0);
for (int64_t i = 0; i < levels_num; ++i) {
const int64_t level_rois_offset = rois_per_level[i];
const int64_t level_rois_num = rois_per_level[i + 1] - level_rois_offset;
if (level_rois_num > 0) {
const float* featuremap = inputs[input_features_start_port + i].data();
const int64_t featuremap_height = static_cast<int64_t>(input_shapes[input_features_start_port + i][2]);
const int64_t featuremap_width = static_cast<int64_t>(input_shapes[input_features_start_port + i][3]);
ROIAlignForward<float>(feaxels_per_roi * level_rois_num,
featuremap,
1.0f / pyramid_scales[i],
channels_num,
featuremap_height,
featuremap_width,
pooled_height,
pooled_width,
sampling_ratio,
&reordered_rois[4 * level_rois_offset],
aligned,
&output_rois_features_temp[feaxels_per_roi * level_rois_offset]);
}
}
std::vector<int64_t> dummy_mapping(num_rois, 0);
reord(output_rois_features_temp, original_rois_mapping, feaxels_per_roi, output_rois_features, dummy_mapping);
memcpy(output_rois, input_rois, 4 * num_rois * sizeof(float));
}
void experimental_detectron_roi_feature_extractor_postprocessing(void* prois_features,
void* prois,
const ngraph::element::Type output_type,
const std::vector<float>& output_rois_features,
const std::vector<float>& output_rois,
const Shape& output_rois_features_shape,
const Shape& output_rois_shape) {
size_t output_rois_features_size = shape_size(output_rois_features_shape);
size_t output_rois_size = shape_size(output_rois_shape);
switch (output_type) {
case element::Type_t::bf16: {
bfloat16* output_rois_features_ptr = reinterpret_cast<bfloat16*>(prois_features);
bfloat16* output_rois_ptr = reinterpret_cast<bfloat16*>(prois);
for (size_t i = 0; i < output_rois_features_size; ++i) {
output_rois_features_ptr[i] = bfloat16(output_rois_features[i]);
}
for (size_t i = 0; i < output_rois_size; ++i) {
output_rois_ptr[i] = bfloat16(output_rois[i]);
}
} break;
case element::Type_t::f16: {
float16* output_rois_features_ptr = reinterpret_cast<float16*>(prois_features);
float16* output_rois_ptr = reinterpret_cast<float16*>(prois);
for (size_t i = 0; i < output_rois_features_size; ++i) {
output_rois_features_ptr[i] = float16(output_rois_features[i]);
}
for (size_t i = 0; i < output_rois_size; ++i) {
output_rois_ptr[i] = float16(output_rois[i]);
}
} break;
case element::Type_t::f32: {
float* output_rois_features_ptr = reinterpret_cast<float*>(prois_features);
float* output_rois_ptr = reinterpret_cast<float*>(prois);
memcpy(output_rois_features_ptr, output_rois_features.data(), output_rois_features_size * sizeof(float));
memcpy(output_rois_ptr, output_rois.data(), output_rois_size * sizeof(float));
} break;
default:;
}
}
} // namespace reference
} // namespace runtime
} // namespace ngraph

View File

@ -30,6 +30,7 @@
#include <ngraph/runtime/reference/experimental_detectron_detection_output.hpp>
#include <ngraph/runtime/reference/experimental_detectron_prior_grid_generator.hpp>
#include <ngraph/runtime/reference/experimental_detectron_proposal_single_image.hpp>
#include <ngraph/runtime/reference/experimental_detectron_roi_feature_extractor.hpp>
#include <ngraph/runtime/reference/experimental_detectron_topk_rois.hpp>
#include <ngraph/runtime/reference/extract_image_patches.hpp>
#include <ngraph/runtime/reference/fft.hpp>
@ -1222,6 +1223,80 @@ bool evaluate(const shared_ptr<op::v6::ExperimentalDetectronDetectionOutput>& op
return true;
}
namespace experimental_roi_feature {
struct InfoForEDROIFeature {
Shape output_rois_features_shape;
Shape output_rois_shape;
};
InfoForEDROIFeature get_info_for_ed_roi_feature(const std::vector<Shape> input_shapes,
const op::v6::ExperimentalDetectronROIFeatureExtractor::Attributes& attrs) {
InfoForEDROIFeature result;
size_t output_size = static_cast<size_t>(attrs.output_size);
auto out_shape = Shape{0, 0, output_size, output_size};
auto out_rois_shape = Shape{0, 4};
auto rois_shape = input_shapes[0];
out_shape[0] = rois_shape[0];
out_rois_shape[0] = rois_shape[0];
out_shape[1] = input_shapes[1][1];
result.output_rois_features_shape = out_shape;
result.output_rois_shape = out_rois_shape;
return result;
}
} // namespace experimental_roi_feature
template <element::Type_t ET>
bool evaluate(const shared_ptr<op::v6::ExperimentalDetectronROIFeatureExtractor>& op,
const HostTensorVector& outputs,
const HostTensorVector& inputs)
{
const auto attrs = op->get_attrs();
std::vector<std::vector<float>> input_data;
std::vector<Shape> input_shapes;
for (const auto& input : inputs) {
const auto current_shape = input->get_shape();
input_data.push_back(get_floats(input, current_shape));
input_shapes.push_back(current_shape);
}
const auto info = experimental_roi_feature::get_info_for_ed_roi_feature(input_shapes, attrs);
const auto& output_rois_features_shape = info.output_rois_features_shape;
const auto& output_rois_shape = info.output_rois_shape;
const auto output_type = op->get_input_element_type(0);
outputs[0]->set_element_type(output_type);
outputs[0]->set_shape(output_rois_features_shape);
outputs[1]->set_element_type(output_type);
outputs[1]->set_shape(output_rois_shape);
std::vector<float> output_rois_features(shape_size(output_rois_features_shape));
std::vector<float> output_rois(shape_size(output_rois_shape));
runtime::reference::experimental_detectron_roi_feature_extractor(input_data,
input_shapes,
attrs,
output_rois_features.data(),
output_rois.data());
runtime::reference::experimental_detectron_roi_feature_extractor_postprocessing(outputs[0]->get_data_ptr(),
outputs[1]->get_data_ptr(),
output_type,
output_rois_features,
output_rois,
output_rois_features_shape,
output_rois_shape);
return true;
}
namespace fft_v7 {
struct InfoForFFT7 {
std::vector<float> input_data;

View File

@ -88,6 +88,7 @@ NGRAPH_OP(CTCGreedyDecoderSeqLen, op::v6)
NGRAPH_OP(ExperimentalDetectronDetectionOutput, op::v6)
NGRAPH_OP(ExperimentalDetectronGenerateProposalsSingleImage, op::v6)
NGRAPH_OP(ExperimentalDetectronPriorGridGenerator, op::v6)
NGRAPH_OP(ExperimentalDetectronROIFeatureExtractor, op::v6)
NGRAPH_OP(ExperimentalDetectronTopKROIs, op::v6)
NGRAPH_OP(GatherElements, op::v6)
NGRAPH_OP(MVN, ngraph::op::v6)

View File

@ -129,7 +129,7 @@ quantize_clamp_int32
minimum_u16
# Interpreter backend doesn't implement evaluate method for OP ExperimentalDetectronROIFeatureExtractor
INTERPRETER.onnx_model_experimental_detectron_roi_feature_extractor
# INTERPRETER.onnx_model_experimental_detectron_roi_feature_extractor
# No evaluator for DeformableConv2D
onnx_model_deformable_conv_2d