diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/region_yolo.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/region_yolo.cpp new file mode 100644 index 00000000000..eb2e2807ffe --- /dev/null +++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/region_yolo.cpp @@ -0,0 +1,85 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "single_layer_tests/region_yolo.hpp" +#include "common_test_utils/test_constants.hpp" + +using namespace LayerTestsDefinitions; + +const std::vector inShapes_caffe = { + {1, 125, 13, 13} +}; + +const std::vector inShapes_mxnet = { + {1, 75, 52, 52}, + {1, 75, 32, 32}, + {1, 75, 26, 26}, + {1, 75, 16, 16}, + {1, 75, 13, 13}, + {1, 75, 8, 8} +}; + +const std::vector inShapes_v3 = { + {1, 255, 52, 52}, + {1, 255, 26, 26}, + {1, 255, 13, 13} +}; + +const std::vector> masks = { + {0, 1, 2}, + {3, 4, 5}, + {6, 7, 8} +}; + +const std::vector do_softmax = {true, false}; +const std::vector classes = {80, 20}; +const std::vector num_regions = {5, 9}; +const size_t coords = 4; +const int start_axis = 1; +const int end_axis = 3; + +const auto testCase_yolov3 = ::testing::Combine( + ::testing::ValuesIn(inShapes_v3), + ::testing::Values(classes[0]), + ::testing::Values(coords), + ::testing::Values(num_regions[1]), + ::testing::Values(do_softmax[1]), + ::testing::Values(masks[2]), + ::testing::Values(start_axis), + ::testing::Values(end_axis), + ::testing::Values(InferenceEngine::Precision::FP32), + ::testing::Values(CommonTestUtils::DEVICE_CPU) +); + +const auto testCase_yolov3_mxnet = ::testing::Combine( + ::testing::ValuesIn(inShapes_mxnet), + ::testing::Values(classes[1]), + ::testing::Values(coords), + ::testing::Values(num_regions[1]), + ::testing::Values(do_softmax[1]), + ::testing::Values(masks[1]), + ::testing::Values(start_axis), + ::testing::Values(end_axis), + ::testing::Values(InferenceEngine::Precision::FP32), + ::testing::Values(CommonTestUtils::DEVICE_CPU) +); + +const auto testCase_yolov2_caffe = ::testing::Combine( + ::testing::ValuesIn(inShapes_caffe), + ::testing::Values(classes[1]), + ::testing::Values(coords), + ::testing::Values(num_regions[0]), + ::testing::Values(do_softmax[0]), + ::testing::Values(masks[0]), + ::testing::Values(start_axis), + ::testing::Values(end_axis), + ::testing::Values(InferenceEngine::Precision::FP32), + ::testing::Values(CommonTestUtils::DEVICE_CPU) +); + +INSTANTIATE_TEST_CASE_P(smoke_TestsRegionYolov3, RegionYoloLayerTest, testCase_yolov3, RegionYoloLayerTest::getTestCaseName); +INSTANTIATE_TEST_CASE_P(smoke_TestsRegionYoloMxnet, RegionYoloLayerTest, testCase_yolov3_mxnet, RegionYoloLayerTest::getTestCaseName); +INSTANTIATE_TEST_CASE_P(smoke_TestsRegionYoloCaffe, RegionYoloLayerTest, testCase_yolov2_caffe, RegionYoloLayerTest::getTestCaseName); diff --git a/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/region_yolo.hpp b/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/region_yolo.hpp new file mode 100644 index 00000000000..c8d74f6003f --- /dev/null +++ b/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/region_yolo.hpp @@ -0,0 +1,38 @@ +// Copyright (C) 2019 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +#include "functional_test_utils/layer_test_utils.hpp" +#include "ngraph_functions/builders.hpp" +#include "ngraph_functions/utils/ngraph_helpers.hpp" + +namespace LayerTestsDefinitions { + +using regionYoloParamsTuple = std::tuple< + ngraph::Shape, // Input Shape + size_t, // classes + size_t, // coordinates + size_t, // num regions + bool, // do softmax + std::vector, // mask + int, // start axis + int, // end axis + InferenceEngine::Precision, // Network precision + std::string>; // Device name + +class RegionYoloLayerTest : public testing::WithParamInterface, + virtual public LayerTestsUtils::LayerTestsCommon { +public: + static std::string getTestCaseName(const testing::TestParamInfo &obj); + +protected: + void SetUp() override; +}; + +} // namespace LayerTestsDefinitions \ No newline at end of file diff --git a/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/region_yolo.cpp b/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/region_yolo.cpp new file mode 100644 index 00000000000..968909418bc --- /dev/null +++ b/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/region_yolo.cpp @@ -0,0 +1,63 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "ie_core.hpp" + +#include "common_test_utils/common_utils.hpp" +#include "functional_test_utils/blob_utils.hpp" +#include "functional_test_utils/precision_utils.hpp" +#include "functional_test_utils/plugin_cache.hpp" +#include "functional_test_utils/skip_tests_config.hpp" + +#include "single_layer_tests/region_yolo.hpp" + +namespace LayerTestsDefinitions { + +std::string RegionYoloLayerTest::getTestCaseName(const testing::TestParamInfo &obj) { + ngraph::Shape inputShape; + size_t classes; + size_t coords; + size_t num_regions; + bool do_softmax; + std::vector mask; + int start_axis; + int end_axis; + InferenceEngine::Precision netPrecision; + std::string targetName; + std::tie(inputShape, classes, coords, num_regions, do_softmax , mask, start_axis, end_axis, netPrecision, targetName) = obj.param; + std::ostringstream result; + result << "IS=" << inputShape << "_"; + result << "classes=" << classes << "_"; + result << "coords=" << coords << "_"; + result << "num=" << num_regions << "_"; + result << "doSoftmax=" << do_softmax << "_"; + result << "axis=" << start_axis << "_"; + result << "endAxis=" << end_axis << "_"; + result << "netPRC=" << netPrecision.name() << "_"; + result << "targetDevice=" << targetName << "_"; + return result.str(); +} + +void RegionYoloLayerTest::SetUp() { + ngraph::Shape inputShape; + size_t classes; + size_t coords; + size_t num_regions; + bool do_softmax; + std::vector mask; + int start_axis; + int end_axis; + InferenceEngine::Precision netPrecision; + std::tie(inputShape, classes, coords, num_regions, do_softmax, mask, start_axis, end_axis, netPrecision, targetDevice) = this->GetParam(); + auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision); + auto param = std::make_shared(ngraph::element::f32, inputShape); + auto region_yolo = std::make_shared(param, coords, classes, num_regions, do_softmax, mask, start_axis, end_axis); + function = std::make_shared(std::make_shared(region_yolo), ngraph::ParameterVector{param}, "RegionYolo"); +} + +TEST_P(RegionYoloLayerTest, CompareWithRefs) { + Run(); +}; + +} // namespace LayerTestsDefinitions \ No newline at end of file diff --git a/ngraph/core/include/ngraph/op/region_yolo.hpp b/ngraph/core/include/ngraph/op/region_yolo.hpp index 8dfdbb8e66d..b7d9181a968 100644 --- a/ngraph/core/include/ngraph/op/region_yolo.hpp +++ b/ngraph/core/include/ngraph/op/region_yolo.hpp @@ -79,7 +79,7 @@ namespace ngraph int m_axis; int m_end_axis; }; - } + } // namespace v0 using v0::RegionYolo; - } -} + } // namespace op +} // namespace ngraph diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/region_yolo.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/region_yolo.hpp new file mode 100644 index 00000000000..2ca3f324e4f --- /dev/null +++ b/ngraph/core/reference/include/ngraph/runtime/reference/region_yolo.hpp @@ -0,0 +1,175 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#pragma once + +#include +#include + +#include "ngraph/shape.hpp" + +namespace ngraph +{ + namespace runtime + { + namespace reference + { + static inline int entry_index(int width, + int height, + int coords, + int classes, + int outputs, + int batch, + int location, + int entry) + { + int n = location / (width * height); + int loc = location % (width * height); + return batch * outputs + n * width * height * (coords + classes + 1) + + entry * width * height + loc; + } + + template + static inline T sigmoid(float x) + { + return static_cast(1.f / (1.f + std::exp(-x))); + } + template + static inline void softmax_generic( + const T* src_data, T* dst_data, int batches, int channels, int height, int width) + { + const int area = height * width; + for (unsigned int batch_idx = 0; batch_idx < batches; batch_idx++) + { + const int offset = batch_idx * channels * area; + for (unsigned int i = 0; i < height * width; i++) + { + T max = src_data[batch_idx * channels * area + i]; + for (unsigned int channel_idx = 0; channel_idx < channels; channel_idx++) + { + T val = src_data[offset + channel_idx * area + i]; + max = std::max(max, val); + } + + T sum = 0; + for (unsigned int channel_idx = 0; channel_idx < channels; channel_idx++) + { + dst_data[offset + channel_idx * area + i] = + std::exp(src_data[offset + channel_idx * area + i] - max); + sum += dst_data[offset + channel_idx * area + i]; + } + + for (unsigned int channel_idx = 0; channel_idx < channels; channel_idx++) + { + dst_data[offset + channel_idx * area + i] /= sum; + } + } + } + } + + template + void region_yolo(const T* input, + T* output, + const Shape& input_shape, + const int coords, + const int classes, + const int regions, + const bool do_softmax, + const std::vector& mask) + { + NGRAPH_CHECK(input_shape.size() == 4); + + const int batches = input_shape[0]; + const int channels = input_shape[1]; + const int height = input_shape[2]; + const int width = input_shape[3]; + + const auto mask_size = mask.size(); + + std::copy(input, input + shape_size(input_shape), output); + + int num_regions = 0; + int end_index = 0; + + if (do_softmax) + { + // Region layer (Yolo v2) + num_regions = regions; + end_index = width * height; + } + else + { + // Yolo layer (Yolo v3) + num_regions = mask_size; + end_index = width * height * (classes + 1); + } + + const int inputs_size = width * height * num_regions * (classes + coords + 1); + + for (unsigned int batch_idx = 0; batch_idx < batches; batch_idx++) + { + for (unsigned int n = 0; n < num_regions; n++) + { + int index = entry_index(width, + height, + coords, + classes, + inputs_size, + batch_idx, + n * width * height, + 0); + std::transform(output + index, + output + index + 2 * width * height, + output + index, + [](T elem) { return sigmoid(elem); }); + + index = entry_index(width, + height, + coords, + classes, + inputs_size, + batch_idx, + n * width * height, + coords); + std::transform(output + index, + output + index + end_index, + output + index, + [](T elem) { return sigmoid(elem); }); + } + } + + if (do_softmax) + { + int index = + entry_index(width, height, coords, classes, inputs_size, 0, 0, coords + 1); + int batch_offset = inputs_size / regions; + for (unsigned int batch_idx = 0; batch_idx < batches * regions; batch_idx++) + { + softmax_generic(input + index + batch_idx * batch_offset, + output + index + batch_idx * batch_offset, + 1, + classes, + height, + width); + } + } + } + + } // namespace reference + + } // namespace runtime + +} // namespace ngraph \ No newline at end of file diff --git a/ngraph/core/src/op/region_yolo.cpp b/ngraph/core/src/op/region_yolo.cpp index f260acec7f6..4eed7f59904 100644 --- a/ngraph/core/src/op/region_yolo.cpp +++ b/ngraph/core/src/op/region_yolo.cpp @@ -60,6 +60,12 @@ bool ngraph::op::v0::RegionYolo::visit_attributes(AttributeVisitor& visitor) void op::RegionYolo::validate_and_infer_types() { auto input_et = get_input_element_type(0); + + NODE_VALIDATION_CHECK(this, + input_et.is_real(), + "Type of input is expected to be a floating point type. Got: ", + input_et); + if (get_input_partial_shape(0).is_static()) { Shape input_shape = get_input_partial_shape(0).to_shape(); diff --git a/ngraph/test/CMakeLists.txt b/ngraph/test/CMakeLists.txt index 6f46f1430d9..6e3a9f5312c 100644 --- a/ngraph/test/CMakeLists.txt +++ b/ngraph/test/CMakeLists.txt @@ -325,6 +325,7 @@ set(MULTI_TEST_SRC backend/reduce_min.in.cpp backend/reduce_prod.in.cpp backend/reduce_sum.in.cpp + backend/region_yolo.in.cpp backend/relu.in.cpp backend/reorg_yolo.in.cpp backend/replace_slice.in.cpp diff --git a/ngraph/test/attributes.cpp b/ngraph/test/attributes.cpp index 322c8605de7..64a5a60451d 100644 --- a/ngraph/test/attributes.cpp +++ b/ngraph/test/attributes.cpp @@ -787,7 +787,7 @@ TEST(attributes, reduce_sum_op) TEST(attributes, region_yolo_op) { FactoryRegistry::get().register_factory(); - auto data = make_shared(element::i64, Shape{1, 255, 26, 26}); + auto data = make_shared(element::f32, Shape{1, 255, 26, 26}); size_t num_coords = 4; size_t num_classes = 1; diff --git a/ngraph/test/backend/region_yolo.in.cpp b/ngraph/test/backend/region_yolo.in.cpp new file mode 100644 index 00000000000..8d520c4929a --- /dev/null +++ b/ngraph/test/backend/region_yolo.in.cpp @@ -0,0 +1,86 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#include + +#include "gtest/gtest.h" +#include "ngraph/ngraph.hpp" +#include "util/engine/test_engines.hpp" +#include "util/test_case.hpp" +#include "util/test_control.hpp" + +NGRAPH_SUPPRESS_DEPRECATED_START + +using namespace std; +using namespace ngraph; + +static string s_manifest = "${MANIFEST}"; +using TestEngine = test::ENGINE_CLASS_NAME(${BACKEND_NAME}); + +NGRAPH_TEST(${BACKEND_NAME}, region_yolo_v2_caffe) +{ + const size_t num = 5; + const size_t coords = 4; + const size_t classes = 20; + const size_t batch = 1; + const size_t channels = 125; + const size_t width = 13; + const size_t height = 13; + const size_t count = width * height * channels; + const std::vector mask{0, 1, 2}; + + Shape input_shape{batch, channels, height, width}; + Shape output_shape{batch, channels * height * width}; + + auto A = make_shared(element::f32, input_shape); + auto R = make_shared(A, coords, classes, num, true, mask, 1, 3); + auto f = make_shared(R, ParameterVector{A}); + + auto test_case = test::TestCase(f); + + test_case.add_input_from_file(input_shape, TEST_FILES, "region_in_yolov2_caffe.data"); + test_case.add_expected_output_from_file( + output_shape, TEST_FILES, "region_out_yolov2_caffe.data"); + test_case.run_with_tolerance_as_fp(1.0e-4f); +} + +NGRAPH_TEST(${BACKEND_NAME}, region_yolo_v3_mxnet) +{ + const size_t num = 9; + const size_t coords = 4; + const size_t classes = 20; + const size_t batch = 1; + const size_t channels = 75; + const size_t width = 32; + const size_t height = 32; + const std::vector mask{0, 1, 2}; + + Shape shape{batch, channels, height, width}; + const auto count = shape_size(shape); + + const auto A = make_shared(element::f32, shape); + const auto R = make_shared(A, coords, classes, num, false, mask, 1, 3); + const auto f = make_shared(R, ParameterVector{A}); + + EXPECT_EQ(R->get_output_shape(0), shape); + + auto test_case = test::TestCase(f); + + test_case.add_input_from_file(shape, TEST_FILES, "region_in_yolov3_mxnet.data"); + test_case.add_expected_output_from_file( + shape, TEST_FILES, "region_out_yolov3_mxnet.data"); + test_case.run_with_tolerance_as_fp(1.0e-4f); +} diff --git a/ngraph/test/files/region_in_yolov2_caffe.data b/ngraph/test/files/region_in_yolov2_caffe.data new file mode 100644 index 00000000000..3111300e54d Binary files /dev/null and b/ngraph/test/files/region_in_yolov2_caffe.data differ diff --git a/ngraph/test/files/region_in_yolov3_mxnet.data b/ngraph/test/files/region_in_yolov3_mxnet.data new file mode 100644 index 00000000000..7fea67d7062 Binary files /dev/null and b/ngraph/test/files/region_in_yolov3_mxnet.data differ diff --git a/ngraph/test/files/region_out_yolov2_caffe.data b/ngraph/test/files/region_out_yolov2_caffe.data new file mode 100644 index 00000000000..44807ba0ce6 Binary files /dev/null and b/ngraph/test/files/region_out_yolov2_caffe.data differ diff --git a/ngraph/test/files/region_out_yolov3_mxnet.data b/ngraph/test/files/region_out_yolov3_mxnet.data new file mode 100644 index 00000000000..b5336a7d5dc Binary files /dev/null and b/ngraph/test/files/region_out_yolov3_mxnet.data differ diff --git a/ngraph/test/runtime/ie/unit_test.manifest b/ngraph/test/runtime/ie/unit_test.manifest index b248835517a..f2ae030e8c3 100644 --- a/ngraph/test/runtime/ie/unit_test.manifest +++ b/ngraph/test/runtime/ie/unit_test.manifest @@ -1466,6 +1466,8 @@ IE_GPU.matmul_2x2_2x2 IE_GPU.matmul_2x3_3x3 IE_GPU.matmul_3x2_3x3_transpose IE_GPU.matmul_3x2_2x3_transpose +IE_GPU.region_yolo_v2_caffe +IE_GPU.region_yolo_v3_mxnet # Unsupported collapse op with dynamic shape IE_GPU.builder_opset1_collapse_dyn_shape diff --git a/ngraph/test/runtime/interpreter/int_executable.hpp b/ngraph/test/runtime/interpreter/int_executable.hpp index 0070aaab1dd..d78518810df 100644 --- a/ngraph/test/runtime/interpreter/int_executable.hpp +++ b/ngraph/test/runtime/interpreter/int_executable.hpp @@ -77,6 +77,7 @@ #include "ngraph/runtime/reference/prior_box.hpp" #include "ngraph/runtime/reference/product.hpp" #include "ngraph/runtime/reference/quantize.hpp" +#include "ngraph/runtime/reference/region_yolo.hpp" #include "ngraph/runtime/reference/relu.hpp" #include "ngraph/runtime/reference/reorg_yolo.hpp" #include "ngraph/runtime/reference/replace_slice.hpp" @@ -1187,6 +1188,19 @@ protected: break; } + case OP_TYPEID::RegionYolo_v0: + { + const op::RegionYolo* region_yolo = static_cast(&node); + reference::region_yolo(args[0]->get_data_ptr(), + out[0]->get_data_ptr(), + args[0]->get_shape(), + region_yolo->get_num_coords(), + region_yolo->get_num_classes(), + region_yolo->get_num_regions(), + region_yolo->get_do_softmax(), + region_yolo->get_mask()); + break; + } case OP_TYPEID::Relu: { size_t element_count = shape_size(node.get_output_shape(0)); diff --git a/ngraph/test/runtime/interpreter/opset_int_tbl.hpp b/ngraph/test/runtime/interpreter/opset_int_tbl.hpp index de33cda40be..4cfe6693f17 100644 --- a/ngraph/test/runtime/interpreter/opset_int_tbl.hpp +++ b/ngraph/test/runtime/interpreter/opset_int_tbl.hpp @@ -21,6 +21,7 @@ #define ID_SUFFIX(NAME) NAME##_v0 NGRAPH_OP(CTCGreedyDecoder, ngraph::op::v0) NGRAPH_OP(DetectionOutput, op::v0) +NGRAPH_OP(RegionYolo, op::v0) NGRAPH_OP(ReorgYolo, op::v0) NGRAPH_OP(RNNCell, op::v0) #undef ID_SUFFIX