[Ref][Core][Opset13] Add Multinomial Operation (#19655)

* [Ref] Multinomial base file

* [Ref] Add core & reference implementation reusing other ops

* [Ref] Fix reference implementation, add missing parameters, add tests

* [Core] Add opset13, register multinomial, add shape inference

* [Ref][Core] Fix compile errors

* [Ref][Core] Clang fix

* [TEMPLATE] Remove bf16, f16, f64 types

* [TEMPLATE] Remove incorrect input types for 'input' parameter

* [Ref][Tests] Remove deleted test types

* [Ref] Fix & optimize  shape inference

* [PT FE] Apply suggestions from review

* [Template] Migrate to new API

* [Core] Add a clause for dynamic input in shape inference

* [Tests] Add missing type_prop test (?)

* Update multinomial_shape_inference.hpp

* Update multinomial.hpp

* [Ref] Fix build issues

* [Ref] Fix clang and style

* [Ref] Fix tests without replacement

* [Ref] Fix with_replacement sampling error

* [Ref] Remove debugging artifacts

* [Ref] Cast to 64-bit size for 32-bit systems

* Update multinomial.hpp

* [Ref] Add missing type_prop tests, add shape inference tests

* Update multinomial.cpp

* Update multinomial_shape_inference_test.cpp

* Update multinomial.cpp

* Update multinomial.hpp

* [Ref] Fix compilation errors from shape inference test

* [Ref] Fix compilation error of type_prop, apply recommendations from review

* [Ref] Add multiple shape inference tests

* [Ref] Change TEST to TEST_F, add more type_prop tests

* [Ref] Clang fixes

* [Ref] Fix shape inference tests with mismatching args

* [Ref] Fix remaining type_prop errors

* [Ref] Replace HostTensor with normal Tensor in shape inference tests

* Update opset.cpp

* [Ref] Possible fix for 'function empty' error

* [Ref] Add a cast to remove conversion warning

* [Ref] Add conformance test of Multinomial

* [Ref] Match style of conf test to the remaining tests

* Update single_op_graph.cpp
This commit is contained in:
Piotr Krzemiński 2023-10-04 17:14:32 +02:00 committed by GitHub
parent bdb13aa28d
commit 48164e2279
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 908 additions and 1 deletions

View File

@ -0,0 +1,67 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/op/op.hpp"
namespace ov {
namespace op {
namespace v13 {
/// \brief Multinomial operation creates a sequence of indices of classes sampled from the multinomial distribution.
///
/// \ingroup ov_ops_cpp_api
class OPENVINO_API Multinomial : public Op {
public:
OPENVINO_OP("Multinomial", "opset13");
Multinomial() = default;
/**
* @brief Multinomial operation creates a sequence of indices of classes sampled from the multinomial distribution.
*
* @param probs Input tensor containing at each index poisition probability/log probability of sampling a given
* class. Any floating-point precision values are allowed.
* @param num_samples Scalar or 1D tensor with a single value that determines the number of samples to generate per
* batch. Values should be of an integer type.
* @param convert_type Data type to which to convert the output class indices. Allowed values: i32/i64
* @param with_replacement Boolean that determines whether a sampled class can appear more than once in the output.
* @param log_probs Boolean that determines whether to treat input probabilities as log probabilities.
* @param global_seed First seed value (key) of Phillox random number generation algorithm. (See RandomUniform for
* details)
* @param op_seed Second seed value (counter) of Phillox random number generation algorithm. (See RandomUniform for
* details)
*/
Multinomial(const Output<Node>& input,
const Output<Node>& num_samples,
const ov::element::Type_t output_type,
const bool with_replacement,
const bool log_probs,
const uint64_t global_seed = 0,
const uint64_t op_seed = 0);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
ov::element::Type_t get_convert_type() const;
bool get_with_replacement() const;
bool get_log_probs() const;
uint64_t get_global_seed() const;
uint64_t get_op_seed() const;
void set_convert_type(const ov::element::Type_t output_type);
void set_with_replacement(const bool with_replacement);
void set_log_probs(const bool log_probs);
void set_global_seed(const uint64_t global_seed);
void set_op_seed(const uint64_t op_seed);
private:
ov::element::Type_t m_convert_type;
bool m_with_replacement;
bool m_log_probs;
uint64_t m_global_seed;
uint64_t m_op_seed;
};
} // namespace v13
} // namespace op
} // namespace ov

View File

@ -110,6 +110,7 @@
#include "openvino/op/mish.hpp"
#include "openvino/op/mod.hpp"
#include "openvino/op/multiclass_nms.hpp"
#include "openvino/op/multinomial.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/mvn.hpp"
#include "openvino/op/negative.hpp"

View File

@ -214,3 +214,4 @@ _OPENVINO_OP_REG(BitwiseNot, ov::op::v13)
_OPENVINO_OP_REG(BitwiseOr, ov::op::v13)
_OPENVINO_OP_REG(BitwiseXor, ov::op::v13)
_OPENVINO_OP_REG(NMSRotated, ov::op::v13)
_OPENVINO_OP_REG(Multinomial, ov::op::v13)

View File

@ -0,0 +1,163 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <cmath>
#include <cstddef>
#include "openvino/reference/broadcast.hpp"
#include "openvino/reference/convert.hpp"
#include "openvino/reference/copy.hpp"
#include "openvino/reference/cum_sum.hpp"
#include "openvino/reference/divide.hpp"
#include "openvino/reference/exp.hpp"
#include "openvino/reference/random_uniform.hpp"
#include "openvino/reference/slice.hpp"
namespace ov {
namespace reference {
namespace multinomial {
/**
* @brief Multinomial operation creates a sequence of indices of classes sampled from the multinomial distribution.
*
* @tparam T Data type of the probs' values.
* @tparam U Data type of num_samples' values.
* @tparam V Data type of output's values.
* @param probs Input tensor containing at each index poisition probability/log probability of sampling a given class.
* @param probs_shape Shape of the 'probs' tensor.
* @param num_samples Scalar or 1D tensor with a single value that determines the number of samples to generate per
* batch.
* @param num_samples_shape Shape of the 'num_samples' tensor.
* @param output Output tensor for the generated class indices.
* @param output_shape Shape of the 'output' tensor.
* @param with_replacement Boolean that determines whether a sampled class can appear more than once in the output.
* @param log_probs Boolean that determines whether to treat input probabilities as log probabilities.
* @param global_seed First seed value (key) of Phillox random number generation algorithm. (See RandomUniform for
* details)
* @param op_seed Second seed value (counter) of Phillox random number generation algorithm. (See RandomUniform for
* details)
*/
template <typename T, typename U, typename V>
void multinomial(const T* probs,
const Shape& probs_shape,
const U* num_samples,
const Shape& num_samples_shape,
V* output,
const Shape& output_shape,
const bool with_replacement,
const bool log_probs,
const uint64_t global_seed,
const uint64_t op_seed) {
const auto total_inputs_elements_count = shape_size<Shape>(probs_shape);
const auto total_output_elements_count = shape_size<Shape>(output_shape);
// If probabilities are log probabilities, exponentiate to get normal probabilities
std::vector<T> input_vals(total_inputs_elements_count);
if (log_probs) {
exp(probs, input_vals.data(), total_inputs_elements_count);
} else {
copy(probs, input_vals.data(), total_inputs_elements_count);
}
// Create a cdf of probabilties on the last axis, per batch. Note cumsum exclusive == false
std::vector<T> cdf(total_inputs_elements_count);
const auto last_axis = probs_shape.size() - 1;
cumsum(input_vals.data(), last_axis, cdf.data(), probs_shape, false, false);
// Obtain max value from cdf, per batch (from cumsum it is the last element)
std::vector<T> max_value_per_batch(total_inputs_elements_count / probs_shape[last_axis]);
Shape max_value_per_batch_shape(probs_shape);
max_value_per_batch_shape[last_axis] = 1;
const std::vector<int64_t> start{static_cast<int64_t>(probs_shape[last_axis] - 1)};
const std::vector<int64_t> step{1};
const std::vector<int64_t> target_axis_vec{static_cast<int64_t>(last_axis)};
slice(reinterpret_cast<const char*>(cdf.data()),
probs_shape, // == cdf shape
reinterpret_cast<char*>(max_value_per_batch.data()),
max_value_per_batch_shape,
sizeof(T),
start,
step,
target_axis_vec);
// Normalize the cdf by dividing all elements by the max value in each batch
std::vector<T> max_value_per_batch_divisor(total_inputs_elements_count);
ov::AxisSet target_axis_set = ov::AxisSet({last_axis});
broadcast(reinterpret_cast<const char*>(max_value_per_batch.data()),
reinterpret_cast<char*>(max_value_per_batch_divisor.data()),
max_value_per_batch_shape,
probs_shape, // expand to original shape (expands last dim)
target_axis_set,
sizeof(T));
divide(cdf.data(), max_value_per_batch_divisor.data(), cdf.data(), total_inputs_elements_count, false);
// Generate random probability samples
std::vector<double> uniform_samples(total_output_elements_count);
const double zero = 0;
const double one = 1;
const ov::Shape output_shape_shape{output_shape.size()};
const std::vector<uint64_t> output_shape_u64(output_shape.begin(), output_shape.end());
const std::pair<uint64_t, uint64_t> initial_state(0, 0);
random_uniform(output_shape_u64.data(),
reinterpret_cast<const char*>(&zero),
reinterpret_cast<const char*>(&one),
reinterpret_cast<char*>(uniform_samples.data()),
output_shape_shape,
ov::element::f64,
global_seed,
op_seed,
initial_state);
auto batch_size = probs_shape.size() == 2 ? static_cast<size_t>(probs_shape[0]) : static_cast<size_t>(1);
auto class_size =
probs_shape.size() == 2 ? static_cast<size_t>(probs_shape[1]) : static_cast<size_t>(probs_shape[0]);
auto samples_size =
probs_shape.size() == 2 ? static_cast<size_t>(num_samples[0]) : static_cast<size_t>(probs_shape[0]);
// Iterate over each channel in uniform samples
std::vector<U> output_samples(total_output_elements_count);
for (size_t i = 0; i < batch_size * samples_size; i += samples_size) {
for (size_t j = 0; j < samples_size; ++j) {
// Iterate over cdf to find the index for a given sample
// If no class found (all have 0 probability), selects last - undefined behavior
auto i_translated = i / samples_size * class_size;
auto selected_class_idx = class_size;
auto sample_value = uniform_samples[i + j];
for (size_t k = 0; k < class_size; ++k) {
if (sample_value <= cdf[i_translated + k]) {
output_samples[i + j] = static_cast<U>(k);
selected_class_idx = k;
break;
}
}
// Additional step with replacement - change probability of a given class to 0, and update the cdf
if (with_replacement) {
T class_probability = selected_class_idx ? cdf[i_translated + selected_class_idx] -
cdf[i_translated + selected_class_idx - 1]
: cdf[i_translated + selected_class_idx];
T divisor = 1 - class_probability;
for (size_t k = 0; k < class_size; ++k) {
if (k >= selected_class_idx) {
cdf[i_translated + k] -= class_probability;
}
cdf[i_translated + k] /= divisor;
}
}
}
}
// Finally convert the samples to the requested data type
convert<U, V>(output_samples.data(), output, total_output_elements_count);
}
} // namespace multinomial
} // namespace reference
namespace op {
namespace multinomial {
namespace validate {
void input_types(const Node* op);
} // namespace validate
} // namespace multinomial
} // namespace op
} // namespace ov

View File

@ -0,0 +1,58 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "dimension_util.hpp"
#include "openvino/op/multinomial.hpp"
#include "utils.hpp"
namespace ov {
namespace op {
namespace v13 {
template <class TShape, class TRShape = result_shape_t<TShape>>
std::vector<TRShape> shape_infer(const Multinomial* op,
const std::vector<TShape>& input_shapes,
const ITensorAccessor& ta = make_tensor_accessor()) {
NODE_VALIDATION_CHECK(op, input_shapes.size() == 2);
const auto& input_shape = input_shapes[0];
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
input_shape.rank().compatible(1) || input_shape.rank().compatible(2),
"The rank of the 'probs' tensor defining output shape must be either 1 or 2.");
const auto& num_samples_shape = input_shapes[1];
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
num_samples_shape.compatible(TRShape{}) || num_samples_shape.compatible(TRShape{1}),
"Number of samples must be a scalar or one element 1D tensor.");
auto output_shapes = std::vector<TRShape>(1);
auto& result_shape = output_shapes[0];
const auto input_rank_static = input_shape.rank().is_static();
if (input_rank_static) {
const auto& num_samples = get_input_const_data_as_shape<TRShape>(op, 1, ta);
if (num_samples) {
NODE_VALIDATION_CHECK(op,
(*num_samples)[0].get_min_length() >= 0,
"Number of samples must be non-negative. Got number of samples: ",
(*num_samples)[0].get_min_length());
result_shape = *num_samples;
} else {
result_shape = ov::PartialShape::dynamic(1);
}
if (input_shape.rank().compatible(2)) {
result_shape.insert(result_shape.begin(), input_shape[0]);
}
} else {
result_shape = ov::PartialShape::dynamic();
}
return output_shapes;
}
} // namespace v13
} // namespace op
} // namespace ov

