Reference Implementation for RegionYolo operator (#2474)

This commit is contained in:
Gabriele Galiero Casay 2020-10-15 22:30:12 +02:00 committed by GitHub
parent db85069713
commit c9b16a79f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 475 additions and 4 deletions

View File

@ -0,0 +1,85 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <vector>
#include "single_layer_tests/region_yolo.hpp"
#include "common_test_utils/test_constants.hpp"
using namespace LayerTestsDefinitions;
const std::vector<ngraph::Shape> inShapes_caffe = {
{1, 125, 13, 13}
};
const std::vector<ngraph::Shape> inShapes_mxnet = {
{1, 75, 52, 52},
{1, 75, 32, 32},
{1, 75, 26, 26},
{1, 75, 16, 16},
{1, 75, 13, 13},
{1, 75, 8, 8}
};
const std::vector<ngraph::Shape> inShapes_v3 = {
{1, 255, 52, 52},
{1, 255, 26, 26},
{1, 255, 13, 13}
};
const std::vector<std::vector<int64_t>> masks = {
{0, 1, 2},
{3, 4, 5},
{6, 7, 8}
};
const std::vector<bool> do_softmax = {true, false};
const std::vector<size_t> classes = {80, 20};
const std::vector<size_t> num_regions = {5, 9};
const size_t coords = 4;
const int start_axis = 1;
const int end_axis = 3;
const auto testCase_yolov3 = ::testing::Combine(
::testing::ValuesIn(inShapes_v3),
::testing::Values(classes[0]),
::testing::Values(coords),
::testing::Values(num_regions[1]),
::testing::Values(do_softmax[1]),
::testing::Values(masks[2]),
::testing::Values(start_axis),
::testing::Values(end_axis),
::testing::Values(InferenceEngine::Precision::FP32),
::testing::Values(CommonTestUtils::DEVICE_CPU)
);
const auto testCase_yolov3_mxnet = ::testing::Combine(
::testing::ValuesIn(inShapes_mxnet),
::testing::Values(classes[1]),
::testing::Values(coords),
::testing::Values(num_regions[1]),
::testing::Values(do_softmax[1]),
::testing::Values(masks[1]),
::testing::Values(start_axis),
::testing::Values(end_axis),
::testing::Values(InferenceEngine::Precision::FP32),
::testing::Values(CommonTestUtils::DEVICE_CPU)
);
const auto testCase_yolov2_caffe = ::testing::Combine(
::testing::ValuesIn(inShapes_caffe),
::testing::Values(classes[1]),
::testing::Values(coords),
::testing::Values(num_regions[0]),
::testing::Values(do_softmax[0]),
::testing::Values(masks[0]),
::testing::Values(start_axis),
::testing::Values(end_axis),
::testing::Values(InferenceEngine::Precision::FP32),
::testing::Values(CommonTestUtils::DEVICE_CPU)
);
INSTANTIATE_TEST_CASE_P(smoke_TestsRegionYolov3, RegionYoloLayerTest, testCase_yolov3, RegionYoloLayerTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(smoke_TestsRegionYoloMxnet, RegionYoloLayerTest, testCase_yolov3_mxnet, RegionYoloLayerTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(smoke_TestsRegionYoloCaffe, RegionYoloLayerTest, testCase_yolov2_caffe, RegionYoloLayerTest::getTestCaseName);

View File

@ -0,0 +1,38 @@
// Copyright (C) 2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <tuple>
#include <string>
#include <vector>
#include "functional_test_utils/layer_test_utils.hpp"
#include "ngraph_functions/builders.hpp"
#include "ngraph_functions/utils/ngraph_helpers.hpp"
namespace LayerTestsDefinitions {
// Parameter pack for the RegionYolo single-layer test.
// The element order is positional: both getTestCaseName() and SetUp() unpack
// it with std::tie, so any reordering here must be mirrored there.
// Note that `classes` precedes `coordinates` in this tuple, while SetUp()
// passes them to the RegionYolo constructor as (coords, classes).
using regionYoloParamsTuple = std::tuple<
ngraph::Shape, // Input shape
size_t, // Number of classes
size_t, // Number of coordinates
size_t, // Number of regions
bool, // Whether softmax is applied (do_softmax)
std::vector<int64_t>, // Mask
int, // Start axis
int, // End axis
InferenceEngine::Precision, // Network precision
std::string>; // Target device name
// Value-parameterized test fixture that builds a function containing a single
// RegionYolo operation from a regionYoloParamsTuple.
class RegionYoloLayerTest : public testing::WithParamInterface<regionYoloParamsTuple>,
virtual public LayerTestsUtils::LayerTestsCommon {
public:
// Produces the human-readable test name for one parameter set; it is passed
// as the name generator when the test is instantiated.
static std::string getTestCaseName(const testing::TestParamInfo<regionYoloParamsTuple> &obj);
protected:
// Unpacks the test parameters and builds the function under test.
void SetUp() override;
};
} // namespace LayerTestsDefinitions

View File

@ -0,0 +1,63 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ie_core.hpp"
#include "common_test_utils/common_utils.hpp"
#include "functional_test_utils/blob_utils.hpp"
#include "functional_test_utils/precision_utils.hpp"
#include "functional_test_utils/plugin_cache.hpp"
#include "functional_test_utils/skip_tests_config.hpp"
#include "single_layer_tests/region_yolo.hpp"
namespace LayerTestsDefinitions {
// Builds the test name for one parameter set by concatenating "key=value_"
// fragments. Tuple positions: 0 input shape, 1 classes, 2 coords, 3 regions,
// 4 do_softmax, 5 mask, 6 start axis, 7 end axis, 8 precision, 9 device.
// The mask (position 5) is not part of the generated name.
std::string RegionYoloLayerTest::getTestCaseName(const testing::TestParamInfo<regionYoloParamsTuple> &obj) {
    const regionYoloParamsTuple &params = obj.param;

    std::ostringstream name;
    name << "IS=" << std::get<0>(params) << "_"
         << "classes=" << std::get<1>(params) << "_"
         << "coords=" << std::get<2>(params) << "_"
         << "num=" << std::get<3>(params) << "_"
         << "doSoftmax=" << std::get<4>(params) << "_"
         << "axis=" << std::get<6>(params) << "_"
         << "endAxis=" << std::get<7>(params) << "_"
         << "netPRC=" << std::get<8>(params).name() << "_"
         << "targetDevice=" << std::get<9>(params) << "_";
    return name.str();
}
// Unpacks the test parameters and builds the function under test:
// Parameter -> RegionYolo -> Result. The last tuple element is written
// straight into `targetDevice`, and the graph is stored in `function`;
// neither is declared locally, so both come from the fixture.
void RegionYoloLayerTest::SetUp() {
    ngraph::Shape inputShape;
    size_t classes;
    size_t coords;
    size_t num_regions;
    bool do_softmax;
    std::vector<int64_t> mask;
    int start_axis;
    int end_axis;
    InferenceEngine::Precision netPrecision;
    std::tie(inputShape, classes, coords, num_regions, do_softmax, mask, start_axis, end_axis, netPrecision, targetDevice) = this->GetParam();

    // Create the input with the requested network precision. It was previously
    // hard-coded to f32, which silently ignored the precision test parameter.
    const auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
    auto param = std::make_shared<ngraph::op::Parameter>(ngPrc, inputShape);

    // The constructor takes (coords, classes), the reverse of the tuple order.
    auto region_yolo = std::make_shared<ngraph::op::v0::RegionYolo>(param, coords, classes, num_regions, do_softmax, mask, start_axis, end_axis);
    function = std::make_shared<ngraph::Function>(std::make_shared<ngraph::opset1::Result>(region_yolo), ngraph::ParameterVector{param}, "RegionYolo");
}
// Runs the fixture's Run() for every instantiated parameter set.
TEST_P(RegionYoloLayerTest, CompareWithRefs) {
    Run();
}
} // namespace LayerTestsDefinitions

View File

@ -79,7 +79,7 @@ namespace ngraph
int m_axis;
int m_end_axis;
};
}
} // namespace v0
using v0::RegionYolo;
}
}
} // namespace op
} // namespace ngraph

