[GPU] RandomUniform new shape inference for dynamism support (#19087)

Sergey Shlyapnikov 2023-08-10 09:43:08 +04:00 committed by GitHub
parent f683fabcbf
commit d91d72c89c
10 changed files with 504 additions and 60 deletions

View File

@@ -20,8 +20,7 @@ struct random_uniform : public primitive_base<random_uniform> {
     random_uniform() : primitive_base("", {}),
                        global_seed(0),
                        op_seed(0),
-                       output_shape{},
-                       output_format(format::type::any) {}
+                       output_shape{} {}
 
     DECLARE_OBJECT_TYPE_SERIALIZATION
@@ -36,20 +35,26 @@ struct random_uniform : public primitive_base<random_uniform> {
      */
     random_uniform(const primitive_id &id, const std::vector<input_info> &inputs,
                    const data_types &data_type, const uint64_t global_seed,
-                   const uint64_t op_seed, const tensor output_shape,
-                   const format output_format,
+                   const uint64_t op_seed, const ov::Shape output_shape,
                    const padding &output_padding = padding())
         : primitive_base(id, inputs, {output_padding},
                          {optional_data_type{data_type}}),
           global_seed(global_seed),
           op_seed(op_seed),
-          output_shape(output_shape),
-          output_format(output_format) {}
+          output_shape(output_shape) {}
+
+    random_uniform(const primitive_id &id, const std::vector<input_info> &inputs,
+                   const data_types &data_type, const uint64_t global_seed,
+                   const uint64_t op_seed, const padding &output_padding = padding())
+        : primitive_base(id, inputs, {output_padding},
+                         {optional_data_type{data_type}}),
+          global_seed(global_seed),
+          op_seed(op_seed),
+          output_shape() {}
 
     const uint64_t global_seed;
     const uint64_t op_seed;
-    const tensor output_shape;
-    const format output_format;
+    const ov::Shape output_shape;
 
     size_t hash() const override {
         size_t seed = primitive::hash();
@@ -73,17 +78,13 @@ struct random_uniform : public primitive_base<random_uniform> {
         ob << global_seed;
         ob << op_seed;
         ob << output_shape;
-        ob << make_data(&output_format.value, sizeof(format::type));
     }
 
     void load(BinaryInputBuffer& ib) override {
         primitive_base<random_uniform>::load(ib);
         ib >> *const_cast<uint64_t*>(&global_seed);
         ib >> *const_cast<uint64_t*>(&op_seed);
-        ib >> *const_cast<tensor*>(&output_shape);
-        format::type tmp_type = format::type::any;
-        ib >> make_data(&tmp_type, sizeof(format::type));
-        *const_cast<format*>(&output_format) = format(tmp_type);
+        ib >> *const_cast<ov::Shape*>(&output_shape);
    }
};
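A minimal usage sketch of the two constructors after this change (the inputs vector, seeds, and shape below are hypothetical, not taken from the diff):

    // Shape known at build time: pass it to the primitive directly.
    auto ru_static = cldnn::random_uniform("ru_static", inputs, cldnn::data_types::f32,
                                           /*global_seed=*/100, /*op_seed=*/10, ov::Shape{2, 4, 3});

    // Dynamic case: the new constructor omits the shape, which is then resolved
    // from input 0 by shape inference at runtime.
    auto ru_dynamic = cldnn::random_uniform("ru_dynamic", inputs, cldnn::data_types::f32,
                                            /*global_seed=*/100, /*op_seed=*/10);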

View File

@@ -9,6 +9,17 @@
 namespace cldnn {
+
+template <>
+struct typed_program_node<random_uniform> : public typed_program_node_base<random_uniform> {
+    using parent = typed_program_node_base<random_uniform>;
+public:
+    using parent::parent;
+
+    program_node& input(size_t index = 0) const { return get_dependency(index); }
+    std::vector<size_t> get_shape_infer_dependencies() const override { return {0}; }
+};
+
 using random_uniform_node = typed_program_node<random_uniform>;
template<>
@@ -17,6 +28,8 @@ class typed_primitive_inst<random_uniform> : public typed_primitive_inst_base<random_uniform> {
     using parent::parent;
 
 public:
+    template<typename ShapeType>
+    static std::vector<layout> calc_output_layouts(random_uniform_node const& /*node*/, const kernel_impl_params& impl_param);
     static layout calc_output_layout(random_uniform_node const &node, kernel_impl_params const& impl_param);
 
     static std::string to_string(random_uniform_node const &node);
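Returning {0} from get_shape_infer_dependencies() marks input 0 (the target shape tensor) as a dependency whose buffer must be available on the host for shape inference; calc_output_layouts below then receives it through impl_param.memory_deps.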

View File

@@ -2,11 +2,11 @@
 // SPDX-License-Identifier: Apache-2.0
 //
-#include <random_uniform_inst.h>
+#include "random_uniform_inst.h"
 #include "primitive_type_base.h"
 #include <sstream>
-#include <json_object.h>
-#include <data_inst.h>
+#include "json_object.h"
+#include "random_uniform_shape_inference.hpp"
 
 namespace cldnn {
 GPU_DEFINE_PRIMITIVE_TYPE_ID(random_uniform)
@@ -17,9 +17,56 @@ random_uniform_inst::typed_primitive_inst(network& network, random_uniform_node
 layout random_uniform_inst::calc_output_layout(random_uniform_node const &node, kernel_impl_params const& impl_param) {
     auto primitive = impl_param.typed_desc<random_uniform>();
-    return {*primitive->output_data_types[0], primitive->output_format, primitive->output_shape};
+    auto format = format::get_default_format(primitive->output_shape.size());
+
+    return {primitive->output_shape, *primitive->output_data_types[0], format};
 }
 
+template<typename ShapeType>
+std::vector<layout> random_uniform_inst::calc_output_layouts(random_uniform_node const& /*node*/, kernel_impl_params const& impl_param) {
+    auto desc = impl_param.typed_desc<random_uniform>();
+    auto output_data_type = desc->output_data_types[0].value_or(impl_param.get_input_layout().data_type);
+
+    std::vector<ShapeType> output_shapes;
+    std::vector<ShapeType> input_shapes = { impl_param.get_input_layout(0).get_partial_shape(),
+                                            impl_param.get_input_layout(1).get_partial_shape(),
+                                            impl_param.get_input_layout(2).get_partial_shape() };
+
+    auto& memory_deps = impl_param.memory_deps;
+    std::map<size_t, ngraph::HostTensorPtr> const_data;
+
+    auto run_shape_infer = [&]() {
+        ov::op::v8::RandomUniform op;
+        if (memory_deps.count(1) > 0 && memory_deps.count(2) > 0) {
+            auto min_val = memory_deps.at(1);
+            cldnn::mem_lock<uint8_t, mem_lock_type::read> min_val_lock(min_val, impl_param.get_stream());
+            const_data.emplace(1, make_host_tensor(min_val->get_layout(), min_val_lock.data()));
+
+            auto max_val = memory_deps.at(2);
+            cldnn::mem_lock<uint8_t, mem_lock_type::read> max_val_lock(max_val, impl_param.get_stream());
+            const_data.emplace(2, make_host_tensor(max_val->get_layout(), max_val_lock.data()));
+
+            return ov::op::v8::shape_infer(&op, input_shapes, ov::make_tensor_accessor(const_data));
+        } else {
+            return ov::op::v8::shape_infer(&op, input_shapes, ov::make_tensor_accessor(const_data));
+        }
+    };
+
+    if (memory_deps.count(0) > 0) {
+        auto output_shape = memory_deps.at(0);
+        cldnn::mem_lock<uint8_t, mem_lock_type::read> output_shape_lock(output_shape, impl_param.get_stream());
+        const_data.emplace(0, make_host_tensor(output_shape->get_layout(), output_shape_lock.data()));
+
+        output_shapes = run_shape_infer();
+    } else {
+        output_shapes = run_shape_infer();
+    }
+
+    return { layout{output_shapes[0], output_data_type, format::get_default_format(output_shapes[0].size())} };
+}
+template std::vector<layout> random_uniform_inst::calc_output_layouts<ov::PartialShape>(random_uniform_node const& node, const kernel_impl_params& impl_param);
 
 std::string random_uniform_inst::to_string(random_uniform_node const &node) {
     auto node_info = node.desc_to_json();
     json_composite random_uniform_info;
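A note on the resolution behavior of calc_output_layouts above, with hypothetical values: when memory_deps holds the shape input (dependency 0), e.g. a buffer containing [2, 4, 3], the result is the single static layout {2, 4, 3} in the default format for that rank (bfyx); when it does not, the output keeps a dynamic shape whose rank equals the element count of the 1D shape input. The min/max buffers (dependencies 1 and 2) are locked inside run_shape_infer so the host tensors stay valid for the duration of the shape_infer call, which is also why both branches return from inside the lambda.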

View File

@@ -15,17 +15,33 @@ namespace {
 void CreateRandomUniformOp(Program &p, const std::shared_ptr<ngraph::op::v8::RandomUniform> &op) {
     auto inputs = p.GetInputInfo(op);
-    auto output_shape = op->get_output_shape(0);
-    cldnn::format outputFormat = cldnn::format::get_default_format(output_shape.size());
+    auto input_pshape = op->get_input_partial_shape(0);
+    auto output_pshape = op->get_output_partial_shape(0);
 
-    auto random_uniform_prim = cldnn::random_uniform(layer_type_name_ID(op),
-                                                     inputs,
-                                                     cldnn::element_type_to_data_type(op->get_out_type()),
-                                                     op->get_global_seed(),
-                                                     op->get_op_seed(),
-                                                     tensor_from_dims(output_shape),
-                                                     outputFormat);
-    p.add_primitive(*op, random_uniform_prim);
+    OPENVINO_ASSERT(input_pshape.is_static(), "[GPU] Dynamic input of RandomUniform leads to dynamic output rank, but GPU doesn't support it yet");
+
+    if (output_pshape.is_static() && !p.use_new_shape_infer()) {
+        auto output_shape = output_pshape.get_shape();
+        // Extend to 4D shape
+        output_shape.insert(output_shape.end(), 4 - output_shape.size(), 1ul);
+
+        auto random_uniform_prim = cldnn::random_uniform(layer_type_name_ID(op),
+                                                         inputs,
+                                                         cldnn::element_type_to_data_type(op->get_out_type()),
+                                                         op->get_global_seed(),
+                                                         op->get_op_seed(),
+                                                         output_shape);
+        p.add_primitive(*op, random_uniform_prim);
+    } else {
+        OPENVINO_ASSERT(input_pshape.size() == 1, "[GPU] RandomUniform expects 1D input, got ", input_pshape.size());
+
+        auto random_uniform_prim = cldnn::random_uniform(layer_type_name_ID(op),
+                                                         inputs,
+                                                         cldnn::element_type_to_data_type(op->get_out_type()),
+                                                         op->get_global_seed(),
+                                                         op->get_op_seed());
+        p.add_primitive(*op, random_uniform_prim);
+    }
 }
 } // namespace
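The legacy branch pads the static shape to rank 4 with trailing ones; a quick illustration of the insert call above (values hypothetical):

    ov::Shape output_shape{3, 5};
    output_shape.insert(output_shape.end(), 4 - output_shape.size(), 1ul);
    // output_shape is now {3, 5, 1, 1}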

View File

@@ -0,0 +1,208 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph_functions/builders.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "shared_test_classes/base/layer_test_utils.hpp"
using namespace ngraph;
using namespace ov::test;
namespace GPULayerTestsDefinitions {
typedef std::tuple<
std::vector<InputShape>, // Input shapes
std::pair<double, double>, // Min value, Max value
std::pair<uint64_t, uint64_t>, // Global seed, operation seed
ElementType, // Network precision
TargetDevice, // Device name
std::map<std::string, std::string> // Additional network configuration
> RandomUniformDynamicGPUTestParamsSet;
class RandomUniformDynamicGPUTest : public testing::WithParamInterface<RandomUniformDynamicGPUTestParamsSet>,
virtual public SubgraphBaseTest {
public:
static std::string getTestCaseName(const testing::TestParamInfo<RandomUniformDynamicGPUTestParamsSet>& obj) {
RandomUniformDynamicGPUTestParamsSet basicParamsSet = obj.param;
std::ostringstream result;
std::vector<InputShape> input_shapes;
std::pair<double, double> min_max_values;
std::pair<uint64_t, uint64_t> seeds;
ElementType precision;
TargetDevice target_device;
std::map<std::string, std::string> additionalConfig;
std::tie(input_shapes, min_max_values, seeds, precision, target_device, additionalConfig) = basicParamsSet;
result << "shape=";
for (const auto& shape : input_shapes) {
result << ov::test::utils::partialShape2str({shape.first}) << "_";
for (const auto& actual_shape : shape.second) {
result << ov::test::utils::partialShape2str({actual_shape}) << "_";
}
}
result << "precision=" << precision << "_";
result << "min_max_values=" << min_max_values.first << "_" << min_max_values.second << "_";
result << "seeds=" << seeds.first << "_" << seeds.second << "_";
result << "target_device=" << target_device;
return result.str();
}
protected:
void init_input_shapes(const std::vector<InputShape>& shapes) {
if (shapes.empty()) {
targetStaticShapes = {{}};
return;
}
size_t targetStaticShapeSize = shapes.front().second.size();
for (size_t i = 1; i < shapes.size(); ++i) {
if (targetStaticShapeSize < shapes[i].second.size()) {
targetStaticShapeSize = shapes[i].second.size();
}
}
targetStaticShapes.resize(targetStaticShapeSize);
for (const auto& shape : shapes) {
auto dynShape = shape.first;
inputDynamicShapes.push_back(dynShape);
for (size_t i = 0; i < targetStaticShapeSize; ++i) {
targetStaticShapes[i].push_back(i < shape.second.size() ? shape.second.at(i) : shape.second.back());
}
}
}
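When one input provides fewer static shapes than another, the loop above reuses that input's last static shape; e.g. (hypothetical) an input with shape.second = {{2, 3}} contributes {2, 3} to every target set when targetStaticShapeSize is 3.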
template<typename T>
void set_tensor_value(T scalar, ov::Tensor& tensor) {
#define CASE(X) \
case X: { \
auto *dataPtr = tensor.data<element_type_traits<X>::value_type>(); \
dataPtr[0] = static_cast<element_type_traits<X>::value_type>(scalar); \
break; \
}
switch (tensor.get_element_type()) {
CASE(ElementType::boolean)
CASE(ElementType::i8)
CASE(ElementType::i16)
CASE(ElementType::i32)
CASE(ElementType::i64)
CASE(ElementType::u8)
CASE(ElementType::u16)
CASE(ElementType::u32)
CASE(ElementType::u64)
CASE(ElementType::bf16)
CASE(ElementType::f16)
CASE(ElementType::f32)
CASE(ElementType::f64)
CASE(ElementType::u1)
CASE(ElementType::i4)
CASE(ElementType::u4)
default: OPENVINO_THROW("Unsupported element type: ", tensor.get_element_type());
}
}
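For reference, the expansion of a single CASE label above, written out for f32 (a sketch):

    case ElementType::f32: {
        auto *dataPtr = tensor.data<element_type_traits<ElementType::f32>::value_type>();
        dataPtr[0] = static_cast<element_type_traits<ElementType::f32>::value_type>(scalar);
        break;
    }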
void generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) override {
inputs.clear();
const auto& funcInputs = function->inputs();
auto generate_input = [&](size_t index, ElementType element_type) {
ov::Tensor tensor(element_type, targetInputStaticShapes[index]);
if (index != 0) {
auto scalar_val = index == 1 ? min_max_values.first : min_max_values.second;
set_tensor_value(scalar_val, tensor);
}
inputs.insert({funcInputs[index].get_node_shared_ptr(), tensor});
};
for (size_t i = 0; i < targetInputStaticShapes.size(); ++i)
generate_input(i, funcInputs[i].get_element_type());
}
void SetUp() override {
RandomUniformDynamicGPUTestParamsSet basicParamsSet = this->GetParam();
std::vector<InputShape> shapes;
ElementType netType;
std::map<std::string, std::string> additionalConfig;
std::pair<uint64_t, uint64_t> seeds;
ov::ParameterVector params;
std::tie(shapes, min_max_values, seeds, netType, targetDevice, additionalConfig) = basicParamsSet;
init_input_shapes(shapes);
params = builder::makeDynamicParams(netType, inputDynamicShapes);
const auto shape_of = std::make_shared<ov::op::v3::ShapeOf>(params[0]);
const auto random_uniform = std::make_shared<ov::op::v8::RandomUniform>(shape_of, params[1], params[2], netType, seeds.first, seeds.second);
ov::ResultVector results = {std::make_shared<ov::op::v0::Result>(random_uniform)};
function = std::make_shared<ov::Model>(results, params, "random_uniform_test");
}
precisions_map get_ref_precisions_convert_map() override {
// Do not convert the reference function from FP16 to FP32 precision, since for the RandomUniform
// operation the data type matters
return {};
}
private:
std::pair<double, double> min_max_values;
};
TEST_P(RandomUniformDynamicGPUTest, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
run();
}
namespace {
std::map<std::string, std::string> emptyAdditionalConfig;
const std::vector<std::vector<ov::test::InputShape>> dynInputShapes = {
{
{{ov::PartialShape::dynamic(4)}, {{1, 2, 3, 4}, {1, 1, 5, 5}, {2, 3, 4, 5}}},
{{1}, {{1}}},
{{1}, {{1}}}
},
{
{{ov::PartialShape::dynamic(3)}, {{1, 2, 3}, {1, 1, 5}, {2, 3, 4}}},
{{1}, {{1}}},
{{1}, {{1}}}
},
{
{{ov::PartialShape::dynamic(2)}, {{1, 2}, {1, 1}, {2, 3}}},
{{1}, {{1}}},
{{1}, {{1}}}
},
{
{{ov::PartialShape::dynamic(1)}, {{1}, {2}, {3}}},
{{1}, {{1}}},
{{1}, {{1}}}
},
};
const std::vector<std::pair<double, double>> min_max_values = {
{10, 30},
};
const std::vector<std::pair<uint64_t, uint64_t>> seeds = {
{100, 10},
};
const std::vector<ElementType> netPrecisions = {
ElementType::i32,
ElementType::f32,
ElementType::f16,
};
const auto testParams_smoke = ::testing::Combine(::testing::ValuesIn(dynInputShapes),
::testing::ValuesIn(min_max_values),
::testing::ValuesIn(seeds),
::testing::ValuesIn(netPrecisions),
::testing::Values(ov::test::utils::DEVICE_GPU),
::testing::Values(emptyAdditionalConfig));
INSTANTIATE_TEST_SUITE_P(smoke_dynamic_random_uniform, RandomUniformDynamicGPUTest,
testParams_smoke, RandomUniformDynamicGPUTest::getTestCaseName);
} // namespace
} // namespace GPULayerTestsDefinitions

View File

@@ -0,0 +1,143 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "test_utils.h"
#include <intel_gpu/primitives/input_layout.hpp>
#include <intel_gpu/primitives/crop.hpp>
#include <intel_gpu/primitives/data.hpp>
#include "random_uniform_inst.h"
#include "program_wrapper.h"
using namespace cldnn;
using namespace ::tests;
namespace shape_infer_tests {
struct random_uniform_si_test_params {
ov::PartialShape expected_out_pshape;
data_types out_data_type;
std::pair<float, float> min_max_vals;
};
class random_uniform_si_test : public testing::TestWithParam<random_uniform_si_test_params> { };
TEST_P(random_uniform_si_test, shape_infer) {
auto p = GetParam();
auto& engine = get_test_engine();
cldnn::program prog(engine);
std::vector<std::shared_ptr<primitive>> input_prims;
std::vector<input_info> input_prim_ids;
std::vector<layout> input_layouts;
const size_t num_inputs = 3;
for (size_t idx = 0; idx < num_inputs; idx++) {
auto in_layout = layout{{1}, p.out_data_type, format::bfyx};
if (idx == 0) {
auto input_pshape = ov::PartialShape{static_cast<long int>(p.expected_out_pshape.size())};
in_layout = layout{input_pshape, data_types::i64, format::bfyx};
}
input_layouts.push_back(in_layout);
auto prim_id = "input_" + std::to_string(idx);
auto const_data_prim = std::make_shared<input_layout>(prim_id, in_layout);
input_prims.push_back(const_data_prim);
input_prim_ids.push_back(input_info(prim_id));
}
auto random_uniform_prim = std::make_shared<random_uniform>("random_uniform", input_prim_ids, p.out_data_type, 0, 0);
auto& random_uniform_node = prog.get_or_create(random_uniform_prim);
for (auto& iprim : input_prims) {
auto& input_node = prog.get_or_create(iprim);
program_wrapper::add_connection(prog, input_node, random_uniform_node);
}
auto params = random_uniform_node.get_kernel_impl_params();
params->memory_deps.clear();
auto get_mem = [&](size_t idx, float val) -> memory::ptr {
auto in_layout = input_layouts[idx];
auto allocated_mem = engine.allocate_memory(in_layout);
switch (p.out_data_type) {
case data_types::f16:
set_values(allocated_mem, {float_to_half(val)});
break;
case data_types::f32:
set_values(allocated_mem, {static_cast<data_type_to_type<data_types::f32>::type>(val)});
break;
case data_types::i32:
set_values(allocated_mem, {static_cast<data_type_to_type<data_types::i32>::type>(val)});
break;
case data_types::i64:
set_values(allocated_mem, {static_cast<data_type_to_type<data_types::i64>::type>(val)});
break;
case data_types::i8:
set_values(allocated_mem, {static_cast<data_type_to_type<data_types::i8>::type>(val)});
break;
case data_types::u8:
set_values(allocated_mem, {static_cast<data_type_to_type<data_types::u8>::type>(val)});
break;
case data_types::bin:
default:
break;
}
return allocated_mem;
};
if (p.expected_out_pshape.is_static()) {
auto input_mem = engine.allocate_memory(input_layouts[0]);
set_values(input_mem, p.expected_out_pshape.get_shape());
params->memory_deps.emplace(0, input_mem);
}
params->memory_deps.emplace(1, get_mem(1, p.min_max_vals.first));
params->memory_deps.emplace(2, get_mem(2, p.min_max_vals.second));
if (p.min_max_vals.first < p.min_max_vals.second) {
auto res = random_uniform_inst::calc_output_layouts<ov::PartialShape>(random_uniform_node, *params);
auto expected_out_layout = layout{p.expected_out_pshape, p.out_data_type, format::get_default_format(p.expected_out_pshape.size())};
ASSERT_EQ(res.size(), 1);
ASSERT_EQ(res[0], expected_out_layout);
} else {
ASSERT_ANY_THROW(random_uniform_inst::calc_output_layouts<ov::PartialShape>(random_uniform_node, *params));
}
}
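The ASSERT_ANY_THROW branch relies on the min/max scalars reaching core shape inference: calc_output_layouts locks the memory_deps buffers at keys 1 and 2 and forwards them as constant data to ov::op::v8::shape_infer, which rejects min >= max, as the last four parameter sets exercise.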
INSTANTIATE_TEST_SUITE_P(smoke, random_uniform_si_test,
testing::ValuesIn(std::vector<random_uniform_si_test_params>{
{ov::PartialShape{2}, data_types::i32, {0, 10}},
{ov::PartialShape{2}, data_types::i8, {0, 10}},
{ov::PartialShape{2}, data_types::u8, {0, 10}},
{ov::PartialShape{2}, data_types::i64, {0, 10}},
{ov::PartialShape{2}, data_types::i32, {0, 10}},
{ov::PartialShape{2}, data_types::f32, {0, 10}},
{ov::PartialShape{2}, data_types::f16, {0, 10}},
{ov::PartialShape{2,4}, data_types::i32, {0, 10}},
{ov::PartialShape{2,4}, data_types::f32, {0, 10}},
{ov::PartialShape{2,4,3}, data_types::i32, {0, 10}},
{ov::PartialShape{2,4,3}, data_types::f32, {0, 10}},
{ov::PartialShape{2,4,3,2}, data_types::i32, {0, 10}},
{ov::PartialShape{2,4,3,2}, data_types::f32, {0, 10}},
{ov::PartialShape{2,4,3,1,2}, data_types::i32, {0, 10}},
{ov::PartialShape{2,4,3,1,2}, data_types::f32, {0, 10}},
// Dynamic output shape
{ov::PartialShape::dynamic(1), data_types::f32, {0, 10}},
{ov::PartialShape::dynamic(2), data_types::f32, {0, 10}},
{ov::PartialShape::dynamic(3), data_types::f32, {0, 10}},
{ov::PartialShape::dynamic(4), data_types::f32, {0, 10}},
{ov::PartialShape::dynamic(5), data_types::f32, {0, 10}},
// Incorrect min/max values
{ov::PartialShape{2}, data_types::i32, {20, 20}},
{ov::PartialShape{2,4,3,1,2}, data_types::i32, {20, 10}},
{ov::PartialShape::dynamic(1), data_types::f32, {20, 20}},
{ov::PartialShape::dynamic(5), data_types::f32, {20, 10}},
}));
}; // shape_infer_tests

View File

@@ -18,8 +18,7 @@ using namespace ::tests;
  */
 template<typename T>
 struct RandomUniformParams {
-    tensor output_tensor;
-    format f;
+    ov::Shape output_shape;
     T min_val;
     T max_val;
     uint64_t global_seed;
@@ -36,20 +35,20 @@
         RandomUniformParams<T> params = testing::TestWithParam<RandomUniformParams<T> >::GetParam();
         auto &engine = get_test_engine();
+        auto format = format::get_default_format(params.output_shape.size());
         auto shape = engine.allocate_memory(
-                {data_type, params.f, {1, 1, static_cast<int32_t >(params.output_tensor.sizes().size()), 1}});
+                {{1, 1, 1, static_cast<long int>(params.output_shape.size())}, data_type, format});
         auto min_val = engine.allocate_memory(layout(data_type, format::bfyx, {1, 1, 1, 1}));
         auto max_val = engine.allocate_memory(layout(data_type, format::bfyx, {1, 1, 1, 1}));
-        set_values(shape, params.output_tensor.sizes());
+        set_values(shape, params.output_shape);
         set_values(min_val, {params.min_val});
         set_values(max_val, {params.max_val});
 
         topology topology;
         topology.add(
                 random_uniform("random_uniform", { input_info("shape"), input_info("min_val"), input_info("max_val") }, data_type, params.global_seed,
-                               params.op_seed, params.output_tensor,
-                               params.f));
+                               params.op_seed, params.output_shape));
         topology.add(input_layout("shape", shape->get_layout()));
         topology.add(input_layout("min_val", min_val->get_layout()));
         topology.add(input_layout("max_val", max_val->get_layout()));
@@ -78,11 +77,11 @@ struct PrintToStringParamName {
     template<class T>
     std::string operator()(const testing::TestParamInfo<RandomUniformParams<T> > &param) {
         std::stringstream buf;
-        buf << " output tensor" << param.param.output_tensor.to_string()
-            << " min_value " << param.param.min_val
-            << " max_value " << param.param.max_val
-            << " global_seed " << param.param.global_seed
-            << " op_seed " << param.param.op_seed;
+        buf << "output_tensor_" << param.param.output_shape
+            << "_min_value_" << param.param.min_val
+            << "_max_value_" << param.param.max_val
+            << "_global_seed_" << param.param.global_seed
+            << "_op_seed_" << param.param.op_seed;
         return buf.str();
     }
@@ -91,11 +90,11 @@ struct PrintToStringParamName {
 template<>
 std::string PrintToStringParamName::operator()(const testing::TestParamInfo<RandomUniformParams<half_t> > &param) {
     std::stringstream buf;
-    buf << " output tensor" << param.param.output_tensor.to_string()
-        << " min_value " << static_cast<float>(param.param.min_val)
-        << " max_value " << static_cast<float>(param.param.max_val)
-        << " global_seed " << param.param.global_seed
-        << " op_seed " << param.param.op_seed;
+    buf << "output_tensor_" << param.param.output_shape
+        << "_min_value_" << static_cast<float>(param.param.min_val)
+        << "_max_value_" << static_cast<float>(param.param.max_val)
+        << "_global_seed_" << param.param.global_seed
+        << "_op_seed_" << param.param.op_seed;
     return buf.str();
 }
@@ -124,7 +123,7 @@ TEST_P(random_uniform_gpu_test_f16, random_f16) {
 INSTANTIATE_TEST_SUITE_P(smoke_random_uniform_int32,
                          random_uniform_gpu_test_i32,
                          ::testing::Values(
-                                 RandomUniformParams<int32_t>{tensor(1, 1, 2, 3), format::bfyx, 50, 100, 80, 100,
+                                 RandomUniformParams<int32_t>{ov::Shape{1, 1, 3, 2}, 50, 100, 80, 100,
                                          std::vector<int32_t>{
                                                  65, 70, 56,
                                                  59, 82, 92
@@ -135,7 +134,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_random_uniform_int32,
 INSTANTIATE_TEST_SUITE_P(smoke_random_uniform_int64,
                          random_uniform_gpu_test_i64,
                          ::testing::Values(
-                                 RandomUniformParams<int64_t>{tensor(1, 1, 5, 4, 3), format::bfzyx, -2600, 3700, 755,
+                                 RandomUniformParams<int64_t>{ov::Shape{1, 1, 3, 4, 5}, -2600, 3700, 755,
                                          951,
                                          {
                                                  2116L, -1581L, 2559L, -339L, -1660L, 519L, 90L,
@@ -151,11 +150,17 @@ INSTANTIATE_TEST_SUITE_P(smoke_random_uniform_int64,
                          ),
                          PrintToStringParamName());
 
 INSTANTIATE_TEST_SUITE_P(smoke_random_uniform_f32,
                          random_uniform_gpu_test_f32,
                          ::testing::Values(
-                                 RandomUniformParams<float>{tensor(1, 1, 3, 3), format::bfyx, 0.0, 1.0, 150, 10,
+                                 RandomUniformParams<float>{ov::Shape{1, 1, 3, 3}, 0.0, 1.0, 150, 10,
                                          {
                                                  0.7011236, 0.30539632, 0.93931055,
                                                  0.9456035, 0.11694777, 0.50770056,
                                                  0.5197197, 0.22727466, 0.991374
                                          }
+                                 },
+                                 RandomUniformParams<float>{ov::Shape{3, 3}, 0.0, 1.0, 150, 10,
+                                         {
+                                                 0.7011236, 0.30539632, 0.93931055,
+                                                 0.9456035, 0.11694777, 0.50770056,
@@ -165,11 +170,10 @@ INSTANTIATE_TEST_SUITE_P(smoke_random_uniform_f32,
                          ),
                          PrintToStringParamName());
 
 INSTANTIATE_TEST_SUITE_P(smoke_random_uniform_f16,
                          random_uniform_gpu_test_f16,
                          ::testing::Values(
-                                 RandomUniformParams<half_t>{tensor(1, 1, 3, 2, 4), format::bfzyx, half_t(-1.5),
+                                 RandomUniformParams<half_t>{ov::Shape{1, 1, 4, 2, 3}, half_t(-1.5),
                                          half_t(-1.0), 150, 10,
                                          {half_t(-1.19726562), half_t(-1.09667969),
                                           half_t(-1.08398438), half_t(-1.30859375),
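Note on the reordered size literals in these parameter updates: cldnn::tensor takes its spatial sizes innermost-first, as (batch, feature, x, y[, z]), while ov::Shape is written outermost-first, so tensor(1, 1, 2, 3) and ov::Shape{1, 1, 3, 2} describe the same bfyx buffer, and tensor(1, 1, 5, 4, 3) matches ov::Shape{1, 1, 3, 4, 5} in bfzyx.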

View File

@@ -5,6 +5,7 @@
 #pragma once
 
 #include "openvino/core/model.hpp"
+#include "transformations/convert_precision.hpp"
 #include "common_test_utils/test_common.hpp"
 #include "functional_test_utils/ov_plugin_cache.hpp"
@@ -69,6 +70,7 @@ protected:
     virtual std::vector<ov::Tensor> calculate_refs();
     virtual std::vector<ov::Tensor> get_plugin_outputs();
+    virtual precisions_map get_ref_precisions_convert_map();
 
     friend void core_configuration(SubgraphBaseTest* test);
 };

View File

@@ -261,14 +261,12 @@ void SubgraphBaseTest::infer() {
     inferRequest.infer();
 }
 
-std::vector<ov::Tensor> SubgraphBaseTest::calculate_refs() {
-    using InputsMap = std::map<std::shared_ptr<ov::Node>, ov::Tensor>;
-
-    auto functionToProcess = functionRefs->clone();
+precisions_map SubgraphBaseTest::get_ref_precisions_convert_map() {
     // TODO: remove these conversions as soon as the function interpreter fully supports bf16 and f16
     precisions_map precisions = {
             { ngraph::element::bf16, ngraph::element::f32 }
     };
+
     auto convert_added = false;
     for (const auto &param : function->get_parameters()) {
         for (size_t i = 0; i < param->get_output_size(); i++) {
@@ -281,11 +279,21 @@ std::vector<ov::Tensor> SubgraphBaseTest::calculate_refs() {
             }
         }
     }
     if (!convert_added) {
         precisions.insert({ ngraph::element::f16, ngraph::element::f32});
     }
+    return precisions;
+}
+
+std::vector<ov::Tensor> SubgraphBaseTest::calculate_refs() {
+    using InputsMap = std::map<std::shared_ptr<ov::Node>, ov::Tensor>;
+
+    auto functionToProcess = functionRefs->clone();
+    precisions_map convert_precisions = get_ref_precisions_convert_map();
     pass::Manager manager;
-    manager.register_pass<ov::pass::ConvertPrecision>(precisions);
+    manager.register_pass<ov::pass::ConvertPrecision>(convert_precisions);
     manager.run_passes(functionToProcess);
     functionToProcess->validate_nodes_and_infer_types();
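With the map extracted into a virtual method, a derived test can override just the conversion policy instead of all of calculate_refs(); the GPU dynamic RandomUniform test in this commit does exactly that:

    precisions_map get_ref_precisions_convert_map() override {
        // Keep f16/bf16 in the reference model: RandomUniform results depend on the data type.
        return {};
    }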

View File

@@ -63,13 +63,15 @@ void RandomUniformLayerTest::SetUp() {
     std::string targetName;
     std::tie(output_shape, randomUniformParams, global_seed, op_seed, targetDevice) = this->GetParam();
     const auto precision = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(randomUniformParams.precision);
-    auto out_shape_ = std::make_shared<ov::op::v0::Constant>(ov::element::i64,
-                                                             ov::Shape{output_shape.size()},
-                                                             output_shape);
+
+    // Use a Parameter as the input with the desired precision to properly set up the execution
+    // configuration in the CoreConfiguration() function
+    auto input = std::make_shared<ov::op::v0::Parameter>(precision, output_shape);
+    auto shape_of = std::make_shared<ov::op::v3::ShapeOf>(input);
 
     auto min_value = createConstant(randomUniformParams.precision, randomUniformParams.min_value);
     auto max_value = createConstant(randomUniformParams.precision, randomUniformParams.max_value);
-    auto random_uniform = std::make_shared<ngraph::op::v8::RandomUniform>(out_shape_,
+    auto random_uniform = std::make_shared<ngraph::op::v8::RandomUniform>(shape_of,
                                                                           min_value,
                                                                           max_value,
                                                                           precision,
@@ -77,7 +79,7 @@
                                                                           op_seed);
     ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(random_uniform)};
-    function = std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{}, "random_uniform");
+    function = std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{input}, "random_uniform");
 }
 
 void RandomUniformLayerTest::ConvertRefsParams() {