View File

@ -0,0 +1,129 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/op/multinomial.hpp"
#include <cstring>
#include "bound_evaluate.hpp"
#include "itt.hpp"
#include "multinomial_shape_inference.hpp"
#include "openvino/core/attribute_visitor.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/reference/multinomial.hpp"
namespace ov {
// ------------------------------ v13 ------------------------------
op::v13::Multinomial::Multinomial(const Output<Node>& probs,
const Output<Node>& num_samples,
const ov::element::Type_t convert_type,
const bool with_replacement,
const bool log_probs,
const uint64_t global_seed,
const uint64_t op_seed)
: Op({probs, num_samples}),
m_convert_type(convert_type),
m_with_replacement(with_replacement),
m_log_probs(log_probs),
m_global_seed(global_seed),
m_op_seed(op_seed) {
constructor_validate_and_infer_types();
}
bool op::v13::Multinomial::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v13_Multinomial_visit_attributes);
visitor.on_attribute("convert_type", m_convert_type);
visitor.on_attribute("with_replacement", m_with_replacement);
visitor.on_attribute("log_probs", m_log_probs);
visitor.on_attribute("global_seed", m_global_seed);
visitor.on_attribute("op_seed", m_op_seed);
return true;
}
void op::v13::Multinomial::validate_and_infer_types() {
OV_OP_SCOPE(v13_Multinomial_validate_and_infer_types);
OPENVINO_SUPPRESS_DEPRECATED_START
const auto input_shapes = get_node_input_partial_shapes(*this);
OPENVINO_SUPPRESS_DEPRECATED_END
const auto output_shapes = shape_infer(this, input_shapes);
multinomial::validate::input_types(this);
set_output_type(0, m_convert_type, output_shapes[0]);
}
std::shared_ptr<Node> op::v13::Multinomial::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v13_Multinomial_clone_with_new_inputs);
check_new_args_count<OutputVector>(this, new_args);
return std::make_shared<op::v13::Multinomial>(new_args.at(0),
new_args.at(1),
m_convert_type,
m_with_replacement,
m_log_probs,
m_global_seed,
m_op_seed);
}
ov::element::Type_t op::v13::Multinomial::get_convert_type() const {
return m_convert_type;
}
bool op::v13::Multinomial::get_with_replacement() const {
return m_with_replacement;
}
bool op::v13::Multinomial::get_log_probs() const {
return m_log_probs;
}
uint64_t op::v13::Multinomial::get_global_seed() const {
return m_global_seed;
}
uint64_t op::v13::Multinomial::get_op_seed() const {
return m_op_seed;
}
void op::v13::Multinomial::set_convert_type(const ov::element::Type_t convert_type) {
m_convert_type = convert_type;
}
void op::v13::Multinomial::set_with_replacement(const bool with_replacement) {
m_with_replacement = with_replacement;
}
void op::v13::Multinomial::set_log_probs(const bool log_probs) {
m_log_probs = log_probs;
}
void op::v13::Multinomial::set_global_seed(const uint64_t global_seed) {
m_global_seed = global_seed;
}
void op::v13::Multinomial::set_op_seed(const uint64_t op_seed) {
m_op_seed = op_seed;
}
namespace op {
namespace multinomial {
namespace validate {
void input_types(const Node* op) {
NODE_VALIDATION_CHECK(op,
op->get_input_element_type(0).is_real(),
"Expected floating point type as element type for the 'probs' input.");
NODE_VALIDATION_CHECK(op,
op->get_input_element_type(1).is_integral_number(),
"Expected integer type as element type for the 'num_samples' input.");
}
} // namespace validate
} // namespace multinomial
} // namespace op
} // namespace ov

