[GPU] Update RegionYolo to use nGraph shape inference (#18657)

* Update RegionYolo to use ngraph shape infer

Signed-off-by: Andrew Park <andrew.park@intel.com>

* Add dynamic TCs for ov_gpu_func_tests

Signed-off-by: Andrew Park <andrew.park@intel.com>

* Add shape infer TCs for ov_gpu_unit_tests

Signed-off-by: Andrew Park <andrew.park@intel.com>

---------

Signed-off-by: Andrew Park <andrew.park@intel.com>
This commit is contained in:
Andrew Kwangwoong Park 2023-07-24 13:55:21 +09:00 committed by GitHub
parent 74b5f8673c
commit 2c889c8b5e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 383 additions and 19 deletions

View File

@ -50,27 +50,51 @@ public:
size_t get_num_coords() const { size_t get_num_coords() const {
return m_num_coords; return m_num_coords;
} }
void set_num_coords(const size_t num_coords) {
m_num_coords = num_coords;
}
size_t get_num_classes() const { size_t get_num_classes() const {
return m_num_classes; return m_num_classes;
} }
void set_num_classes(const size_t num_classes) {
m_num_classes = num_classes;
}
size_t get_num_regions() const { size_t get_num_regions() const {
return m_num_regions; return m_num_regions;
} }
void set_num_regions(const size_t num_regions) {
m_num_regions = num_regions;
}
bool get_do_softmax() const { bool get_do_softmax() const {
return m_do_softmax; return m_do_softmax;
} }
void set_do_softmax(const bool do_softmax) {
m_do_softmax = do_softmax;
}
const std::vector<int64_t>& get_mask() const { const std::vector<int64_t>& get_mask() const {
return m_mask; return m_mask;
} }
void set_mask(const std::vector<int64_t>& mask) {
m_mask = mask;
}
const std::vector<float>& get_anchors() const { const std::vector<float>& get_anchors() const {
return m_anchors; return m_anchors;
} }
void set_anchors(const std::vector<float>& anchors) {
m_anchors = anchors;
}
int get_axis() const { int get_axis() const {
return m_axis; return m_axis;
} }
void set_axis(const int axis) {
m_axis = axis;
}
int get_end_axis() const { int get_end_axis() const {
return m_end_axis; return m_end_axis;
} }
void set_end_axis(const int end_axis) {
m_end_axis = end_axis;
}
private: private:
size_t m_num_coords; size_t m_num_coords;

View File

@ -27,14 +27,20 @@ struct region_yolo : public primitive_base<region_yolo> {
const uint32_t coords, const uint32_t coords,
const uint32_t classes, const uint32_t classes,
const uint32_t num, const uint32_t num,
const uint32_t mask_size = 0, const std::vector<int64_t>& mask,
const uint32_t mask_size,
const int32_t axis,
const int32_t end_axis,
const bool do_softmax = true, const bool do_softmax = true,
const padding& output_padding = padding()) const padding& output_padding = padding())
: primitive_base(id, {input}, {output_padding}), : primitive_base(id, {input}, {output_padding}),
coords(coords), coords(coords),
classes(classes), classes(classes),
num(num), num(num),
mask(mask),
mask_size(mask_size), mask_size(mask_size),
axis(axis),
end_axis(end_axis),
do_softmax(do_softmax) {} do_softmax(do_softmax) {}
/// @brief Defines a scope of a region yolo normalization /// @brief Defines a scope of a region yolo normalization
@ -43,7 +49,10 @@ struct region_yolo : public primitive_base<region_yolo> {
uint32_t coords; uint32_t coords;
uint32_t classes; uint32_t classes;
uint32_t num; uint32_t num;
std::vector<int64_t> mask;
uint32_t mask_size; uint32_t mask_size;
int32_t axis;
int32_t end_axis;
bool do_softmax; bool do_softmax;
size_t hash() const override { size_t hash() const override {
@ -51,7 +60,10 @@ struct region_yolo : public primitive_base<region_yolo> {
seed = hash_combine(seed, coords); seed = hash_combine(seed, coords);
seed = hash_combine(seed, classes); seed = hash_combine(seed, classes);
seed = hash_combine(seed, num); seed = hash_combine(seed, num);
seed = hash_range(seed, mask.begin(), mask.end());
seed = hash_combine(seed, mask_size); seed = hash_combine(seed, mask_size);
seed = hash_combine(seed, axis);
seed = hash_combine(seed, end_axis);
seed = hash_combine(seed, do_softmax); seed = hash_combine(seed, do_softmax);
return seed; return seed;
} }
@ -65,7 +77,10 @@ struct region_yolo : public primitive_base<region_yolo> {
return coords == rhs_casted.coords && return coords == rhs_casted.coords &&
classes == rhs_casted.classes && classes == rhs_casted.classes &&
num == rhs_casted.num && num == rhs_casted.num &&
mask == rhs_casted.mask &&
mask_size == rhs_casted.mask_size && mask_size == rhs_casted.mask_size &&
axis == rhs_casted.axis &&
end_axis == rhs_casted.end_axis &&
do_softmax == rhs_casted.do_softmax; do_softmax == rhs_casted.do_softmax;
} }
@ -74,7 +89,10 @@ struct region_yolo : public primitive_base<region_yolo> {
ob << coords; ob << coords;
ob << classes; ob << classes;
ob << num; ob << num;
ob << mask;
ob << mask_size; ob << mask_size;
ob << axis;
ob << end_axis;
ob << do_softmax; ob << do_softmax;
} }
@ -83,7 +101,10 @@ struct region_yolo : public primitive_base<region_yolo> {
ib >> coords; ib >> coords;
ib >> classes; ib >> classes;
ib >> num; ib >> num;
ib >> mask;
ib >> mask_size; ib >> mask_size;
ib >> axis;
ib >> end_axis;
ib >> do_softmax; ib >> do_softmax;
} }
}; };

