[Core][CPU] Add transformation for i64 support for Multinomial (#21469)

* [CORE] Add transformation for i64 support for Multinomial

* [CPU][TESTS] Apply test suggestions, remove unnecessary override - comment from cpu PR

* Update src/common/transformations/src/transformations/convert_precision.cpp

Co-authored-by: Mateusz Tabaka <mateusz.tabaka@intel.com>

* Update src/common/transformations/src/transformations/convert_precision.cpp

Co-authored-by: Mateusz Tabaka <mateusz.tabaka@intel.com>

---------

Co-authored-by: Ivan Tikhonov <ivan.tikhonov@intel.com>
Co-authored-by: Mateusz Tabaka <mateusz.tabaka@intel.com>
This commit is contained in:
Piotr Krzemiński 2023-12-13 17:22:34 +01:00 committed by GitHub
parent 1b2eed0bbe
commit 65439eda5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 70 additions and 86 deletions

View File

@ -11,6 +11,7 @@
#include "openvino/opsets/opset1.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/opsets/opset11.hpp"
#include "openvino/opsets/opset13.hpp"
#include "openvino/opsets/opset3.hpp"
#include "openvino/opsets/opset4.hpp"
#include "openvino/opsets/opset5.hpp"
@ -58,6 +59,7 @@ bool fuse_type_to_nms9(const std::shared_ptr<ov::Node>& node, const precisions_m
bool fuse_type_to_nms_rotated(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
bool fuse_type_to_matrix_nms(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
bool fuse_type_to_multiclass_nms(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
bool fuse_type_to_multinomial_v13(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
bool fuse_type_to_generate_proposals(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
bool fuse_type_to_topk(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
bool fuse_type_to_maxpool(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
@ -438,7 +440,9 @@ bool ov::pass::ConvertPrecision::run_on_model(const std::shared_ptr<ov::Model>&
{opset4::Range::get_type_info_static(), fuse_type_to_range_v4},
{opset9::Eye::get_type_info_static(), fuse_type_to_eye_v9},
{opset10::Unique::get_type_info_static(), fuse_type_to_unique_v10},
{opset8::RandomUniform::get_type_info_static(), fuse_type_to_random_uniform_v8}};
{opset8::RandomUniform::get_type_info_static(), fuse_type_to_random_uniform_v8},
{opset13::Multinomial::get_type_info_static(), fuse_type_to_multinomial_v13},
};
for (const auto& it : m_additional_type_to_fuse_map) {
type_to_fuse[it.first] = it.second;
@ -844,6 +848,17 @@ bool fuse_type_to_multiclass_nms(const std::shared_ptr<ov::Node>& node, const pr
});
}
// Precision-fusing callback registered for opset13::Multinomial.
//
// When `node` is a Multinomial-13 op, the requested precision change is delegated to
// update_type(), which is handed a setter that writes the new element type into the
// op's `convert_type` attribute.
//
// @param node        Candidate node; anything that is not opset13::Multinomial is rejected.
// @param precisions  Map of requested precision conversions, forwarded to update_type().
// @return false when `node` is not a Multinomial-13 op; otherwise whatever update_type() returns.
bool fuse_type_to_multinomial_v13(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions) {
    if (const auto op = ov::as_type_ptr<opset13::Multinomial>(node)) {
        // Setter invoked by update_type() with the element type chosen for the op.
        const auto apply_convert_type = [&op](const element::Type& new_type) {
            op->set_convert_type(new_type);
        };
        // NOTE(review): the leading 0 is presumably the output index being retyped - confirm against update_type().
        return update_type(0, node, precisions, apply_convert_type);
    }
    return false;
}
bool fuse_type_to_generate_proposals(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions) {
auto generate_proposals = ov::as_type_ptr<opset9::GenerateProposals>(node);
if (!generate_proposals) {

View File

@ -68,42 +68,6 @@ void Multinomial::initSupportedPrimitiveDescriptors() {
ref_any);
}
std::string Multinomial::getPrimitiveDescriptorType() const {
std::string str_type;
auto selectedPrimitiveDesc = getSelectedPrimitiveDescriptor();
impl_desc_type type = impl_desc_type::undef;
if (selectedPrimitiveDesc) {
type = selectedPrimitiveDesc->getImplementationType();
}
if (type == impl_desc_type::unknown)
str_type += "unknown_";
if ((type & impl_desc_type::jit) == impl_desc_type::jit)
str_type += "jit_";
if ((type & impl_desc_type::ref) == impl_desc_type::ref)
str_type += "ref_";
if ((type & impl_desc_type::avx512) == impl_desc_type::avx512)
str_type += "avx512_";
if ((type & impl_desc_type::avx2) == impl_desc_type::avx2)
str_type += "avx2_";
if ((type & impl_desc_type::sse42) == impl_desc_type::sse42)
str_type += "sse42_";
if ((type & impl_desc_type::any) == impl_desc_type::any)
str_type += "any_";
if (str_type.empty())
str_type += "undef_";
if (selectedPrimitiveDesc) {
str_type += m_output_precision.get_type_name();
} else {
str_type.pop_back();
}
return str_type;
}
// Reports whether shape inference has to be run.
//
// @return true unless both m_const_inputs[NUM_SAMPLES_PORT] and m_const_batch are set;
//         i.e. inference is needed as soon as either of them is false.
bool Multinomial::needShapeInfer() const {
    return !m_const_inputs[NUM_SAMPLES_PORT] || !m_const_batch;
}

View File

@ -21,7 +21,6 @@ public:
void getSupportedDescriptors() override;
void initSupportedPrimitiveDescriptors() override;
std::string getPrimitiveDescriptorType() const override;
bool created() const override;

View File

@ -11,8 +11,6 @@ namespace {
using ov::test::MultinomialLayerTest;
std::vector<std::pair<uint64_t, uint64_t>> global_op_seed = {{1ul, 2ul}, {0ul, 0ul}};
std::vector<float> probs_4x4_f32 = {0.00001f,
0.001f,
0.1f,
@ -53,61 +51,69 @@ std::vector<ov::bfloat16> probs_1x3_bf16_log = {ov::bfloat16(3.0f), ov::bfloat16
std::vector<int> num_samples_scalar_i32 = {1};
std::vector<int64_t> num_samples_1x1_i64 = {2};
std::vector<int64_t> num_samples_scalar_i64 = {3};
const std::vector<ov::Tensor> probs = {ov::Tensor(ov::element::f32, {4, 4}, probs_4x4_f32.data()),
ov::Tensor(ov::element::f16, {2, 3}, probs_2x3_f16.data()),
ov::Tensor(ov::element::bf16, {1, 3}, probs_1x3_bf16.data())};
const auto probs = testing::Values(ov::Tensor(ov::element::f32, {4, 4}, probs_4x4_f32.data()),
ov::Tensor(ov::element::f16, {2, 3}, probs_2x3_f16.data()),
ov::Tensor(ov::element::bf16, {1, 3}, probs_1x3_bf16.data()));
const std::vector<ov::Tensor> probs_log = {ov::Tensor(ov::element::f32, {4, 4}, probs_4x4_f32_log.data()),
ov::Tensor(ov::element::f16, {2, 3}, probs_2x3_f16_log.data()),
ov::Tensor(ov::element::bf16, {1, 3}, probs_1x3_bf16_log.data())};
const auto probs_log = testing::Values(ov::Tensor(ov::element::f32, {4, 4}, probs_4x4_f32_log.data()),
ov::Tensor(ov::element::f16, {2, 3}, probs_2x3_f16_log.data()),
ov::Tensor(ov::element::bf16, {1, 3}, probs_1x3_bf16_log.data()));
const std::vector<ov::Tensor> num_samples = {ov::Tensor(ov::element::i32, {}, num_samples_scalar_i32.data()),
ov::Tensor(ov::element::i64, {1}, num_samples_1x1_i64.data()),
ov::Tensor(ov::element::i64, {}, num_samples_scalar_i64.data())};
const auto num_samples = testing::Values(ov::Tensor(ov::element::i32, {}, num_samples_scalar_i32.data()),
ov::Tensor(ov::element::i64, {1}, num_samples_1x1_i64.data()));
const std::vector<ov::test::ElementType> convert_type = {ov::test::ElementType::i32};
const auto convert_type = testing::Values(ov::test::ElementType::i32, ov::test::ElementType::i64);
const std::vector<bool> with_replacement = {
// true,
false};
const auto with_replacement = testing::Values(true, false);
const auto params_static = ::testing::Combine(::testing::Values("static"),
::testing::ValuesIn(probs),
::testing::ValuesIn(num_samples),
::testing::ValuesIn(convert_type),
::testing::ValuesIn(with_replacement),
::testing::Values(false), // log_probs
::testing::ValuesIn(global_op_seed),
::testing::Values(ov::test::utils::DEVICE_CPU));
const auto log_probs_true = testing::Values(true);
const auto log_probs_false = testing::Values(false);
const auto params_static_log = ::testing::Combine(::testing::Values("static"),
::testing::ValuesIn(probs_log),
::testing::ValuesIn(num_samples),
::testing::ValuesIn(convert_type),
::testing::ValuesIn(with_replacement),
::testing::Values(true), // log_probs
::testing::ValuesIn(global_op_seed),
::testing::Values(ov::test::utils::DEVICE_CPU));
const auto test_type_static = testing::Values("static");
const auto test_type_dynamic = testing::Values("dynamic");
const auto params_dynamic = ::testing::Combine(::testing::Values("dynamic"),
::testing::ValuesIn(probs),
::testing::ValuesIn(num_samples),
::testing::ValuesIn(convert_type),
::testing::ValuesIn(with_replacement),
::testing::Values(false), // log_probs
::testing::ValuesIn(global_op_seed),
::testing::Values(ov::test::utils::DEVICE_CPU));
// NOTE: (0,0) seeds are skipped (ticket 126095)
const auto global_op_seed =
testing::Values(std::pair<uint64_t, uint64_t>{1ul, 2ul}, std::pair<uint64_t, uint64_t>{0ul, 0ul});
const auto params_dynamic_log = ::testing::Combine(::testing::Values("dynamic"),
::testing::ValuesIn(probs_log),
::testing::ValuesIn(num_samples),
::testing::ValuesIn(convert_type),
::testing::ValuesIn(with_replacement),
::testing::Values(true), // log_probs
::testing::ValuesIn(global_op_seed),
::testing::Values(ov::test::utils::DEVICE_CPU));
const auto device_cpu = testing::Values(ov::test::utils::DEVICE_CPU);
// Parameter sets for MultinomialLayerTest. Each one is the cartesian product
// (::testing::Combine) of the shared generators declared earlier in this file; the sets
// differ only in the test type ("static" / "dynamic") and in whether the log-probability
// inputs (probs_log + log_probs_true) or the plain ones (probs + log_probs_false) are used.

// "static" test type, plain probabilities (log_probs = false).
const auto params_static = ::testing::Combine(test_type_static,
probs,
num_samples,
convert_type,
with_replacement,
log_probs_false,
global_op_seed,
device_cpu);
// "static" test type, log-probability inputs (log_probs = true).
const auto params_static_log = ::testing::Combine(test_type_static,
probs_log,
num_samples,
convert_type,
with_replacement,
log_probs_true,
global_op_seed,
device_cpu);
// "dynamic" test type, plain probabilities (log_probs = false).
const auto params_dynamic = ::testing::Combine(test_type_dynamic,
probs,
num_samples,
convert_type,
with_replacement,
log_probs_false,
global_op_seed,
device_cpu);
// "dynamic" test type, log-probability inputs (log_probs = true).
const auto params_dynamic_log = ::testing::Combine(test_type_dynamic,
probs_log,
num_samples,
convert_type,
with_replacement,
log_probs_true,
global_op_seed,
device_cpu);
INSTANTIATE_TEST_SUITE_P(smoke_MultinomialStatic,
MultinomialLayerTest,