View File

@ -71,7 +71,7 @@ INSTANTIATE_TEST_SUITE_P(opset,
OpsetTestParams{ov::get_opset10, 177},
OpsetTestParams{ov::get_opset11, 177},
OpsetTestParams{ov::get_opset12, 178},
OpsetTestParams{ov::get_opset13, 183}),
OpsetTestParams{ov::get_opset13, 184}),
OpsetTestNameGenerator{});
class MyOpOld : public ov::op::Op {

View File

@ -0,0 +1,62 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/op/multinomial.hpp"
#include <gtest/gtest.h>
#include "common_test_utils/test_assertions.hpp"
#include "common_test_utils/type_prop.hpp"
using namespace testing;
class TypePropMultinomialV13Test : public TypePropOpTest<ov::op::v13::Multinomial> {};
TEST_F(TypePropMultinomialV13Test, input_probs_f64_num_samples_i32_convert_i32) {
const auto probs = std::make_shared<ov::op::v0::Parameter>(ov::element::f64, ov::Shape{4});
const auto num_samples = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::Shape{1});
const auto op = make_op(probs, num_samples, ov::element::i32, false, false, 0, 0);
EXPECT_EQ(op->get_element_type(), ov::element::i32);
EXPECT_EQ(op->get_output_partial_shape(0), (ov::PartialShape::dynamic(1)));
}
TEST_F(TypePropMultinomialV13Test, input_probs_f32_num_samples_i32_convert_i64) {
const auto probs = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4, 4});
const auto num_samples = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::Shape{});
const auto op = make_op(probs, num_samples, ov::element::i64, false, false, 0, 0);
EXPECT_EQ(op->get_element_type(), ov::element::i64);
EXPECT_EQ(op->get_output_partial_shape(0), (ov::PartialShape{4, -1}));
}
TEST_F(TypePropMultinomialV13Test, probs_incompatibile_data_type) {
const auto probs = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::Shape{4});
const auto num_samples = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::Shape{});
OV_EXPECT_THROW(std::ignore = make_op(probs, num_samples, ov::element::u64, false, false, 0, 0),
ov::NodeValidationFailure,
HasSubstr("Expected floating point type as element type for the 'probs' input."));
}
TEST_F(TypePropMultinomialV13Test, num_samples_incompatibile_data_type) {
const auto probs = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4});
const auto num_samples = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{});
OV_EXPECT_THROW(std::ignore = make_op(probs, num_samples, ov::element::u64, false, false, 0, 0),
ov::NodeValidationFailure,
HasSubstr("Expected integer type as element type for the 'num_samples' input."));
}
TEST_F(TypePropMultinomialV13Test, probs_incompatibile_rank) {
const auto probs = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4, 4, 4});
const auto num_samples = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::Shape{1});
OV_EXPECT_THROW(std::ignore = make_op(probs, num_samples, ov::element::boolean, false, false, 0, 0),
ov::NodeValidationFailure,
HasSubstr("The rank of the 'probs' tensor defining output shape must be either 1 or 2."));
}
TEST_F(TypePropMultinomialV13Test, num_samples_incompatibile_rank) {
const auto probs = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4});
const auto num_samples = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::Shape{1, 2});
OV_EXPECT_THROW(std::ignore = make_op(probs, num_samples, ov::element::boolean, false, false, 0, 0),
ov::NodeValidationFailure,
HasSubstr("Number of samples must be a scalar or one element 1D tensor."));
}

