[GPU] Update RegionYolo to use nGraph shape inference (#18657)
* Update RegionYolo to use ngraph shape infer Signed-off-by: Andrew Park <andrew.park@intel.com> * Add dynamic TCs for ov_gpu_func_tests Signed-off-by: Andrew Park <andrew.park@intel.com> * Add shape infer TCs for ov_gpu_unit_tests Signed-off-by: Andrew Park <andrew.park@intel.com> --------- Signed-off-by: Andrew Park <andrew.park@intel.com>
This commit is contained in:
parent
74b5f8673c
commit
2c889c8b5e
@ -50,27 +50,51 @@ public:
|
||||
size_t get_num_coords() const {
|
||||
return m_num_coords;
|
||||
}
|
||||
void set_num_coords(const size_t num_coords) {
|
||||
m_num_coords = num_coords;
|
||||
}
|
||||
size_t get_num_classes() const {
|
||||
return m_num_classes;
|
||||
}
|
||||
void set_num_classes(const size_t num_classes) {
|
||||
m_num_classes = num_classes;
|
||||
}
|
||||
size_t get_num_regions() const {
|
||||
return m_num_regions;
|
||||
}
|
||||
void set_num_regions(const size_t num_regions) {
|
||||
m_num_regions = num_regions;
|
||||
}
|
||||
bool get_do_softmax() const {
|
||||
return m_do_softmax;
|
||||
}
|
||||
void set_do_softmax(const bool do_softmax) {
|
||||
m_do_softmax = do_softmax;
|
||||
}
|
||||
const std::vector<int64_t>& get_mask() const {
|
||||
return m_mask;
|
||||
}
|
||||
void set_mask(const std::vector<int64_t>& mask) {
|
||||
m_mask = mask;
|
||||
}
|
||||
const std::vector<float>& get_anchors() const {
|
||||
return m_anchors;
|
||||
}
|
||||
void set_anchors(const std::vector<float>& anchors) {
|
||||
m_anchors = anchors;
|
||||
}
|
||||
int get_axis() const {
|
||||
return m_axis;
|
||||
}
|
||||
void set_axis(const int axis) {
|
||||
m_axis = axis;
|
||||
}
|
||||
int get_end_axis() const {
|
||||
return m_end_axis;
|
||||
}
|
||||
void set_end_axis(const int end_axis) {
|
||||
m_end_axis = end_axis;
|
||||
}
|
||||
|
||||
private:
|
||||
size_t m_num_coords;
|
||||
|
@ -27,14 +27,20 @@ struct region_yolo : public primitive_base<region_yolo> {
|
||||
const uint32_t coords,
|
||||
const uint32_t classes,
|
||||
const uint32_t num,
|
||||
const uint32_t mask_size = 0,
|
||||
const std::vector<int64_t>& mask,
|
||||
const uint32_t mask_size,
|
||||
const int32_t axis,
|
||||
const int32_t end_axis,
|
||||
const bool do_softmax = true,
|
||||
const padding& output_padding = padding())
|
||||
: primitive_base(id, {input}, {output_padding}),
|
||||
coords(coords),
|
||||
classes(classes),
|
||||
num(num),
|
||||
mask(mask),
|
||||
mask_size(mask_size),
|
||||
axis(axis),
|
||||
end_axis(end_axis),
|
||||
do_softmax(do_softmax) {}
|
||||
|
||||
/// @brief Defines a scope of a region yolo normalization
|
||||
@ -43,7 +49,10 @@ struct region_yolo : public primitive_base<region_yolo> {
|
||||
uint32_t coords;
|
||||
uint32_t classes;
|
||||
uint32_t num;
|
||||
std::vector<int64_t> mask;
|
||||
uint32_t mask_size;
|
||||
int32_t axis;
|
||||
int32_t end_axis;
|
||||
bool do_softmax;
|
||||
|
||||
size_t hash() const override {
|
||||
@ -51,7 +60,10 @@ struct region_yolo : public primitive_base<region_yolo> {
|
||||
seed = hash_combine(seed, coords);
|
||||
seed = hash_combine(seed, classes);
|
||||
seed = hash_combine(seed, num);
|
||||
seed = hash_range(seed, mask.begin(), mask.end());
|
||||
seed = hash_combine(seed, mask_size);
|
||||
seed = hash_combine(seed, axis);
|
||||
seed = hash_combine(seed, end_axis);
|
||||
seed = hash_combine(seed, do_softmax);
|
||||
return seed;
|
||||
}
|
||||
@ -65,7 +77,10 @@ struct region_yolo : public primitive_base<region_yolo> {
|
||||
return coords == rhs_casted.coords &&
|
||||
classes == rhs_casted.classes &&
|
||||
num == rhs_casted.num &&
|
||||
mask == rhs_casted.mask &&
|
||||
mask_size == rhs_casted.mask_size &&
|
||||
axis == rhs_casted.axis &&
|
||||
end_axis == rhs_casted.end_axis &&
|
||||
do_softmax == rhs_casted.do_softmax;
|
||||
}
|
||||
|
||||
@ -74,7 +89,10 @@ struct region_yolo : public primitive_base<region_yolo> {
|
||||
ob << coords;
|
||||
ob << classes;
|
||||
ob << num;
|
||||
ob << mask;
|
||||
ob << mask_size;
|
||||
ob << axis;
|
||||
ob << end_axis;
|
||||
ob << do_softmax;
|
||||
}
|
||||
|
||||
@ -83,7 +101,10 @@ struct region_yolo : public primitive_base<region_yolo> {
|
||||
ib >> coords;
|
||||
ib >> classes;
|
||||
ib >> num;
|
||||
ib >> mask;
|
||||
ib >> mask_size;
|
||||
ib >> axis;
|
||||
ib >> end_axis;
|
||||
ib >> do_softmax;
|
||||
}
|
||||
};
|
||||
|
@ -9,6 +9,17 @@
|
||||
#include <string>
|
||||
|
||||
namespace cldnn {
|
||||
template <>
|
||||
struct typed_program_node<region_yolo> : public typed_program_node_base<region_yolo> {
|
||||
using parent = typed_program_node_base<region_yolo>;
|
||||
|
||||
public:
|
||||
using parent::parent;
|
||||
|
||||
program_node& input(size_t index = 0) const { return get_dependency(index); }
|
||||
std::vector<size_t> get_shape_infer_dependencies() const override { return {}; }
|
||||
};
|
||||
|
||||
using region_yolo_node = typed_program_node<region_yolo>;
|
||||
|
||||
template <>
|
||||
@ -17,6 +28,8 @@ class typed_primitive_inst<region_yolo> : public typed_primitive_inst_base<regio
|
||||
using parent::parent;
|
||||
|
||||
public:
|
||||
template<typename ShapeType>
|
||||
static std::vector<layout> calc_output_layouts(region_yolo_node const& node, kernel_impl_params const& impl_param);
|
||||
static layout calc_output_layout(region_yolo_node const& node, kernel_impl_params const& impl_param);
|
||||
static std::string to_string(region_yolo_node const& node);
|
||||
|
||||
|
@ -3,6 +3,8 @@
|
||||
//
|
||||
|
||||
#include "region_yolo_inst.h"
|
||||
#include "region_yolo_shape_inference.hpp"
|
||||
|
||||
#include "primitive_type_base.h"
|
||||
#include "json_object.h"
|
||||
#include <string>
|
||||
@ -34,6 +36,33 @@ layout region_yolo_inst::calc_output_layout(region_yolo_node const& node, kernel
|
||||
}
|
||||
}
|
||||
|
||||
template<typename ShapeType>
|
||||
std::vector<layout> region_yolo_inst::calc_output_layouts(region_yolo_node const& node, kernel_impl_params const& impl_param) {
|
||||
auto desc = impl_param.typed_desc<region_yolo>();
|
||||
auto input_layout = impl_param.get_input_layout(0);
|
||||
auto output_type = desc->output_data_types[0].value_or(input_layout.data_type);
|
||||
auto output_format = input_layout.format;
|
||||
|
||||
ov::op::v0::RegionYolo op;
|
||||
op.set_num_coords(static_cast<size_t>(desc->coords));
|
||||
op.set_num_classes(static_cast<size_t>(desc->classes));
|
||||
op.set_num_regions(static_cast<size_t>(desc->num));
|
||||
op.set_do_softmax(desc->do_softmax);
|
||||
op.set_mask(desc->mask);
|
||||
op.set_axis(desc->axis);
|
||||
op.set_end_axis(desc->end_axis);
|
||||
|
||||
std::vector<ShapeType> output_shapes = { ShapeType() };
|
||||
std::vector<ShapeType> input_shapes = {
|
||||
input_layout.get<ShapeType>()
|
||||
};
|
||||
ov::op::v0::shape_infer(&op, input_shapes, output_shapes);
|
||||
|
||||
return { layout{output_shapes[0], output_type, output_format} };
|
||||
}
|
||||
|
||||
template std::vector<layout> region_yolo_inst::calc_output_layouts<ov::PartialShape>(region_yolo_node const& node, const kernel_impl_params& impl_param);
|
||||
|
||||
std::string region_yolo_inst::to_string(region_yolo_node const& node) {
|
||||
auto desc = node.get_primitive();
|
||||
auto node_info = node.desc_to_json();
|
||||
@ -41,7 +70,10 @@ std::string region_yolo_inst::to_string(region_yolo_node const& node) {
|
||||
auto classes = desc->classes;
|
||||
auto num = desc->num;
|
||||
auto do_softmax = desc->do_softmax;
|
||||
auto mask = desc->mask;
|
||||
auto mask_size = desc->mask_size;
|
||||
auto axis = desc->axis;
|
||||
auto end_axis = desc->end_axis;
|
||||
|
||||
std::stringstream primitive_description;
|
||||
|
||||
@ -50,7 +82,11 @@ std::string region_yolo_inst::to_string(region_yolo_node const& node) {
|
||||
region_yolo_info.add("classes", classes);
|
||||
region_yolo_info.add("num", num);
|
||||
region_yolo_info.add("do_softmax", do_softmax);
|
||||
region_yolo_info.add("mask", mask);
|
||||
region_yolo_info.add("mask_size", mask_size);
|
||||
region_yolo_info.add("axis", axis);
|
||||
region_yolo_info.add("end_axis", end_axis);
|
||||
|
||||
|
||||
node_info->add("region yolo info", region_yolo_info);
|
||||
node_info->dump(primitive_description);
|
||||
|
@ -21,14 +21,20 @@ static void CreateRegionYoloOp(Program& p, const std::shared_ptr<ngraph::op::v0:
|
||||
uint32_t classes = static_cast<uint32_t>(op->get_num_classes());
|
||||
uint32_t num = static_cast<uint32_t>(op->get_num_regions());
|
||||
bool do_softmax = op->get_do_softmax();
|
||||
uint32_t mask_size = static_cast<uint32_t>(op->get_mask().size());
|
||||
std::vector<int64_t> mask = op->get_mask();
|
||||
uint32_t mask_size = static_cast<uint32_t>(mask.size());
|
||||
int32_t axis = op->get_axis();
|
||||
int32_t end_axis = op->get_end_axis();
|
||||
|
||||
auto regionPrim = cldnn::region_yolo(layerName,
|
||||
inputs[0],
|
||||
coords,
|
||||
classes,
|
||||
num,
|
||||
mask,
|
||||
mask_size,
|
||||
axis,
|
||||
end_axis,
|
||||
do_softmax);
|
||||
|
||||
p.add_primitive(*op, regionPrim);
|
||||
|
@ -0,0 +1,177 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "shared_test_classes/single_layer/region_yolo.hpp"
|
||||
#include "shared_test_classes/base/ov_subgraph.hpp"
|
||||
#include "ie_precision.hpp"
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
#include "common_test_utils/ov_tensor_utils.hpp"
|
||||
#include <string>
|
||||
|
||||
using namespace ngraph;
|
||||
using namespace InferenceEngine;
|
||||
using namespace ov::test;
|
||||
|
||||
namespace GPULayerTestsDefinitions {
|
||||
|
||||
struct regionYoloAttributes {
|
||||
size_t classes;
|
||||
size_t coordinates;
|
||||
size_t num_regions;
|
||||
bool do_softmax;
|
||||
int start_axis;
|
||||
int end_axis;
|
||||
};
|
||||
|
||||
typedef std::tuple<
|
||||
InputShape, // Input Shape
|
||||
regionYoloAttributes, // Params
|
||||
std::vector<int64_t>, // mask
|
||||
ov::test::ElementType, // Network input precision
|
||||
ov::test::ElementType, // Network output precision
|
||||
std::map<std::string, std::string>, // Additional network configuration
|
||||
std::string // Device name
|
||||
> RegionYoloGPUTestParam;
|
||||
|
||||
class RegionYoloLayerGPUTest : public testing::WithParamInterface<RegionYoloGPUTestParam>,
|
||||
virtual public ov::test::SubgraphBaseTest {
|
||||
public:
|
||||
static std::string getTestCaseName(testing::TestParamInfo<RegionYoloGPUTestParam> obj) {
|
||||
InputShape inputShape;
|
||||
regionYoloAttributes attributes;
|
||||
std::vector<int64_t> mask;
|
||||
ov::test::ElementType inpPrecision;
|
||||
ov::test::ElementType outPrecision;
|
||||
std::string targetName;
|
||||
std::map<std::string, std::string> additionalConfig;
|
||||
|
||||
std::tie(inputShape, attributes, mask, inpPrecision, outPrecision, additionalConfig, targetName) = obj.param;
|
||||
|
||||
std::ostringstream result;
|
||||
result << "IS=" << inputShape << "_";
|
||||
result << "classes=" << attributes.classes << "_";
|
||||
result << "coords=" << attributes.coordinates << "_";
|
||||
result << "num=" << attributes.num_regions << "_";
|
||||
result << "doSoftmax=" << attributes.do_softmax << "_";
|
||||
result << "axis=" << attributes.start_axis << "_";
|
||||
result << "endAxis=" << attributes.end_axis << "_";
|
||||
result << "inpPRC=" << inpPrecision << "_";
|
||||
result << "outPRC=" << outPrecision << "_";
|
||||
result << "targetDevice=" << targetName << "_";
|
||||
return result.str();
|
||||
}
|
||||
|
||||
protected:
|
||||
void SetUp() override {
|
||||
InputShape inputShape;
|
||||
regionYoloAttributes attributes;
|
||||
std::vector<int64_t> mask;
|
||||
ov::test::ElementType inPrc;
|
||||
ov::test::ElementType outPrc;
|
||||
std::map<std::string, std::string> additionalConfig;
|
||||
|
||||
std::tie(inputShape, attributes, mask, inPrc, outPrc, additionalConfig, targetDevice) = this->GetParam();
|
||||
|
||||
init_input_shapes({ inputShape });
|
||||
|
||||
auto paramRegionYolo = ngraph::builder::makeDynamicParams(inPrc, inputDynamicShapes);
|
||||
|
||||
const auto region_yolo = std::make_shared<ngraph::op::v0::RegionYolo>(paramRegionYolo[0],
|
||||
attributes.coordinates, attributes.classes, attributes.num_regions,
|
||||
attributes.do_softmax, mask, attributes.start_axis, attributes.end_axis);
|
||||
|
||||
ngraph::ResultVector results;
|
||||
for (size_t i = 0; i < region_yolo->get_output_size(); i++)
|
||||
results.push_back(std::make_shared<ngraph::opset1::Result>(region_yolo->output(i)));
|
||||
function = std::make_shared<ngraph::Function>(results, paramRegionYolo, "RegionYolo");
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(RegionYoloLayerGPUTest, CompareWithRefs) {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
|
||||
run();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
std::map<std::string, std::string> emptyAdditionalConfig;
|
||||
|
||||
const std::vector<ov::test::ElementType> inpOutPrc = {ov::test::ElementType::f16, ov::test::ElementType::f32};
|
||||
|
||||
const std::vector<InputShape> inShapes_caffe_dynamic = {
|
||||
{{-1, -1, -1, -1}, {{1, 125, 13, 13}, {1, 125, 26, 26}}},
|
||||
{{{1, 2}, {100, 125}, {13, 26}, {13, 26}}, {{1, 125, 13, 13}, {1, 125, 26, 26}}}
|
||||
};
|
||||
|
||||
const std::vector<InputShape> inShapes_mxnet_dynamic = {
|
||||
{{-1, -1, -1, -1}, {{1, 75, 52, 52}, {1, 75, 32, 32}, {1, 75, 26, 26}}},
|
||||
{{{1, 2}, {75, 80}, {26, 52}, {26, 52}}, {{1, 75, 52, 52}, {1, 75, 32, 32}, {1, 75, 26, 26}}},
|
||||
};
|
||||
|
||||
const std::vector<InputShape> inShapes_v3_dynamic = {
|
||||
{{-1, -1, -1, -1}, {{1, 255, 52, 52}, {1, 255, 26, 26}, {1, 255, 13, 13}}},
|
||||
{{{1, 2}, {255, 256}, {13, 52}, {13, 52}}, {{1, 255, 52, 52}, {1, 255, 26, 26}, {1, 255, 13, 13}}}
|
||||
};
|
||||
|
||||
const std::vector<std::vector<int64_t>> masks = {
|
||||
{0, 1, 2},
|
||||
{3, 4, 5},
|
||||
{6, 7, 8}
|
||||
};
|
||||
|
||||
const std::vector<bool> do_softmax = {true, false};
|
||||
const std::vector<size_t> classes = {80, 20};
|
||||
const std::vector<size_t> num_regions = {5, 9};
|
||||
|
||||
const regionYoloAttributes yoloV3attr = {80, 4, 9, false, 1, 3};
|
||||
|
||||
const auto testCase_yolov3_dynamic = ::testing::Combine(
|
||||
::testing::ValuesIn(inShapes_v3_dynamic),
|
||||
::testing::Values(yoloV3attr),
|
||||
::testing::Values(masks[2]),
|
||||
::testing::ValuesIn(inpOutPrc),
|
||||
::testing::ValuesIn(inpOutPrc),
|
||||
::testing::Values(emptyAdditionalConfig),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)
|
||||
);
|
||||
|
||||
const regionYoloAttributes yoloV3mxnetAttr = {20, 4, 9, false, 1, 3};
|
||||
|
||||
const auto testCase_yolov3_mxnet_dynamic = ::testing::Combine(
|
||||
::testing::ValuesIn(inShapes_mxnet_dynamic),
|
||||
::testing::Values(yoloV3mxnetAttr),
|
||||
::testing::Values(masks[1]),
|
||||
::testing::ValuesIn(inpOutPrc),
|
||||
::testing::ValuesIn(inpOutPrc),
|
||||
::testing::Values(emptyAdditionalConfig),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)
|
||||
);
|
||||
|
||||
const regionYoloAttributes yoloV2caffeAttr = {20, 4, 5, true, 1, 3};
|
||||
|
||||
const auto testCase_yolov2_caffe_dynamic = ::testing::Combine(
|
||||
::testing::ValuesIn(inShapes_caffe_dynamic),
|
||||
::testing::Values(yoloV2caffeAttr),
|
||||
::testing::Values(masks[0]),
|
||||
::testing::ValuesIn(inpOutPrc),
|
||||
::testing::ValuesIn(inpOutPrc),
|
||||
::testing::Values(emptyAdditionalConfig),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)
|
||||
);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_GPURegionYolov3Dynamic, RegionYoloLayerGPUTest,
|
||||
testCase_yolov3_dynamic,
|
||||
RegionYoloLayerGPUTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_GPURegionYoloMxnetDynamic, RegionYoloLayerGPUTest,
|
||||
testCase_yolov3_mxnet_dynamic,
|
||||
RegionYoloLayerGPUTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_GPURegionYoloCaffeDynamic, RegionYoloLayerGPUTest,
|
||||
testCase_yolov2_caffe_dynamic,
|
||||
RegionYoloLayerGPUTest::getTestCaseName);
|
||||
|
||||
} // namespace
|
||||
} // namespace GPULayerTestsDefinitions
|
@ -0,0 +1,84 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "test_utils.h"
|
||||
|
||||
#include <intel_gpu/primitives/input_layout.hpp>
|
||||
#include <intel_gpu/primitives/region_yolo.hpp>
|
||||
#include <intel_gpu/primitives/data.hpp>
|
||||
|
||||
#include "region_yolo_inst.h"
|
||||
|
||||
#include "program_wrapper.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
|
||||
using namespace cldnn;
|
||||
using namespace ::tests;
|
||||
|
||||
namespace shape_infer_tests {
|
||||
|
||||
struct region_yolo_test_params {
|
||||
layout in_layout;
|
||||
uint32_t coords;
|
||||
uint32_t classes;
|
||||
uint32_t num;
|
||||
std::vector<int64_t> mask;
|
||||
int32_t axis;
|
||||
int32_t end_axis;
|
||||
bool do_softmax;
|
||||
layout expected_layout;
|
||||
};
|
||||
|
||||
class region_yolo_test : public testing::TestWithParam<region_yolo_test_params> { };
|
||||
|
||||
TEST_P(region_yolo_test, shape_infer) {
|
||||
auto p = GetParam();
|
||||
|
||||
auto& engine = get_test_engine();
|
||||
|
||||
auto input_layout_prim = std::make_shared<input_layout>("input", p.in_layout);
|
||||
auto region_yolo_prim = std::make_shared<region_yolo>("region_yolo",
|
||||
input_info("input"),
|
||||
p.coords,
|
||||
p.classes,
|
||||
p.num,
|
||||
p.mask,
|
||||
static_cast<uint32_t>(p.mask.size()),
|
||||
p.axis,
|
||||
p.end_axis,
|
||||
p.do_softmax);
|
||||
|
||||
cldnn::program prog(engine);
|
||||
|
||||
auto& input_layout_node = prog.get_or_create(input_layout_prim);
|
||||
auto& region_yolo_node = prog.get_or_create(region_yolo_prim);
|
||||
program_wrapper::add_connection(prog, input_layout_node, region_yolo_node);
|
||||
auto res = region_yolo_inst::calc_output_layouts<ov::PartialShape>(region_yolo_node, *region_yolo_node.get_kernel_impl_params());
|
||||
|
||||
ASSERT_EQ(res.size(), 1);
|
||||
ASSERT_EQ(res[0], p.expected_layout);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke, region_yolo_test,
|
||||
testing::ValuesIn(std::vector<region_yolo_test_params>{
|
||||
{
|
||||
layout{ov::PartialShape{1, 255, 26, 26}, data_types::f32, format::bfyx},
|
||||
4, 80, 6, { 0, 1, 2 }, 1, 3, false,
|
||||
layout{ov::PartialShape{1, 255, 26, 26}, data_types::f32, format::bfyx}
|
||||
},
|
||||
{
|
||||
layout{ov::PartialShape::dynamic(4), data_types::f32, format::bfyx},
|
||||
4, 80, 6, { 0, 1, 2 }, 1, 3, false,
|
||||
layout{ov::PartialShape{-1, 255, -1, -1}, data_types::f32, format::bfyx}
|
||||
},
|
||||
{
|
||||
layout{ov::PartialShape::dynamic(4), data_types::f32, format::bfyx},
|
||||
4, 80, 6, { 0, 1, 2 }, 1, 3, true,
|
||||
layout{ov::PartialShape::dynamic(2), data_types::f32, format::bfyx}
|
||||
},
|
||||
}));
|
||||
|
||||
} // shape_infer_tests
|
@ -160,6 +160,8 @@ struct region_yolo_test_params {
|
||||
uint32_t coords;
|
||||
uint32_t classes;
|
||||
uint32_t regionNum;
|
||||
int32_t axis;
|
||||
int32_t end_axis;
|
||||
data_types dataType;
|
||||
format fmt;
|
||||
bool softMax;
|
||||
@ -179,7 +181,8 @@ void runRegionTest(region_yolo_test_params& params, bool is_caching_test = false
|
||||
topology.add(input_layout("InputData", inputPrim->get_layout()));
|
||||
topology.add(reorder("reorder_pre", input_info("InputData"), params.fmt, params.dataType));
|
||||
topology.add(region_yolo("region_yolo", input_info("reorder_pre"), params.coords, params.classes,
|
||||
params.regionNum, static_cast<uint32_t>(params.mask.size()), params.softMax));
|
||||
params.regionNum, params.mask, static_cast<uint32_t>(params.mask.size()),
|
||||
params.axis, params.end_axis, params.softMax));
|
||||
topology.add(reorder("reorder_post", input_info("region_yolo"), format::bfyx, params.dataType));
|
||||
|
||||
cldnn::network::ptr network = get_network(engine, topology, get_test_default_config(engine), get_test_stream_ptr(), is_caching_test);
|
||||
@ -204,82 +207,82 @@ void runRegionTest(region_yolo_test_params& params, bool is_caching_test = false
|
||||
} // namespace
|
||||
|
||||
TEST(region_yolo_gpu_fp32, bfyx) {
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f32, format::bfyx, false};
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f32, format::bfyx, false};
|
||||
runRegionTest<float>(params);
|
||||
}
|
||||
|
||||
TEST(region_yolo_gpu_fp32, bfyx_softmax) {
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f32, format::bfyx, true};
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f32, format::bfyx, true};
|
||||
runRegionTest<float>(params);
|
||||
}
|
||||
|
||||
TEST(region_yolo_gpu_fp32, byxf) {
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f32, format::byxf, false};
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f32, format::byxf, false};
|
||||
runRegionTest<float>(params);
|
||||
}
|
||||
|
||||
TEST(region_yolo_gpu_fp32, byxf_softmax) {
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f32, format::byxf, true};
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f32, format::byxf, true};
|
||||
runRegionTest<float>(params);
|
||||
}
|
||||
|
||||
TEST(region_yolo_gpu_fp16, bfyx) {
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f16, format::bfyx, false};
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f16, format::bfyx, false};
|
||||
runRegionTest<FLOAT16>(params);
|
||||
}
|
||||
|
||||
TEST(region_yolo_gpu_fp16, bfyx_softmax) {
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f16, format::bfyx, true};
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f16, format::bfyx, true};
|
||||
runRegionTest<FLOAT16>(params);
|
||||
}
|
||||
|
||||
TEST(region_yolo_gpu_fp16, byxf) {
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f16, format::byxf, false};
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f16, format::byxf, false};
|
||||
runRegionTest<FLOAT16>(params);
|
||||
}
|
||||
|
||||
TEST(region_yolo_gpu_fp16, byxf_softmax) {
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f16, format::byxf, true};
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f16, format::byxf, true};
|
||||
runRegionTest<FLOAT16>(params);
|
||||
}
|
||||
|
||||
#ifdef RUN_ALL_MODEL_CACHING_TESTS
|
||||
TEST(region_yolo_gpu_fp32, bfyx_cached) {
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f32, format::bfyx, false};
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f32, format::bfyx, false};
|
||||
runRegionTest<float>(params, true);
|
||||
}
|
||||
|
||||
TEST(region_yolo_gpu_fp32, bfyx_softmax_cached) {
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f32, format::bfyx, true};
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f32, format::bfyx, true};
|
||||
runRegionTest<float>(params, true);
|
||||
}
|
||||
|
||||
TEST(region_yolo_gpu_fp32, byxf_cached) {
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f32, format::byxf, false};
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f32, format::byxf, false};
|
||||
runRegionTest<float>(params, true);
|
||||
}
|
||||
|
||||
TEST(region_yolo_gpu_fp32, byxf_softmax_cached) {
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f32, format::byxf, true};
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f32, format::byxf, true};
|
||||
runRegionTest<float>(params, true);
|
||||
}
|
||||
|
||||
TEST(region_yolo_gpu_fp16, bfyx_cached) {
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f16, format::bfyx, false};
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f16, format::bfyx, false};
|
||||
runRegionTest<FLOAT16>(params, true);
|
||||
}
|
||||
|
||||
TEST(region_yolo_gpu_fp16, bfyx_softmax_cached) {
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f16, format::bfyx, true};
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f16, format::bfyx, true};
|
||||
runRegionTest<FLOAT16>(params, true);
|
||||
}
|
||||
|
||||
TEST(region_yolo_gpu_fp16, byxf_cached) {
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f16, format::byxf, false};
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f16, format::byxf, false};
|
||||
runRegionTest<FLOAT16>(params, true);
|
||||
}
|
||||
#endif // RUN_ALL_MODEL_CACHING_TESTS
|
||||
TEST(region_yolo_gpu_fp16, byxf_softmax_cached) {
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f16, format::byxf, true};
|
||||
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f16, format::byxf, true};
|
||||
runRegionTest<FLOAT16>(params, true);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user