[Core][CPU] Add transformation for i64 support for Multinomial (#21469)
* [CORE] Add transformation for i64 support for Multinomial * [CPU][TESTS] Apply test suggestions, remove unnecessary override - comment from cpu PR * Update src/common/transformations/src/transformations/convert_precision.cpp Co-authored-by: Mateusz Tabaka <mateusz.tabaka@intel.com> * Update src/common/transformations/src/transformations/convert_precision.cpp Co-authored-by: Mateusz Tabaka <mateusz.tabaka@intel.com> --------- Co-authored-by: Ivan Tikhonov <ivan.tikhonov@intel.com> Co-authored-by: Mateusz Tabaka <mateusz.tabaka@intel.com>
This commit is contained in:
parent
1b2eed0bbe
commit
65439eda5d
@ -11,6 +11,7 @@
|
|||||||
#include "openvino/opsets/opset1.hpp"
|
#include "openvino/opsets/opset1.hpp"
|
||||||
#include "openvino/opsets/opset10.hpp"
|
#include "openvino/opsets/opset10.hpp"
|
||||||
#include "openvino/opsets/opset11.hpp"
|
#include "openvino/opsets/opset11.hpp"
|
||||||
|
#include "openvino/opsets/opset13.hpp"
|
||||||
#include "openvino/opsets/opset3.hpp"
|
#include "openvino/opsets/opset3.hpp"
|
||||||
#include "openvino/opsets/opset4.hpp"
|
#include "openvino/opsets/opset4.hpp"
|
||||||
#include "openvino/opsets/opset5.hpp"
|
#include "openvino/opsets/opset5.hpp"
|
||||||
@ -58,6 +59,7 @@ bool fuse_type_to_nms9(const std::shared_ptr<ov::Node>& node, const precisions_m
|
|||||||
bool fuse_type_to_nms_rotated(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
|
bool fuse_type_to_nms_rotated(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
|
||||||
bool fuse_type_to_matrix_nms(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
|
bool fuse_type_to_matrix_nms(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
|
||||||
bool fuse_type_to_multiclass_nms(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
|
bool fuse_type_to_multiclass_nms(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
|
||||||
|
bool fuse_type_to_multinomial_v13(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
|
||||||
bool fuse_type_to_generate_proposals(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
|
bool fuse_type_to_generate_proposals(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
|
||||||
bool fuse_type_to_topk(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
|
bool fuse_type_to_topk(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
|
||||||
bool fuse_type_to_maxpool(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
|
bool fuse_type_to_maxpool(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
|
||||||
@ -438,7 +440,9 @@ bool ov::pass::ConvertPrecision::run_on_model(const std::shared_ptr<ov::Model>&
|
|||||||
{opset4::Range::get_type_info_static(), fuse_type_to_range_v4},
|
{opset4::Range::get_type_info_static(), fuse_type_to_range_v4},
|
||||||
{opset9::Eye::get_type_info_static(), fuse_type_to_eye_v9},
|
{opset9::Eye::get_type_info_static(), fuse_type_to_eye_v9},
|
||||||
{opset10::Unique::get_type_info_static(), fuse_type_to_unique_v10},
|
{opset10::Unique::get_type_info_static(), fuse_type_to_unique_v10},
|
||||||
{opset8::RandomUniform::get_type_info_static(), fuse_type_to_random_uniform_v8}};
|
{opset8::RandomUniform::get_type_info_static(), fuse_type_to_random_uniform_v8},
|
||||||
|
{opset13::Multinomial::get_type_info_static(), fuse_type_to_multinomial_v13},
|
||||||
|
};
|
||||||
|
|
||||||
for (const auto& it : m_additional_type_to_fuse_map) {
|
for (const auto& it : m_additional_type_to_fuse_map) {
|
||||||
type_to_fuse[it.first] = it.second;
|
type_to_fuse[it.first] = it.second;
|
||||||
@ -844,6 +848,17 @@ bool fuse_type_to_multiclass_nms(const std::shared_ptr<ov::Node>& node, const pr
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool fuse_type_to_multinomial_v13(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions) {
|
||||||
|
auto multinomial = ov::as_type_ptr<opset13::Multinomial>(node);
|
||||||
|
if (!multinomial) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return update_type(0, node, precisions, [&](const element::Type& type) {
|
||||||
|
multinomial->set_convert_type(type);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
bool fuse_type_to_generate_proposals(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions) {
|
bool fuse_type_to_generate_proposals(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions) {
|
||||||
auto generate_proposals = ov::as_type_ptr<opset9::GenerateProposals>(node);
|
auto generate_proposals = ov::as_type_ptr<opset9::GenerateProposals>(node);
|
||||||
if (!generate_proposals) {
|
if (!generate_proposals) {
|
||||||
|
@ -68,42 +68,6 @@ void Multinomial::initSupportedPrimitiveDescriptors() {
|
|||||||
ref_any);
|
ref_any);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string Multinomial::getPrimitiveDescriptorType() const {
|
|
||||||
std::string str_type;
|
|
||||||
auto selectedPrimitiveDesc = getSelectedPrimitiveDescriptor();
|
|
||||||
|
|
||||||
impl_desc_type type = impl_desc_type::undef;
|
|
||||||
if (selectedPrimitiveDesc) {
|
|
||||||
type = selectedPrimitiveDesc->getImplementationType();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type == impl_desc_type::unknown)
|
|
||||||
str_type += "unknown_";
|
|
||||||
if ((type & impl_desc_type::jit) == impl_desc_type::jit)
|
|
||||||
str_type += "jit_";
|
|
||||||
if ((type & impl_desc_type::ref) == impl_desc_type::ref)
|
|
||||||
str_type += "ref_";
|
|
||||||
if ((type & impl_desc_type::avx512) == impl_desc_type::avx512)
|
|
||||||
str_type += "avx512_";
|
|
||||||
if ((type & impl_desc_type::avx2) == impl_desc_type::avx2)
|
|
||||||
str_type += "avx2_";
|
|
||||||
if ((type & impl_desc_type::sse42) == impl_desc_type::sse42)
|
|
||||||
str_type += "sse42_";
|
|
||||||
if ((type & impl_desc_type::any) == impl_desc_type::any)
|
|
||||||
str_type += "any_";
|
|
||||||
|
|
||||||
if (str_type.empty())
|
|
||||||
str_type += "undef_";
|
|
||||||
|
|
||||||
if (selectedPrimitiveDesc) {
|
|
||||||
str_type += m_output_precision.get_type_name();
|
|
||||||
} else {
|
|
||||||
str_type.pop_back();
|
|
||||||
}
|
|
||||||
|
|
||||||
return str_type;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Multinomial::needShapeInfer() const {
|
bool Multinomial::needShapeInfer() const {
|
||||||
return !(m_const_inputs[NUM_SAMPLES_PORT] && m_const_batch);
|
return !(m_const_inputs[NUM_SAMPLES_PORT] && m_const_batch);
|
||||||
}
|
}
|
||||||
|
@ -21,7 +21,6 @@ public:
|
|||||||
|
|
||||||
void getSupportedDescriptors() override;
|
void getSupportedDescriptors() override;
|
||||||
void initSupportedPrimitiveDescriptors() override;
|
void initSupportedPrimitiveDescriptors() override;
|
||||||
std::string getPrimitiveDescriptorType() const override;
|
|
||||||
|
|
||||||
bool created() const override;
|
bool created() const override;
|
||||||
|
|
||||||
|
@ -11,8 +11,6 @@ namespace {
|
|||||||
|
|
||||||
using ov::test::MultinomialLayerTest;
|
using ov::test::MultinomialLayerTest;
|
||||||
|
|
||||||
std::vector<std::pair<uint64_t, uint64_t>> global_op_seed = {{1ul, 2ul}, {0ul, 0ul}};
|
|
||||||
|
|
||||||
std::vector<float> probs_4x4_f32 = {0.00001f,
|
std::vector<float> probs_4x4_f32 = {0.00001f,
|
||||||
0.001f,
|
0.001f,
|
||||||
0.1f,
|
0.1f,
|
||||||
@ -53,61 +51,69 @@ std::vector<ov::bfloat16> probs_1x3_bf16_log = {ov::bfloat16(3.0f), ov::bfloat16
|
|||||||
|
|
||||||
std::vector<int> num_samples_scalar_i32 = {1};
|
std::vector<int> num_samples_scalar_i32 = {1};
|
||||||
std::vector<int64_t> num_samples_1x1_i64 = {2};
|
std::vector<int64_t> num_samples_1x1_i64 = {2};
|
||||||
std::vector<int64_t> num_samples_scalar_i64 = {3};
|
|
||||||
|
|
||||||
const std::vector<ov::Tensor> probs = {ov::Tensor(ov::element::f32, {4, 4}, probs_4x4_f32.data()),
|
const auto probs = testing::Values(ov::Tensor(ov::element::f32, {4, 4}, probs_4x4_f32.data()),
|
||||||
ov::Tensor(ov::element::f16, {2, 3}, probs_2x3_f16.data()),
|
ov::Tensor(ov::element::f16, {2, 3}, probs_2x3_f16.data()),
|
||||||
ov::Tensor(ov::element::bf16, {1, 3}, probs_1x3_bf16.data())};
|
ov::Tensor(ov::element::bf16, {1, 3}, probs_1x3_bf16.data()));
|
||||||
|
|
||||||
const std::vector<ov::Tensor> probs_log = {ov::Tensor(ov::element::f32, {4, 4}, probs_4x4_f32_log.data()),
|
const auto probs_log = testing::Values(ov::Tensor(ov::element::f32, {4, 4}, probs_4x4_f32_log.data()),
|
||||||
ov::Tensor(ov::element::f16, {2, 3}, probs_2x3_f16_log.data()),
|
ov::Tensor(ov::element::f16, {2, 3}, probs_2x3_f16_log.data()),
|
||||||
ov::Tensor(ov::element::bf16, {1, 3}, probs_1x3_bf16_log.data())};
|
ov::Tensor(ov::element::bf16, {1, 3}, probs_1x3_bf16_log.data()));
|
||||||
|
|
||||||
const std::vector<ov::Tensor> num_samples = {ov::Tensor(ov::element::i32, {}, num_samples_scalar_i32.data()),
|
const auto num_samples = testing::Values(ov::Tensor(ov::element::i32, {}, num_samples_scalar_i32.data()),
|
||||||
ov::Tensor(ov::element::i64, {1}, num_samples_1x1_i64.data()),
|
ov::Tensor(ov::element::i64, {1}, num_samples_1x1_i64.data()));
|
||||||
ov::Tensor(ov::element::i64, {}, num_samples_scalar_i64.data())};
|
|
||||||
|
|
||||||
const std::vector<ov::test::ElementType> convert_type = {ov::test::ElementType::i32};
|
const auto convert_type = testing::Values(ov::test::ElementType::i32, ov::test::ElementType::i64);
|
||||||
|
|
||||||
const std::vector<bool> with_replacement = {
|
const auto with_replacement = testing::Values(true, false);
|
||||||
// true,
|
|
||||||
false};
|
|
||||||
|
|
||||||
const auto params_static = ::testing::Combine(::testing::Values("static"),
|
const auto log_probs_true = testing::Values(true);
|
||||||
::testing::ValuesIn(probs),
|
const auto log_probs_false = testing::Values(false);
|
||||||
::testing::ValuesIn(num_samples),
|
|
||||||
::testing::ValuesIn(convert_type),
|
|
||||||
::testing::ValuesIn(with_replacement),
|
|
||||||
::testing::Values(false), // log_probs
|
|
||||||
::testing::ValuesIn(global_op_seed),
|
|
||||||
::testing::Values(ov::test::utils::DEVICE_CPU));
|
|
||||||
|
|
||||||
const auto params_static_log = ::testing::Combine(::testing::Values("static"),
|
const auto test_type_static = testing::Values("static");
|
||||||
::testing::ValuesIn(probs_log),
|
const auto test_type_dynamic = testing::Values("dynamic");
|
||||||
::testing::ValuesIn(num_samples),
|
|
||||||
::testing::ValuesIn(convert_type),
|
|
||||||
::testing::ValuesIn(with_replacement),
|
|
||||||
::testing::Values(true), // log_probs
|
|
||||||
::testing::ValuesIn(global_op_seed),
|
|
||||||
::testing::Values(ov::test::utils::DEVICE_CPU));
|
|
||||||
|
|
||||||
const auto params_dynamic = ::testing::Combine(::testing::Values("dynamic"),
|
// NOTE: (0,0) seeds are skipped (ticket 126095)
|
||||||
::testing::ValuesIn(probs),
|
const auto global_op_seed =
|
||||||
::testing::ValuesIn(num_samples),
|
testing::Values(std::pair<uint64_t, uint64_t>{1ul, 2ul}, std::pair<uint64_t, uint64_t>{0ul, 0ul});
|
||||||
::testing::ValuesIn(convert_type),
|
|
||||||
::testing::ValuesIn(with_replacement),
|
|
||||||
::testing::Values(false), // log_probs
|
|
||||||
::testing::ValuesIn(global_op_seed),
|
|
||||||
::testing::Values(ov::test::utils::DEVICE_CPU));
|
|
||||||
|
|
||||||
const auto params_dynamic_log = ::testing::Combine(::testing::Values("dynamic"),
|
const auto device_cpu = testing::Values(ov::test::utils::DEVICE_CPU);
|
||||||
::testing::ValuesIn(probs_log),
|
|
||||||
::testing::ValuesIn(num_samples),
|
const auto params_static = ::testing::Combine(test_type_static,
|
||||||
::testing::ValuesIn(convert_type),
|
probs,
|
||||||
::testing::ValuesIn(with_replacement),
|
num_samples,
|
||||||
::testing::Values(true), // log_probs
|
convert_type,
|
||||||
::testing::ValuesIn(global_op_seed),
|
with_replacement,
|
||||||
::testing::Values(ov::test::utils::DEVICE_CPU));
|
log_probs_false,
|
||||||
|
global_op_seed,
|
||||||
|
device_cpu);
|
||||||
|
|
||||||
|
const auto params_static_log = ::testing::Combine(test_type_static,
|
||||||
|
probs_log,
|
||||||
|
num_samples,
|
||||||
|
convert_type,
|
||||||
|
with_replacement,
|
||||||
|
log_probs_true,
|
||||||
|
global_op_seed,
|
||||||
|
device_cpu);
|
||||||
|
|
||||||
|
const auto params_dynamic = ::testing::Combine(test_type_dynamic,
|
||||||
|
probs,
|
||||||
|
num_samples,
|
||||||
|
convert_type,
|
||||||
|
with_replacement,
|
||||||
|
log_probs_false,
|
||||||
|
global_op_seed,
|
||||||
|
device_cpu);
|
||||||
|
|
||||||
|
const auto params_dynamic_log = ::testing::Combine(test_type_dynamic,
|
||||||
|
probs_log,
|
||||||
|
num_samples,
|
||||||
|
convert_type,
|
||||||
|
with_replacement,
|
||||||
|
log_probs_true,
|
||||||
|
global_op_seed,
|
||||||
|
device_cpu);
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(smoke_MultinomialStatic,
|
INSTANTIATE_TEST_SUITE_P(smoke_MultinomialStatic,
|
||||||
MultinomialLayerTest,
|
MultinomialLayerTest,
|
||||||
|
Loading…
Reference in New Issue
Block a user