View File

@ -0,0 +1,31 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/op/multinomial.hpp"
#include <gtest/gtest.h>
#include "openvino/op/unique.hpp"
#include "visitors/visitors.hpp"
using namespace ov;
using ov::test::NodeBuilder;
TEST(attributes, multinomial) {
NodeBuilder::get_ops().register_factory<ov::op::v13::Multinomial>();
const auto probs = std::make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 2});
const auto num_samples = std::make_shared<ov::op::v0::Parameter>(element::i32, Shape{1});
const auto op = std::make_shared<ov::op::v13::Multinomial>(probs, num_samples, element::f32, false, true, 0, 0);
NodeBuilder builder(op, {probs, num_samples});
auto g_multi = ov::as_type_ptr<ov::op::v13::Multinomial>(builder.create());
constexpr auto expected_attr_count = 5;
EXPECT_EQ(builder.get_value_map_size(), expected_attr_count);
EXPECT_EQ(op->get_with_replacement(), g_multi->get_with_replacement());
EXPECT_EQ(op->get_global_seed(), g_multi->get_global_seed());
EXPECT_EQ(op->get_convert_type(), g_multi->get_convert_type());
EXPECT_EQ(op->get_log_probs(), g_multi->get_log_probs());
EXPECT_EQ(op->get_op_seed(), g_multi->get_op_seed());
}