View File

@ -0,0 +1,175 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <algorithm>
#include <cmath>
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
static inline int entry_index(int width,
int height,
int coords,
int classes,
int outputs,
int batch,
int location,
int entry)
{
int n = location / (width * height);
int loc = location % (width * height);
return batch * outputs + n * width * height * (coords + classes + 1) +
entry * width * height + loc;
}
template <typename T>
static inline T sigmoid(float x)
{
return static_cast<T>(1.f / (1.f + std::exp(-x)));
}
template <typename T>
static inline void softmax_generic(
const T* src_data, T* dst_data, int batches, int channels, int height, int width)
{
const int area = height * width;
for (unsigned int batch_idx = 0; batch_idx < batches; batch_idx++)
{
const int offset = batch_idx * channels * area;
for (unsigned int i = 0; i < height * width; i++)
{
T max = src_data[batch_idx * channels * area + i];
for (unsigned int channel_idx = 0; channel_idx < channels; channel_idx++)
{
T val = src_data[offset + channel_idx * area + i];
max = std::max(max, val);
}
T sum = 0;
for (unsigned int channel_idx = 0; channel_idx < channels; channel_idx++)
{
dst_data[offset + channel_idx * area + i] =
std::exp(src_data[offset + channel_idx * area + i] - max);
sum += dst_data[offset + channel_idx * area + i];
}
for (unsigned int channel_idx = 0; channel_idx < channels; channel_idx++)
{
dst_data[offset + channel_idx * area + i] /= sum;
}
}
}
}
template <typename T>
void region_yolo(const T* input,
T* output,
const Shape& input_shape,
const int coords,
const int classes,
const int regions,
const bool do_softmax,
const std::vector<int64_t>& mask)
{
NGRAPH_CHECK(input_shape.size() == 4);
const int batches = input_shape[0];
const int channels = input_shape[1];
const int height = input_shape[2];
const int width = input_shape[3];
const auto mask_size = mask.size();
std::copy(input, input + shape_size(input_shape), output);
int num_regions = 0;
int end_index = 0;
if (do_softmax)
{
// Region layer (Yolo v2)
num_regions = regions;
end_index = width * height;
}
else
{
// Yolo layer (Yolo v3)
num_regions = mask_size;
end_index = width * height * (classes + 1);
}
const int inputs_size = width * height * num_regions * (classes + coords + 1);
for (unsigned int batch_idx = 0; batch_idx < batches; batch_idx++)
{
for (unsigned int n = 0; n < num_regions; n++)
{
int index = entry_index(width,
height,
coords,
classes,
inputs_size,
batch_idx,
n * width * height,
0);
std::transform(output + index,
output + index + 2 * width * height,
output + index,
[](T elem) { return sigmoid<T>(elem); });
index = entry_index(width,
height,
coords,
classes,
inputs_size,
batch_idx,
n * width * height,
coords);
std::transform(output + index,
output + index + end_index,
output + index,
[](T elem) { return sigmoid<T>(elem); });
}
}
if (do_softmax)
{
int index =
entry_index(width, height, coords, classes, inputs_size, 0, 0, coords + 1);
int batch_offset = inputs_size / regions;
for (unsigned int batch_idx = 0; batch_idx < batches * regions; batch_idx++)
{
softmax_generic<T>(input + index + batch_idx * batch_offset,
output + index + batch_idx * batch_offset,
1,
classes,
height,
width);
}
}
}
} // namespace reference
} // namespace runtime
} // namespace ngraph

