[CPU] Multinomial implementation (#20406)
* [CPU] Temp save commit * [CPU] Add initial CPU implementation of Multinomial * [CPU] Add parallel implementation with mock randomuniform * [CPU] Fix accumulate incorrect iterator provided * [CPU] Add tests for multinomial * [CORE] Add lost tests * [CPU] Add dynamic shape inference and descriptors init * [CPU] Revamp tests to multiple files * [CPU/SPEC] Apply suggested changes * [CPU] Fix test compilation issues, clang fix * Update multinomial.cpp * [CPU] Fix Incorrect Primitive Descriptor for multiple combinations * [CPU] Change params to inputs in testing function * [CPU] Fix dynamic shape inference tensor access error * [CPU] Save stable version * [CPU] Add template execute for different input dtypes * [CPU] Introduce new method of loading data to tests, fix dynamic shape inference * [CPU] Improve parallelism * [CPU] Improve parallelism - fix indexes * [CPU] Fix no_replacement tests, fix randomness in tests * [CPU] Split tests into log and no_log version to avoid rounding when values are close to 0 * [CPU] Add mersenne-twister seed and random_uniform distribution as source for randomness, add debug prints * [CPU] Apply suggestions from review, fix 4x4 log tests * [CPU] Force i32 convert format * [CPU] Fix double to float conversion warning * [CPU] Remove debugging prints, fix CIs float error * [CPU] Fix for convert_type in CIs * Update src/plugins/intel_cpu/src/shape_inference/custom/multinomial.hpp Co-authored-by: Mateusz Mikolajczyk <mateusz.mikolajczyk@intel.com> * Update src/plugins/intel_cpu/src/nodes/multinomial.hpp Co-authored-by: Mateusz Mikolajczyk <mateusz.mikolajczyk@intel.com> * Update src/plugins/intel_cpu/src/shape_inference/custom/multinomial.hpp Co-authored-by: Mateusz Mikolajczyk <mateusz.mikolajczyk@intel.com> * [CPU] Migrate to CPU API 2.0 * [Ref/CPU] Remove support for 1D tensors, use Core Shape Inference * [CPU] Remove unnecessary symbols * Update multinomial.cpp * Update multinomial.cpp * Update ops.py * [CPU] Fix const identifier missing 
after reinterpret cast * [CPU] Fix Mac cpplint error * [CPU] Apply recommended changes - 0-seed nondeterminism, casts in tests, shape_infer optimization * [CPU] Apply iterator optimization suggestion * [CPU] Replace casts with class constructors in tests * [CPU] Remove unnecessary static_casts to void* * Update multinomial.cpp * [CPU] Apply suggestions from review - move template, fix i64 precision, turn off shape precision for const inputs, set always-execute for const inputs * [CPU] Relocate tests to shared, remove using namespace from header files * [CPU] Add definitions for files eaten by clang fix * [CPU] Fix seed for Mersenne Twister Engine * [CPU] Try fix incorrect 1x3 for 3 samples test (bf16) * [CPU] Use only mersenne for seed generation * [CPU] Relocate test, add debug prints * [CPU] Add relocated test that got eaten * [CPU] Remove uniform distribution, replace with division by max value * Update multinomial.cpp * Update multinomial.cpp * [CPU] Add explicit float cast for CIs * Update multinomial.cpp * [CPU] Use intel_cpu::bfloat16 to reduce inaccuracies * [CPU] Remove debug caps, all tests pass * [CPU] Clang fix * [GPU] Remove GPU 1D test case * [CPU] Modify tests to add seed=0 case, add ignore statement for this test and add subtask to complete after current release --------- Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com> Co-authored-by: Mateusz Mikolajczyk <mateusz.mikolajczyk@intel.com>
This commit is contained in:
parent
d4c342fc79
commit
44f7bf7e3f
@ -162,7 +162,7 @@ def multinomial(
|
|||||||
) -> Node:
|
) -> Node:
|
||||||
"""Return a node which generates a sequence of class indices sampled from the multinomial distribution.
|
"""Return a node which generates a sequence of class indices sampled from the multinomial distribution.
|
||||||
|
|
||||||
:param probs: Tensor with probabilities of floating-point type, and shape [class_size] or [batch_size, class_size].
|
:param probs: Tensor with probabilities of floating-point type, and shape [batch_size, class_size].
|
||||||
:param num_samples: Tensor (scalar or 1D) a single element of type i32 or i64,
|
:param num_samples: Tensor (scalar or 1D) a single element of type i32 or i64,
|
||||||
specifying the number of samples to draw from the multinomial distribution.
|
specifying the number of samples to draw from the multinomial distribution.
|
||||||
:param convert_type: Specifies the output tensor type, possible values: 'i64', 'i32'.
|
:param convert_type: Specifies the output tensor type, possible values: 'i64', 'i32'.
|
||||||
|
@ -13,7 +13,7 @@ from openvino import PartialShape, Type
|
|||||||
("probs_shape", "num_samples_shape", "convert_type", "with_replacement", "log_probs", "global_seed", "op_seed", "expected_out_shape"),
|
("probs_shape", "num_samples_shape", "convert_type", "with_replacement", "log_probs", "global_seed", "op_seed", "expected_out_shape"),
|
||||||
[
|
[
|
||||||
([4, 16], [], "i32", False, True, 7461, 1546, PartialShape([4, -1])),
|
([4, 16], [], "i32", False, True, 7461, 1546, PartialShape([4, -1])),
|
||||||
([8], [1], "i64", True, False, 0, 0, PartialShape([-1])),
|
([1, 8], [1], "i64", True, False, 0, 0, PartialShape([1, -1])),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_multinomial_param_inputs(probs_shape, num_samples_shape, convert_type, with_replacement, log_probs, global_seed, op_seed, expected_out_shape):
|
def test_multinomial_param_inputs(probs_shape, num_samples_shape, convert_type, with_replacement, log_probs, global_seed, op_seed, expected_out_shape):
|
||||||
@ -35,7 +35,7 @@ def test_multinomial_param_inputs(probs_shape, num_samples_shape, convert_type,
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("probs_array", "num_samples_val", "convert_type", "with_replacement", "log_probs", "global_seed", "op_seed", "expected_out_shape"),
|
("probs_array", "num_samples_val", "convert_type", "with_replacement", "log_probs", "global_seed", "op_seed", "expected_out_shape"),
|
||||||
[
|
[
|
||||||
(np.array([0.7, 0.3, 0.6, 0.5]), 3, "i32", False, True, 111, 222, PartialShape([3])),
|
(np.array([[0.7, 0.3, 0.6, 0.5]]), 3, "i32", False, True, 111, 222, PartialShape([1, 3])),
|
||||||
(np.array([[0.7, 0.3], [0.6, 0.5]]), 2, "i64", True, False, 111, 222, PartialShape([2, 2])),
|
(np.array([[0.7, 0.3], [0.6, 0.5]]), 2, "i64", True, False, 111, 222, PartialShape([2, 2])),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -59,7 +59,7 @@ def test_multinomial_const_inputs(probs_array, num_samples_val, convert_type, wi
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("probs_shape", "num_samples_shape", "convert_type", "with_replacement", "log_probs", "expected_out_shape"),
|
("probs_shape", "num_samples_shape", "convert_type", "with_replacement", "log_probs", "expected_out_shape"),
|
||||||
[
|
[
|
||||||
([10], [1], "i32", True, True, PartialShape([-1])),
|
([1, 10], [1], "i32", True, True, PartialShape([1, -1])),
|
||||||
([2, 16], [], "i64", False, False, PartialShape([2, -1])),
|
([2, 16], [], "i64", False, False, PartialShape([2, -1])),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -33,7 +33,7 @@ public:
|
|||||||
*/
|
*/
|
||||||
Multinomial(const Output<Node>& input,
|
Multinomial(const Output<Node>& input,
|
||||||
const Output<Node>& num_samples,
|
const Output<Node>& num_samples,
|
||||||
const ov::element::Type_t output_type,
|
const ov::element::Type_t convert_type,
|
||||||
const bool with_replacement,
|
const bool with_replacement,
|
||||||
const bool log_probs,
|
const bool log_probs,
|
||||||
const uint64_t global_seed = 0,
|
const uint64_t global_seed = 0,
|
||||||
@ -49,7 +49,7 @@ public:
|
|||||||
uint64_t get_global_seed() const;
|
uint64_t get_global_seed() const;
|
||||||
uint64_t get_op_seed() const;
|
uint64_t get_op_seed() const;
|
||||||
|
|
||||||
void set_convert_type(const ov::element::Type_t output_type);
|
void set_convert_type(const ov::element::Type_t convert_type);
|
||||||
void set_with_replacement(const bool with_replacement);
|
void set_with_replacement(const bool with_replacement);
|
||||||
void set_log_probs(const bool log_probs);
|
void set_log_probs(const bool log_probs);
|
||||||
void set_global_seed(const uint64_t global_seed);
|
void set_global_seed(const uint64_t global_seed);
|
||||||
|
@ -135,7 +135,7 @@ void multinomial(const T* probs,
|
|||||||
if (!with_replacement) {
|
if (!with_replacement) {
|
||||||
T class_probability = selected_class_idx ? cdf[i_translated + selected_class_idx] -
|
T class_probability = selected_class_idx ? cdf[i_translated + selected_class_idx] -
|
||||||
cdf[i_translated + selected_class_idx - 1]
|
cdf[i_translated + selected_class_idx - 1]
|
||||||
: cdf[i_translated + selected_class_idx];
|
: cdf[i_translated];
|
||||||
T divisor = 1 - class_probability;
|
T divisor = 1 - class_probability;
|
||||||
for (size_t k = 0; k < class_size; ++k) {
|
for (size_t k = 0; k < class_size; ++k) {
|
||||||
if (k >= selected_class_idx) {
|
if (k >= selected_class_idx) {
|
||||||
|
@ -20,8 +20,8 @@ std::vector<TRShape> shape_infer(const Multinomial* op,
|
|||||||
const auto& input_shape = input_shapes[0];
|
const auto& input_shape = input_shapes[0];
|
||||||
NODE_SHAPE_INFER_CHECK(op,
|
NODE_SHAPE_INFER_CHECK(op,
|
||||||
input_shapes,
|
input_shapes,
|
||||||
input_shape.rank().compatible(1) || input_shape.rank().compatible(2),
|
input_shape.rank().compatible(2),
|
||||||
"The rank of the 'probs' tensor defining output shape must be either 1 or 2.");
|
"Input probabilities must be a 2D tensor.");
|
||||||
|
|
||||||
const auto& num_samples_shape = input_shapes[1];
|
const auto& num_samples_shape = input_shapes[1];
|
||||||
NODE_SHAPE_INFER_CHECK(op,
|
NODE_SHAPE_INFER_CHECK(op,
|
||||||
@ -33,19 +33,16 @@ std::vector<TRShape> shape_infer(const Multinomial* op,
|
|||||||
auto& result_shape = output_shapes[0];
|
auto& result_shape = output_shapes[0];
|
||||||
const auto input_rank_static = input_shape.rank().is_static();
|
const auto input_rank_static = input_shape.rank().is_static();
|
||||||
if (input_rank_static) {
|
if (input_rank_static) {
|
||||||
|
result_shape.push_back(input_shape[0]);
|
||||||
const auto& num_samples = get_input_const_data_as_shape<TRShape>(op, 1, ta);
|
const auto& num_samples = get_input_const_data_as_shape<TRShape>(op, 1, ta);
|
||||||
if (num_samples) {
|
if (num_samples) {
|
||||||
NODE_VALIDATION_CHECK(op,
|
NODE_VALIDATION_CHECK(op,
|
||||||
(*num_samples)[0].get_min_length() >= 0,
|
(*num_samples)[0].get_min_length() >= 0,
|
||||||
"Number of samples must be non-negative. Got number of samples: ",
|
"Number of samples must be non-negative. Got number of samples: ",
|
||||||
(*num_samples)[0].get_min_length());
|
(*num_samples)[0].get_min_length());
|
||||||
result_shape = *num_samples;
|
result_shape.push_back((*num_samples)[0]);
|
||||||
} else {
|
} else {
|
||||||
result_shape = ov::PartialShape::dynamic(1);
|
result_shape.push_back(ov::Dimension::dynamic());
|
||||||
}
|
|
||||||
|
|
||||||
if (input_shape.rank().compatible(2)) {
|
|
||||||
result_shape.insert(result_shape.begin(), input_shape[0]);
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
result_shape = ov::PartialShape::dynamic();
|
result_shape = ov::PartialShape::dynamic();
|
||||||
|
@ -8,17 +8,18 @@
|
|||||||
|
|
||||||
#include "common_test_utils/test_assertions.hpp"
|
#include "common_test_utils/test_assertions.hpp"
|
||||||
#include "common_test_utils/type_prop.hpp"
|
#include "common_test_utils/type_prop.hpp"
|
||||||
|
#include "openvino/op/constant.hpp"
|
||||||
|
|
||||||
using namespace testing;
|
using namespace testing;
|
||||||
|
|
||||||
class TypePropMultinomialV13Test : public TypePropOpTest<ov::op::v13::Multinomial> {};
|
class TypePropMultinomialV13Test : public TypePropOpTest<ov::op::v13::Multinomial> {};
|
||||||
|
|
||||||
TEST_F(TypePropMultinomialV13Test, input_probs_f64_num_samples_i32_convert_i32) {
|
TEST_F(TypePropMultinomialV13Test, input_probs_const_f64_num_samples_i32_convert_i32) {
|
||||||
const auto probs = std::make_shared<ov::op::v0::Parameter>(ov::element::f64, ov::Shape{4});
|
const auto probs = ov::op::v0::Constant::create(ov::element::f64, ov::Shape{2, 2}, {1.0f, 1.0f, 1.0f, 1.0f});
|
||||||
const auto num_samples = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::Shape{1});
|
const auto num_samples = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::Shape{1});
|
||||||
const auto op = make_op(probs, num_samples, ov::element::i32, false, false, 0, 0);
|
const auto op = make_op(probs, num_samples, ov::element::i32, false, false, 0, 0);
|
||||||
EXPECT_EQ(op->get_element_type(), ov::element::i32);
|
EXPECT_EQ(op->get_element_type(), ov::element::i32);
|
||||||
EXPECT_EQ(op->get_output_partial_shape(0), (ov::PartialShape::dynamic(1)));
|
EXPECT_EQ(op->get_output_partial_shape(0), (ov::PartialShape{2, -1}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TypePropMultinomialV13Test, input_probs_f32_num_samples_i32_convert_i64) {
|
TEST_F(TypePropMultinomialV13Test, input_probs_f32_num_samples_i32_convert_i64) {
|
||||||
@ -29,8 +30,16 @@ TEST_F(TypePropMultinomialV13Test, input_probs_f32_num_samples_i32_convert_i64)
|
|||||||
EXPECT_EQ(op->get_output_partial_shape(0), (ov::PartialShape{4, -1}));
|
EXPECT_EQ(op->get_output_partial_shape(0), (ov::PartialShape{4, -1}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(TypePropMultinomialV13Test, input_probs_f32_num_samples_const_i32_convert_i64) {
|
||||||
|
const auto probs = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4, 4});
|
||||||
|
const auto num_samples = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {10});
|
||||||
|
const auto op = make_op(probs, num_samples, ov::element::i64, false, false, 0, 0);
|
||||||
|
EXPECT_EQ(op->get_element_type(), ov::element::i64);
|
||||||
|
EXPECT_EQ(op->get_output_partial_shape(0), (ov::PartialShape{4, 10}));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(TypePropMultinomialV13Test, probs_incompatibile_data_type) {
|
TEST_F(TypePropMultinomialV13Test, probs_incompatibile_data_type) {
|
||||||
const auto probs = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::Shape{4});
|
const auto probs = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::Shape{4, 4});
|
||||||
const auto num_samples = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::Shape{});
|
const auto num_samples = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::Shape{});
|
||||||
OV_EXPECT_THROW(std::ignore = make_op(probs, num_samples, ov::element::u64, false, false, 0, 0),
|
OV_EXPECT_THROW(std::ignore = make_op(probs, num_samples, ov::element::u64, false, false, 0, 0),
|
||||||
ov::NodeValidationFailure,
|
ov::NodeValidationFailure,
|
||||||
@ -38,23 +47,31 @@ TEST_F(TypePropMultinomialV13Test, probs_incompatibile_data_type) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TypePropMultinomialV13Test, num_samples_incompatibile_data_type) {
|
TEST_F(TypePropMultinomialV13Test, num_samples_incompatibile_data_type) {
|
||||||
const auto probs = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4});
|
const auto probs = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4, 4});
|
||||||
const auto num_samples = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{});
|
const auto num_samples = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{});
|
||||||
OV_EXPECT_THROW(std::ignore = make_op(probs, num_samples, ov::element::u64, false, false, 0, 0),
|
OV_EXPECT_THROW(std::ignore = make_op(probs, num_samples, ov::element::u64, false, false, 0, 0),
|
||||||
ov::NodeValidationFailure,
|
ov::NodeValidationFailure,
|
||||||
HasSubstr("Expected integer type as element type for the 'num_samples' input."));
|
HasSubstr("Expected integer type as element type for the 'num_samples' input."));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TypePropMultinomialV13Test, probs_incompatibile_rank) {
|
TEST_F(TypePropMultinomialV13Test, probs_incompatibile_rank_too_big) {
|
||||||
const auto probs = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4, 4, 4});
|
const auto probs = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4, 4, 4});
|
||||||
const auto num_samples = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::Shape{1});
|
const auto num_samples = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::Shape{1});
|
||||||
OV_EXPECT_THROW(std::ignore = make_op(probs, num_samples, ov::element::boolean, false, false, 0, 0),
|
OV_EXPECT_THROW(std::ignore = make_op(probs, num_samples, ov::element::boolean, false, false, 0, 0),
|
||||||
ov::NodeValidationFailure,
|
ov::NodeValidationFailure,
|
||||||
HasSubstr("The rank of the 'probs' tensor defining output shape must be either 1 or 2."));
|
HasSubstr("Input probabilities must be a 2D tensor."));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TypePropMultinomialV13Test, probs_incompatibile_rank_too_small) {
|
||||||
|
const auto probs = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4});
|
||||||
|
const auto num_samples = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::Shape{1});
|
||||||
|
OV_EXPECT_THROW(std::ignore = make_op(probs, num_samples, ov::element::boolean, false, false, 0, 0),
|
||||||
|
ov::NodeValidationFailure,
|
||||||
|
HasSubstr("Input probabilities must be a 2D tensor."));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TypePropMultinomialV13Test, num_samples_incompatibile_rank) {
|
TEST_F(TypePropMultinomialV13Test, num_samples_incompatibile_rank) {
|
||||||
const auto probs = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4});
|
const auto probs = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4, 4});
|
||||||
const auto num_samples = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::Shape{1, 2});
|
const auto num_samples = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::Shape{1, 2});
|
||||||
OV_EXPECT_THROW(std::ignore = make_op(probs, num_samples, ov::element::boolean, false, false, 0, 0),
|
OV_EXPECT_THROW(std::ignore = make_op(probs, num_samples, ov::element::boolean, false, false, 0, 0),
|
||||||
ov::NodeValidationFailure,
|
ov::NodeValidationFailure,
|
||||||
|
@ -205,6 +205,7 @@ static const TypeToNameMap& get_type_to_name_tbl() {
|
|||||||
{ "MatrixNms", Type::MatrixNms},
|
{ "MatrixNms", Type::MatrixNms},
|
||||||
{ "MulticlassNms", Type::MulticlassNms},
|
{ "MulticlassNms", Type::MulticlassNms},
|
||||||
{ "MulticlassNmsIEInternal", Type::MulticlassNms},
|
{ "MulticlassNmsIEInternal", Type::MulticlassNms},
|
||||||
|
{ "Multinomial", Type::Multinomial},
|
||||||
{ "Reference", Type::Reference},
|
{ "Reference", Type::Reference},
|
||||||
{ "Subgraph", Type::Subgraph},
|
{ "Subgraph", Type::Subgraph},
|
||||||
{ "PriorBox", Type::PriorBox},
|
{ "PriorBox", Type::PriorBox},
|
||||||
@ -321,6 +322,7 @@ std::string NameFromType(const Type type) {
|
|||||||
CASE(NonMaxSuppression);
|
CASE(NonMaxSuppression);
|
||||||
CASE(MatrixNms);
|
CASE(MatrixNms);
|
||||||
CASE(MulticlassNms);
|
CASE(MulticlassNms);
|
||||||
|
CASE(Multinomial);
|
||||||
CASE(Reference);
|
CASE(Reference);
|
||||||
CASE(Subgraph);
|
CASE(Subgraph);
|
||||||
CASE(PriorBox);
|
CASE(PriorBox);
|
||||||
|
@ -4,10 +4,10 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "caseless.hpp"
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "caseless.hpp"
|
||||||
|
|
||||||
namespace ov {
|
namespace ov {
|
||||||
namespace intel_cpu {
|
namespace intel_cpu {
|
||||||
@ -105,6 +105,7 @@ enum class Type {
|
|||||||
NonMaxSuppression,
|
NonMaxSuppression,
|
||||||
MatrixNms,
|
MatrixNms,
|
||||||
MulticlassNms,
|
MulticlassNms,
|
||||||
|
Multinomial,
|
||||||
Subgraph,
|
Subgraph,
|
||||||
PriorBox,
|
PriorBox,
|
||||||
PriorBoxClustered,
|
PriorBoxClustered,
|
||||||
@ -262,5 +263,5 @@ std::string NameFromType(const Type type);
|
|||||||
|
|
||||||
std::string algToString(const Algorithm alg);
|
std::string algToString(const Algorithm alg);
|
||||||
|
|
||||||
} // namespace intel_cpu
|
} // namespace intel_cpu
|
||||||
} // namespace ov
|
} // namespace ov
|
||||||
|
287
src/plugins/intel_cpu/src/nodes/multinomial.cpp
Normal file
287
src/plugins/intel_cpu/src/nodes/multinomial.cpp
Normal file
@ -0,0 +1,287 @@
|
|||||||
|
// Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "multinomial.hpp"
|
||||||
|
|
||||||
|
#include "ie_ngraph_utils.hpp"
|
||||||
|
#include "openvino/op/multinomial.hpp"
|
||||||
|
#include "utils/bfloat16.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace intel_cpu {
|
||||||
|
namespace node {
|
||||||
|
|
||||||
|
// Builds the CPU node from an opset13 Multinomial operation.
//
// op      - operation to wrap; rejected via THROW_CPU_NODE_ERR when isSupportedOperation() fails.
// context - graph context forwarded to the base Node.
//
// Copies the sampling attributes from the operation, fixes the precision used to read
// 'num_samples' to i32 and records which inputs are constants.
Multinomial::Multinomial(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr& context)
    : Node(op, context, NgraphShapeInferFactory(op, PortMask(NUM_SAMPLES_PORT))) {
    std::string errorMessage;
    if (!isSupportedOperation(op, errorMessage)) {
        THROW_CPU_NODE_ERR(errorMessage);
    }

    auto multinomial_op = as_type_ptr<op::v13::Multinomial>(op);
    m_with_replacement = multinomial_op->get_with_replacement();
    m_global_seed = multinomial_op->get_global_seed();
    m_log_probs = multinomial_op->get_log_probs();
    m_op_seed = multinomial_op->get_op_seed();

    m_num_samples_precision = ov::element::i32;
    m_output_precision = multinomial_op->get_convert_type();

    constant = ConstantType::NoConst;

    // The batch dimension can only be inspected when the rank itself is known;
    // with a dynamic rank the batch is treated as non-constant.
    const auto& probs_shape = op->get_input_partial_shape(PROBS_PORT);
    m_const_batch = probs_shape.rank().is_static() && probs_shape[0].is_static();
    m_const_inputs[PROBS_PORT] = is_type<op::v0::Constant>(op->get_input_node_ptr(PROBS_PORT));
    m_const_inputs[NUM_SAMPLES_PORT] = is_type<op::v0::Constant>(op->get_input_node_ptr(NUM_SAMPLES_PORT));
}
|
||||||
|
|
||||||
|
// Tells whether the CPU plugin can execute the given operation with this node.
//
// op           - operation to inspect.
// errorMessage - receives the reason when the operation is not the opset13 Multinomial;
//                left untouched when an exception is caught.
// Returns true only for op::v13::Multinomial.
bool Multinomial::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
    try {
        const bool wrong_type = op->get_type_info() != op::v13::Multinomial::get_type_info_static();
        if (!wrong_type) {
            return true;
        }
        errorMessage = "Only Multinomial operation from the opset13 is supported by the CPU plugin.";
        return false;
    } catch (...) {
        return false;
    }
}
|
||||||
|
|
||||||
|
void Multinomial::getSupportedDescriptors() {
|
||||||
|
if (getParentEdges().size() != 2) {
|
||||||
|
THROW_CPU_NODE_ERR("has incorrect number of input edges.");
|
||||||
|
}
|
||||||
|
if (getChildEdges().size() != 1) {
|
||||||
|
THROW_CPU_NODE_ERR("has incorrect number of output edges.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Multinomial::initSupportedPrimitiveDescriptors() {
|
||||||
|
m_probs_precision = getOriginalInputPrecisionAtPort(PROBS_PORT);
|
||||||
|
if (!one_of(m_probs_precision, ov::element::f32, ov::element::f16, ov::element::bf16)) {
|
||||||
|
m_probs_precision = ov::element::f32;
|
||||||
|
}
|
||||||
|
|
||||||
|
addSupportedPrimDesc({{LayoutType::ncsp, m_probs_precision, m_const_inputs[PROBS_PORT]},
|
||||||
|
{LayoutType::ncsp, m_num_samples_precision, m_const_inputs[NUM_SAMPLES_PORT]}},
|
||||||
|
{{LayoutType::ncsp, m_output_precision}},
|
||||||
|
ref_any);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string Multinomial::getPrimitiveDescriptorType() const {
|
||||||
|
std::string str_type;
|
||||||
|
auto selectedPrimitiveDesc = getSelectedPrimitiveDescriptor();
|
||||||
|
|
||||||
|
impl_desc_type type = impl_desc_type::undef;
|
||||||
|
if (selectedPrimitiveDesc) {
|
||||||
|
type = selectedPrimitiveDesc->getImplementationType();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type == impl_desc_type::unknown)
|
||||||
|
str_type += "unknown_";
|
||||||
|
if ((type & impl_desc_type::jit) == impl_desc_type::jit)
|
||||||
|
str_type += "jit_";
|
||||||
|
if ((type & impl_desc_type::ref) == impl_desc_type::ref)
|
||||||
|
str_type += "ref_";
|
||||||
|
if ((type & impl_desc_type::avx512) == impl_desc_type::avx512)
|
||||||
|
str_type += "avx512_";
|
||||||
|
if ((type & impl_desc_type::avx2) == impl_desc_type::avx2)
|
||||||
|
str_type += "avx2_";
|
||||||
|
if ((type & impl_desc_type::sse42) == impl_desc_type::sse42)
|
||||||
|
str_type += "sse42_";
|
||||||
|
if ((type & impl_desc_type::any) == impl_desc_type::any)
|
||||||
|
str_type += "any_";
|
||||||
|
|
||||||
|
if (str_type.empty())
|
||||||
|
str_type += "undef_";
|
||||||
|
|
||||||
|
if (selectedPrimitiveDesc) {
|
||||||
|
str_type += m_output_precision.get_type_name();
|
||||||
|
} else {
|
||||||
|
str_type.pop_back();
|
||||||
|
}
|
||||||
|
|
||||||
|
return str_type;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Multinomial::needShapeInfer() const {
|
||||||
|
return !(m_const_inputs[NUM_SAMPLES_PORT] && m_const_batch);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Always reports that parameters have to be prepared.
bool Multinomial::needPrepareParams() const {
    return true;
}
|
||||||
|
|
||||||
|
void Multinomial::prepareParams() {
|
||||||
|
const auto& probs_shape = getParentEdgeAt(PROBS_PORT)->getMemory().getStaticDims();
|
||||||
|
const auto& num_samples_shape = getParentEdgeAt(NUM_SAMPLES_PORT)->getMemory().getStaticDims();
|
||||||
|
|
||||||
|
if (probs_shape.size() != 2) {
|
||||||
|
THROW_CPU_NODE_ERR("has incompatible 'probs' shape ",
|
||||||
|
PartialShape(probs_shape),
|
||||||
|
". Only 2D tensors are allowed.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (num_samples_shape.size() != 1) {
|
||||||
|
THROW_CPU_NODE_ERR("has incompatible 'num_samples' shape ",
|
||||||
|
PartialShape(num_samples_shape),
|
||||||
|
". Only scalar and 1D single element tensors are allowed.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (m_num_samples_precision == ov::element::i32) {
|
||||||
|
m_samples_count =
|
||||||
|
reinterpret_cast<const int32_t*>(getParentEdgeAt(NUM_SAMPLES_PORT)->getMemoryPtr()->getData())[0];
|
||||||
|
} else {
|
||||||
|
m_samples_count =
|
||||||
|
reinterpret_cast<const int64_t*>(getParentEdgeAt(NUM_SAMPLES_PORT)->getMemoryPtr()->getData())[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
m_batches_count = probs_shape[0];
|
||||||
|
m_probs_count = probs_shape[1];
|
||||||
|
m_samples_probs_count = m_samples_count * m_probs_count;
|
||||||
|
m_input_elements_count = m_batches_count * m_probs_count;
|
||||||
|
m_output_elements_count = m_batches_count * m_samples_count;
|
||||||
|
m_batches_samples_probs_count = m_output_elements_count * m_probs_count;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Multinomial::isExecutable() const {
|
||||||
|
return !isInputTensorAtPortEmpty(PROBS_PORT) && !isInputTensorAtPortEmpty(NUM_SAMPLES_PORT);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Multinomial::created() const {
|
||||||
|
return getType() == Type::Multinomial;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dispatches execution on the precision of the 'probs' input (f32, f16 or bf16).
// Any other precision is reported through THROW_CPU_NODE_ERR.
void Multinomial::execute(dnnl::stream strm) {
    if (m_probs_precision == ov::element::f32) {
        return execute_probs_type<float>();
    }
    if (m_probs_precision == ov::element::f16) {
        return execute_probs_type<float16>();
    }
    if (m_probs_precision == ov::element::bf16) {
        return execute_probs_type<bfloat16_t>();
    }
    THROW_CPU_NODE_ERR("Multinomial CPU implementation does not support probs element type: ", m_probs_precision);
}
|
||||||
|
|
||||||
|
// Dynamic-shape execution forwards to execute() unchanged.
void Multinomial::executeDynamicImpl(dnnl::stream strm) {
    execute(strm);
}
|
||||||
|
|
||||||
|
template <typename P>
|
||||||
|
void Multinomial::execute_probs_type() {
|
||||||
|
switch (m_output_precision) {
|
||||||
|
case ov::element::i32:
|
||||||
|
return execute_convert_type<P, int32_t>();
|
||||||
|
default:
|
||||||
|
THROW_CPU_NODE_ERR("Multinomial CPU implementation does not support output convert type: ", m_output_precision);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename P, typename O>
|
||||||
|
void Multinomial::execute_convert_type() {
|
||||||
|
const auto* probs = reinterpret_cast<const P*>(getParentEdgeAt(PROBS_PORT)->getMemoryPtr()->getData());
|
||||||
|
auto* output = reinterpret_cast<O*>(getChildEdgeAt(OUTPUT_PORT)->getMemoryPtr()->getData());
|
||||||
|
|
||||||
|
std::vector<P> m_cdf(m_input_elements_count);
|
||||||
|
std::vector<P> m_max_per_batch(m_batches_count);
|
||||||
|
std::vector<P> m_random_samples(m_output_elements_count);
|
||||||
|
|
||||||
|
// exp & cumsum
|
||||||
|
if (m_log_probs) {
|
||||||
|
parallel_for(m_batches_count, [&](size_t idx) {
|
||||||
|
const auto start_idx = idx * m_probs_count;
|
||||||
|
m_cdf[start_idx] = std::exp(probs[start_idx]);
|
||||||
|
for (size_t prev = start_idx, curr = prev + 1; curr < (start_idx + m_probs_count); ++prev, ++curr) {
|
||||||
|
m_cdf[curr] = std::exp(probs[curr]) + m_cdf[prev];
|
||||||
|
}
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
parallel_for(m_batches_count, [&](size_t idx_batch) {
|
||||||
|
const auto start_idx = idx_batch * m_probs_count;
|
||||||
|
const auto* probs_start_idx = probs + start_idx;
|
||||||
|
std::partial_sum(probs_start_idx, probs_start_idx + m_probs_count, m_cdf.begin() + start_idx);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO RandomUniform - should use RandomUniform kernel to match other frameworks' seed results
|
||||||
|
std::mt19937 gen;
|
||||||
|
if (m_global_seed == 0 && m_op_seed == 0) {
|
||||||
|
gen.seed(std::time(NULL));
|
||||||
|
} else {
|
||||||
|
std::seed_seq seed{m_global_seed, m_op_seed};
|
||||||
|
gen.seed(seed);
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto gen_max = static_cast<float>(gen.max());
|
||||||
|
std::generate(m_random_samples.begin(), m_random_samples.end(), [&]() {
|
||||||
|
return static_cast<P>(static_cast<float>(gen()) / gen_max);
|
||||||
|
});
|
||||||
|
|
||||||
|
// max & divide
|
||||||
|
const auto min_value_of_max = std::numeric_limits<P>::min();
|
||||||
|
parallel_for(m_batches_count, [&](size_t idx) {
|
||||||
|
m_max_per_batch[idx] = std::max(m_cdf[(idx + 1) * m_probs_count - 1], min_value_of_max);
|
||||||
|
});
|
||||||
|
|
||||||
|
parallel_for(m_input_elements_count, [&](size_t idx) {
|
||||||
|
size_t idx_max_elem = idx / m_probs_count;
|
||||||
|
m_cdf[idx] = m_cdf[idx] / m_max_per_batch[idx_max_elem];
|
||||||
|
});
|
||||||
|
|
||||||
|
if (m_with_replacement) {
|
||||||
|
parallel_for(m_batches_samples_probs_count, [&](size_t idx) {
|
||||||
|
size_t idx_batch = idx / m_samples_probs_count;
|
||||||
|
size_t idx_num_samples_probs = idx % m_samples_probs_count;
|
||||||
|
size_t idx_prob = idx_num_samples_probs % m_probs_count;
|
||||||
|
size_t idx_sample = idx_num_samples_probs / m_probs_count;
|
||||||
|
|
||||||
|
size_t idx_input = idx_batch * m_probs_count + idx_prob;
|
||||||
|
size_t idx_output = idx_batch * m_samples_count + idx_sample;
|
||||||
|
if (m_random_samples[idx_output] <= m_cdf[idx_input] &&
|
||||||
|
(!idx_prob || m_random_samples[idx_output] > m_cdf[idx_input - 1])) {
|
||||||
|
output[idx_output] = static_cast<O>(idx_prob);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
} else { // without replacement - adjust cdf after each sample drawn from batch, sequentially
|
||||||
|
parallel_for(m_batches_count, [&](size_t idx_batch) {
|
||||||
|
for (size_t idx_sample = 0LU; idx_sample < m_samples_count; ++idx_sample) {
|
||||||
|
size_t idx_input = idx_batch * m_probs_count;
|
||||||
|
size_t idx_output = idx_batch * m_samples_count + idx_sample;
|
||||||
|
|
||||||
|
bool class_selected = false;
|
||||||
|
size_t selected_class = m_probs_count;
|
||||||
|
P sample_value = m_random_samples[idx_output];
|
||||||
|
for (size_t idx_prob = 0LU; idx_prob < m_probs_count; ++idx_prob) {
|
||||||
|
if (sample_value <= m_cdf[idx_input + idx_prob]) {
|
||||||
|
output[idx_output] = static_cast<O>(idx_prob);
|
||||||
|
selected_class = idx_prob;
|
||||||
|
class_selected = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (class_selected) {
|
||||||
|
P class_probability;
|
||||||
|
if (selected_class) {
|
||||||
|
class_probability = m_cdf[idx_input + selected_class] - m_cdf[idx_input + selected_class - 1];
|
||||||
|
} else {
|
||||||
|
class_probability = m_cdf[idx_input];
|
||||||
|
}
|
||||||
|
P divisor = 1 - class_probability;
|
||||||
|
for (size_t idx_prob = 0LU; idx_prob < m_probs_count; ++idx_prob) {
|
||||||
|
if (idx_prob >= selected_class) {
|
||||||
|
m_cdf[idx_input + idx_prob] = m_cdf[idx_input + idx_prob] - class_probability;
|
||||||
|
}
|
||||||
|
m_cdf[idx_input + idx_prob] = m_cdf[idx_input + idx_prob] / divisor;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace node
|
||||||
|
} // namespace intel_cpu
|
||||||
|
} // namespace ov
|
80
src/plugins/intel_cpu/src/nodes/multinomial.hpp
Normal file
80
src/plugins/intel_cpu/src/nodes/multinomial.hpp
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
// Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <random>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "ie_common.h"
|
||||||
|
#include "ie_parallel.hpp"
|
||||||
|
#include "node.h"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace intel_cpu {
|
||||||
|
namespace node {
|
||||||
|
|
||||||
|
class Multinomial : public Node {
|
||||||
|
public:
|
||||||
|
Multinomial(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr& context);
|
||||||
|
|
||||||
|
void getSupportedDescriptors() override;
|
||||||
|
void initSupportedPrimitiveDescriptors() override;
|
||||||
|
std::string getPrimitiveDescriptorType() const override;
|
||||||
|
|
||||||
|
bool created() const override;
|
||||||
|
|
||||||
|
static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;
|
||||||
|
|
||||||
|
bool needPrepareParams() const override;
|
||||||
|
void prepareParams() override;
|
||||||
|
|
||||||
|
bool isExecutable() const override;
|
||||||
|
void execute(dnnl::stream strm) override;
|
||||||
|
void executeDynamicImpl(dnnl::stream strm) override;
|
||||||
|
bool canBeInPlace() const override {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
bool needShapeInfer() const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// Multinomial params
|
||||||
|
bool m_with_replacement = false;
|
||||||
|
bool m_log_probs = false;
|
||||||
|
uint64_t m_global_seed = 0;
|
||||||
|
uint64_t m_op_seed = 0;
|
||||||
|
|
||||||
|
/// Shape inference
|
||||||
|
static constexpr size_t PROBS_PORT = 0lu;
|
||||||
|
static constexpr size_t NUM_SAMPLES_PORT = 1lu;
|
||||||
|
static constexpr size_t OUTPUT_PORT = 0lu;
|
||||||
|
bool m_const_inputs[2] = {false, false};
|
||||||
|
bool m_const_batch = false;
|
||||||
|
VectorDims m_output_shape = {};
|
||||||
|
|
||||||
|
/// General algorithm variables
|
||||||
|
ov::element::Type m_probs_precision;
|
||||||
|
ov::element::Type m_num_samples_precision;
|
||||||
|
ov::element::Type m_output_precision;
|
||||||
|
|
||||||
|
size_t m_probs_count = 0;
|
||||||
|
size_t m_batches_count = 0;
|
||||||
|
size_t m_samples_count = 0;
|
||||||
|
size_t m_samples_probs_count = 0;
|
||||||
|
size_t m_input_elements_count = 0;
|
||||||
|
size_t m_output_elements_count = 0;
|
||||||
|
size_t m_batches_samples_probs_count = 0;
|
||||||
|
|
||||||
|
template <typename P>
|
||||||
|
void execute_probs_type();
|
||||||
|
|
||||||
|
template <typename P, typename O>
|
||||||
|
void execute_convert_type();
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace node
|
||||||
|
} // namespace intel_cpu
|
||||||
|
} // namespace ov
|
@ -2,108 +2,107 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
#include "nodes/reference.h"
|
|
||||||
#include "nodes/shapeof.h"
|
|
||||||
#include "nodes/batch_to_space.h"
|
|
||||||
#include "nodes/multiclass_nms.hpp"
|
|
||||||
#include "nodes/adaptive_pooling.h"
|
#include "nodes/adaptive_pooling.h"
|
||||||
#include "nodes/conv.h"
|
#include "nodes/batch_to_space.h"
|
||||||
#include "nodes/roi_align.h"
|
|
||||||
#include "nodes/lrn.h"
|
|
||||||
#include "nodes/generic.h"
|
|
||||||
#include "nodes/experimental_detectron_roifeatureextractor.h"
|
|
||||||
#include "nodes/eltwise.h"
|
|
||||||
#include "nodes/reorg_yolo.h"
|
|
||||||
#include "nodes/pooling.h"
|
|
||||||
#include "nodes/transpose.h"
|
|
||||||
#include "nodes/grn.h"
|
|
||||||
#include "nodes/interpolate.h"
|
|
||||||
#include "nodes/experimental_detectron_detection_output.h"
|
|
||||||
#include "nodes/roll.h"
|
|
||||||
#include "nodes/fake_quantize.h"
|
|
||||||
#include "nodes/embedding_segments_sum.h"
|
|
||||||
#include "nodes/region_yolo.h"
|
|
||||||
#include "nodes/matmul.h"
|
|
||||||
#include "nodes/detection_output.h"
|
|
||||||
#include "nodes/reverse_sequence.h"
|
|
||||||
#include "nodes/pad.h"
|
|
||||||
#include "nodes/ctc_greedy_decoder_seq_len.h"
|
|
||||||
#include "nodes/reshape.h"
|
|
||||||
#include "nodes/psroi_pooling.h"
|
|
||||||
#include "nodes/memory.hpp"
|
|
||||||
#include "nodes/bin_conv.h"
|
#include "nodes/bin_conv.h"
|
||||||
#include "nodes/gather_elements.h"
|
|
||||||
#include "nodes/experimental_detectron_priorgridgenerator.h"
|
|
||||||
#include "nodes/tile.h"
|
|
||||||
#include "nodes/mathematics.h"
|
|
||||||
#include "nodes/normalize.h"
|
|
||||||
#include "nodes/proposal.h"
|
|
||||||
#include "nodes/tensoriterator.h"
|
|
||||||
#include "nodes/fullyconnected.h"
|
|
||||||
#include "nodes/extract_image_patches.h"
|
|
||||||
#include "nodes/ctc_loss.h"
|
|
||||||
#include "nodes/reorder.h"
|
|
||||||
#include "nodes/gather_nd.h"
|
|
||||||
#include "nodes/shuffle_channels.h"
|
|
||||||
#include "nodes/bucketize.h"
|
|
||||||
#include "nodes/space_to_depth.h"
|
|
||||||
#include "nodes/concat.h"
|
|
||||||
#include "nodes/softmax.h"
|
|
||||||
#include "nodes/space_to_batch.h"
|
|
||||||
#include "nodes/topk.h"
|
|
||||||
#include "nodes/broadcast.h"
|
#include "nodes/broadcast.h"
|
||||||
#include "nodes/matrix_nms.h"
|
#include "nodes/bucketize.h"
|
||||||
#include "nodes/mvn.h"
|
|
||||||
#include "nodes/gather.h"
|
|
||||||
#include "nodes/grid_sample.hpp"
|
|
||||||
#include "nodes/scatter_update.h"
|
|
||||||
#include "nodes/gather_tree.h"
|
|
||||||
#include "nodes/def_conv.h"
|
|
||||||
#include "nodes/embedding_bag_offset_sum.h"
|
|
||||||
#include "nodes/deconv.h"
|
|
||||||
#include "nodes/roi_pooling.h"
|
|
||||||
#include "nodes/range.h"
|
|
||||||
#include "nodes/split.h"
|
|
||||||
#include "nodes/one_hot.h"
|
|
||||||
#include "nodes/log_softmax.h"
|
|
||||||
#include "nodes/strided_slice.h"
|
|
||||||
#include "nodes/dft.h"
|
|
||||||
#include "nodes/rdft.h"
|
|
||||||
#include "nodes/non_max_suppression.h"
|
|
||||||
#include "nodes/convert.h"
|
|
||||||
#include "nodes/rnn.h"
|
|
||||||
#include "nodes/experimental_detectron_topkrois.h"
|
|
||||||
#include "nodes/cum_sum.h"
|
|
||||||
#include "nodes/depth_to_space.h"
|
|
||||||
#include "nodes/input.h"
|
|
||||||
#include "nodes/experimental_detectron_generate_proposals_single_image.h"
|
|
||||||
#include "nodes/generate_proposals.h"
|
|
||||||
#include "nodes/embedding_bag_packed_sum.h"
|
|
||||||
#include "nodes/random_uniform.hpp"
|
|
||||||
#include "nodes/reduce.h"
|
|
||||||
#include "nodes/if.h"
|
|
||||||
#include "nodes/ctc_greedy_decoder.h"
|
|
||||||
#include "nodes/non_zero.h"
|
|
||||||
#include "nodes/color_convert.h"
|
#include "nodes/color_convert.h"
|
||||||
#include "nodes/subgraph.h"
|
#include "nodes/concat.h"
|
||||||
|
#include "nodes/conv.h"
|
||||||
|
#include "nodes/convert.h"
|
||||||
|
#include "nodes/ctc_greedy_decoder.h"
|
||||||
|
#include "nodes/ctc_greedy_decoder_seq_len.h"
|
||||||
|
#include "nodes/ctc_loss.h"
|
||||||
|
#include "nodes/cum_sum.h"
|
||||||
|
#include "nodes/deconv.h"
|
||||||
|
#include "nodes/def_conv.h"
|
||||||
|
#include "nodes/depth_to_space.h"
|
||||||
|
#include "nodes/detection_output.h"
|
||||||
|
#include "nodes/dft.h"
|
||||||
|
#include "nodes/eltwise.h"
|
||||||
|
#include "nodes/embedding_bag_offset_sum.h"
|
||||||
|
#include "nodes/embedding_bag_packed_sum.h"
|
||||||
|
#include "nodes/embedding_segments_sum.h"
|
||||||
|
#include "nodes/experimental_detectron_detection_output.h"
|
||||||
|
#include "nodes/experimental_detectron_generate_proposals_single_image.h"
|
||||||
|
#include "nodes/experimental_detectron_priorgridgenerator.h"
|
||||||
|
#include "nodes/experimental_detectron_roifeatureextractor.h"
|
||||||
|
#include "nodes/experimental_detectron_topkrois.h"
|
||||||
|
#include "nodes/extract_image_patches.h"
|
||||||
|
#include "nodes/eye.h"
|
||||||
|
#include "nodes/fake_quantize.h"
|
||||||
|
#include "nodes/fullyconnected.h"
|
||||||
|
#include "nodes/gather.h"
|
||||||
|
#include "nodes/gather_elements.h"
|
||||||
|
#include "nodes/gather_nd.h"
|
||||||
|
#include "nodes/gather_tree.h"
|
||||||
|
#include "nodes/generate_proposals.h"
|
||||||
|
#include "nodes/generic.h"
|
||||||
|
#include "nodes/grid_sample.hpp"
|
||||||
|
#include "nodes/grn.h"
|
||||||
|
#include "nodes/if.h"
|
||||||
|
#include "nodes/input.h"
|
||||||
|
#include "nodes/interaction.h"
|
||||||
|
#include "nodes/interpolate.h"
|
||||||
|
#include "nodes/log_softmax.h"
|
||||||
|
#include "nodes/lrn.h"
|
||||||
|
#include "nodes/mathematics.h"
|
||||||
|
#include "nodes/matmul.h"
|
||||||
|
#include "nodes/matrix_nms.h"
|
||||||
|
#include "nodes/memory.hpp"
|
||||||
|
#include "nodes/mha.h"
|
||||||
|
#include "nodes/multiclass_nms.hpp"
|
||||||
|
#include "nodes/multinomial.hpp"
|
||||||
|
#include "nodes/mvn.h"
|
||||||
|
#include "nodes/ngram.h"
|
||||||
|
#include "nodes/non_max_suppression.h"
|
||||||
|
#include "nodes/non_zero.h"
|
||||||
|
#include "nodes/normalize.h"
|
||||||
|
#include "nodes/one_hot.h"
|
||||||
|
#include "nodes/pad.h"
|
||||||
|
#include "nodes/pooling.h"
|
||||||
#include "nodes/priorbox.h"
|
#include "nodes/priorbox.h"
|
||||||
#include "nodes/priorbox_clustered.h"
|
#include "nodes/priorbox_clustered.h"
|
||||||
#include "nodes/eye.h"
|
#include "nodes/proposal.h"
|
||||||
#include "nodes/interaction.h"
|
#include "nodes/psroi_pooling.h"
|
||||||
#include "nodes/mha.h"
|
#include "nodes/random_uniform.hpp"
|
||||||
#include "nodes/unique.hpp"
|
#include "nodes/range.h"
|
||||||
#include "nodes/ngram.h"
|
#include "nodes/rdft.h"
|
||||||
#include "nodes/scaled_attn.h"
|
#include "nodes/reduce.h"
|
||||||
|
#include "nodes/reference.h"
|
||||||
|
#include "nodes/region_yolo.h"
|
||||||
|
#include "nodes/reorder.h"
|
||||||
|
#include "nodes/reorg_yolo.h"
|
||||||
|
#include "nodes/reshape.h"
|
||||||
|
#include "nodes/reverse_sequence.h"
|
||||||
|
#include "nodes/rnn.h"
|
||||||
|
#include "nodes/roi_align.h"
|
||||||
|
#include "nodes/roi_pooling.h"
|
||||||
|
#include "nodes/roll.h"
|
||||||
#include "nodes/rope.h"
|
#include "nodes/rope.h"
|
||||||
|
#include "nodes/scaled_attn.h"
|
||||||
|
#include "nodes/scatter_update.h"
|
||||||
|
#include "nodes/shapeof.h"
|
||||||
|
#include "nodes/shuffle_channels.h"
|
||||||
|
#include "nodes/softmax.h"
|
||||||
|
#include "nodes/space_to_batch.h"
|
||||||
|
#include "nodes/space_to_depth.h"
|
||||||
|
#include "nodes/split.h"
|
||||||
|
#include "nodes/strided_slice.h"
|
||||||
|
#include "nodes/subgraph.h"
|
||||||
|
#include "nodes/tensoriterator.h"
|
||||||
|
#include "nodes/tile.h"
|
||||||
|
#include "nodes/topk.h"
|
||||||
|
#include "nodes/transpose.h"
|
||||||
|
#include "nodes/unique.hpp"
|
||||||
|
|
||||||
namespace ov {
|
namespace ov {
|
||||||
namespace intel_cpu {
|
namespace intel_cpu {
|
||||||
|
|
||||||
#define INTEL_CPU_NODE(__prim, __type) \
|
#define INTEL_CPU_NODE(__prim, __type) registerNodeIfRequired(intel_cpu, __prim, __type, NodeImpl<__prim>)
|
||||||
registerNodeIfRequired(intel_cpu, __prim, __type, NodeImpl<__prim>)
|
|
||||||
|
|
||||||
Node::NodesFactory::NodesFactory()
|
Node::NodesFactory::NodesFactory() : Factory("NodesFactory") {
|
||||||
: Factory("NodesFactory") {
|
|
||||||
using namespace node;
|
using namespace node;
|
||||||
INTEL_CPU_NODE(Generic, Type::Generic);
|
INTEL_CPU_NODE(Generic, Type::Generic);
|
||||||
INTEL_CPU_NODE(CumSum, Type::CumSum);
|
INTEL_CPU_NODE(CumSum, Type::CumSum);
|
||||||
@ -136,7 +135,8 @@ Node::NodesFactory::NodesFactory()
|
|||||||
INTEL_CPU_NODE(ReorgYolo, Type::ReorgYolo);
|
INTEL_CPU_NODE(ReorgYolo, Type::ReorgYolo);
|
||||||
INTEL_CPU_NODE(EmbeddingSegmentsSum, Type::EmbeddingSegmentsSum);
|
INTEL_CPU_NODE(EmbeddingSegmentsSum, Type::EmbeddingSegmentsSum);
|
||||||
INTEL_CPU_NODE(ShapeOf, Type::ShapeOf);
|
INTEL_CPU_NODE(ShapeOf, Type::ShapeOf);
|
||||||
INTEL_CPU_NODE(ExperimentalDetectronGenerateProposalsSingleImage, Type::ExperimentalDetectronGenerateProposalsSingleImage);
|
INTEL_CPU_NODE(ExperimentalDetectronGenerateProposalsSingleImage,
|
||||||
|
Type::ExperimentalDetectronGenerateProposalsSingleImage);
|
||||||
INTEL_CPU_NODE(GenerateProposals, Type::GenerateProposals);
|
INTEL_CPU_NODE(GenerateProposals, Type::GenerateProposals);
|
||||||
INTEL_CPU_NODE(ReverseSequence, Type::ReverseSequence);
|
INTEL_CPU_NODE(ReverseSequence, Type::ReverseSequence);
|
||||||
INTEL_CPU_NODE(ExperimentalDetectronPriorGridGenerator, Type::ExperimentalDetectronPriorGridGenerator);
|
INTEL_CPU_NODE(ExperimentalDetectronPriorGridGenerator, Type::ExperimentalDetectronPriorGridGenerator);
|
||||||
@ -162,6 +162,7 @@ Node::NodesFactory::NodesFactory()
|
|||||||
INTEL_CPU_NODE(Reshape, Type::Reshape);
|
INTEL_CPU_NODE(Reshape, Type::Reshape);
|
||||||
INTEL_CPU_NODE(MVN, Type::MVN);
|
INTEL_CPU_NODE(MVN, Type::MVN);
|
||||||
INTEL_CPU_NODE(MatMul, Type::MatMul);
|
INTEL_CPU_NODE(MatMul, Type::MatMul);
|
||||||
|
INTEL_CPU_NODE(Multinomial, Type::Multinomial);
|
||||||
INTEL_CPU_NODE(ScatterUpdate, Type::ScatterUpdate);
|
INTEL_CPU_NODE(ScatterUpdate, Type::ScatterUpdate);
|
||||||
INTEL_CPU_NODE(ScatterUpdate, Type::ScatterElementsUpdate);
|
INTEL_CPU_NODE(ScatterUpdate, Type::ScatterElementsUpdate);
|
||||||
INTEL_CPU_NODE(ScatterUpdate, Type::ScatterNDUpdate);
|
INTEL_CPU_NODE(ScatterUpdate, Type::ScatterNDUpdate);
|
||||||
@ -208,5 +209,5 @@ Node::NodesFactory::NodesFactory()
|
|||||||
|
|
||||||
#undef INTEL_CPU_NODE
|
#undef INTEL_CPU_NODE
|
||||||
|
|
||||||
} // namespace intel_cpu
|
} // namespace intel_cpu
|
||||||
} // namespace ov
|
} // namespace ov
|
||||||
|
@ -1,12 +1,15 @@
|
|||||||
// Copyright (C) 2018-2023 Intel Corporation
|
// Copyright (C) 2018-2023 Intel Corporation
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
#include "shape_inference.hpp"
|
||||||
|
|
||||||
#include <ngraph/runtime/host_tensor.hpp>
|
#include <ngraph/runtime/host_tensor.hpp>
|
||||||
#include <openvino/core/node.hpp>
|
#include <openvino/core/node.hpp>
|
||||||
#include <openvino/opsets/opset1.hpp>
|
#include <openvino/opsets/opset1.hpp>
|
||||||
#include <openvino/opsets/opset10.hpp>
|
#include <openvino/opsets/opset10.hpp>
|
||||||
#include <openvino/opsets/opset11.hpp>
|
#include <openvino/opsets/opset11.hpp>
|
||||||
#include <openvino/opsets/opset12.hpp>
|
#include <openvino/opsets/opset12.hpp>
|
||||||
|
#include <openvino/opsets/opset13.hpp>
|
||||||
#include <openvino/opsets/opset2.hpp>
|
#include <openvino/opsets/opset2.hpp>
|
||||||
#include <openvino/opsets/opset3.hpp>
|
#include <openvino/opsets/opset3.hpp>
|
||||||
#include <openvino/opsets/opset4.hpp>
|
#include <openvino/opsets/opset4.hpp>
|
||||||
@ -68,6 +71,7 @@
|
|||||||
#include "matmul_shape_inference.hpp"
|
#include "matmul_shape_inference.hpp"
|
||||||
#include "matrix_nms_shape_inference.hpp"
|
#include "matrix_nms_shape_inference.hpp"
|
||||||
#include "max_pool_shape_inference.hpp"
|
#include "max_pool_shape_inference.hpp"
|
||||||
|
#include "multinomial_shape_inference.hpp"
|
||||||
#include "nms_shape_inference.hpp"
|
#include "nms_shape_inference.hpp"
|
||||||
#include "nv12_shape_inference.hpp"
|
#include "nv12_shape_inference.hpp"
|
||||||
#include "one_hot_shape_inference.hpp"
|
#include "one_hot_shape_inference.hpp"
|
||||||
@ -93,7 +97,6 @@
|
|||||||
#include "scatter_elements_update_shape_inference.hpp"
|
#include "scatter_elements_update_shape_inference.hpp"
|
||||||
#include "scatter_nd_base_shape_inference.hpp"
|
#include "scatter_nd_base_shape_inference.hpp"
|
||||||
#include "select_shape_inference.hpp"
|
#include "select_shape_inference.hpp"
|
||||||
#include "shape_inference.hpp"
|
|
||||||
#include "shape_nodes.hpp"
|
#include "shape_nodes.hpp"
|
||||||
#include "shuffle_channels_shape_inference.hpp"
|
#include "shuffle_channels_shape_inference.hpp"
|
||||||
#include "slice_shape_inference.hpp"
|
#include "slice_shape_inference.hpp"
|
||||||
@ -393,6 +396,8 @@ using IStaticShapeInferFactory =
|
|||||||
// To use other version of operators, explicitly specify operator with opset version namespace.
|
// To use other version of operators, explicitly specify operator with opset version namespace.
|
||||||
template <>
|
template <>
|
||||||
const IStaticShapeInferFactory::TRegistry IStaticShapeInferFactory::registry{
|
const IStaticShapeInferFactory::TRegistry IStaticShapeInferFactory::registry{
|
||||||
|
// opset13
|
||||||
|
_OV_OP_SHAPE_INFER_MASK_REG(opset13::Multinomial, ShapeInferTA, util::bit::mask(1)),
|
||||||
// opset12
|
// opset12
|
||||||
_OV_OP_SHAPE_INFER_MASK_REG(opset12::Pad, ShapeInferTA, util::bit::mask(1, 2)),
|
_OV_OP_SHAPE_INFER_MASK_REG(opset12::Pad, ShapeInferTA, util::bit::mask(1, 2)),
|
||||||
_OV_OP_SHAPE_INFER_MASK_REG(opset12::ScatterElementsUpdate, ShapeInferTA, util::bit::mask(3)),
|
_OV_OP_SHAPE_INFER_MASK_REG(opset12::ScatterElementsUpdate, ShapeInferTA, util::bit::mask(3)),
|
||||||
|
@ -146,11 +146,15 @@ std::vector<std::string> disabledTestPatterns() {
|
|||||||
// Issue: 95607
|
// Issue: 95607
|
||||||
R"(.*CachingSupportCase.*LoadNetworkCacheTestBase.*(TIwithLSTMcell1|MatMulBias|2InputSubtract)_(i|u).*)",
|
R"(.*CachingSupportCase.*LoadNetworkCacheTestBase.*(TIwithLSTMcell1|MatMulBias|2InputSubtract)_(i|u).*)",
|
||||||
R"(.*CachingSupportCase.*ReadConcatSplitAssign.*)",
|
R"(.*CachingSupportCase.*ReadConcatSplitAssign.*)",
|
||||||
// 94982. FP32->I32 conversion issue in the reference implementation. There can be some garbage in the rest of float values like 0.333333745.
|
// 94982. FP32->I32 conversion issue in the reference implementation. There can be some garbage in the rest of
|
||||||
// The kernel does not have such garbage. The diff 0.000000745 is taken into account in calculations and affects further type conversion.
|
// float values like 0.333333745.
|
||||||
// Reorder->GridSample->Reorder also does not work here. Potential fix is to use nearest conversion instead of truncation.
|
// The kernel does not have such garbage. The diff 0.000000745 is taken into account in calculations and affects
|
||||||
|
// further type conversion.
|
||||||
|
// Reorder->GridSample->Reorder also does not work here. Potential fix is to use nearest conversion instead of
|
||||||
|
// truncation.
|
||||||
R"(.*GridSampleLayerTestCPU.*(BILINEAR|BICUBIC).*(i32|i8).*)",
|
R"(.*GridSampleLayerTestCPU.*(BILINEAR|BICUBIC).*(i32|i8).*)",
|
||||||
// AdaptiveAvgPool is converted into Reduce op for suitable parameters. CPU Reduce impl doesn't support non planar layout for 3D case
|
// AdaptiveAvgPool is converted into Reduce op for suitable parameters. CPU Reduce impl doesn't support non
|
||||||
|
// planar layout for 3D case
|
||||||
R"(.*StaticAdaPoolAvg3DLayoutTest.*OS=\(1\).*_inFmts=(nwc|nCw16c|nCw8c).*)",
|
R"(.*StaticAdaPoolAvg3DLayoutTest.*OS=\(1\).*_inFmts=(nwc|nCw16c|nCw8c).*)",
|
||||||
// Issue: 111404
|
// Issue: 111404
|
||||||
R"(.*smoke_set1/GatherElementsCPUTest.*)",
|
R"(.*smoke_set1/GatherElementsCPUTest.*)",
|
||||||
@ -217,6 +221,8 @@ std::vector<std::string> disabledTestPatterns() {
|
|||||||
R"(.*smoke_Snippets_MHA_.?D_SplitDimensionM.*)",
|
R"(.*smoke_Snippets_MHA_.?D_SplitDimensionM.*)",
|
||||||
// Issue: 122356
|
// Issue: 122356
|
||||||
R"(.*NmsRotatedOpTest.*(SortDesc=True|Clockwise=False).*)",
|
R"(.*NmsRotatedOpTest.*(SortDesc=True|Clockwise=False).*)",
|
||||||
|
// Issue: 126095
|
||||||
|
R"(^smoke_Multinomial(?:Static|Dynamic)+(?:Log)*.*seed_g=0_seed_o=0.*device=CPU.*)",
|
||||||
#ifdef OPENVINO_ARCH_32_BIT
|
#ifdef OPENVINO_ARCH_32_BIT
|
||||||
// Issue: 126177
|
// Issue: 126177
|
||||||
R"(.*smoke_CompareWithRefs_4D_Bitwise.*/EltwiseLayerCPUTest.CompareWithRefs/.*_eltwiseOpType=Bitwise.*_NetType=i32_.*)"
|
R"(.*smoke_CompareWithRefs_4D_Bitwise.*/EltwiseLayerCPUTest.CompareWithRefs/.*_eltwiseOpType=Bitwise.*_NetType=i32_.*)"
|
||||||
@ -231,29 +237,39 @@ std::vector<std::string> disabledTestPatterns() {
|
|||||||
#elif defined(OPENVINO_ARCH_ARM64) || defined(OPENVINO_ARCH_ARM)
|
#elif defined(OPENVINO_ARCH_ARM64) || defined(OPENVINO_ARCH_ARM)
|
||||||
{
|
{
|
||||||
// Issue: 121709
|
// Issue: 121709
|
||||||
retVector.emplace_back(R"(smoke_ConversionLayerTest/ConversionLayerTest.Inference/conversionOpType=Convert_IS.*_inputPRC=f16_targetPRC=(u|i)8_trgDev=CPU.*)");
|
retVector.emplace_back(
|
||||||
|
R"(smoke_ConversionLayerTest/ConversionLayerTest.Inference/conversionOpType=Convert_IS.*_inputPRC=f16_targetPRC=(u|i)8_trgDev=CPU.*)");
|
||||||
// Issue: 121710
|
// Issue: 121710
|
||||||
retVector.emplace_back(R"(smoke_GRUCellCommon/GRUCellTest.Inference/decomposition0_batch=5_.*WType=CONSTANT_RType=CONSTANT_BType=CONSTANT_netPRC=f16_targetDevice=CPU_.*)");
|
retVector.emplace_back(
|
||||||
|
R"(smoke_GRUCellCommon/GRUCellTest.Inference/decomposition0_batch=5_.*WType=CONSTANT_RType=CONSTANT_BType=CONSTANT_netPRC=f16_targetDevice=CPU_.*)");
|
||||||
// Issue: 121715
|
// Issue: 121715
|
||||||
retVector.emplace_back(R"(smoke_CompareWithRefs_static/EltwiseLayerTest.Inference/IS.*_eltwise_op_type=Div_secondary_input_type=PARAMETER_opType=VECTOR_model_type=i32_InType=undefined_OutType=undefined_trgDev=CPU.*)");
|
retVector.emplace_back(
|
||||||
retVector.emplace_back(R"(smoke_CompareWithRefs_static_check_collapsing/EltwiseLayerTest.Inference/IS.*_eltwise_op_type=Div_secondary_input_type=PARAMETER_opType=VECTOR_model_type=i32_InType=undefined_OutType=undefined_trgDev=CPU.*)");
|
R"(smoke_CompareWithRefs_static/EltwiseLayerTest.Inference/IS.*_eltwise_op_type=Div_secondary_input_type=PARAMETER_opType=VECTOR_model_type=i32_InType=undefined_OutType=undefined_trgDev=CPU.*)");
|
||||||
|
retVector.emplace_back(
|
||||||
|
R"(smoke_CompareWithRefs_static_check_collapsing/EltwiseLayerTest.Inference/IS.*_eltwise_op_type=Div_secondary_input_type=PARAMETER_opType=VECTOR_model_type=i32_InType=undefined_OutType=undefined_trgDev=CPU.*)");
|
||||||
// TODO: enable once streams / tput mode is supported
|
// TODO: enable once streams / tput mode is supported
|
||||||
retVector.emplace_back(R"(OVClassConfigTestCPU.smoke_CpuExecNetworkCheck(Model|Core)StreamsHasHigherPriorityThanLatencyHint.*)");
|
retVector.emplace_back(
|
||||||
retVector.emplace_back(R"(smoke_BehaviorTests/CorrectConfigCheck.canSetConfigAndCheckGetConfig.*CPU_THROUGHPUT_STREAMS=8.*)");
|
R"(OVClassConfigTestCPU.smoke_CpuExecNetworkCheck(Model|Core)StreamsHasHigherPriorityThanLatencyHint.*)");
|
||||||
retVector.emplace_back(R"(smoke_BehaviorTests/CorrectConfigCheck.canSetConfigTwiceAndCheckGetConfig.*CPU_THROUGHPUT_STREAMS=8.*)");
|
retVector.emplace_back(
|
||||||
retVector.emplace_back(R"(smoke_CPU_OVClassLoadNetworkAndCheckWithSecondaryPropertiesTest/OVClassLoadNetworkAndCheckSecondaryPropertiesTest.LoadNetworkAndCheckSecondaryPropertiesTest.*)");
|
R"(smoke_BehaviorTests/CorrectConfigCheck.canSetConfigAndCheckGetConfig.*CPU_THROUGHPUT_STREAMS=8.*)");
|
||||||
retVector.emplace_back(R"(smoke_CPU_OVClassLoadNetworkAndCheckWithSecondaryPropertiesDoubleTest/OVClassLoadNetworkAndCheckSecondaryPropertiesTest.LoadNetworkAndCheckSecondaryPropertiesTest.*)");
|
retVector.emplace_back(
|
||||||
|
R"(smoke_BehaviorTests/CorrectConfigCheck.canSetConfigTwiceAndCheckGetConfig.*CPU_THROUGHPUT_STREAMS=8.*)");
|
||||||
|
retVector.emplace_back(
|
||||||
|
R"(smoke_CPU_OVClassLoadNetworkAndCheckWithSecondaryPropertiesTest/OVClassLoadNetworkAndCheckSecondaryPropertiesTest.LoadNetworkAndCheckSecondaryPropertiesTest.*)");
|
||||||
|
retVector.emplace_back(
|
||||||
|
R"(smoke_CPU_OVClassLoadNetworkAndCheckWithSecondaryPropertiesDoubleTest/OVClassLoadNetworkAndCheckSecondaryPropertiesTest.LoadNetworkAndCheckSecondaryPropertiesTest.*)");
|
||||||
retVector.emplace_back(R"(smoke_CPU_OVClassCompileModelAndCheckSecondaryPropertiesTest.*)");
|
retVector.emplace_back(R"(smoke_CPU_OVClassCompileModelAndCheckSecondaryPropertiesTest.*)");
|
||||||
retVector.emplace_back(R"(smoke_CPU_OVClassCompileModelAndCheckWithSecondaryPropertiesDoubleTest.*)");
|
retVector.emplace_back(R"(smoke_CPU_OVClassCompileModelAndCheckWithSecondaryPropertiesDoubleTest.*)");
|
||||||
// Issue: 123321
|
// Issue: 123321
|
||||||
retVector.emplace_back(R"(.*smoke_RNNSequenceCommonZeroClip/RNNSequenceTest.Inference.*hidden_size=1.*relu.*direction=reverse.*)");
|
retVector.emplace_back(
|
||||||
|
R"(.*smoke_RNNSequenceCommonZeroClip/RNNSequenceTest.Inference.*hidden_size=1.*relu.*direction=reverse.*)");
|
||||||
}
|
}
|
||||||
// invalid test: checks u8 precision for runtime graph, while it should be f32
|
// invalid test: checks u8 precision for runtime graph, while it should be f32
|
||||||
retVector.emplace_back(R"(smoke_NegativeQuantizedMatMulMultiplyFusion.*)");
|
retVector.emplace_back(R"(smoke_NegativeQuantizedMatMulMultiplyFusion.*)");
|
||||||
// int8 specific
|
// int8 specific
|
||||||
retVector.emplace_back(R"(smoke_Quantized.*)");
|
retVector.emplace_back(R"(smoke_Quantized.*)");
|
||||||
|
|
||||||
#if defined(OV_CPU_ARM_ENABLE_FP16)
|
# if defined(OV_CPU_ARM_ENABLE_FP16)
|
||||||
// Issue: 123019
|
// Issue: 123019
|
||||||
retVector.emplace_back(R"(smoke_CompareWithRefs_Mvn.*INFERENCE_PRECISION_HINT=f16.*)");
|
retVector.emplace_back(R"(smoke_CompareWithRefs_Mvn.*INFERENCE_PRECISION_HINT=f16.*)");
|
||||||
retVector.emplace_back(R"(smoke_staticShapes4D.*INFERENCE_PRECISION_HINT=f16.*)");
|
retVector.emplace_back(R"(smoke_staticShapes4D.*INFERENCE_PRECISION_HINT=f16.*)");
|
||||||
@ -264,10 +280,11 @@ std::vector<std::string> disabledTestPatterns() {
|
|||||||
retVector.emplace_back(R"(.*smoke_BehaviorTests/InferRequestPerfCountersTest.CheckOperationInPerfMap.*)");
|
retVector.emplace_back(R"(.*smoke_BehaviorTests/InferRequestPerfCountersTest.CheckOperationInPerfMap.*)");
|
||||||
retVector.emplace_back(R"(smoke_BehaviorTests/ExecutableNetworkBaseTest.CheckExecGraphInfo.*)");
|
retVector.emplace_back(R"(smoke_BehaviorTests/ExecutableNetworkBaseTest.CheckExecGraphInfo.*)");
|
||||||
retVector.emplace_back(R"(smoke_BehaviorTests/OVCompiledModelBaseTestOptional.CheckExecGraphInfo.*)");
|
retVector.emplace_back(R"(smoke_BehaviorTests/OVCompiledModelBaseTestOptional.CheckExecGraphInfo.*)");
|
||||||
retVector.emplace_back(R"(smoke_ExecGraph/ExecGraphRuntimePrecision.CheckRuntimePrecision/Function=FakeQuantizeBinaryConvolution.*)");
|
retVector.emplace_back(
|
||||||
|
R"(smoke_ExecGraph/ExecGraphRuntimePrecision.CheckRuntimePrecision/Function=FakeQuantizeBinaryConvolution.*)");
|
||||||
// Issue: 124395
|
// Issue: 124395
|
||||||
retVector.emplace_back(R"(smoke_VariableStateBasic/InferRequestVariableStateTest.*)");
|
retVector.emplace_back(R"(smoke_VariableStateBasic/InferRequestVariableStateTest.*)");
|
||||||
#endif
|
# endif
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@ -281,7 +298,8 @@ std::vector<std::string> disabledTestPatterns() {
|
|||||||
retVector.emplace_back(R"(.*OVInferConsistencyTest.*)");
|
retVector.emplace_back(R"(.*OVInferConsistencyTest.*)");
|
||||||
// TODO: generate new 'expected' runtime graph for non-x64 CPU
|
// TODO: generate new 'expected' runtime graph for non-x64 CPU
|
||||||
retVector.emplace_back(R"(smoke_serialization/ExecGraphSerializationTest.ExecutionGraph.*)");
|
retVector.emplace_back(R"(smoke_serialization/ExecGraphSerializationTest.ExecutionGraph.*)");
|
||||||
retVector.emplace_back(R"(smoke_ExecGraph/ExecGraphRuntimePrecision.CheckRuntimePrecision/Function=(EltwiseWithTwoDynamicInputs|FakeQuantizeRelu).*)");
|
retVector.emplace_back(
|
||||||
|
R"(smoke_ExecGraph/ExecGraphRuntimePrecision.CheckRuntimePrecision/Function=(EltwiseWithTwoDynamicInputs|FakeQuantizeRelu).*)");
|
||||||
// Issue 108803: bug in CPU scalar implementation
|
// Issue 108803: bug in CPU scalar implementation
|
||||||
retVector.emplace_back(R"(smoke_TestsDFT_(1|2|3|4)d/DFTLayerTest.CompareWithRefs.*)");
|
retVector.emplace_back(R"(smoke_TestsDFT_(1|2|3|4)d/DFTLayerTest.CompareWithRefs.*)");
|
||||||
retVector.emplace_back(R"(smoke_TestsDFT_(1|2|3|4)d/DFTLayerTest.Inference.*)");
|
retVector.emplace_back(R"(smoke_TestsDFT_(1|2|3|4)d/DFTLayerTest.Inference.*)");
|
||||||
@ -312,13 +330,14 @@ std::vector<std::string> disabledTestPatterns() {
|
|||||||
retVector.emplace_back(R"(.*INFERENCE_PRECISION_HINT=(F|f)16.*)");
|
retVector.emplace_back(R"(.*INFERENCE_PRECISION_HINT=(F|f)16.*)");
|
||||||
}
|
}
|
||||||
#elif defined(OPENVINO_ARCH_ARM64) || defined(OPENVINO_ARCH_ARM)
|
#elif defined(OPENVINO_ARCH_ARM64) || defined(OPENVINO_ARCH_ARM)
|
||||||
#if !defined(OV_CPU_ARM_ENABLE_FP16)
|
# if !defined(OV_CPU_ARM_ENABLE_FP16)
|
||||||
// Skip fp16 tests for platforms that don't support fp16 precision
|
// Skip fp16 tests for platforms that don't support fp16 precision
|
||||||
retVector.emplace_back(R"(.*INFERENCE_PRECISION_HINT=(F|f)16.*)");
|
retVector.emplace_back(R"(.*INFERENCE_PRECISION_HINT=(F|f)16.*)");
|
||||||
#else
|
# else
|
||||||
// Issue 117407
|
// Issue 117407
|
||||||
retVector.emplace_back(R"(.*EltwiseLayerCPUTest.*IS=\(\[1\.\.10\.2\.5\.6\]_\).*eltwiseOpType=SqDiff.*_configItem=INFERENCE_PRECISION_HINT=f16.*)");
|
retVector.emplace_back(
|
||||||
#endif // OV_CPU_ARM_ENABLE_FP16
|
R"(.*EltwiseLayerCPUTest.*IS=\(\[1\.\.10\.2\.5\.6\]_\).*eltwiseOpType=SqDiff.*_configItem=INFERENCE_PRECISION_HINT=f16.*)");
|
||||||
|
# endif // OV_CPU_ARM_ENABLE_FP16
|
||||||
#endif
|
#endif
|
||||||
if (!InferenceEngine::with_cpu_x86_avx512_core_vnni() && !InferenceEngine::with_cpu_x86_avx512_core_amx_int8()) {
|
if (!InferenceEngine::with_cpu_x86_avx512_core_vnni() && !InferenceEngine::with_cpu_x86_avx512_core_amx_int8()) {
|
||||||
// MatMul in Snippets uses BRGEMM that supports i8 only on platforms with VNNI or AMX instructions
|
// MatMul in Snippets uses BRGEMM that supports i8 only on platforms with VNNI or AMX instructions
|
||||||
@ -329,7 +348,7 @@ std::vector<std::string> disabledTestPatterns() {
|
|||||||
retVector.emplace_back(R"(.*Snippets.*MHAQuant.*)");
|
retVector.emplace_back(R"(.*Snippets.*MHAQuant.*)");
|
||||||
}
|
}
|
||||||
if (!InferenceEngine::with_cpu_x86_avx512_core_amx_int8())
|
if (!InferenceEngine::with_cpu_x86_avx512_core_amx_int8())
|
||||||
//TODO: Issue 92895
|
// TODO: Issue 92895
|
||||||
// on platforms which do not support AMX, we are disabling I8 input tests
|
// on platforms which do not support AMX, we are disabling I8 input tests
|
||||||
retVector.emplace_back(R"(smoke_LPT/FakeQuantizeWithNotOptimalTransformation.CompareWithRefImpl.*CPU.*i8.*)");
|
retVector.emplace_back(R"(smoke_LPT/FakeQuantizeWithNotOptimalTransformation.CompareWithRefImpl.*CPU.*i8.*)");
|
||||||
if (!InferenceEngine::with_cpu_x86_avx512_core_amx_bf16() && !InferenceEngine::with_cpu_x86_bfloat16()) {
|
if (!InferenceEngine::with_cpu_x86_avx512_core_amx_bf16() && !InferenceEngine::with_cpu_x86_bfloat16()) {
|
||||||
|
@ -0,0 +1,131 @@
|
|||||||
|
// Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "single_op_tests/multinomial.hpp"
|
||||||
|
|
||||||
|
#include <openvino/core/type/element_type.hpp>
|
||||||
|
#include <openvino/runtime/tensor.hpp>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ov::test::MultinomialLayerTest;
|
||||||
|
|
||||||
|
std::vector<std::pair<uint64_t, uint64_t>> global_op_seed = {{1ul, 2ul}, {0ul, 0ul}};
|
||||||
|
|
||||||
|
std::vector<float> probs_4x4_f32 = {0.00001f,
|
||||||
|
0.001f,
|
||||||
|
0.1f,
|
||||||
|
10.0f,
|
||||||
|
0.001f,
|
||||||
|
0.00001f,
|
||||||
|
10.0f,
|
||||||
|
0.1f,
|
||||||
|
0.1f,
|
||||||
|
10.0f,
|
||||||
|
0.00001f,
|
||||||
|
0.001f,
|
||||||
|
10.0f,
|
||||||
|
0.1f,
|
||||||
|
0.001f,
|
||||||
|
0.00001f};
|
||||||
|
|
||||||
|
std::vector<ov::float16> probs_2x3_f16 = {ov::float16(0.001f),
|
||||||
|
ov::float16(0.1f),
|
||||||
|
ov::float16(10.0f),
|
||||||
|
ov::float16(10.0f),
|
||||||
|
ov::float16(0.001f),
|
||||||
|
ov::float16(0.1f)};
|
||||||
|
|
||||||
|
std::vector<ov::bfloat16> probs_1x3_bf16 = {ov::bfloat16(0.1f), ov::bfloat16(1.0f), ov::bfloat16(10.0f)};
|
||||||
|
|
||||||
|
std::vector<float> probs_4x4_f32_log =
|
||||||
|
{3.0f, 6.0f, 10.0f, 0.0f, 3.0f, 0.0f, 10.0f, 6.0f, 6.0f, 10.0f, 0.0f, 3.0f, 10.0f, 6.0f, 3.0f, 0.0f};
|
||||||
|
|
||||||
|
std::vector<ov::float16> probs_2x3_f16_log = {ov::float16(3.0f),
|
||||||
|
ov::float16(6.0f),
|
||||||
|
ov::float16(10.0f),
|
||||||
|
ov::float16(10.0f),
|
||||||
|
ov::float16(3.0f),
|
||||||
|
ov::float16(6.0f)};
|
||||||
|
|
||||||
|
std::vector<ov::bfloat16> probs_1x3_bf16_log = {ov::bfloat16(3.0f), ov::bfloat16(6.0f), ov::bfloat16(10.0f)};
|
||||||
|
|
||||||
|
std::vector<int> num_samples_scalar_i32 = {1};
|
||||||
|
std::vector<int64_t> num_samples_1x1_i64 = {2};
|
||||||
|
std::vector<int64_t> num_samples_scalar_i64 = {3};
|
||||||
|
|
||||||
|
const std::vector<ov::Tensor> probs = {ov::Tensor(ov::element::f32, {4, 4}, probs_4x4_f32.data()),
|
||||||
|
ov::Tensor(ov::element::f16, {2, 3}, probs_2x3_f16.data()),
|
||||||
|
ov::Tensor(ov::element::bf16, {1, 3}, probs_1x3_bf16.data())};
|
||||||
|
|
||||||
|
const std::vector<ov::Tensor> probs_log = {ov::Tensor(ov::element::f32, {4, 4}, probs_4x4_f32_log.data()),
|
||||||
|
ov::Tensor(ov::element::f16, {2, 3}, probs_2x3_f16_log.data()),
|
||||||
|
ov::Tensor(ov::element::bf16, {1, 3}, probs_1x3_bf16_log.data())};
|
||||||
|
|
||||||
|
const std::vector<ov::Tensor> num_samples = {ov::Tensor(ov::element::i32, {}, num_samples_scalar_i32.data()),
|
||||||
|
ov::Tensor(ov::element::i64, {1}, num_samples_1x1_i64.data()),
|
||||||
|
ov::Tensor(ov::element::i64, {}, num_samples_scalar_i64.data())};
|
||||||
|
|
||||||
|
const std::vector<ov::test::ElementType> convert_type = {ov::test::ElementType::i32};
|
||||||
|
|
||||||
|
const std::vector<bool> with_replacement = {
|
||||||
|
// true,
|
||||||
|
false};
|
||||||
|
|
||||||
|
const auto params_static = ::testing::Combine(::testing::Values("static"),
|
||||||
|
::testing::ValuesIn(probs),
|
||||||
|
::testing::ValuesIn(num_samples),
|
||||||
|
::testing::ValuesIn(convert_type),
|
||||||
|
::testing::ValuesIn(with_replacement),
|
||||||
|
::testing::Values(false), // log_probs
|
||||||
|
::testing::ValuesIn(global_op_seed),
|
||||||
|
::testing::Values(ov::test::utils::DEVICE_CPU));
|
||||||
|
|
||||||
|
const auto params_static_log = ::testing::Combine(::testing::Values("static"),
|
||||||
|
::testing::ValuesIn(probs_log),
|
||||||
|
::testing::ValuesIn(num_samples),
|
||||||
|
::testing::ValuesIn(convert_type),
|
||||||
|
::testing::ValuesIn(with_replacement),
|
||||||
|
::testing::Values(true), // log_probs
|
||||||
|
::testing::ValuesIn(global_op_seed),
|
||||||
|
::testing::Values(ov::test::utils::DEVICE_CPU));
|
||||||
|
|
||||||
|
const auto params_dynamic = ::testing::Combine(::testing::Values("dynamic"),
|
||||||
|
::testing::ValuesIn(probs),
|
||||||
|
::testing::ValuesIn(num_samples),
|
||||||
|
::testing::ValuesIn(convert_type),
|
||||||
|
::testing::ValuesIn(with_replacement),
|
||||||
|
::testing::Values(false), // log_probs
|
||||||
|
::testing::ValuesIn(global_op_seed),
|
||||||
|
::testing::Values(ov::test::utils::DEVICE_CPU));
|
||||||
|
|
||||||
|
const auto params_dynamic_log = ::testing::Combine(::testing::Values("dynamic"),
|
||||||
|
::testing::ValuesIn(probs_log),
|
||||||
|
::testing::ValuesIn(num_samples),
|
||||||
|
::testing::ValuesIn(convert_type),
|
||||||
|
::testing::ValuesIn(with_replacement),
|
||||||
|
::testing::Values(true), // log_probs
|
||||||
|
::testing::ValuesIn(global_op_seed),
|
||||||
|
::testing::Values(ov::test::utils::DEVICE_CPU));
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(smoke_MultinomialStatic,
|
||||||
|
MultinomialLayerTest,
|
||||||
|
params_static,
|
||||||
|
MultinomialLayerTest::getTestCaseName);
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(smoke_MultinomialStaticLog,
|
||||||
|
MultinomialLayerTest,
|
||||||
|
params_static_log,
|
||||||
|
MultinomialLayerTest::getTestCaseName);
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(smoke_MultinomialDynamic,
|
||||||
|
MultinomialLayerTest,
|
||||||
|
params_dynamic,
|
||||||
|
MultinomialLayerTest::getTestCaseName);
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(smoke_MultinomialDynamicLog,
|
||||||
|
MultinomialLayerTest,
|
||||||
|
params_dynamic_log,
|
||||||
|
MultinomialLayerTest::getTestCaseName);
|
||||||
|
} // namespace
|
@ -11,21 +11,6 @@
|
|||||||
using namespace ov;
|
using namespace ov;
|
||||||
using namespace ov::intel_cpu;
|
using namespace ov::intel_cpu;
|
||||||
|
|
||||||
TEST(StaticShapeInferenceTest, MultinomialStaticShapeInferenceTest1D) {
|
|
||||||
auto probs = std::make_shared<op::v0::Parameter>(element::f32, Shape{4});
|
|
||||||
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, Shape{1});
|
|
||||||
auto multinomial = std::make_shared<op::v13::Multinomial>(probs, num_samples, element::i32, false, false, 0, 0);
|
|
||||||
|
|
||||||
// Test Static Shape 1D input
|
|
||||||
std::vector<StaticShape> static_input_shapes = {StaticShape{4}, StaticShape{1}};
|
|
||||||
int32_t num_elements_val = 2;
|
|
||||||
auto const_data =
|
|
||||||
std::unordered_map<size_t, Tensor>{{1, {element::i32, Shape{1}, &num_elements_val}}};
|
|
||||||
auto acc = make_tensor_accessor(const_data);
|
|
||||||
auto static_output_shapes = shape_infer(multinomial.get(), static_input_shapes, acc);
|
|
||||||
ASSERT_EQ(static_output_shapes[0], StaticShape({2}));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(StaticShapeInferenceTest, MultinomialStaticShapeInferenceTest2D) {
|
TEST(StaticShapeInferenceTest, MultinomialStaticShapeInferenceTest2D) {
|
||||||
auto probs = std::make_shared<op::v0::Parameter>(element::f32, Shape{4, 4});
|
auto probs = std::make_shared<op::v0::Parameter>(element::f32, Shape{4, 4});
|
||||||
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, Shape{1});
|
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, Shape{1});
|
||||||
@ -34,28 +19,12 @@ TEST(StaticShapeInferenceTest, MultinomialStaticShapeInferenceTest2D) {
|
|||||||
// Test Static Shape 2D input
|
// Test Static Shape 2D input
|
||||||
std::vector<StaticShape> static_input_shapes = {StaticShape{4, 4}, StaticShape{1}};
|
std::vector<StaticShape> static_input_shapes = {StaticShape{4, 4}, StaticShape{1}};
|
||||||
int32_t num_elements_val = 2;
|
int32_t num_elements_val = 2;
|
||||||
auto const_data =
|
auto const_data = std::unordered_map<size_t, Tensor>{{1, {element::i32, Shape{1}, &num_elements_val}}};
|
||||||
std::unordered_map<size_t, Tensor>{{1, {element::i32, Shape{1}, &num_elements_val}}};
|
|
||||||
auto acc = make_tensor_accessor(const_data);
|
auto acc = make_tensor_accessor(const_data);
|
||||||
auto static_output_shapes = shape_infer(multinomial.get(), static_input_shapes, acc);
|
auto static_output_shapes = shape_infer(multinomial.get(), static_input_shapes, acc);
|
||||||
ASSERT_EQ(static_output_shapes[0], StaticShape({4, 2}));
|
ASSERT_EQ(static_output_shapes[0], StaticShape({4, 2}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(StaticShapeInferenceTest, MultinomialDynamicShapeInferenceTestAllDimKnown1D) {
|
|
||||||
auto probs = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{3});
|
|
||||||
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{1});
|
|
||||||
auto multinomial = std::make_shared<op::v13::Multinomial>(probs, num_samples, element::i32, false, false, 0, 0);
|
|
||||||
|
|
||||||
// Test Partial Shape 1D input
|
|
||||||
std::vector<PartialShape> partial_input_shapes = {PartialShape{3}, PartialShape{1}};
|
|
||||||
int32_t num_elements_val = 2;
|
|
||||||
auto const_data =
|
|
||||||
std::unordered_map<size_t, Tensor>{{1, {element::i32, Shape{1}, &num_elements_val}}};
|
|
||||||
auto acc = make_tensor_accessor(const_data);
|
|
||||||
auto partial_output_shapes = shape_infer(multinomial.get(), partial_input_shapes, acc);
|
|
||||||
ASSERT_EQ(partial_output_shapes[0], PartialShape({2}));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(StaticShapeInferenceTest, MultinomialDynamicShapeInferenceTestAllDimKnown2D) {
|
TEST(StaticShapeInferenceTest, MultinomialDynamicShapeInferenceTestAllDimKnown2D) {
|
||||||
auto probs = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{2, 3});
|
auto probs = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{2, 3});
|
||||||
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{1});
|
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{1});
|
||||||
@ -64,24 +33,12 @@ TEST(StaticShapeInferenceTest, MultinomialDynamicShapeInferenceTestAllDimKnown2D
|
|||||||
// Test Partial Shape 2D input
|
// Test Partial Shape 2D input
|
||||||
std::vector<PartialShape> partial_input_shapes = {PartialShape{2, 3}, PartialShape{1}};
|
std::vector<PartialShape> partial_input_shapes = {PartialShape{2, 3}, PartialShape{1}};
|
||||||
int32_t num_elements_val = 2;
|
int32_t num_elements_val = 2;
|
||||||
auto const_data =
|
auto const_data = std::unordered_map<size_t, Tensor>{{1, {element::i32, Shape{1}, &num_elements_val}}};
|
||||||
std::unordered_map<size_t, Tensor>{{1, {element::i32, Shape{1}, &num_elements_val}}};
|
|
||||||
auto acc = make_tensor_accessor(const_data);
|
auto acc = make_tensor_accessor(const_data);
|
||||||
auto partial_output_shapes = shape_infer(multinomial.get(), partial_input_shapes, acc);
|
auto partial_output_shapes = shape_infer(multinomial.get(), partial_input_shapes, acc);
|
||||||
ASSERT_EQ(partial_output_shapes[0], PartialShape({2, 2}));
|
ASSERT_EQ(partial_output_shapes[0], PartialShape({2, 2}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(StaticShapeInferenceTest, MultinomialDynamicShapeInferenceTestDynamicNumSamples1D) {
|
|
||||||
auto probs = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{4});
|
|
||||||
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1});
|
|
||||||
auto multinomial = std::make_shared<op::v13::Multinomial>(probs, num_samples, element::i32, false, false, 0, 0);
|
|
||||||
|
|
||||||
// Test Partial Shape 1D input, unknown num_samples
|
|
||||||
std::vector<PartialShape> partial_input_shapes = {PartialShape{4}, PartialShape{-1}};
|
|
||||||
auto partial_output_shapes = shape_infer(multinomial.get(), partial_input_shapes, make_tensor_accessor());
|
|
||||||
ASSERT_EQ(partial_output_shapes[0], PartialShape({-1}));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(StaticShapeInferenceTest, MultinomialDynamicShapeInferenceTestDynamicNumSamples2D) {
|
TEST(StaticShapeInferenceTest, MultinomialDynamicShapeInferenceTestDynamicNumSamples2D) {
|
||||||
auto probs = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{4, 4});
|
auto probs = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{4, 4});
|
||||||
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1});
|
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||||
@ -93,17 +50,6 @@ TEST(StaticShapeInferenceTest, MultinomialDynamicShapeInferenceTestDynamicNumSam
|
|||||||
ASSERT_EQ(partial_output_shapes[0], PartialShape({4, -1}));
|
ASSERT_EQ(partial_output_shapes[0], PartialShape({4, -1}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(StaticShapeInferenceTest, MultinomialDynamicShapeInferenceTestDynamicProbsDynamicNumSamples1D) {
|
|
||||||
auto probs = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1});
|
|
||||||
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1});
|
|
||||||
auto multinomial = std::make_shared<op::v13::Multinomial>(probs, num_samples, element::i32, false, false, 0, 0);
|
|
||||||
|
|
||||||
// Test Partial Shape 1D input, unknown num_samples and probs shape
|
|
||||||
std::vector<PartialShape> partial_input_shapes = {PartialShape{-1}, PartialShape{-1}};
|
|
||||||
auto partial_output_shapes = shape_infer(multinomial.get(), partial_input_shapes, make_tensor_accessor());
|
|
||||||
ASSERT_EQ(partial_output_shapes[0], PartialShape({-1}));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(StaticShapeInferenceTest, MultinomialDynamicShapeInferenceTestDynamicProbsDynamicNumSamples2D) {
|
TEST(StaticShapeInferenceTest, MultinomialDynamicShapeInferenceTestDynamicProbsDynamicNumSamples2D) {
|
||||||
auto probs = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1});
|
auto probs = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1});
|
||||||
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1});
|
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||||
|
@ -15,7 +15,6 @@ const std::vector<ov::test::ElementType> netPrecisions = {
|
|||||||
const std::vector<ov::Shape> inputShapes = {
|
const std::vector<ov::Shape> inputShapes = {
|
||||||
{1, 32},
|
{1, 32},
|
||||||
{2, 28},
|
{2, 28},
|
||||||
{32},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const std::vector<int64_t> numSamples = {
|
const std::vector<int64_t> numSamples = {
|
||||||
|
@ -83,36 +83,37 @@ template <ov::element::Type_t et>
|
|||||||
std::vector<MultinomialParams> generateMultinomialParams() {
|
std::vector<MultinomialParams> generateMultinomialParams() {
|
||||||
using vt = typename ov::element_type_traits<et>::value_type;
|
using vt = typename ov::element_type_traits<et>::value_type;
|
||||||
|
|
||||||
const ov::Shape prob_2d_shape{2, 4};
|
|
||||||
const ov::Shape prob_1d_shape{4};
|
|
||||||
const ov::Shape num_samples_shape{1};
|
const ov::Shape num_samples_shape{1};
|
||||||
const ov::Shape prob_1d_shape_expand_small{2};
|
|
||||||
const ov::Shape out_1d_shape_expand_big{16};
|
|
||||||
|
|
||||||
reference_tests::Tensor num_samples(num_samples_shape, ov::element::Type_t::i32, std::vector<int32_t>{4});
|
reference_tests::Tensor num_samples(num_samples_shape, ov::element::Type_t::i32, std::vector<int32_t>{4});
|
||||||
reference_tests::Tensor num_samples_big(num_samples_shape, ov::element::Type_t::i32, std::vector<int32_t>{16});
|
reference_tests::Tensor num_samples_big(num_samples_shape, ov::element::Type_t::i32, std::vector<int32_t>{16});
|
||||||
|
|
||||||
|
const ov::Shape prob_2d_shape{2, 4};
|
||||||
|
const ov::Shape prob_pseudo_1d_shape{1, 4};
|
||||||
|
const ov::Shape prob_pseudo_1d_shape_expand_small{1, 2};
|
||||||
reference_tests::Tensor probabilities_2d_no_log(prob_2d_shape,
|
reference_tests::Tensor probabilities_2d_no_log(prob_2d_shape,
|
||||||
et,
|
et,
|
||||||
std::vector<vt>{0.001, 0.01, 0.1, 0.899, 0.899, 0.1, 0.01, 0.001});
|
std::vector<vt>{0.001, 0.01, 0.1, 0.899, 0.899, 0.1, 0.01, 0.001});
|
||||||
reference_tests::Tensor probabilities_2d_log(prob_2d_shape, et, std::vector<vt>{1, 2, 3, 4, 2, 4, 6, 8});
|
reference_tests::Tensor probabilities_2d_log(prob_2d_shape, et, std::vector<vt>{1, 2, 3, 4, 2, 4, 6, 8});
|
||||||
reference_tests::Tensor probabilities_1d_no_log(prob_1d_shape, et, std::vector<vt>{0.001, 0.01, 0.1, 0.899});
|
reference_tests::Tensor probabilities_1d_no_log(prob_pseudo_1d_shape, et, std::vector<vt>{0.001, 0.01, 0.1, 0.899});
|
||||||
reference_tests::Tensor probabilities_1d_log(prob_1d_shape, et, std::vector<vt>{1, 10, 7, 3});
|
reference_tests::Tensor probabilities_1d_log(prob_pseudo_1d_shape, et, std::vector<vt>{1, 10, 7, 3});
|
||||||
reference_tests::Tensor probabilities_1d_expand(prob_1d_shape_expand_small, et, std::vector<vt>{0.00001, 0.99999});
|
reference_tests::Tensor probabilities_1d_expand(prob_pseudo_1d_shape_expand_small,
|
||||||
|
et,
|
||||||
|
std::vector<vt>{0.00001, 0.99999});
|
||||||
|
|
||||||
|
const ov::Shape out_pseudo_1d_shape_expand{1, 16};
|
||||||
reference_tests::Tensor output_2d_no_log_replacement(prob_2d_shape,
|
reference_tests::Tensor output_2d_no_log_replacement(prob_2d_shape,
|
||||||
ov::element::Type_t::i32,
|
ov::element::Type_t::i32,
|
||||||
std::vector<int32_t>{3, 3, 3, 3, 0, 0, 0, 0});
|
std::vector<int32_t>{3, 3, 3, 3, 0, 0, 0, 0});
|
||||||
reference_tests::Tensor output_2d_log_replacement(prob_2d_shape,
|
reference_tests::Tensor output_2d_log_replacement(prob_2d_shape,
|
||||||
ov::element::Type_t::i32,
|
ov::element::Type_t::i32,
|
||||||
std::vector<int32_t>{3, 3, 2, 3, 3, 3, 3, 3});
|
std::vector<int32_t>{3, 3, 2, 3, 3, 3, 3, 3});
|
||||||
reference_tests::Tensor output_1d_no_log_no_replacement(prob_1d_shape,
|
reference_tests::Tensor output_1d_no_log_no_replacement(prob_pseudo_1d_shape,
|
||||||
ov::element::Type_t::i64,
|
ov::element::Type_t::i64,
|
||||||
std::vector<int64_t>{3, 2, 1, 0});
|
std::vector<int64_t>{3, 2, 1, 0});
|
||||||
reference_tests::Tensor output_1d_log_no_replacement(prob_1d_shape,
|
reference_tests::Tensor output_1d_log_no_replacement(prob_pseudo_1d_shape,
|
||||||
ov::element::Type_t::i64,
|
ov::element::Type_t::i64,
|
||||||
std::vector<int64_t>{1, 2, 3, 0});
|
std::vector<int64_t>{1, 2, 3, 0});
|
||||||
reference_tests::Tensor output_1d_expand(out_1d_shape_expand_big,
|
reference_tests::Tensor output_1d_expand(out_pseudo_1d_shape_expand,
|
||||||
ov::element::Type_t::i64,
|
ov::element::Type_t::i64,
|
||||||
std::vector<int64_t>{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
|
std::vector<int64_t>{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
|
||||||
|
|
||||||
|
@ -0,0 +1,15 @@
|
|||||||
|
// Copyright (C) 2022 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "shared_test_classes/single_op/multinomial.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace test {
|
||||||
|
// Executes the shared Multinomial layer test once per instantiated parameter set.
// (No trailing ';' after the body: it would be an empty declaration and
// triggers extra-semicolon warnings in pedantic builds.)
TEST_P(MultinomialLayerTest, Inference) {
    run();
}
|
||||||
|
} // namespace test
|
||||||
|
} // namespace ov
|
@ -0,0 +1,38 @@
|
|||||||
|
// Copyright (C) 2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "openvino/runtime/tensor.hpp"
|
||||||
|
#include "shared_test_classes/base/ov_subgraph.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace test {
|
||||||
|
|
||||||
|
typedef std::tuple<std::string, // test type
|
||||||
|
ov::Tensor, // probs
|
||||||
|
ov::Tensor, // num_samples
|
||||||
|
ov::test::ElementType, // convert_type
|
||||||
|
bool, // with_replacement
|
||||||
|
bool, // log_probs
|
||||||
|
std::pair<uint64_t, uint64_t>, // global_op_seed
|
||||||
|
std::string // device_name
|
||||||
|
>
|
||||||
|
MultinomialTestParams;
|
||||||
|
|
||||||
|
/// Parameterized single-op test fixture for the Multinomial operator.
///
/// The probs / num_samples tensors that arrive through the test parameters are
/// kept in the fixture so that generate_inputs() can feed exactly those tensors
/// to the model.
class MultinomialLayerTest : public testing::WithParamInterface<MultinomialTestParams>,
                             virtual public SubgraphBaseTest {
public:
    /// Produces the human-readable gtest case name for one parameter set.
    static std::string getTestCaseName(const testing::TestParamInfo<MultinomialTestParams>& obj);

protected:
    /// Unpacks the test parameters and builds the model under test.
    void SetUp() override;

    /// Binds the stored probs / num_samples tensors to the model inputs.
    void generate_inputs(const std::vector<ov::Shape>& target_shapes) override;

private:
    ov::Tensor m_probs;        // probabilities tensor taken from the test parameters
    ov::Tensor m_num_samples;  // sample-count tensor taken from the test parameters
};
|
||||||
|
} // namespace test
|
||||||
|
} // namespace ov
|
@ -0,0 +1,118 @@
|
|||||||
|
// Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "shared_test_classes/single_op/multinomial.hpp"
|
||||||
|
|
||||||
|
#include "ov_models/builders.hpp"
|
||||||
|
|
||||||
|
using namespace ov::test;
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace test {
|
||||||
|
/// Builds the gtest case name for one Multinomial parameter set.
///
/// The name has the form
///   <test_type>_probs_shape=..._num_samples=..._convert_type=..._replace=..._log=..._seed_g=..._seed_o=..._device=...
///
/// @param obj  gtest parameter info wrapping a MultinomialTestParams tuple.
/// @return     The '_'-separated test case name.
std::string MultinomialLayerTest::getTestCaseName(const testing::TestParamInfo<MultinomialTestParams>& obj) {
    std::string test_type;
    ov::Tensor probs;
    ov::Tensor num_samples;
    ov::test::ElementType convert_type;
    bool with_replacement;
    bool log_probs;
    std::pair<uint64_t, uint64_t> global_op_seed;
    std::string device_name;

    std::tie(test_type, probs, num_samples, convert_type, with_replacement, log_probs, global_op_seed, device_name) =
        obj.param;

    const uint64_t global_seed = global_op_seed.first;
    const uint64_t op_seed = global_op_seed.second;

    const char separator = '_';
    std::ostringstream result;
    result << test_type << separator;
    result << "probs_shape=" << probs.get_shape().to_string() << separator;

    // Read the first num_samples element through a fixed-width type that matches
    // the tensor precision. `long` must not be used for i64: it is only 32 bits
    // wide on LLP64 platforms such as Windows.
    if (num_samples.get_element_type() == ov::test::ElementType::i32) {
        result << "num_samples=" << static_cast<const int32_t*>(num_samples.data())[0] << separator;
    } else {  // i64
        result << "num_samples=" << static_cast<const int64_t*>(num_samples.data())[0] << separator;
    }

    result << "convert_type=" << convert_type << separator;
    result << "replace=" << ov::test::utils::bool2str(with_replacement) << separator;
    result << "log=" << ov::test::utils::bool2str(log_probs) << separator;
    result << "seed_g=" << global_seed << separator;
    result << "seed_o=" << op_seed << separator;
    result << "device=" << device_name;

    return result.str();
}
|
||||||
|
|
||||||
|
void MultinomialLayerTest::SetUp() {
|
||||||
|
MultinomialTestParams test_params;
|
||||||
|
|
||||||
|
std::string test_type;
|
||||||
|
ov::Tensor probs;
|
||||||
|
ov::Tensor num_samples;
|
||||||
|
ov::test::ElementType convert_type;
|
||||||
|
bool with_replacement;
|
||||||
|
bool log_probs;
|
||||||
|
std::pair<uint64_t, uint64_t> global_op_seed;
|
||||||
|
|
||||||
|
std::tie(test_type, probs, num_samples, convert_type, with_replacement, log_probs, global_op_seed, targetDevice) =
|
||||||
|
GetParam();
|
||||||
|
|
||||||
|
m_probs = probs;
|
||||||
|
m_num_samples = num_samples;
|
||||||
|
|
||||||
|
uint64_t global_seed = global_op_seed.first;
|
||||||
|
uint64_t op_seed = global_op_seed.second;
|
||||||
|
|
||||||
|
InputShape probs_shape;
|
||||||
|
InputShape num_samples_shape;
|
||||||
|
const ov::Shape probs_tensor_shape = probs.get_shape();
|
||||||
|
const ov::Shape num_samples_tensor_shape = num_samples.get_shape();
|
||||||
|
if (test_type == "static") {
|
||||||
|
probs_shape = {ov::PartialShape(probs_tensor_shape), {probs_tensor_shape}};
|
||||||
|
num_samples_shape = {ov::PartialShape(num_samples_tensor_shape), {num_samples_tensor_shape}};
|
||||||
|
} else { // dynamic
|
||||||
|
probs_shape = {ov::PartialShape::dynamic(ov::Rank(probs_tensor_shape.size())), {probs_tensor_shape}};
|
||||||
|
num_samples_shape = {ov::PartialShape::dynamic(ov::Rank(num_samples_tensor_shape.size())),
|
||||||
|
{num_samples_tensor_shape}};
|
||||||
|
}
|
||||||
|
init_input_shapes({probs_shape, num_samples_shape});
|
||||||
|
|
||||||
|
ov::ParameterVector params;
|
||||||
|
std::vector<std::shared_ptr<ov::Node>> inputs;
|
||||||
|
|
||||||
|
auto probs_param = std::make_shared<ov::op::v0::Parameter>(probs.get_element_type(), probs_shape.first);
|
||||||
|
probs_param->set_friendly_name("probs");
|
||||||
|
inputs.push_back(probs_param);
|
||||||
|
params.push_back(probs_param);
|
||||||
|
|
||||||
|
auto num_samples_param =
|
||||||
|
std::make_shared<ov::op::v0::Parameter>(num_samples.get_element_type(), num_samples_shape.first);
|
||||||
|
num_samples_param->set_friendly_name("num_samples");
|
||||||
|
inputs.push_back(num_samples_param);
|
||||||
|
params.push_back(num_samples_param);
|
||||||
|
|
||||||
|
auto multinomial = std::make_shared<ov::op::v13::Multinomial>(inputs[0],
|
||||||
|
inputs[1],
|
||||||
|
convert_type,
|
||||||
|
with_replacement,
|
||||||
|
log_probs,
|
||||||
|
global_seed,
|
||||||
|
op_seed);
|
||||||
|
|
||||||
|
ov::ResultVector results{std::make_shared<ov::opset10::Result>(multinomial)};
|
||||||
|
function = std::make_shared<ov::Model>(results, params, "Multinomial");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Binds the tensors captured in SetUp() to the two model inputs.
///
/// @param target_shapes  Not read by this override; the stored tensors are
///                       used as-is.
void MultinomialLayerTest::generate_inputs(const std::vector<ov::Shape>& target_shapes) {
    inputs.clear();

    const auto& model_inputs = function->inputs();

    // Input 0 receives the probabilities, input 1 the sample count.
    inputs.insert({model_inputs[0].get_node_shared_ptr(), m_probs});
    inputs.insert({model_inputs[1].get_node_shared_ptr(), m_num_samples});
}
|
||||||
|
} // namespace test
|
||||||
|
} // namespace ov
|
Loading…
Reference in New Issue
Block a user