[Core][CPU] Add transformation for i64 support for Multinomial (#21469)
* [CORE] Add transformation for i64 support for Multinomial * [CPU][TESTS] Apply test suggestions, remove unnecessary override - comment from cpu PR * Update src/common/transformations/src/transformations/convert_precision.cpp Co-authored-by: Mateusz Tabaka <mateusz.tabaka@intel.com> * Update src/common/transformations/src/transformations/convert_precision.cpp Co-authored-by: Mateusz Tabaka <mateusz.tabaka@intel.com> --------- Co-authored-by: Ivan Tikhonov <ivan.tikhonov@intel.com> Co-authored-by: Mateusz Tabaka <mateusz.tabaka@intel.com>
This commit is contained in:
parent
1b2eed0bbe
commit
65439eda5d
@ -11,6 +11,7 @@
|
||||
#include "openvino/opsets/opset1.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/opsets/opset11.hpp"
|
||||
#include "openvino/opsets/opset13.hpp"
|
||||
#include "openvino/opsets/opset3.hpp"
|
||||
#include "openvino/opsets/opset4.hpp"
|
||||
#include "openvino/opsets/opset5.hpp"
|
||||
@ -58,6 +59,7 @@ bool fuse_type_to_nms9(const std::shared_ptr<ov::Node>& node, const precisions_m
|
||||
bool fuse_type_to_nms_rotated(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
|
||||
bool fuse_type_to_matrix_nms(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
|
||||
bool fuse_type_to_multiclass_nms(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
|
||||
bool fuse_type_to_multinomial_v13(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
|
||||
bool fuse_type_to_generate_proposals(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
|
||||
bool fuse_type_to_topk(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
|
||||
bool fuse_type_to_maxpool(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
|
||||
@ -438,7 +440,9 @@ bool ov::pass::ConvertPrecision::run_on_model(const std::shared_ptr<ov::Model>&
|
||||
{opset4::Range::get_type_info_static(), fuse_type_to_range_v4},
|
||||
{opset9::Eye::get_type_info_static(), fuse_type_to_eye_v9},
|
||||
{opset10::Unique::get_type_info_static(), fuse_type_to_unique_v10},
|
||||
{opset8::RandomUniform::get_type_info_static(), fuse_type_to_random_uniform_v8}};
|
||||
{opset8::RandomUniform::get_type_info_static(), fuse_type_to_random_uniform_v8},
|
||||
{opset13::Multinomial::get_type_info_static(), fuse_type_to_multinomial_v13},
|
||||
};
|
||||
|
||||
for (const auto& it : m_additional_type_to_fuse_map) {
|
||||
type_to_fuse[it.first] = it.second;
|
||||
@ -844,6 +848,17 @@ bool fuse_type_to_multiclass_nms(const std::shared_ptr<ov::Node>& node, const pr
|
||||
});
|
||||
}
|
||||
|
||||
bool fuse_type_to_multinomial_v13(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions) {
|
||||
auto multinomial = ov::as_type_ptr<opset13::Multinomial>(node);
|
||||
if (!multinomial) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return update_type(0, node, precisions, [&](const element::Type& type) {
|
||||
multinomial->set_convert_type(type);
|
||||
});
|
||||
}
|
||||
|
||||
bool fuse_type_to_generate_proposals(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions) {
|
||||
auto generate_proposals = ov::as_type_ptr<opset9::GenerateProposals>(node);
|
||||
if (!generate_proposals) {
|
||||
|
@ -68,42 +68,6 @@ void Multinomial::initSupportedPrimitiveDescriptors() {
|
||||
ref_any);
|
||||
}
|
||||
|
||||
std::string Multinomial::getPrimitiveDescriptorType() const {
|
||||
std::string str_type;
|
||||
auto selectedPrimitiveDesc = getSelectedPrimitiveDescriptor();
|
||||
|
||||
impl_desc_type type = impl_desc_type::undef;
|
||||
if (selectedPrimitiveDesc) {
|
||||
type = selectedPrimitiveDesc->getImplementationType();
|
||||
}
|
||||
|
||||
if (type == impl_desc_type::unknown)
|
||||
str_type += "unknown_";
|
||||
if ((type & impl_desc_type::jit) == impl_desc_type::jit)
|
||||
str_type += "jit_";
|
||||
if ((type & impl_desc_type::ref) == impl_desc_type::ref)
|
||||
str_type += "ref_";
|
||||
if ((type & impl_desc_type::avx512) == impl_desc_type::avx512)
|
||||
str_type += "avx512_";
|
||||
if ((type & impl_desc_type::avx2) == impl_desc_type::avx2)
|
||||
str_type += "avx2_";
|
||||
if ((type & impl_desc_type::sse42) == impl_desc_type::sse42)
|
||||
str_type += "sse42_";
|
||||
if ((type & impl_desc_type::any) == impl_desc_type::any)
|
||||
str_type += "any_";
|
||||
|
||||
if (str_type.empty())
|
||||
str_type += "undef_";
|
||||
|
||||
if (selectedPrimitiveDesc) {
|
||||
str_type += m_output_precision.get_type_name();
|
||||
} else {
|
||||
str_type.pop_back();
|
||||
}
|
||||
|
||||
return str_type;
|
||||
}
|
||||
|
||||
bool Multinomial::needShapeInfer() const {
|
||||
return !(m_const_inputs[NUM_SAMPLES_PORT] && m_const_batch);
|
||||
}
|
||||
|
@ -21,7 +21,6 @@ public:
|
||||
|
||||
void getSupportedDescriptors() override;
|
||||
void initSupportedPrimitiveDescriptors() override;
|
||||
std::string getPrimitiveDescriptorType() const override;
|
||||
|
||||
bool created() const override;
|
||||
|
||||
|
@ -11,8 +11,6 @@ namespace {
|
||||
|
||||
using ov::test::MultinomialLayerTest;
|
||||
|
||||
std::vector<std::pair<uint64_t, uint64_t>> global_op_seed = {{1ul, 2ul}, {0ul, 0ul}};
|
||||
|
||||
std::vector<float> probs_4x4_f32 = {0.00001f,
|
||||
0.001f,
|
||||
0.1f,
|
||||
@ -53,61 +51,69 @@ std::vector<ov::bfloat16> probs_1x3_bf16_log = {ov::bfloat16(3.0f), ov::bfloat16
|
||||
|
||||
std::vector<int> num_samples_scalar_i32 = {1};
|
||||
std::vector<int64_t> num_samples_1x1_i64 = {2};
|
||||
std::vector<int64_t> num_samples_scalar_i64 = {3};
|
||||
|
||||
const std::vector<ov::Tensor> probs = {ov::Tensor(ov::element::f32, {4, 4}, probs_4x4_f32.data()),
|
||||
ov::Tensor(ov::element::f16, {2, 3}, probs_2x3_f16.data()),
|
||||
ov::Tensor(ov::element::bf16, {1, 3}, probs_1x3_bf16.data())};
|
||||
const auto probs = testing::Values(ov::Tensor(ov::element::f32, {4, 4}, probs_4x4_f32.data()),
|
||||
ov::Tensor(ov::element::f16, {2, 3}, probs_2x3_f16.data()),
|
||||
ov::Tensor(ov::element::bf16, {1, 3}, probs_1x3_bf16.data()));
|
||||
|
||||
const std::vector<ov::Tensor> probs_log = {ov::Tensor(ov::element::f32, {4, 4}, probs_4x4_f32_log.data()),
|
||||
ov::Tensor(ov::element::f16, {2, 3}, probs_2x3_f16_log.data()),
|
||||
ov::Tensor(ov::element::bf16, {1, 3}, probs_1x3_bf16_log.data())};
|
||||
const auto probs_log = testing::Values(ov::Tensor(ov::element::f32, {4, 4}, probs_4x4_f32_log.data()),
|
||||
ov::Tensor(ov::element::f16, {2, 3}, probs_2x3_f16_log.data()),
|
||||
ov::Tensor(ov::element::bf16, {1, 3}, probs_1x3_bf16_log.data()));
|
||||
|
||||
const std::vector<ov::Tensor> num_samples = {ov::Tensor(ov::element::i32, {}, num_samples_scalar_i32.data()),
|
||||
ov::Tensor(ov::element::i64, {1}, num_samples_1x1_i64.data()),
|
||||
ov::Tensor(ov::element::i64, {}, num_samples_scalar_i64.data())};
|
||||
const auto num_samples = testing::Values(ov::Tensor(ov::element::i32, {}, num_samples_scalar_i32.data()),
|
||||
ov::Tensor(ov::element::i64, {1}, num_samples_1x1_i64.data()));
|
||||
|
||||
const std::vector<ov::test::ElementType> convert_type = {ov::test::ElementType::i32};
|
||||
const auto convert_type = testing::Values(ov::test::ElementType::i32, ov::test::ElementType::i64);
|
||||
|
||||
const std::vector<bool> with_replacement = {
|
||||
// true,
|
||||
false};
|
||||
const auto with_replacement = testing::Values(true, false);
|
||||
|
||||
const auto params_static = ::testing::Combine(::testing::Values("static"),
|
||||
::testing::ValuesIn(probs),
|
||||
::testing::ValuesIn(num_samples),
|
||||
::testing::ValuesIn(convert_type),
|
||||
::testing::ValuesIn(with_replacement),
|
||||
::testing::Values(false), // log_probs
|
||||
::testing::ValuesIn(global_op_seed),
|
||||
::testing::Values(ov::test::utils::DEVICE_CPU));
|
||||
const auto log_probs_true = testing::Values(true);
|
||||
const auto log_probs_false = testing::Values(false);
|
||||
|
||||
const auto params_static_log = ::testing::Combine(::testing::Values("static"),
|
||||
::testing::ValuesIn(probs_log),
|
||||
::testing::ValuesIn(num_samples),
|
||||
::testing::ValuesIn(convert_type),
|
||||
::testing::ValuesIn(with_replacement),
|
||||
::testing::Values(true), // log_probs
|
||||
::testing::ValuesIn(global_op_seed),
|
||||
::testing::Values(ov::test::utils::DEVICE_CPU));
|
||||
const auto test_type_static = testing::Values("static");
|
||||
const auto test_type_dynamic = testing::Values("dynamic");
|
||||
|
||||
const auto params_dynamic = ::testing::Combine(::testing::Values("dynamic"),
|
||||
::testing::ValuesIn(probs),
|
||||
::testing::ValuesIn(num_samples),
|
||||
::testing::ValuesIn(convert_type),
|
||||
::testing::ValuesIn(with_replacement),
|
||||
::testing::Values(false), // log_probs
|
||||
::testing::ValuesIn(global_op_seed),
|
||||
::testing::Values(ov::test::utils::DEVICE_CPU));
|
||||
// NOTE: (0,0) seeds are skipped (ticket 126095)
|
||||
const auto global_op_seed =
|
||||
testing::Values(std::pair<uint64_t, uint64_t>{1ul, 2ul}, std::pair<uint64_t, uint64_t>{0ul, 0ul});
|
||||
|
||||
const auto params_dynamic_log = ::testing::Combine(::testing::Values("dynamic"),
|
||||
::testing::ValuesIn(probs_log),
|
||||
::testing::ValuesIn(num_samples),
|
||||
::testing::ValuesIn(convert_type),
|
||||
::testing::ValuesIn(with_replacement),
|
||||
::testing::Values(true), // log_probs
|
||||
::testing::ValuesIn(global_op_seed),
|
||||
::testing::Values(ov::test::utils::DEVICE_CPU));
|
||||
const auto device_cpu = testing::Values(ov::test::utils::DEVICE_CPU);
|
||||
|
||||
const auto params_static = ::testing::Combine(test_type_static,
|
||||
probs,
|
||||
num_samples,
|
||||
convert_type,
|
||||
with_replacement,
|
||||
log_probs_false,
|
||||
global_op_seed,
|
||||
device_cpu);
|
||||
|
||||
const auto params_static_log = ::testing::Combine(test_type_static,
|
||||
probs_log,
|
||||
num_samples,
|
||||
convert_type,
|
||||
with_replacement,
|
||||
log_probs_true,
|
||||
global_op_seed,
|
||||
device_cpu);
|
||||
|
||||
const auto params_dynamic = ::testing::Combine(test_type_dynamic,
|
||||
probs,
|
||||
num_samples,
|
||||
convert_type,
|
||||
with_replacement,
|
||||
log_probs_false,
|
||||
global_op_seed,
|
||||
device_cpu);
|
||||
|
||||
const auto params_dynamic_log = ::testing::Combine(test_type_dynamic,
|
||||
probs_log,
|
||||
num_samples,
|
||||
convert_type,
|
||||
with_replacement,
|
||||
log_probs_true,
|
||||
global_op_seed,
|
||||
device_cpu);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_MultinomialStatic,
|
||||
MultinomialLayerTest,
|
||||
|
Loading…
Reference in New Issue
Block a user