View File

@ -9,6 +9,17 @@
#include <string> #include <string>
namespace cldnn { namespace cldnn {
template <>
struct typed_program_node<region_yolo> : public typed_program_node_base<region_yolo> {
using parent = typed_program_node_base<region_yolo>;
public:
using parent::parent;
program_node& input(size_t index = 0) const { return get_dependency(index); }
std::vector<size_t> get_shape_infer_dependencies() const override { return {}; }
};
using region_yolo_node = typed_program_node<region_yolo>; using region_yolo_node = typed_program_node<region_yolo>;
template <> template <>
@ -17,6 +28,8 @@ class typed_primitive_inst<region_yolo> : public typed_primitive_inst_base<regio
using parent::parent; using parent::parent;
public: public:
template<typename ShapeType>
static std::vector<layout> calc_output_layouts(region_yolo_node const& node, kernel_impl_params const& impl_param);
static layout calc_output_layout(region_yolo_node const& node, kernel_impl_params const& impl_param); static layout calc_output_layout(region_yolo_node const& node, kernel_impl_params const& impl_param);
static std::string to_string(region_yolo_node const& node); static std::string to_string(region_yolo_node const& node);

View File

@ -3,6 +3,8 @@
// //
#include "region_yolo_inst.h" #include "region_yolo_inst.h"
#include "region_yolo_shape_inference.hpp"
#include "primitive_type_base.h" #include "primitive_type_base.h"
#include "json_object.h" #include "json_object.h"
#include <string> #include <string>
@ -34,6 +36,33 @@ layout region_yolo_inst::calc_output_layout(region_yolo_node const& node, kernel
} }
} }
template<typename ShapeType>
std::vector<layout> region_yolo_inst::calc_output_layouts(region_yolo_node const& node, kernel_impl_params const& impl_param) {
auto desc = impl_param.typed_desc<region_yolo>();
auto input_layout = impl_param.get_input_layout(0);
auto output_type = desc->output_data_types[0].value_or(input_layout.data_type);
auto output_format = input_layout.format;
ov::op::v0::RegionYolo op;
op.set_num_coords(static_cast<size_t>(desc->coords));
op.set_num_classes(static_cast<size_t>(desc->classes));
op.set_num_regions(static_cast<size_t>(desc->num));
op.set_do_softmax(desc->do_softmax);
op.set_mask(desc->mask);
op.set_axis(desc->axis);
op.set_end_axis(desc->end_axis);
std::vector<ShapeType> output_shapes = { ShapeType() };
std::vector<ShapeType> input_shapes = {
input_layout.get<ShapeType>()
};
ov::op::v0::shape_infer(&op, input_shapes, output_shapes);
return { layout{output_shapes[0], output_type, output_format} };
}
template std::vector<layout> region_yolo_inst::calc_output_layouts<ov::PartialShape>(region_yolo_node const& node, const kernel_impl_params& impl_param);
std::string region_yolo_inst::to_string(region_yolo_node const& node) { std::string region_yolo_inst::to_string(region_yolo_node const& node) {
auto desc = node.get_primitive(); auto desc = node.get_primitive();
auto node_info = node.desc_to_json(); auto node_info = node.desc_to_json();
@ -41,7 +70,10 @@ std::string region_yolo_inst::to_string(region_yolo_node const& node) {
auto classes = desc->classes; auto classes = desc->classes;
auto num = desc->num; auto num = desc->num;
auto do_softmax = desc->do_softmax; auto do_softmax = desc->do_softmax;
auto mask = desc->mask;
auto mask_size = desc->mask_size; auto mask_size = desc->mask_size;
auto axis = desc->axis;
auto end_axis = desc->end_axis;
std::stringstream primitive_description; std::stringstream primitive_description;
@ -50,7 +82,11 @@ std::string region_yolo_inst::to_string(region_yolo_node const& node) {
region_yolo_info.add("classes", classes); region_yolo_info.add("classes", classes);
region_yolo_info.add("num", num); region_yolo_info.add("num", num);
region_yolo_info.add("do_softmax", do_softmax); region_yolo_info.add("do_softmax", do_softmax);
region_yolo_info.add("mask", mask);
region_yolo_info.add("mask_size", mask_size); region_yolo_info.add("mask_size", mask_size);
region_yolo_info.add("axis", axis);
region_yolo_info.add("end_axis", end_axis);
node_info->add("region yolo info", region_yolo_info); node_info->add("region yolo info", region_yolo_info);
node_info->dump(primitive_description); node_info->dump(primitive_description);

View File

@ -21,14 +21,20 @@ static void CreateRegionYoloOp(Program& p, const std::shared_ptr<ngraph::op::v0:
uint32_t classes = static_cast<uint32_t>(op->get_num_classes()); uint32_t classes = static_cast<uint32_t>(op->get_num_classes());
uint32_t num = static_cast<uint32_t>(op->get_num_regions()); uint32_t num = static_cast<uint32_t>(op->get_num_regions());
bool do_softmax = op->get_do_softmax(); bool do_softmax = op->get_do_softmax();
uint32_t mask_size = static_cast<uint32_t>(op->get_mask().size()); std::vector<int64_t> mask = op->get_mask();
uint32_t mask_size = static_cast<uint32_t>(mask.size());
int32_t axis = op->get_axis();
int32_t end_axis = op->get_end_axis();
auto regionPrim = cldnn::region_yolo(layerName, auto regionPrim = cldnn::region_yolo(layerName,
inputs[0], inputs[0],
coords, coords,
classes, classes,
num, num,
mask,
mask_size, mask_size,
axis,
end_axis,
do_softmax); do_softmax);
p.add_primitive(*op, regionPrim); p.add_primitive(*op, regionPrim);

View File

@ -0,0 +1,177 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "shared_test_classes/single_layer/region_yolo.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "ie_precision.hpp"
#include "ngraph_functions/builders.hpp"
#include "common_test_utils/ov_tensor_utils.hpp"
#include <string>
using namespace ngraph;
using namespace InferenceEngine;
using namespace ov::test;
namespace GPULayerTestsDefinitions {
struct regionYoloAttributes {
size_t classes;
size_t coordinates;
size_t num_regions;
bool do_softmax;
int start_axis;
int end_axis;
};
typedef std::tuple<
InputShape, // Input Shape
regionYoloAttributes, // Params
std::vector<int64_t>, // mask
ov::test::ElementType, // Network input precision
ov::test::ElementType, // Network output precision
std::map<std::string, std::string>, // Additional network configuration
std::string // Device name
> RegionYoloGPUTestParam;
class RegionYoloLayerGPUTest : public testing::WithParamInterface<RegionYoloGPUTestParam>,
virtual public ov::test::SubgraphBaseTest {
public:
static std::string getTestCaseName(testing::TestParamInfo<RegionYoloGPUTestParam> obj) {
InputShape inputShape;
regionYoloAttributes attributes;
std::vector<int64_t> mask;
ov::test::ElementType inpPrecision;
ov::test::ElementType outPrecision;
std::string targetName;
std::map<std::string, std::string> additionalConfig;
std::tie(inputShape, attributes, mask, inpPrecision, outPrecision, additionalConfig, targetName) = obj.param;
std::ostringstream result;
result << "IS=" << inputShape << "_";
result << "classes=" << attributes.classes << "_";
result << "coords=" << attributes.coordinates << "_";
result << "num=" << attributes.num_regions << "_";
result << "doSoftmax=" << attributes.do_softmax << "_";
result << "axis=" << attributes.start_axis << "_";
result << "endAxis=" << attributes.end_axis << "_";
result << "inpPRC=" << inpPrecision << "_";
result << "outPRC=" << outPrecision << "_";
result << "targetDevice=" << targetName << "_";
return result.str();
}
protected:
void SetUp() override {
InputShape inputShape;
regionYoloAttributes attributes;
std::vector<int64_t> mask;
ov::test::ElementType inPrc;
ov::test::ElementType outPrc;
std::map<std::string, std::string> additionalConfig;
std::tie(inputShape, attributes, mask, inPrc, outPrc, additionalConfig, targetDevice) = this->GetParam();
init_input_shapes({ inputShape });
auto paramRegionYolo = ngraph::builder::makeDynamicParams(inPrc, inputDynamicShapes);
const auto region_yolo = std::make_shared<ngraph::op::v0::RegionYolo>(paramRegionYolo[0],
attributes.coordinates, attributes.classes, attributes.num_regions,
attributes.do_softmax, mask, attributes.start_axis, attributes.end_axis);
ngraph::ResultVector results;
for (size_t i = 0; i < region_yolo->get_output_size(); i++)
results.push_back(std::make_shared<ngraph::opset1::Result>(region_yolo->output(i)));
function = std::make_shared<ngraph::Function>(results, paramRegionYolo, "RegionYolo");
}
};
TEST_P(RegionYoloLayerGPUTest, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
run();
}
namespace {
std::map<std::string, std::string> emptyAdditionalConfig;
const std::vector<ov::test::ElementType> inpOutPrc = {ov::test::ElementType::f16, ov::test::ElementType::f32};
const std::vector<InputShape> inShapes_caffe_dynamic = {
{{-1, -1, -1, -1}, {{1, 125, 13, 13}, {1, 125, 26, 26}}},
{{{1, 2}, {100, 125}, {13, 26}, {13, 26}}, {{1, 125, 13, 13}, {1, 125, 26, 26}}}
};
const std::vector<InputShape> inShapes_mxnet_dynamic = {
{{-1, -1, -1, -1}, {{1, 75, 52, 52}, {1, 75, 32, 32}, {1, 75, 26, 26}}},
{{{1, 2}, {75, 80}, {26, 52}, {26, 52}}, {{1, 75, 52, 52}, {1, 75, 32, 32}, {1, 75, 26, 26}}},
};
const std::vector<InputShape> inShapes_v3_dynamic = {
{{-1, -1, -1, -1}, {{1, 255, 52, 52}, {1, 255, 26, 26}, {1, 255, 13, 13}}},
{{{1, 2}, {255, 256}, {13, 52}, {13, 52}}, {{1, 255, 52, 52}, {1, 255, 26, 26}, {1, 255, 13, 13}}}
};
const std::vector<std::vector<int64_t>> masks = {
{0, 1, 2},
{3, 4, 5},
{6, 7, 8}
};
const std::vector<bool> do_softmax = {true, false};
const std::vector<size_t> classes = {80, 20};
const std::vector<size_t> num_regions = {5, 9};
const regionYoloAttributes yoloV3attr = {80, 4, 9, false, 1, 3};
const auto testCase_yolov3_dynamic = ::testing::Combine(
::testing::ValuesIn(inShapes_v3_dynamic),
::testing::Values(yoloV3attr),
::testing::Values(masks[2]),
::testing::ValuesIn(inpOutPrc),
::testing::ValuesIn(inpOutPrc),
::testing::Values(emptyAdditionalConfig),
::testing::Values(CommonTestUtils::DEVICE_GPU)
);
const regionYoloAttributes yoloV3mxnetAttr = {20, 4, 9, false, 1, 3};
const auto testCase_yolov3_mxnet_dynamic = ::testing::Combine(
::testing::ValuesIn(inShapes_mxnet_dynamic),
::testing::Values(yoloV3mxnetAttr),
::testing::Values(masks[1]),
::testing::ValuesIn(inpOutPrc),
::testing::ValuesIn(inpOutPrc),
::testing::Values(emptyAdditionalConfig),
::testing::Values(CommonTestUtils::DEVICE_GPU)
);
const regionYoloAttributes yoloV2caffeAttr = {20, 4, 5, true, 1, 3};
const auto testCase_yolov2_caffe_dynamic = ::testing::Combine(
::testing::ValuesIn(inShapes_caffe_dynamic),
::testing::Values(yoloV2caffeAttr),
::testing::Values(masks[0]),
::testing::ValuesIn(inpOutPrc),
::testing::ValuesIn(inpOutPrc),
::testing::Values(emptyAdditionalConfig),
::testing::Values(CommonTestUtils::DEVICE_GPU)
);
INSTANTIATE_TEST_SUITE_P(smoke_GPURegionYolov3Dynamic, RegionYoloLayerGPUTest,
testCase_yolov3_dynamic,
RegionYoloLayerGPUTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_GPURegionYoloMxnetDynamic, RegionYoloLayerGPUTest,
testCase_yolov3_mxnet_dynamic,
RegionYoloLayerGPUTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_GPURegionYoloCaffeDynamic, RegionYoloLayerGPUTest,
testCase_yolov2_caffe_dynamic,
RegionYoloLayerGPUTest::getTestCaseName);
} // namespace
} // namespace GPULayerTestsDefinitions

View File

@ -0,0 +1,84 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "test_utils.h"
#include <intel_gpu/primitives/input_layout.hpp>
#include <intel_gpu/primitives/region_yolo.hpp>
#include <intel_gpu/primitives/data.hpp>
#include "region_yolo_inst.h"
#include "program_wrapper.h"
#include <cmath>
#include <algorithm>
using namespace cldnn;
using namespace ::tests;
namespace shape_infer_tests {
struct region_yolo_test_params {
layout in_layout;
uint32_t coords;
uint32_t classes;
uint32_t num;
std::vector<int64_t> mask;
int32_t axis;
int32_t end_axis;
bool do_softmax;
layout expected_layout;
};
class region_yolo_test : public testing::TestWithParam<region_yolo_test_params> { };
TEST_P(region_yolo_test, shape_infer) {
auto p = GetParam();
auto& engine = get_test_engine();
auto input_layout_prim = std::make_shared<input_layout>("input", p.in_layout);
auto region_yolo_prim = std::make_shared<region_yolo>("region_yolo",
input_info("input"),
p.coords,
p.classes,
p.num,
p.mask,
static_cast<uint32_t>(p.mask.size()),
p.axis,
p.end_axis,
p.do_softmax);
cldnn::program prog(engine);
auto& input_layout_node = prog.get_or_create(input_layout_prim);
auto& region_yolo_node = prog.get_or_create(region_yolo_prim);
program_wrapper::add_connection(prog, input_layout_node, region_yolo_node);
auto res = region_yolo_inst::calc_output_layouts<ov::PartialShape>(region_yolo_node, *region_yolo_node.get_kernel_impl_params());
ASSERT_EQ(res.size(), 1);
ASSERT_EQ(res[0], p.expected_layout);
}
INSTANTIATE_TEST_SUITE_P(smoke, region_yolo_test,
testing::ValuesIn(std::vector<region_yolo_test_params>{
{
layout{ov::PartialShape{1, 255, 26, 26}, data_types::f32, format::bfyx},
4, 80, 6, { 0, 1, 2 }, 1, 3, false,
layout{ov::PartialShape{1, 255, 26, 26}, data_types::f32, format::bfyx}
},
{
layout{ov::PartialShape::dynamic(4), data_types::f32, format::bfyx},
4, 80, 6, { 0, 1, 2 }, 1, 3, false,
layout{ov::PartialShape{-1, 255, -1, -1}, data_types::f32, format::bfyx}
},
{
layout{ov::PartialShape::dynamic(4), data_types::f32, format::bfyx},
4, 80, 6, { 0, 1, 2 }, 1, 3, true,
layout{ov::PartialShape::dynamic(2), data_types::f32, format::bfyx}
},
}));
} // shape_infer_tests

View File

@ -160,6 +160,8 @@ struct region_yolo_test_params {
uint32_t coords; uint32_t coords;
uint32_t classes; uint32_t classes;
uint32_t regionNum; uint32_t regionNum;
int32_t axis;
int32_t end_axis;
data_types dataType; data_types dataType;
format fmt; format fmt;
bool softMax; bool softMax;
@ -179,7 +181,8 @@ void runRegionTest(region_yolo_test_params& params, bool is_caching_test = false
topology.add(input_layout("InputData", inputPrim->get_layout())); topology.add(input_layout("InputData", inputPrim->get_layout()));
topology.add(reorder("reorder_pre", input_info("InputData"), params.fmt, params.dataType)); topology.add(reorder("reorder_pre", input_info("InputData"), params.fmt, params.dataType));
topology.add(region_yolo("region_yolo", input_info("reorder_pre"), params.coords, params.classes, topology.add(region_yolo("region_yolo", input_info("reorder_pre"), params.coords, params.classes,
params.regionNum, static_cast<uint32_t>(params.mask.size()), params.softMax)); params.regionNum, params.mask, static_cast<uint32_t>(params.mask.size()),
params.axis, params.end_axis, params.softMax));
topology.add(reorder("reorder_post", input_info("region_yolo"), format::bfyx, params.dataType)); topology.add(reorder("reorder_post", input_info("region_yolo"), format::bfyx, params.dataType));
cldnn::network::ptr network = get_network(engine, topology, get_test_default_config(engine), get_test_stream_ptr(), is_caching_test); cldnn::network::ptr network = get_network(engine, topology, get_test_default_config(engine), get_test_stream_ptr(), is_caching_test);
@ -204,82 +207,82 @@ void runRegionTest(region_yolo_test_params& params, bool is_caching_test = false
} // namespace } // namespace
TEST(region_yolo_gpu_fp32, bfyx) { TEST(region_yolo_gpu_fp32, bfyx) {
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f32, format::bfyx, false}; region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f32, format::bfyx, false};
runRegionTest<float>(params); runRegionTest<float>(params);
} }
TEST(region_yolo_gpu_fp32, bfyx_softmax) { TEST(region_yolo_gpu_fp32, bfyx_softmax) {
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f32, format::bfyx, true}; region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f32, format::bfyx, true};
runRegionTest<float>(params); runRegionTest<float>(params);
} }
TEST(region_yolo_gpu_fp32, byxf) { TEST(region_yolo_gpu_fp32, byxf) {
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f32, format::byxf, false}; region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f32, format::byxf, false};
runRegionTest<float>(params); runRegionTest<float>(params);
} }
TEST(region_yolo_gpu_fp32, byxf_softmax) { TEST(region_yolo_gpu_fp32, byxf_softmax) {
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f32, format::byxf, true}; region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f32, format::byxf, true};
runRegionTest<float>(params); runRegionTest<float>(params);
} }
TEST(region_yolo_gpu_fp16, bfyx) { TEST(region_yolo_gpu_fp16, bfyx) {
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f16, format::bfyx, false}; region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f16, format::bfyx, false};
runRegionTest<FLOAT16>(params); runRegionTest<FLOAT16>(params);
} }
TEST(region_yolo_gpu_fp16, bfyx_softmax) { TEST(region_yolo_gpu_fp16, bfyx_softmax) {
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f16, format::bfyx, true}; region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f16, format::bfyx, true};
runRegionTest<FLOAT16>(params); runRegionTest<FLOAT16>(params);
} }
TEST(region_yolo_gpu_fp16, byxf) { TEST(region_yolo_gpu_fp16, byxf) {
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f16, format::byxf, false}; region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f16, format::byxf, false};
runRegionTest<FLOAT16>(params); runRegionTest<FLOAT16>(params);
} }
TEST(region_yolo_gpu_fp16, byxf_softmax) { TEST(region_yolo_gpu_fp16, byxf_softmax) {
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f16, format::byxf, true}; region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f16, format::byxf, true};
runRegionTest<FLOAT16>(params); runRegionTest<FLOAT16>(params);
} }
#ifdef RUN_ALL_MODEL_CACHING_TESTS #ifdef RUN_ALL_MODEL_CACHING_TESTS
TEST(region_yolo_gpu_fp32, bfyx_cached) { TEST(region_yolo_gpu_fp32, bfyx_cached) {
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f32, format::bfyx, false}; region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f32, format::bfyx, false};
runRegionTest<float>(params, true); runRegionTest<float>(params, true);
} }
TEST(region_yolo_gpu_fp32, bfyx_softmax_cached) { TEST(region_yolo_gpu_fp32, bfyx_softmax_cached) {
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f32, format::bfyx, true}; region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f32, format::bfyx, true};
runRegionTest<float>(params, true); runRegionTest<float>(params, true);
} }
TEST(region_yolo_gpu_fp32, byxf_cached) { TEST(region_yolo_gpu_fp32, byxf_cached) {
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f32, format::byxf, false}; region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f32, format::byxf, false};
runRegionTest<float>(params, true); runRegionTest<float>(params, true);
} }
TEST(region_yolo_gpu_fp32, byxf_softmax_cached) { TEST(region_yolo_gpu_fp32, byxf_softmax_cached) {
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f32, format::byxf, true}; region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f32, format::byxf, true};
runRegionTest<float>(params, true); runRegionTest<float>(params, true);
} }
TEST(region_yolo_gpu_fp16, bfyx_cached) { TEST(region_yolo_gpu_fp16, bfyx_cached) {
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f16, format::bfyx, false}; region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f16, format::bfyx, false};
runRegionTest<FLOAT16>(params, true); runRegionTest<FLOAT16>(params, true);
} }
TEST(region_yolo_gpu_fp16, bfyx_softmax_cached) { TEST(region_yolo_gpu_fp16, bfyx_softmax_cached) {
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f16, format::bfyx, true}; region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f16, format::bfyx, true};
runRegionTest<FLOAT16>(params, true); runRegionTest<FLOAT16>(params, true);
} }
TEST(region_yolo_gpu_fp16, byxf_cached) { TEST(region_yolo_gpu_fp16, byxf_cached) {
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f16, format::byxf, false}; region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f16, format::byxf, false};
runRegionTest<FLOAT16>(params, true); runRegionTest<FLOAT16>(params, true);
} }
#endif // RUN_ALL_MODEL_CACHING_TESTS #endif // RUN_ALL_MODEL_CACHING_TESTS
TEST(region_yolo_gpu_fp16, byxf_softmax_cached) { TEST(region_yolo_gpu_fp16, byxf_softmax_cached) {
region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, data_types::f16, format::byxf, true}; region_yolo_test_params params{{ 1, 33, 52, 52 }, { 0, 1, 2 }, 4, 6, 3, 1, 3, data_types::f16, format::byxf, true};
runRegionTest<FLOAT16>(params, true); runRegionTest<FLOAT16>(params, true);
} }