View File

@ -0,0 +1,127 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "multinomial_shape_inference.hpp"
#include <gtest/gtest.h>
#include "utils.hpp"
using namespace ov;
using namespace ov::intel_cpu;
TEST(StaticShapeInferenceTest, MultinomialStaticShapeInferenceTest1D) {
auto probs = std::make_shared<op::v0::Parameter>(element::f32, Shape{4});
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, Shape{1});
auto multinomial = std::make_shared<op::v13::Multinomial>(probs, num_samples, element::i32, false, false, 0, 0);
// Test Static Shape 1D input
std::vector<StaticShape> static_input_shapes = {StaticShape{4}, StaticShape{1}};
int32_t num_elements_val = 2;
auto const_data =
std::unordered_map<size_t, Tensor>{{1, {element::i32, Shape{1}, &num_elements_val}}};
auto acc = make_tensor_accessor(const_data);
auto static_output_shapes = shape_infer(multinomial.get(), static_input_shapes, acc);
ASSERT_EQ(static_output_shapes[0], StaticShape({2}));
}
TEST(StaticShapeInferenceTest, MultinomialStaticShapeInferenceTest2D) {
auto probs = std::make_shared<op::v0::Parameter>(element::f32, Shape{4, 4});
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, Shape{1});
auto multinomial = std::make_shared<op::v13::Multinomial>(probs, num_samples, element::i32, false, false, 0, 0);
// Test Static Shape 2D input
std::vector<StaticShape> static_input_shapes = {StaticShape{4, 4}, StaticShape{1}};
int32_t num_elements_val = 2;
auto const_data =
std::unordered_map<size_t, Tensor>{{1, {element::i32, Shape{1}, &num_elements_val}}};
auto acc = make_tensor_accessor(const_data);
auto static_output_shapes = shape_infer(multinomial.get(), static_input_shapes, acc);
ASSERT_EQ(static_output_shapes[0], StaticShape({4, 2}));
}
TEST(StaticShapeInferenceTest, MultinomialDynamicShapeInferenceTestAllDimKnown1D) {
auto probs = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{3});
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{1});
auto multinomial = std::make_shared<op::v13::Multinomial>(probs, num_samples, element::i32, false, false, 0, 0);
// Test Partial Shape 1D input
std::vector<PartialShape> partial_input_shapes = {PartialShape{3}, PartialShape{1}};
int32_t num_elements_val = 2;
auto const_data =
std::unordered_map<size_t, Tensor>{{1, {element::i32, Shape{1}, &num_elements_val}}};
auto acc = make_tensor_accessor(const_data);
auto partial_output_shapes = shape_infer(multinomial.get(), partial_input_shapes, acc);
ASSERT_EQ(partial_output_shapes[0], PartialShape({2}));
}
TEST(StaticShapeInferenceTest, MultinomialDynamicShapeInferenceTestAllDimKnown2D) {
auto probs = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{2, 3});
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{1});
auto multinomial = std::make_shared<op::v13::Multinomial>(probs, num_samples, element::i32, false, false, 0, 0);
// Test Partial Shape 2D input
std::vector<PartialShape> partial_input_shapes = {PartialShape{2, 3}, PartialShape{1}};
int32_t num_elements_val = 2;
auto const_data =
std::unordered_map<size_t, Tensor>{{1, {element::i32, Shape{1}, &num_elements_val}}};
auto acc = make_tensor_accessor(const_data);
auto partial_output_shapes = shape_infer(multinomial.get(), partial_input_shapes, acc);
ASSERT_EQ(partial_output_shapes[0], PartialShape({2, 2}));
}
TEST(StaticShapeInferenceTest, MultinomialDynamicShapeInferenceTestDynamicNumSamples1D) {
auto probs = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{4});
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1});
auto multinomial = std::make_shared<op::v13::Multinomial>(probs, num_samples, element::i32, false, false, 0, 0);
// Test Partial Shape 1D input, unknown num_samples
std::vector<PartialShape> partial_input_shapes = {PartialShape{4}, PartialShape{-1}};
auto partial_output_shapes = shape_infer(multinomial.get(), partial_input_shapes, make_tensor_accessor());
ASSERT_EQ(partial_output_shapes[0], PartialShape({-1}));
}
TEST(StaticShapeInferenceTest, MultinomialDynamicShapeInferenceTestDynamicNumSamples2D) {
auto probs = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{4, 4});
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1});
auto multinomial = std::make_shared<op::v13::Multinomial>(probs, num_samples, element::i32, false, false, 0, 0);
// Test Partial Shape 2D input, unknown num_samples
std::vector<PartialShape> partial_input_shapes = {PartialShape{4, 4}, PartialShape{-1}};
auto partial_output_shapes = shape_infer(multinomial.get(), partial_input_shapes, make_tensor_accessor());
ASSERT_EQ(partial_output_shapes[0], PartialShape({4, -1}));
}
TEST(StaticShapeInferenceTest, MultinomialDynamicShapeInferenceTestDynamicProbsDynamicNumSamples1D) {
auto probs = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1});
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1});
auto multinomial = std::make_shared<op::v13::Multinomial>(probs, num_samples, element::i32, false, false, 0, 0);
// Test Partial Shape 1D input, unknown num_samples and probs shape
std::vector<PartialShape> partial_input_shapes = {PartialShape{-1}, PartialShape{-1}};
auto partial_output_shapes = shape_infer(multinomial.get(), partial_input_shapes, make_tensor_accessor());
ASSERT_EQ(partial_output_shapes[0], PartialShape({-1}));
}
TEST(StaticShapeInferenceTest, MultinomialDynamicShapeInferenceTestDynamicProbsDynamicNumSamples2D) {
auto probs = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1});
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1});
auto multinomial = std::make_shared<op::v13::Multinomial>(probs, num_samples, element::i32, false, false, 0, 0);
// Test Partial Shape 2D input, unknown num_samples and probs shape
std::vector<PartialShape> partial_input_shapes = {PartialShape{-1, -1}, PartialShape{-1}};
auto partial_output_shapes = shape_infer(multinomial.get(), partial_input_shapes, make_tensor_accessor());
ASSERT_EQ(partial_output_shapes[0], PartialShape({-1, -1}));
}
TEST(StaticShapeInferenceTest, MultinomialDynamicShapeInferenceTestDynamicProbsDynamicNumSamplesDynamicRank) {
auto probs = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1});
auto multinomial = std::make_shared<op::v13::Multinomial>(probs, num_samples, element::i32, false, false, 0, 0);
// Test Partial Shape dynamic input, unknown num_samples and probs shape
std::vector<PartialShape> partial_input_shapes = {PartialShape::dynamic(), PartialShape{-1}};
auto partial_output_shapes = shape_infer(multinomial.get(), partial_input_shapes, make_tensor_accessor());
ASSERT_EQ(partial_output_shapes[0], PartialShape::dynamic());
}