View File

@ -60,6 +60,12 @@ bool ngraph::op::v0::RegionYolo::visit_attributes(AttributeVisitor& visitor)
void op::RegionYolo::validate_and_infer_types()
{
auto input_et = get_input_element_type(0);
NODE_VALIDATION_CHECK(this,
input_et.is_real(),
"Type of input is expected to be a floating point type. Got: ",
input_et);
if (get_input_partial_shape(0).is_static())
{
Shape input_shape = get_input_partial_shape(0).to_shape();

View File

@ -325,6 +325,7 @@ set(MULTI_TEST_SRC
backend/reduce_min.in.cpp
backend/reduce_prod.in.cpp
backend/reduce_sum.in.cpp
backend/region_yolo.in.cpp
backend/relu.in.cpp
backend/reorg_yolo.in.cpp
backend/replace_slice.in.cpp

View File

@ -787,7 +787,7 @@ TEST(attributes, reduce_sum_op)
TEST(attributes, region_yolo_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::RegionYolo>();
auto data = make_shared<op::Parameter>(element::i64, Shape{1, 255, 26, 26});
auto data = make_shared<op::Parameter>(element::f32, Shape{1, 255, 26, 26});
size_t num_coords = 4;
size_t num_classes = 1;

View File

@ -0,0 +1,86 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <fstream>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/engine/test_engines.hpp"
#include "util/test_case.hpp"
#include "util/test_control.hpp"
// NOTE(review): presumably disables deprecation warnings for the rest of this
// file; no matching *_END macro appears below — confirm this is intended.
NGRAPH_SUPPRESS_DEPRECATED_START
using namespace std;
using namespace ngraph;
// NOTE(review): "${MANIFEST}" and ${BACKEND_NAME} look like placeholders that
// are substituted when this .in.cpp template is configured for a concrete
// backend — confirm against the build scripts.
static string s_manifest = "${MANIFEST}";
// Test engine type selected by the backend name; used by the TestCase
// instances in the tests below.
using TestEngine = test::ENGINE_CLASS_NAME(${BACKEND_NAME});
// RegionYolo with do_softmax enabled (5 regions, 20 classes, 4 coords) on a
// 1x125x13x13 input, compared against reference data loaded from files.
NGRAPH_TEST(${BACKEND_NAME}, region_yolo_v2_caffe)
{
    const size_t num = 5;
    const size_t coords = 4;
    const size_t classes = 20;
    const size_t batch = 1;
    const size_t channels = 125;
    const size_t width = 13;
    const size_t height = 13;
    const std::vector<int64_t> mask{0, 1, 2};

    Shape input_shape{batch, channels, height, width};
    // The expected output is declared flattened to {batch, C * H * W}.
    Shape output_shape{batch, channels * height * width};

    auto A = make_shared<op::Parameter>(element::f32, input_shape);
    auto R = make_shared<op::v0::RegionYolo>(A, coords, classes, num, true, mask, 1, 3);
    auto f = make_shared<Function>(R, ParameterVector{A});

    auto test_case = test::TestCase<TestEngine>(f);
    test_case.add_input_from_file<float>(input_shape, TEST_FILES, "region_in_yolov2_caffe.data");
    test_case.add_expected_output_from_file<float>(
        output_shape, TEST_FILES, "region_out_yolov2_caffe.data");
    test_case.run_with_tolerance_as_fp(1.0e-4f);
}
// RegionYolo with do_softmax disabled (mask of size 3, 20 classes, 4 coords)
// on a 1x75x32x32 input, compared against reference data loaded from files.
NGRAPH_TEST(${BACKEND_NAME}, region_yolo_v3_mxnet)
{
    const size_t num = 9;
    const size_t coords = 4;
    const size_t classes = 20;
    const size_t batch = 1;
    const size_t channels = 75;
    const size_t width = 32;
    const size_t height = 32;
    const std::vector<int64_t> mask{0, 1, 2};

    Shape shape{batch, channels, height, width};

    const auto A = make_shared<op::Parameter>(element::f32, shape);
    const auto R = make_shared<op::v0::RegionYolo>(A, coords, classes, num, false, mask, 1, 3);
    const auto f = make_shared<Function>(R, ParameterVector{A});

    // Without softmax the output keeps the input shape.
    EXPECT_EQ(R->get_output_shape(0), shape);

    auto test_case = test::TestCase<TestEngine>(f);
    test_case.add_input_from_file<float>(shape, TEST_FILES, "region_in_yolov3_mxnet.data");
    test_case.add_expected_output_from_file<float>(
        shape, TEST_FILES, "region_out_yolov3_mxnet.data");
    test_case.run_with_tolerance_as_fp(1.0e-4f);
}

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -1466,6 +1466,8 @@ IE_GPU.matmul_2x2_2x2
IE_GPU.matmul_2x3_3x3
IE_GPU.matmul_3x2_3x3_transpose
IE_GPU.matmul_3x2_2x3_transpose
IE_GPU.region_yolo_v2_caffe
IE_GPU.region_yolo_v3_mxnet
# Unsupported collapse op with dynamic shape
IE_GPU.builder_opset1_collapse_dyn_shape

View File

@ -77,6 +77,7 @@
#include "ngraph/runtime/reference/prior_box.hpp"
#include "ngraph/runtime/reference/product.hpp"
#include "ngraph/runtime/reference/quantize.hpp"
#include "ngraph/runtime/reference/region_yolo.hpp"
#include "ngraph/runtime/reference/relu.hpp"
#include "ngraph/runtime/reference/reorg_yolo.hpp"
#include "ngraph/runtime/reference/replace_slice.hpp"
@ -1187,6 +1188,19 @@ protected:
break;
}
case OP_TYPEID::RegionYolo_v0:
{
const op::RegionYolo* region_yolo = static_cast<const op::RegionYolo*>(&node);
reference::region_yolo<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
args[0]->get_shape(),
region_yolo->get_num_coords(),
region_yolo->get_num_classes(),
region_yolo->get_num_regions(),
region_yolo->get_do_softmax(),
region_yolo->get_mask());
break;
}
case OP_TYPEID::Relu:
{
size_t element_count = shape_size(node.get_output_shape(0));

View File

@ -21,6 +21,7 @@
#define ID_SUFFIX(NAME) NAME##_v0
NGRAPH_OP(CTCGreedyDecoder, ngraph::op::v0)
NGRAPH_OP(DetectionOutput, op::v0)
NGRAPH_OP(RegionYolo, op::v0)
NGRAPH_OP(ReorgYolo, op::v0)
NGRAPH_OP(RNNCell, op::v0)
#undef ID_SUFFIX