View File

@ -0,0 +1,93 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/reference/multinomial.hpp"
#include "evaluate_node.hpp"
#include "multinomial_shape_inference.hpp"
template <ov::element::Type_t INPUT_T, ov::element::Type_t SAMPLES_T, ov::element::Type_t OUTPUT_T>
inline void evaluate_output_t(const std::shared_ptr<ov::op::v13::Multinomial>& op,
ov::TensorVector& outputs,
const ov::TensorVector& inputs) {
using T1 = typename ov::element_type_traits<INPUT_T>::value_type;
using T2 = typename ov::element_type_traits<SAMPLES_T>::value_type;
using T3 = typename ov::element_type_traits<OUTPUT_T>::value_type;
const auto tensor_acc = make_tensor_accessor(inputs);
const std::vector<ov::PartialShape> input_shapes{op->get_input_shape(0), op->get_input_shape(1)};
const auto out_shape = ov::op::v13::shape_infer(op.get(), input_shapes, tensor_acc).front().to_shape();
outputs[0].set_shape(out_shape);
ov::reference::multinomial::multinomial<T1, T2, T3>(inputs[0].data<const T1>(),
op->get_input_shape(0),
inputs[1].data<const T2>(),
op->get_input_shape(1),
outputs[0].data<T3>(),
out_shape,
op->get_with_replacement(),
op->get_log_probs(),
op->get_global_seed(),
op->get_op_seed());
}
template <ov::element::Type_t INPUT_T, ov::element::Type_t SAMPLES_T>
inline void evaluate_samples_t(const std::shared_ptr<ov::op::v13::Multinomial>& op,
ov::TensorVector& outputs,
const ov::TensorVector& inputs) {
switch (op->get_convert_type()) {
case ov::element::Type_t::i32:
evaluate_output_t<INPUT_T, SAMPLES_T, ov::element::Type_t::i32>(op, outputs, inputs);
return;
case ov::element::Type_t::i64:
evaluate_output_t<INPUT_T, SAMPLES_T, ov::element::Type_t::i64>(op, outputs, inputs);
return;
default:
OPENVINO_THROW(std::string("Unhandled convert data type '") +
ov::element::Type(op->get_convert_type()).get_type_name() +
std::string("' in evaluate_node(). Use either i32 or i64 and apply conversion manually."));
}
}
template <ov::element::Type_t INPUT_T>
bool evaluate_input_t(const std::shared_ptr<ov::op::v13::Multinomial>& op,
ov::TensorVector& outputs,
const ov::TensorVector& inputs) {
switch (inputs[1].get_element_type()) {
case ov::element::Type_t::i64:
evaluate_samples_t<INPUT_T, ov::element::Type_t::i64>(op, outputs, inputs);
break;
default:
evaluate_samples_t<INPUT_T, ov::element::Type_t::i32>(op, outputs, inputs);
break;
}
return true;
}
template <>
bool evaluate_node<ov::op::v13::Multinomial>(std::shared_ptr<ov::Node> node,
ov::TensorVector& outputs,
const ov::TensorVector& inputs) {
switch (node->get_input_element_type(0)) {
case ov::element::Type_t::f16:
return evaluate_input_t<ov::element::Type_t::f16>(ov::as_type_ptr<ov::op::v13::Multinomial>(node),
outputs,
inputs);
case ov::element::Type_t::f32:
return evaluate_input_t<ov::element::Type_t::f32>(ov::as_type_ptr<ov::op::v13::Multinomial>(node),
outputs,
inputs);
case ov::element::Type_t::f64:
return evaluate_input_t<ov::element::Type_t::f64>(ov::as_type_ptr<ov::op::v13::Multinomial>(node),
outputs,
inputs);
case ov::element::Type_t::bf16:
return evaluate_input_t<ov::element::Type_t::bf16>(ov::as_type_ptr<ov::op::v13::Multinomial>(node),
outputs,
inputs);
default:
OPENVINO_THROW(std::string("Unhandled input data type ") + node->get_input_element_type(0).get_type_name() +
std::string(" in evaluate_node()."));
}
}

View File

@ -465,6 +465,10 @@ extern template bool evaluate_node<ov::op::v13::NMSRotated>(std::shared_ptr<ov::
ov::TensorVector& outputs,
const ov::TensorVector& inputs);
extern template bool evaluate_node<ov::op::v13::Multinomial>(std::shared_ptr<ov::Node> node,
ov::TensorVector& outputs,
const ov::TensorVector& inputs);
extern template bool evaluate_node<ov::op::internal::AUGRUCell>(std::shared_ptr<ov::Node> node,
ov::TensorVector& outputs,
const ov::TensorVector& inputs);

View File

@ -155,6 +155,7 @@ _OPENVINO_OP_REG(BitwiseNot, ov::op::v13)
_OPENVINO_OP_REG(BitwiseOr, ov::op::v13)
_OPENVINO_OP_REG(BitwiseXor, ov::op::v13)
_OPENVINO_OP_REG(NMSRotated, ov::op::v13)
_OPENVINO_OP_REG(Multinomial, ov::op::v13)
_OPENVINO_OP_REG(AUGRUCell, ov::op::internal)
_OPENVINO_OP_REG(AUGRUSequence, ov::op::internal)

View File

@ -0,0 +1,161 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/op/multinomial.hpp"
#include "base_reference_test.hpp"
#include "gtest/gtest.h"
#include "openvino/op/parameter.hpp"
namespace {
struct MultinomialParams {
MultinomialParams(const reference_tests::Tensor& probabilities,
const reference_tests::Tensor& num_samples,
const reference_tests::Tensor& expected_tensor,
ov::element::Type_t convert_type,
bool log_probs,
bool with_replacement,
std::string name)
: probabilities{probabilities},
num_samples{num_samples},
expected_tensor(expected_tensor),
convert_type{convert_type},
log_probs(log_probs),
with_replacement(with_replacement),
test_case_name{std::move(name)} {}
reference_tests::Tensor probabilities;
reference_tests::Tensor num_samples;
reference_tests::Tensor expected_tensor;
ov::element::Type_t convert_type;
bool log_probs;
bool with_replacement;
std::string test_case_name;
};
class ReferenceMultinomial : public testing::TestWithParam<MultinomialParams>,
public reference_tests::CommonReferenceTest {
public:
void SetUp() override {
const auto& params = GetParam();
function = CreateFunction(params);
inputData = {params.probabilities.data, params.num_samples.data};
refOutData = {params.expected_tensor.data};
}
static std::string getTestCaseName(const testing::TestParamInfo<MultinomialParams>& obj) {
std::ostringstream name;
name << obj.param.test_case_name;
name << "_input_type_";
name << obj.param.probabilities.type;
name << "_samples_type_";
name << obj.param.num_samples.type;
name << "_convert_type_";
name << obj.param.convert_type;
name << "_log_";
name << obj.param.log_probs;
name << "_replacement_";
name << obj.param.with_replacement;
return name.str();
}
private:
static std::shared_ptr<ov::Model> CreateFunction(const MultinomialParams& params) {
const auto in_probabilities =
std::make_shared<ov::op::v0::Parameter>(params.probabilities.type, params.probabilities.shape);
const auto in_num_samples =
std::make_shared<ov::op::v0::Parameter>(params.num_samples.type, params.num_samples.shape);
const auto multinomial = std::make_shared<ov::op::v13::Multinomial>(in_probabilities,
in_num_samples,
params.convert_type,
params.with_replacement,
params.log_probs,
1,
1);
return std::make_shared<ov::Model>(multinomial->outputs(),
ov::ParameterVector{in_probabilities, in_num_samples});
}
};
template <ov::element::Type_t et>
std::vector<MultinomialParams> generateMultinomialParams() {
using vt = typename ov::element_type_traits<et>::value_type;
const ov::Shape prob_2d_shape{2, 4};
const ov::Shape prob_1d_shape{4};
const ov::Shape num_samples_shape{1};
reference_tests::Tensor num_samples(num_samples_shape, ov::element::Type_t::i32, std::vector<int32_t>{4});
reference_tests::Tensor probabilities_2d_no_log(prob_2d_shape,
et,
std::vector<vt>{0.001, 0.01, 0.1, 0.899, 0.899, 0.1, 0.01, 0.001});
reference_tests::Tensor probabilities_2d_log(prob_2d_shape, et, std::vector<vt>{1, 2, 3, 4, 2, 4, 6, 8});
reference_tests::Tensor probabilities_1d_no_log(prob_1d_shape, et, std::vector<vt>{0.001, 0.01, 0.1, 0.899});
reference_tests::Tensor probabilities_1d_log(prob_1d_shape, et, std::vector<vt>{1, 10, 7, 3});
reference_tests::Tensor output_2d_no_log_no_replacement(prob_2d_shape,
ov::element::Type_t::i32,
std::vector<int32_t>{3, 3, 3, 3, 0, 0, 0, 0});
reference_tests::Tensor output_2d_log_no_replacement(prob_2d_shape,
ov::element::Type_t::i32,
std::vector<int32_t>{3, 3, 2, 3, 3, 3, 3, 3});
reference_tests::Tensor output_1d_no_log_replacement(prob_1d_shape,
ov::element::Type_t::i64,
std::vector<int64_t>{3, 2, 1, 0});
reference_tests::Tensor output_1d_log_replacement(prob_1d_shape,
ov::element::Type_t::i64,
std::vector<int64_t>{1, 2, 3, 0});
std::vector<MultinomialParams> params;
// probabilities, num_samples, output, convert_type, log_probs, with_replacement, name
params.emplace_back(probabilities_2d_no_log,
num_samples,
output_2d_no_log_no_replacement,
ov::element::Type_t::i32,
false,
false,
"input_2d");
params.emplace_back(probabilities_2d_log,
num_samples,
output_2d_log_no_replacement,
ov::element::Type_t::i32,
true,
false,
"input_2d");
params.emplace_back(probabilities_1d_no_log,
num_samples,
output_1d_no_log_replacement,
ov::element::Type_t::i64,
false,
true,
"input_1d");
params.emplace_back(probabilities_1d_log,
num_samples,
output_1d_log_replacement,
ov::element::Type_t::i64,
true,
true,
"input_1d");
return params;
}
std::vector<MultinomialParams> generateMultinomialParams() {
std::vector<std::vector<MultinomialParams>> combo_params{generateMultinomialParams<ov::element::f32>()};
std::vector<MultinomialParams> test_params;
for (auto& params : combo_params)
std::move(params.begin(), params.end(), std::back_inserter(test_params));
return test_params;
}
} // namespace
TEST_P(ReferenceMultinomial, CompareWithRefs) {
Exec();
}
INSTANTIATE_TEST_SUITE_P(smoke,
ReferenceMultinomial,
::testing::ValuesIn(generateMultinomialParams()),
ReferenceMultinomial::getTestCaseName);

View File

@ -606,6 +606,15 @@ std::shared_ptr<ov::Model> generate(const std::shared_ptr<ov::op::v0::MatMul> &n
return std::make_shared<ov::Model>(results, params, "MatMul-1");
}
std::shared_ptr<ov::Model> generate(const std::shared_ptr<ov::op::v13::Multinomial>& node) {
ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{{1, 5}}),
std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::Shape{1})};
auto multinomial =
std::make_shared<ov::op::v13::Multinomial>(params[0], params[1], ov::element::i32, false, false, 0, 0);
ov::ResultVector results{std::make_shared<ov::op::v0::Result>(multinomial)};
return std::make_shared<ov::Model>(results, params, "Multinomial-13");
}
std::shared_ptr<ov::Model> generate(const std::shared_ptr<ov::op::v13::NMSRotated> &node) {
ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{{1, 6, 5}}),
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{{1, 1, 6}}),