[Ref][Core][Opset13] Add Multinomial Operation (#19655)
* [Ref] Multinomial base file * [Ref] Add core & reference implementation reusing other ops * [Ref] Fix reference implementation, add missing parameters, add tests * [Core] Add opset13, register multinomial, add shape inference * [Ref][Core] Fix compile errors * [Ref][Core] Clang fix * [TEMPLATE] Remove bf16, f16, f64 types * [TEMPLATE] Remove incorrect input types for 'input' parameter * [Ref][Tests] Remove deleted test types * [Ref] Fix & optimize shape inference * [PT FE] Apply suggestions from review * [Template] Migrate to new API * [Core] Add a clause for dynamic input in shape inference * [Tests] Add missing type_prop test (?) * Update multinomial_shape_inference.hpp * Update multinomial.hpp * [Ref] Fix build issues * [Ref] Fix clang and style * [Ref] Fix tests without replacement * [Ref] Fix with_replacement sampling error * [Ref] Remove debugging artifacts * [Ref] Cast to 64-bit size for 32-bit systems * Update multinomial.hpp * [Ref] Add missing type_prop tests, add shape inference tests * Update multinomial.cpp * Update multinomial_shape_inference_test.cpp * Update multinomial.cpp * Update multinomial.hpp * [Ref] Fix compilation errors from shape inference test * [Ref] Fix compilation error of type_prop, apply recommendations from review * [Ref] Add multiple shape inference tests * [Ref] Change TEST to TEST_F, add more type_prop tests * [Ref] Clang fixes * [Ref] Fix shape inference tests with mismatching args * [Ref] Fix remaining type_prop errors * [Ref] Replace HostTensor with normal Tensor in shape inference tests * Update opset.cpp * [Ref] Possible fix for 'function empty' error * [Ref] Add a cast to remove conversion warning * [Ref] Add conformance test of Multinomial * [Ref] Match style of conf test to the remaining tests * Update single_op_graph.cpp
This commit is contained in:
parent
bdb13aa28d
commit
48164e2279
67
src/core/include/openvino/op/multinomial.hpp
Normal file
67
src/core/include/openvino/op/multinomial.hpp
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
// Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "openvino/op/op.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace op {
|
||||||
|
namespace v13 {
|
||||||
|
/// \brief Multinomial operation creates a sequence of indices of classes sampled from the multinomial distribution.
|
||||||
|
///
|
||||||
|
/// \ingroup ov_ops_cpp_api
|
||||||
|
class OPENVINO_API Multinomial : public Op {
|
||||||
|
public:
|
||||||
|
OPENVINO_OP("Multinomial", "opset13");
|
||||||
|
Multinomial() = default;
|
||||||
|
/**
|
||||||
|
* @brief Multinomial operation creates a sequence of indices of classes sampled from the multinomial distribution.
|
||||||
|
*
|
||||||
|
* @param probs Input tensor containing at each index poisition probability/log probability of sampling a given
|
||||||
|
* class. Any floating-point precision values are allowed.
|
||||||
|
* @param num_samples Scalar or 1D tensor with a single value that determines the number of samples to generate per
|
||||||
|
* batch. Values should be of an integer type.
|
||||||
|
* @param convert_type Data type to which to convert the output class indices. Allowed values: i32/i64
|
||||||
|
* @param with_replacement Boolean that determines whether a sampled class can appear more than once in the output.
|
||||||
|
* @param log_probs Boolean that determines whether to treat input probabilities as log probabilities.
|
||||||
|
* @param global_seed First seed value (key) of Phillox random number generation algorithm. (See RandomUniform for
|
||||||
|
* details)
|
||||||
|
* @param op_seed Second seed value (counter) of Phillox random number generation algorithm. (See RandomUniform for
|
||||||
|
* details)
|
||||||
|
*/
|
||||||
|
Multinomial(const Output<Node>& input,
|
||||||
|
const Output<Node>& num_samples,
|
||||||
|
const ov::element::Type_t output_type,
|
||||||
|
const bool with_replacement,
|
||||||
|
const bool log_probs,
|
||||||
|
const uint64_t global_seed = 0,
|
||||||
|
const uint64_t op_seed = 0);
|
||||||
|
|
||||||
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
void validate_and_infer_types() override;
|
||||||
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
|
|
||||||
|
ov::element::Type_t get_convert_type() const;
|
||||||
|
bool get_with_replacement() const;
|
||||||
|
bool get_log_probs() const;
|
||||||
|
uint64_t get_global_seed() const;
|
||||||
|
uint64_t get_op_seed() const;
|
||||||
|
|
||||||
|
void set_convert_type(const ov::element::Type_t output_type);
|
||||||
|
void set_with_replacement(const bool with_replacement);
|
||||||
|
void set_log_probs(const bool log_probs);
|
||||||
|
void set_global_seed(const uint64_t global_seed);
|
||||||
|
void set_op_seed(const uint64_t op_seed);
|
||||||
|
|
||||||
|
private:
|
||||||
|
ov::element::Type_t m_convert_type;
|
||||||
|
bool m_with_replacement;
|
||||||
|
bool m_log_probs;
|
||||||
|
uint64_t m_global_seed;
|
||||||
|
uint64_t m_op_seed;
|
||||||
|
};
|
||||||
|
} // namespace v13
|
||||||
|
} // namespace op
|
||||||
|
} // namespace ov
|
@ -110,6 +110,7 @@
|
|||||||
#include "openvino/op/mish.hpp"
|
#include "openvino/op/mish.hpp"
|
||||||
#include "openvino/op/mod.hpp"
|
#include "openvino/op/mod.hpp"
|
||||||
#include "openvino/op/multiclass_nms.hpp"
|
#include "openvino/op/multiclass_nms.hpp"
|
||||||
|
#include "openvino/op/multinomial.hpp"
|
||||||
#include "openvino/op/multiply.hpp"
|
#include "openvino/op/multiply.hpp"
|
||||||
#include "openvino/op/mvn.hpp"
|
#include "openvino/op/mvn.hpp"
|
||||||
#include "openvino/op/negative.hpp"
|
#include "openvino/op/negative.hpp"
|
||||||
|
@ -214,3 +214,4 @@ _OPENVINO_OP_REG(BitwiseNot, ov::op::v13)
|
|||||||
_OPENVINO_OP_REG(BitwiseOr, ov::op::v13)
|
_OPENVINO_OP_REG(BitwiseOr, ov::op::v13)
|
||||||
_OPENVINO_OP_REG(BitwiseXor, ov::op::v13)
|
_OPENVINO_OP_REG(BitwiseXor, ov::op::v13)
|
||||||
_OPENVINO_OP_REG(NMSRotated, ov::op::v13)
|
_OPENVINO_OP_REG(NMSRotated, ov::op::v13)
|
||||||
|
_OPENVINO_OP_REG(Multinomial, ov::op::v13)
|
||||||
|
163
src/core/reference/include/openvino/reference/multinomial.hpp
Normal file
163
src/core/reference/include/openvino/reference/multinomial.hpp
Normal file
@ -0,0 +1,163 @@
|
|||||||
|
// Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstddef>
|
||||||
|
|
||||||
|
#include "openvino/reference/broadcast.hpp"
|
||||||
|
#include "openvino/reference/convert.hpp"
|
||||||
|
#include "openvino/reference/copy.hpp"
|
||||||
|
#include "openvino/reference/cum_sum.hpp"
|
||||||
|
#include "openvino/reference/divide.hpp"
|
||||||
|
#include "openvino/reference/exp.hpp"
|
||||||
|
#include "openvino/reference/random_uniform.hpp"
|
||||||
|
#include "openvino/reference/slice.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace reference {
|
||||||
|
namespace multinomial {
|
||||||
|
/**
|
||||||
|
* @brief Multinomial operation creates a sequence of indices of classes sampled from the multinomial distribution.
|
||||||
|
*
|
||||||
|
* @tparam T Data type of the probs' values.
|
||||||
|
* @tparam U Data type of num_samples' values.
|
||||||
|
* @tparam V Data type of output's values.
|
||||||
|
* @param probs Input tensor containing at each index poisition probability/log probability of sampling a given class.
|
||||||
|
* @param probs_shape Shape of the 'probs' tensor.
|
||||||
|
* @param num_samples Scalar or 1D tensor with a single value that determines the number of samples to generate per
|
||||||
|
* batch.
|
||||||
|
* @param num_samples_shape Shape of the 'num_samples' tensor.
|
||||||
|
* @param output Output tensor for the generated class indices.
|
||||||
|
* @param output_shape Shape of the 'output' tensor.
|
||||||
|
* @param with_replacement Boolean that determines whether a sampled class can appear more than once in the output.
|
||||||
|
* @param log_probs Boolean that determines whether to treat input probabilities as log probabilities.
|
||||||
|
* @param global_seed First seed value (key) of Phillox random number generation algorithm. (See RandomUniform for
|
||||||
|
* details)
|
||||||
|
* @param op_seed Second seed value (counter) of Phillox random number generation algorithm. (See RandomUniform for
|
||||||
|
* details)
|
||||||
|
*/
|
||||||
|
template <typename T, typename U, typename V>
|
||||||
|
void multinomial(const T* probs,
|
||||||
|
const Shape& probs_shape,
|
||||||
|
const U* num_samples,
|
||||||
|
const Shape& num_samples_shape,
|
||||||
|
V* output,
|
||||||
|
const Shape& output_shape,
|
||||||
|
const bool with_replacement,
|
||||||
|
const bool log_probs,
|
||||||
|
const uint64_t global_seed,
|
||||||
|
const uint64_t op_seed) {
|
||||||
|
const auto total_inputs_elements_count = shape_size<Shape>(probs_shape);
|
||||||
|
const auto total_output_elements_count = shape_size<Shape>(output_shape);
|
||||||
|
|
||||||
|
// If probabilities are log probabilities, exponentiate to get normal probabilities
|
||||||
|
std::vector<T> input_vals(total_inputs_elements_count);
|
||||||
|
if (log_probs) {
|
||||||
|
exp(probs, input_vals.data(), total_inputs_elements_count);
|
||||||
|
} else {
|
||||||
|
copy(probs, input_vals.data(), total_inputs_elements_count);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a cdf of probabilties on the last axis, per batch. Note cumsum exclusive == false
|
||||||
|
std::vector<T> cdf(total_inputs_elements_count);
|
||||||
|
const auto last_axis = probs_shape.size() - 1;
|
||||||
|
cumsum(input_vals.data(), last_axis, cdf.data(), probs_shape, false, false);
|
||||||
|
|
||||||
|
// Obtain max value from cdf, per batch (from cumsum it is the last element)
|
||||||
|
std::vector<T> max_value_per_batch(total_inputs_elements_count / probs_shape[last_axis]);
|
||||||
|
Shape max_value_per_batch_shape(probs_shape);
|
||||||
|
max_value_per_batch_shape[last_axis] = 1;
|
||||||
|
const std::vector<int64_t> start{static_cast<int64_t>(probs_shape[last_axis] - 1)};
|
||||||
|
const std::vector<int64_t> step{1};
|
||||||
|
const std::vector<int64_t> target_axis_vec{static_cast<int64_t>(last_axis)};
|
||||||
|
slice(reinterpret_cast<const char*>(cdf.data()),
|
||||||
|
probs_shape, // == cdf shape
|
||||||
|
reinterpret_cast<char*>(max_value_per_batch.data()),
|
||||||
|
max_value_per_batch_shape,
|
||||||
|
sizeof(T),
|
||||||
|
start,
|
||||||
|
step,
|
||||||
|
target_axis_vec);
|
||||||
|
|
||||||
|
// Normalize the cdf by dividing all elements by the max value in each batch
|
||||||
|
std::vector<T> max_value_per_batch_divisor(total_inputs_elements_count);
|
||||||
|
ov::AxisSet target_axis_set = ov::AxisSet({last_axis});
|
||||||
|
broadcast(reinterpret_cast<const char*>(max_value_per_batch.data()),
|
||||||
|
reinterpret_cast<char*>(max_value_per_batch_divisor.data()),
|
||||||
|
max_value_per_batch_shape,
|
||||||
|
probs_shape, // expand to original shape (expands last dim)
|
||||||
|
target_axis_set,
|
||||||
|
sizeof(T));
|
||||||
|
divide(cdf.data(), max_value_per_batch_divisor.data(), cdf.data(), total_inputs_elements_count, false);
|
||||||
|
|
||||||
|
// Generate random probability samples
|
||||||
|
std::vector<double> uniform_samples(total_output_elements_count);
|
||||||
|
const double zero = 0;
|
||||||
|
const double one = 1;
|
||||||
|
const ov::Shape output_shape_shape{output_shape.size()};
|
||||||
|
const std::vector<uint64_t> output_shape_u64(output_shape.begin(), output_shape.end());
|
||||||
|
const std::pair<uint64_t, uint64_t> initial_state(0, 0);
|
||||||
|
random_uniform(output_shape_u64.data(),
|
||||||
|
reinterpret_cast<const char*>(&zero),
|
||||||
|
reinterpret_cast<const char*>(&one),
|
||||||
|
reinterpret_cast<char*>(uniform_samples.data()),
|
||||||
|
output_shape_shape,
|
||||||
|
ov::element::f64,
|
||||||
|
global_seed,
|
||||||
|
op_seed,
|
||||||
|
initial_state);
|
||||||
|
|
||||||
|
auto batch_size = probs_shape.size() == 2 ? static_cast<size_t>(probs_shape[0]) : static_cast<size_t>(1);
|
||||||
|
auto class_size =
|
||||||
|
probs_shape.size() == 2 ? static_cast<size_t>(probs_shape[1]) : static_cast<size_t>(probs_shape[0]);
|
||||||
|
auto samples_size =
|
||||||
|
probs_shape.size() == 2 ? static_cast<size_t>(num_samples[0]) : static_cast<size_t>(probs_shape[0]);
|
||||||
|
|
||||||
|
// Iterate over each channel in uniform samples
|
||||||
|
std::vector<U> output_samples(total_output_elements_count);
|
||||||
|
for (size_t i = 0; i < batch_size * samples_size; i += samples_size) {
|
||||||
|
for (size_t j = 0; j < samples_size; ++j) {
|
||||||
|
// Iterate over cdf to find the index for a given sample
|
||||||
|
// If no class found (all have 0 probability), selects last - undefined behavior
|
||||||
|
auto i_translated = i / samples_size * class_size;
|
||||||
|
auto selected_class_idx = class_size;
|
||||||
|
auto sample_value = uniform_samples[i + j];
|
||||||
|
for (size_t k = 0; k < class_size; ++k) {
|
||||||
|
if (sample_value <= cdf[i_translated + k]) {
|
||||||
|
output_samples[i + j] = static_cast<U>(k);
|
||||||
|
selected_class_idx = k;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Additional step with replacement - change probability of a given class to 0, and update the cdf
|
||||||
|
if (with_replacement) {
|
||||||
|
T class_probability = selected_class_idx ? cdf[i_translated + selected_class_idx] -
|
||||||
|
cdf[i_translated + selected_class_idx - 1]
|
||||||
|
: cdf[i_translated + selected_class_idx];
|
||||||
|
T divisor = 1 - class_probability;
|
||||||
|
for (size_t k = 0; k < class_size; ++k) {
|
||||||
|
if (k >= selected_class_idx) {
|
||||||
|
cdf[i_translated + k] -= class_probability;
|
||||||
|
}
|
||||||
|
cdf[i_translated + k] /= divisor;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Finally convert the samples to the requested data type
|
||||||
|
convert<U, V>(output_samples.data(), output, total_output_elements_count);
|
||||||
|
}
|
||||||
|
} // namespace multinomial
|
||||||
|
} // namespace reference
|
||||||
|
|
||||||
|
namespace op {
|
||||||
|
namespace multinomial {
|
||||||
|
namespace validate {
|
||||||
|
void input_types(const Node* op);
|
||||||
|
} // namespace validate
|
||||||
|
} // namespace multinomial
|
||||||
|
} // namespace op
|
||||||
|
} // namespace ov
|
@ -0,0 +1,58 @@
|
|||||||
|
// Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "dimension_util.hpp"
|
||||||
|
#include "openvino/op/multinomial.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace op {
|
||||||
|
namespace v13 {
|
||||||
|
template <class TShape, class TRShape = result_shape_t<TShape>>
|
||||||
|
std::vector<TRShape> shape_infer(const Multinomial* op,
|
||||||
|
const std::vector<TShape>& input_shapes,
|
||||||
|
const ITensorAccessor& ta = make_tensor_accessor()) {
|
||||||
|
NODE_VALIDATION_CHECK(op, input_shapes.size() == 2);
|
||||||
|
|
||||||
|
const auto& input_shape = input_shapes[0];
|
||||||
|
NODE_SHAPE_INFER_CHECK(op,
|
||||||
|
input_shapes,
|
||||||
|
input_shape.rank().compatible(1) || input_shape.rank().compatible(2),
|
||||||
|
"The rank of the 'probs' tensor defining output shape must be either 1 or 2.");
|
||||||
|
|
||||||
|
const auto& num_samples_shape = input_shapes[1];
|
||||||
|
NODE_SHAPE_INFER_CHECK(op,
|
||||||
|
input_shapes,
|
||||||
|
num_samples_shape.compatible(TRShape{}) || num_samples_shape.compatible(TRShape{1}),
|
||||||
|
"Number of samples must be a scalar or one element 1D tensor.");
|
||||||
|
|
||||||
|
auto output_shapes = std::vector<TRShape>(1);
|
||||||
|
auto& result_shape = output_shapes[0];
|
||||||
|
const auto input_rank_static = input_shape.rank().is_static();
|
||||||
|
if (input_rank_static) {
|
||||||
|
const auto& num_samples = get_input_const_data_as_shape<TRShape>(op, 1, ta);
|
||||||
|
if (num_samples) {
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
|
(*num_samples)[0].get_min_length() >= 0,
|
||||||
|
"Number of samples must be non-negative. Got number of samples: ",
|
||||||
|
(*num_samples)[0].get_min_length());
|
||||||
|
result_shape = *num_samples;
|
||||||
|
} else {
|
||||||
|
result_shape = ov::PartialShape::dynamic(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (input_shape.rank().compatible(2)) {
|
||||||
|
result_shape.insert(result_shape.begin(), input_shape[0]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
result_shape = ov::PartialShape::dynamic();
|
||||||
|
}
|
||||||
|
|
||||||
|
return output_shapes;
|
||||||
|
}
|
||||||
|
} // namespace v13
|
||||||
|
} // namespace op
|
||||||
|
} // namespace ov
|
129
src/core/src/op/multinomial.cpp
Normal file
129
src/core/src/op/multinomial.cpp
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
// Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "openvino/op/multinomial.hpp"
|
||||||
|
|
||||||
|
#include <cstring>
|
||||||
|
|
||||||
|
#include "bound_evaluate.hpp"
|
||||||
|
#include "itt.hpp"
|
||||||
|
#include "multinomial_shape_inference.hpp"
|
||||||
|
#include "openvino/core/attribute_visitor.hpp"
|
||||||
|
#include "openvino/op/constant.hpp"
|
||||||
|
#include "openvino/op/util/op_types.hpp"
|
||||||
|
#include "openvino/reference/multinomial.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
|
||||||
|
// ------------------------------ v13 ------------------------------
|
||||||
|
|
||||||
|
op::v13::Multinomial::Multinomial(const Output<Node>& probs,
|
||||||
|
const Output<Node>& num_samples,
|
||||||
|
const ov::element::Type_t convert_type,
|
||||||
|
const bool with_replacement,
|
||||||
|
const bool log_probs,
|
||||||
|
const uint64_t global_seed,
|
||||||
|
const uint64_t op_seed)
|
||||||
|
: Op({probs, num_samples}),
|
||||||
|
m_convert_type(convert_type),
|
||||||
|
m_with_replacement(with_replacement),
|
||||||
|
m_log_probs(log_probs),
|
||||||
|
m_global_seed(global_seed),
|
||||||
|
m_op_seed(op_seed) {
|
||||||
|
constructor_validate_and_infer_types();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool op::v13::Multinomial::visit_attributes(AttributeVisitor& visitor) {
|
||||||
|
OV_OP_SCOPE(v13_Multinomial_visit_attributes);
|
||||||
|
visitor.on_attribute("convert_type", m_convert_type);
|
||||||
|
visitor.on_attribute("with_replacement", m_with_replacement);
|
||||||
|
visitor.on_attribute("log_probs", m_log_probs);
|
||||||
|
visitor.on_attribute("global_seed", m_global_seed);
|
||||||
|
visitor.on_attribute("op_seed", m_op_seed);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void op::v13::Multinomial::validate_and_infer_types() {
|
||||||
|
OV_OP_SCOPE(v13_Multinomial_validate_and_infer_types);
|
||||||
|
|
||||||
|
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||||
|
const auto input_shapes = get_node_input_partial_shapes(*this);
|
||||||
|
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||||
|
|
||||||
|
const auto output_shapes = shape_infer(this, input_shapes);
|
||||||
|
|
||||||
|
multinomial::validate::input_types(this);
|
||||||
|
|
||||||
|
set_output_type(0, m_convert_type, output_shapes[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<Node> op::v13::Multinomial::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||||
|
OV_OP_SCOPE(v13_Multinomial_clone_with_new_inputs);
|
||||||
|
check_new_args_count<OutputVector>(this, new_args);
|
||||||
|
|
||||||
|
return std::make_shared<op::v13::Multinomial>(new_args.at(0),
|
||||||
|
new_args.at(1),
|
||||||
|
m_convert_type,
|
||||||
|
m_with_replacement,
|
||||||
|
m_log_probs,
|
||||||
|
m_global_seed,
|
||||||
|
m_op_seed);
|
||||||
|
}
|
||||||
|
|
||||||
|
ov::element::Type_t op::v13::Multinomial::get_convert_type() const {
|
||||||
|
return m_convert_type;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool op::v13::Multinomial::get_with_replacement() const {
|
||||||
|
return m_with_replacement;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool op::v13::Multinomial::get_log_probs() const {
|
||||||
|
return m_log_probs;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t op::v13::Multinomial::get_global_seed() const {
|
||||||
|
return m_global_seed;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t op::v13::Multinomial::get_op_seed() const {
|
||||||
|
return m_op_seed;
|
||||||
|
}
|
||||||
|
|
||||||
|
void op::v13::Multinomial::set_convert_type(const ov::element::Type_t convert_type) {
|
||||||
|
m_convert_type = convert_type;
|
||||||
|
}
|
||||||
|
|
||||||
|
void op::v13::Multinomial::set_with_replacement(const bool with_replacement) {
|
||||||
|
m_with_replacement = with_replacement;
|
||||||
|
}
|
||||||
|
|
||||||
|
void op::v13::Multinomial::set_log_probs(const bool log_probs) {
|
||||||
|
m_log_probs = log_probs;
|
||||||
|
}
|
||||||
|
|
||||||
|
void op::v13::Multinomial::set_global_seed(const uint64_t global_seed) {
|
||||||
|
m_global_seed = global_seed;
|
||||||
|
}
|
||||||
|
|
||||||
|
void op::v13::Multinomial::set_op_seed(const uint64_t op_seed) {
|
||||||
|
m_op_seed = op_seed;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace op {
|
||||||
|
namespace multinomial {
|
||||||
|
namespace validate {
|
||||||
|
void input_types(const Node* op) {
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
|
op->get_input_element_type(0).is_real(),
|
||||||
|
"Expected floating point type as element type for the 'probs' input.");
|
||||||
|
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
|
op->get_input_element_type(1).is_integral_number(),
|
||||||
|
"Expected integer type as element type for the 'num_samples' input.");
|
||||||
|
}
|
||||||
|
} // namespace validate
|
||||||
|
} // namespace multinomial
|
||||||
|
} // namespace op
|
||||||
|
} // namespace ov
|
@ -71,7 +71,7 @@ INSTANTIATE_TEST_SUITE_P(opset,
|
|||||||
OpsetTestParams{ov::get_opset10, 177},
|
OpsetTestParams{ov::get_opset10, 177},
|
||||||
OpsetTestParams{ov::get_opset11, 177},
|
OpsetTestParams{ov::get_opset11, 177},
|
||||||
OpsetTestParams{ov::get_opset12, 178},
|
OpsetTestParams{ov::get_opset12, 178},
|
||||||
OpsetTestParams{ov::get_opset13, 183}),
|
OpsetTestParams{ov::get_opset13, 184}),
|
||||||
OpsetTestNameGenerator{});
|
OpsetTestNameGenerator{});
|
||||||
|
|
||||||
class MyOpOld : public ov::op::Op {
|
class MyOpOld : public ov::op::Op {
|
||||||
|
62
src/core/tests/type_prop/multinomial.cpp
Normal file
62
src/core/tests/type_prop/multinomial.cpp
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
// Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "openvino/op/multinomial.hpp"
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include "common_test_utils/test_assertions.hpp"
|
||||||
|
#include "common_test_utils/type_prop.hpp"
|
||||||
|
|
||||||
|
using namespace testing;
|
||||||
|
|
||||||
|
class TypePropMultinomialV13Test : public TypePropOpTest<ov::op::v13::Multinomial> {};
|
||||||
|
|
||||||
|
TEST_F(TypePropMultinomialV13Test, input_probs_f64_num_samples_i32_convert_i32) {
|
||||||
|
const auto probs = std::make_shared<ov::op::v0::Parameter>(ov::element::f64, ov::Shape{4});
|
||||||
|
const auto num_samples = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::Shape{1});
|
||||||
|
const auto op = make_op(probs, num_samples, ov::element::i32, false, false, 0, 0);
|
||||||
|
EXPECT_EQ(op->get_element_type(), ov::element::i32);
|
||||||
|
EXPECT_EQ(op->get_output_partial_shape(0), (ov::PartialShape::dynamic(1)));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TypePropMultinomialV13Test, input_probs_f32_num_samples_i32_convert_i64) {
|
||||||
|
const auto probs = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4, 4});
|
||||||
|
const auto num_samples = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::Shape{});
|
||||||
|
const auto op = make_op(probs, num_samples, ov::element::i64, false, false, 0, 0);
|
||||||
|
EXPECT_EQ(op->get_element_type(), ov::element::i64);
|
||||||
|
EXPECT_EQ(op->get_output_partial_shape(0), (ov::PartialShape{4, -1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TypePropMultinomialV13Test, probs_incompatibile_data_type) {
|
||||||
|
const auto probs = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::Shape{4});
|
||||||
|
const auto num_samples = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::Shape{});
|
||||||
|
OV_EXPECT_THROW(std::ignore = make_op(probs, num_samples, ov::element::u64, false, false, 0, 0),
|
||||||
|
ov::NodeValidationFailure,
|
||||||
|
HasSubstr("Expected floating point type as element type for the 'probs' input."));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TypePropMultinomialV13Test, num_samples_incompatibile_data_type) {
|
||||||
|
const auto probs = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4});
|
||||||
|
const auto num_samples = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{});
|
||||||
|
OV_EXPECT_THROW(std::ignore = make_op(probs, num_samples, ov::element::u64, false, false, 0, 0),
|
||||||
|
ov::NodeValidationFailure,
|
||||||
|
HasSubstr("Expected integer type as element type for the 'num_samples' input."));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TypePropMultinomialV13Test, probs_incompatibile_rank) {
|
||||||
|
const auto probs = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4, 4, 4});
|
||||||
|
const auto num_samples = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::Shape{1});
|
||||||
|
OV_EXPECT_THROW(std::ignore = make_op(probs, num_samples, ov::element::boolean, false, false, 0, 0),
|
||||||
|
ov::NodeValidationFailure,
|
||||||
|
HasSubstr("The rank of the 'probs' tensor defining output shape must be either 1 or 2."));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TypePropMultinomialV13Test, num_samples_incompatibile_rank) {
|
||||||
|
const auto probs = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4});
|
||||||
|
const auto num_samples = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::Shape{1, 2});
|
||||||
|
OV_EXPECT_THROW(std::ignore = make_op(probs, num_samples, ov::element::boolean, false, false, 0, 0),
|
||||||
|
ov::NodeValidationFailure,
|
||||||
|
HasSubstr("Number of samples must be a scalar or one element 1D tensor."));
|
||||||
|
}
|
31
src/core/tests/visitors/op/multinomial.cpp
Normal file
31
src/core/tests/visitors/op/multinomial.cpp
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
// Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "openvino/op/multinomial.hpp"
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include "openvino/op/unique.hpp"
|
||||||
|
#include "visitors/visitors.hpp"
|
||||||
|
|
||||||
|
using namespace ov;
|
||||||
|
using ov::test::NodeBuilder;
|
||||||
|
|
||||||
|
TEST(attributes, multinomial) {
|
||||||
|
NodeBuilder::get_ops().register_factory<ov::op::v13::Multinomial>();
|
||||||
|
const auto probs = std::make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 2});
|
||||||
|
const auto num_samples = std::make_shared<ov::op::v0::Parameter>(element::i32, Shape{1});
|
||||||
|
|
||||||
|
const auto op = std::make_shared<ov::op::v13::Multinomial>(probs, num_samples, element::f32, false, true, 0, 0);
|
||||||
|
NodeBuilder builder(op, {probs, num_samples});
|
||||||
|
auto g_multi = ov::as_type_ptr<ov::op::v13::Multinomial>(builder.create());
|
||||||
|
|
||||||
|
constexpr auto expected_attr_count = 5;
|
||||||
|
EXPECT_EQ(builder.get_value_map_size(), expected_attr_count);
|
||||||
|
EXPECT_EQ(op->get_with_replacement(), g_multi->get_with_replacement());
|
||||||
|
EXPECT_EQ(op->get_global_seed(), g_multi->get_global_seed());
|
||||||
|
EXPECT_EQ(op->get_convert_type(), g_multi->get_convert_type());
|
||||||
|
EXPECT_EQ(op->get_log_probs(), g_multi->get_log_probs());
|
||||||
|
EXPECT_EQ(op->get_op_seed(), g_multi->get_op_seed());
|
||||||
|
}
|
@ -0,0 +1,127 @@
|
|||||||
|
// Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "multinomial_shape_inference.hpp"
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
using namespace ov;
|
||||||
|
using namespace ov::intel_cpu;
|
||||||
|
|
||||||
|
TEST(StaticShapeInferenceTest, MultinomialStaticShapeInferenceTest1D) {
|
||||||
|
auto probs = std::make_shared<op::v0::Parameter>(element::f32, Shape{4});
|
||||||
|
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, Shape{1});
|
||||||
|
auto multinomial = std::make_shared<op::v13::Multinomial>(probs, num_samples, element::i32, false, false, 0, 0);
|
||||||
|
|
||||||
|
// Test Static Shape 1D input
|
||||||
|
std::vector<StaticShape> static_input_shapes = {StaticShape{4}, StaticShape{1}};
|
||||||
|
int32_t num_elements_val = 2;
|
||||||
|
auto const_data =
|
||||||
|
std::unordered_map<size_t, Tensor>{{1, {element::i32, Shape{1}, &num_elements_val}}};
|
||||||
|
auto acc = make_tensor_accessor(const_data);
|
||||||
|
auto static_output_shapes = shape_infer(multinomial.get(), static_input_shapes, acc);
|
||||||
|
ASSERT_EQ(static_output_shapes[0], StaticShape({2}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(StaticShapeInferenceTest, MultinomialStaticShapeInferenceTest2D) {
|
||||||
|
auto probs = std::make_shared<op::v0::Parameter>(element::f32, Shape{4, 4});
|
||||||
|
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, Shape{1});
|
||||||
|
auto multinomial = std::make_shared<op::v13::Multinomial>(probs, num_samples, element::i32, false, false, 0, 0);
|
||||||
|
|
||||||
|
// Test Static Shape 2D input
|
||||||
|
std::vector<StaticShape> static_input_shapes = {StaticShape{4, 4}, StaticShape{1}};
|
||||||
|
int32_t num_elements_val = 2;
|
||||||
|
auto const_data =
|
||||||
|
std::unordered_map<size_t, Tensor>{{1, {element::i32, Shape{1}, &num_elements_val}}};
|
||||||
|
auto acc = make_tensor_accessor(const_data);
|
||||||
|
auto static_output_shapes = shape_infer(multinomial.get(), static_input_shapes, acc);
|
||||||
|
ASSERT_EQ(static_output_shapes[0], StaticShape({4, 2}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(StaticShapeInferenceTest, MultinomialDynamicShapeInferenceTestAllDimKnown1D) {
|
||||||
|
auto probs = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{3});
|
||||||
|
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{1});
|
||||||
|
auto multinomial = std::make_shared<op::v13::Multinomial>(probs, num_samples, element::i32, false, false, 0, 0);
|
||||||
|
|
||||||
|
// Test Partial Shape 1D input
|
||||||
|
std::vector<PartialShape> partial_input_shapes = {PartialShape{3}, PartialShape{1}};
|
||||||
|
int32_t num_elements_val = 2;
|
||||||
|
auto const_data =
|
||||||
|
std::unordered_map<size_t, Tensor>{{1, {element::i32, Shape{1}, &num_elements_val}}};
|
||||||
|
auto acc = make_tensor_accessor(const_data);
|
||||||
|
auto partial_output_shapes = shape_infer(multinomial.get(), partial_input_shapes, acc);
|
||||||
|
ASSERT_EQ(partial_output_shapes[0], PartialShape({2}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(StaticShapeInferenceTest, MultinomialDynamicShapeInferenceTestAllDimKnown2D) {
|
||||||
|
auto probs = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{2, 3});
|
||||||
|
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{1});
|
||||||
|
auto multinomial = std::make_shared<op::v13::Multinomial>(probs, num_samples, element::i32, false, false, 0, 0);
|
||||||
|
|
||||||
|
// Test Partial Shape 2D input
|
||||||
|
std::vector<PartialShape> partial_input_shapes = {PartialShape{2, 3}, PartialShape{1}};
|
||||||
|
int32_t num_elements_val = 2;
|
||||||
|
auto const_data =
|
||||||
|
std::unordered_map<size_t, Tensor>{{1, {element::i32, Shape{1}, &num_elements_val}}};
|
||||||
|
auto acc = make_tensor_accessor(const_data);
|
||||||
|
auto partial_output_shapes = shape_infer(multinomial.get(), partial_input_shapes, acc);
|
||||||
|
ASSERT_EQ(partial_output_shapes[0], PartialShape({2, 2}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(StaticShapeInferenceTest, MultinomialDynamicShapeInferenceTestDynamicNumSamples1D) {
|
||||||
|
auto probs = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{4});
|
||||||
|
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||||
|
auto multinomial = std::make_shared<op::v13::Multinomial>(probs, num_samples, element::i32, false, false, 0, 0);
|
||||||
|
|
||||||
|
// Test Partial Shape 1D input, unknown num_samples
|
||||||
|
std::vector<PartialShape> partial_input_shapes = {PartialShape{4}, PartialShape{-1}};
|
||||||
|
auto partial_output_shapes = shape_infer(multinomial.get(), partial_input_shapes, make_tensor_accessor());
|
||||||
|
ASSERT_EQ(partial_output_shapes[0], PartialShape({-1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(StaticShapeInferenceTest, MultinomialDynamicShapeInferenceTestDynamicNumSamples2D) {
|
||||||
|
auto probs = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{4, 4});
|
||||||
|
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||||
|
auto multinomial = std::make_shared<op::v13::Multinomial>(probs, num_samples, element::i32, false, false, 0, 0);
|
||||||
|
|
||||||
|
// Test Partial Shape 2D input, unknown num_samples
|
||||||
|
std::vector<PartialShape> partial_input_shapes = {PartialShape{4, 4}, PartialShape{-1}};
|
||||||
|
auto partial_output_shapes = shape_infer(multinomial.get(), partial_input_shapes, make_tensor_accessor());
|
||||||
|
ASSERT_EQ(partial_output_shapes[0], PartialShape({4, -1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(StaticShapeInferenceTest, MultinomialDynamicShapeInferenceTestDynamicProbsDynamicNumSamples1D) {
|
||||||
|
auto probs = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1});
|
||||||
|
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||||
|
auto multinomial = std::make_shared<op::v13::Multinomial>(probs, num_samples, element::i32, false, false, 0, 0);
|
||||||
|
|
||||||
|
// Test Partial Shape 1D input, unknown num_samples and probs shape
|
||||||
|
std::vector<PartialShape> partial_input_shapes = {PartialShape{-1}, PartialShape{-1}};
|
||||||
|
auto partial_output_shapes = shape_infer(multinomial.get(), partial_input_shapes, make_tensor_accessor());
|
||||||
|
ASSERT_EQ(partial_output_shapes[0], PartialShape({-1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(StaticShapeInferenceTest, MultinomialDynamicShapeInferenceTestDynamicProbsDynamicNumSamples2D) {
|
||||||
|
auto probs = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1});
|
||||||
|
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||||
|
auto multinomial = std::make_shared<op::v13::Multinomial>(probs, num_samples, element::i32, false, false, 0, 0);
|
||||||
|
|
||||||
|
// Test Partial Shape 2D input, unknown num_samples and probs shape
|
||||||
|
std::vector<PartialShape> partial_input_shapes = {PartialShape{-1, -1}, PartialShape{-1}};
|
||||||
|
auto partial_output_shapes = shape_infer(multinomial.get(), partial_input_shapes, make_tensor_accessor());
|
||||||
|
ASSERT_EQ(partial_output_shapes[0], PartialShape({-1, -1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(StaticShapeInferenceTest, MultinomialDynamicShapeInferenceTestDynamicProbsDynamicNumSamplesDynamicRank) {
|
||||||
|
auto probs = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||||
|
auto num_samples = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||||
|
auto multinomial = std::make_shared<op::v13::Multinomial>(probs, num_samples, element::i32, false, false, 0, 0);
|
||||||
|
|
||||||
|
// Test Partial Shape dynamic input, unknown num_samples and probs shape
|
||||||
|
std::vector<PartialShape> partial_input_shapes = {PartialShape::dynamic(), PartialShape{-1}};
|
||||||
|
auto partial_output_shapes = shape_infer(multinomial.get(), partial_input_shapes, make_tensor_accessor());
|
||||||
|
ASSERT_EQ(partial_output_shapes[0], PartialShape::dynamic());
|
||||||
|
}
|
93
src/plugins/template/backend/ops/multinomial.cpp
Normal file
93
src/plugins/template/backend/ops/multinomial.cpp
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
// Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "openvino/reference/multinomial.hpp"
|
||||||
|
|
||||||
|
#include "evaluate_node.hpp"
|
||||||
|
#include "multinomial_shape_inference.hpp"
|
||||||
|
|
||||||
|
template <ov::element::Type_t INPUT_T, ov::element::Type_t SAMPLES_T, ov::element::Type_t OUTPUT_T>
|
||||||
|
inline void evaluate_output_t(const std::shared_ptr<ov::op::v13::Multinomial>& op,
|
||||||
|
ov::TensorVector& outputs,
|
||||||
|
const ov::TensorVector& inputs) {
|
||||||
|
using T1 = typename ov::element_type_traits<INPUT_T>::value_type;
|
||||||
|
using T2 = typename ov::element_type_traits<SAMPLES_T>::value_type;
|
||||||
|
using T3 = typename ov::element_type_traits<OUTPUT_T>::value_type;
|
||||||
|
|
||||||
|
const auto tensor_acc = make_tensor_accessor(inputs);
|
||||||
|
const std::vector<ov::PartialShape> input_shapes{op->get_input_shape(0), op->get_input_shape(1)};
|
||||||
|
const auto out_shape = ov::op::v13::shape_infer(op.get(), input_shapes, tensor_acc).front().to_shape();
|
||||||
|
outputs[0].set_shape(out_shape);
|
||||||
|
|
||||||
|
ov::reference::multinomial::multinomial<T1, T2, T3>(inputs[0].data<const T1>(),
|
||||||
|
op->get_input_shape(0),
|
||||||
|
inputs[1].data<const T2>(),
|
||||||
|
op->get_input_shape(1),
|
||||||
|
outputs[0].data<T3>(),
|
||||||
|
out_shape,
|
||||||
|
op->get_with_replacement(),
|
||||||
|
op->get_log_probs(),
|
||||||
|
op->get_global_seed(),
|
||||||
|
op->get_op_seed());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <ov::element::Type_t INPUT_T, ov::element::Type_t SAMPLES_T>
|
||||||
|
inline void evaluate_samples_t(const std::shared_ptr<ov::op::v13::Multinomial>& op,
|
||||||
|
ov::TensorVector& outputs,
|
||||||
|
const ov::TensorVector& inputs) {
|
||||||
|
switch (op->get_convert_type()) {
|
||||||
|
case ov::element::Type_t::i32:
|
||||||
|
evaluate_output_t<INPUT_T, SAMPLES_T, ov::element::Type_t::i32>(op, outputs, inputs);
|
||||||
|
return;
|
||||||
|
case ov::element::Type_t::i64:
|
||||||
|
evaluate_output_t<INPUT_T, SAMPLES_T, ov::element::Type_t::i64>(op, outputs, inputs);
|
||||||
|
return;
|
||||||
|
default:
|
||||||
|
OPENVINO_THROW(std::string("Unhandled convert data type '") +
|
||||||
|
ov::element::Type(op->get_convert_type()).get_type_name() +
|
||||||
|
std::string("' in evaluate_node(). Use either i32 or i64 and apply conversion manually."));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <ov::element::Type_t INPUT_T>
|
||||||
|
bool evaluate_input_t(const std::shared_ptr<ov::op::v13::Multinomial>& op,
|
||||||
|
ov::TensorVector& outputs,
|
||||||
|
const ov::TensorVector& inputs) {
|
||||||
|
switch (inputs[1].get_element_type()) {
|
||||||
|
case ov::element::Type_t::i64:
|
||||||
|
evaluate_samples_t<INPUT_T, ov::element::Type_t::i64>(op, outputs, inputs);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
evaluate_samples_t<INPUT_T, ov::element::Type_t::i32>(op, outputs, inputs);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
bool evaluate_node<ov::op::v13::Multinomial>(std::shared_ptr<ov::Node> node,
|
||||||
|
ov::TensorVector& outputs,
|
||||||
|
const ov::TensorVector& inputs) {
|
||||||
|
switch (node->get_input_element_type(0)) {
|
||||||
|
case ov::element::Type_t::f16:
|
||||||
|
return evaluate_input_t<ov::element::Type_t::f16>(ov::as_type_ptr<ov::op::v13::Multinomial>(node),
|
||||||
|
outputs,
|
||||||
|
inputs);
|
||||||
|
case ov::element::Type_t::f32:
|
||||||
|
return evaluate_input_t<ov::element::Type_t::f32>(ov::as_type_ptr<ov::op::v13::Multinomial>(node),
|
||||||
|
outputs,
|
||||||
|
inputs);
|
||||||
|
case ov::element::Type_t::f64:
|
||||||
|
return evaluate_input_t<ov::element::Type_t::f64>(ov::as_type_ptr<ov::op::v13::Multinomial>(node),
|
||||||
|
outputs,
|
||||||
|
inputs);
|
||||||
|
case ov::element::Type_t::bf16:
|
||||||
|
return evaluate_input_t<ov::element::Type_t::bf16>(ov::as_type_ptr<ov::op::v13::Multinomial>(node),
|
||||||
|
outputs,
|
||||||
|
inputs);
|
||||||
|
default:
|
||||||
|
OPENVINO_THROW(std::string("Unhandled input data type ") + node->get_input_element_type(0).get_type_name() +
|
||||||
|
std::string(" in evaluate_node()."));
|
||||||
|
}
|
||||||
|
}
|
@ -465,6 +465,10 @@ extern template bool evaluate_node<ov::op::v13::NMSRotated>(std::shared_ptr<ov::
|
|||||||
ov::TensorVector& outputs,
|
ov::TensorVector& outputs,
|
||||||
const ov::TensorVector& inputs);
|
const ov::TensorVector& inputs);
|
||||||
|
|
||||||
|
extern template bool evaluate_node<ov::op::v13::Multinomial>(std::shared_ptr<ov::Node> node,
|
||||||
|
ov::TensorVector& outputs,
|
||||||
|
const ov::TensorVector& inputs);
|
||||||
|
|
||||||
extern template bool evaluate_node<ov::op::internal::AUGRUCell>(std::shared_ptr<ov::Node> node,
|
extern template bool evaluate_node<ov::op::internal::AUGRUCell>(std::shared_ptr<ov::Node> node,
|
||||||
ov::TensorVector& outputs,
|
ov::TensorVector& outputs,
|
||||||
const ov::TensorVector& inputs);
|
const ov::TensorVector& inputs);
|
||||||
|
@ -155,6 +155,7 @@ _OPENVINO_OP_REG(BitwiseNot, ov::op::v13)
|
|||||||
_OPENVINO_OP_REG(BitwiseOr, ov::op::v13)
|
_OPENVINO_OP_REG(BitwiseOr, ov::op::v13)
|
||||||
_OPENVINO_OP_REG(BitwiseXor, ov::op::v13)
|
_OPENVINO_OP_REG(BitwiseXor, ov::op::v13)
|
||||||
_OPENVINO_OP_REG(NMSRotated, ov::op::v13)
|
_OPENVINO_OP_REG(NMSRotated, ov::op::v13)
|
||||||
|
_OPENVINO_OP_REG(Multinomial, ov::op::v13)
|
||||||
|
|
||||||
_OPENVINO_OP_REG(AUGRUCell, ov::op::internal)
|
_OPENVINO_OP_REG(AUGRUCell, ov::op::internal)
|
||||||
_OPENVINO_OP_REG(AUGRUSequence, ov::op::internal)
|
_OPENVINO_OP_REG(AUGRUSequence, ov::op::internal)
|
||||||
|
@ -0,0 +1,161 @@
|
|||||||
|
// Copyright (C) 2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "openvino/op/multinomial.hpp"
|
||||||
|
|
||||||
|
#include "base_reference_test.hpp"
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
#include "openvino/op/parameter.hpp"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
struct MultinomialParams {
|
||||||
|
MultinomialParams(const reference_tests::Tensor& probabilities,
|
||||||
|
const reference_tests::Tensor& num_samples,
|
||||||
|
const reference_tests::Tensor& expected_tensor,
|
||||||
|
ov::element::Type_t convert_type,
|
||||||
|
bool log_probs,
|
||||||
|
bool with_replacement,
|
||||||
|
std::string name)
|
||||||
|
: probabilities{probabilities},
|
||||||
|
num_samples{num_samples},
|
||||||
|
expected_tensor(expected_tensor),
|
||||||
|
convert_type{convert_type},
|
||||||
|
log_probs(log_probs),
|
||||||
|
with_replacement(with_replacement),
|
||||||
|
test_case_name{std::move(name)} {}
|
||||||
|
|
||||||
|
reference_tests::Tensor probabilities;
|
||||||
|
reference_tests::Tensor num_samples;
|
||||||
|
reference_tests::Tensor expected_tensor;
|
||||||
|
|
||||||
|
ov::element::Type_t convert_type;
|
||||||
|
bool log_probs;
|
||||||
|
bool with_replacement;
|
||||||
|
std::string test_case_name;
|
||||||
|
};
|
||||||
|
|
||||||
|
class ReferenceMultinomial : public testing::TestWithParam<MultinomialParams>,
|
||||||
|
public reference_tests::CommonReferenceTest {
|
||||||
|
public:
|
||||||
|
void SetUp() override {
|
||||||
|
const auto& params = GetParam();
|
||||||
|
function = CreateFunction(params);
|
||||||
|
inputData = {params.probabilities.data, params.num_samples.data};
|
||||||
|
refOutData = {params.expected_tensor.data};
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string getTestCaseName(const testing::TestParamInfo<MultinomialParams>& obj) {
|
||||||
|
std::ostringstream name;
|
||||||
|
name << obj.param.test_case_name;
|
||||||
|
name << "_input_type_";
|
||||||
|
name << obj.param.probabilities.type;
|
||||||
|
name << "_samples_type_";
|
||||||
|
name << obj.param.num_samples.type;
|
||||||
|
name << "_convert_type_";
|
||||||
|
name << obj.param.convert_type;
|
||||||
|
name << "_log_";
|
||||||
|
name << obj.param.log_probs;
|
||||||
|
name << "_replacement_";
|
||||||
|
name << obj.param.with_replacement;
|
||||||
|
return name.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
static std::shared_ptr<ov::Model> CreateFunction(const MultinomialParams& params) {
|
||||||
|
const auto in_probabilities =
|
||||||
|
std::make_shared<ov::op::v0::Parameter>(params.probabilities.type, params.probabilities.shape);
|
||||||
|
const auto in_num_samples =
|
||||||
|
std::make_shared<ov::op::v0::Parameter>(params.num_samples.type, params.num_samples.shape);
|
||||||
|
const auto multinomial = std::make_shared<ov::op::v13::Multinomial>(in_probabilities,
|
||||||
|
in_num_samples,
|
||||||
|
params.convert_type,
|
||||||
|
params.with_replacement,
|
||||||
|
params.log_probs,
|
||||||
|
1,
|
||||||
|
1);
|
||||||
|
return std::make_shared<ov::Model>(multinomial->outputs(),
|
||||||
|
ov::ParameterVector{in_probabilities, in_num_samples});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <ov::element::Type_t et>
|
||||||
|
std::vector<MultinomialParams> generateMultinomialParams() {
|
||||||
|
using vt = typename ov::element_type_traits<et>::value_type;
|
||||||
|
|
||||||
|
const ov::Shape prob_2d_shape{2, 4};
|
||||||
|
const ov::Shape prob_1d_shape{4};
|
||||||
|
const ov::Shape num_samples_shape{1};
|
||||||
|
|
||||||
|
reference_tests::Tensor num_samples(num_samples_shape, ov::element::Type_t::i32, std::vector<int32_t>{4});
|
||||||
|
|
||||||
|
reference_tests::Tensor probabilities_2d_no_log(prob_2d_shape,
|
||||||
|
et,
|
||||||
|
std::vector<vt>{0.001, 0.01, 0.1, 0.899, 0.899, 0.1, 0.01, 0.001});
|
||||||
|
reference_tests::Tensor probabilities_2d_log(prob_2d_shape, et, std::vector<vt>{1, 2, 3, 4, 2, 4, 6, 8});
|
||||||
|
reference_tests::Tensor probabilities_1d_no_log(prob_1d_shape, et, std::vector<vt>{0.001, 0.01, 0.1, 0.899});
|
||||||
|
reference_tests::Tensor probabilities_1d_log(prob_1d_shape, et, std::vector<vt>{1, 10, 7, 3});
|
||||||
|
|
||||||
|
reference_tests::Tensor output_2d_no_log_no_replacement(prob_2d_shape,
|
||||||
|
ov::element::Type_t::i32,
|
||||||
|
std::vector<int32_t>{3, 3, 3, 3, 0, 0, 0, 0});
|
||||||
|
reference_tests::Tensor output_2d_log_no_replacement(prob_2d_shape,
|
||||||
|
ov::element::Type_t::i32,
|
||||||
|
std::vector<int32_t>{3, 3, 2, 3, 3, 3, 3, 3});
|
||||||
|
reference_tests::Tensor output_1d_no_log_replacement(prob_1d_shape,
|
||||||
|
ov::element::Type_t::i64,
|
||||||
|
std::vector<int64_t>{3, 2, 1, 0});
|
||||||
|
reference_tests::Tensor output_1d_log_replacement(prob_1d_shape,
|
||||||
|
ov::element::Type_t::i64,
|
||||||
|
std::vector<int64_t>{1, 2, 3, 0});
|
||||||
|
|
||||||
|
std::vector<MultinomialParams> params;
|
||||||
|
// probabilities, num_samples, output, convert_type, log_probs, with_replacement, name
|
||||||
|
params.emplace_back(probabilities_2d_no_log,
|
||||||
|
num_samples,
|
||||||
|
output_2d_no_log_no_replacement,
|
||||||
|
ov::element::Type_t::i32,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
"input_2d");
|
||||||
|
params.emplace_back(probabilities_2d_log,
|
||||||
|
num_samples,
|
||||||
|
output_2d_log_no_replacement,
|
||||||
|
ov::element::Type_t::i32,
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
"input_2d");
|
||||||
|
params.emplace_back(probabilities_1d_no_log,
|
||||||
|
num_samples,
|
||||||
|
output_1d_no_log_replacement,
|
||||||
|
ov::element::Type_t::i64,
|
||||||
|
false,
|
||||||
|
true,
|
||||||
|
"input_1d");
|
||||||
|
params.emplace_back(probabilities_1d_log,
|
||||||
|
num_samples,
|
||||||
|
output_1d_log_replacement,
|
||||||
|
ov::element::Type_t::i64,
|
||||||
|
true,
|
||||||
|
true,
|
||||||
|
"input_1d");
|
||||||
|
return params;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<MultinomialParams> generateMultinomialParams() {
|
||||||
|
std::vector<std::vector<MultinomialParams>> combo_params{generateMultinomialParams<ov::element::f32>()};
|
||||||
|
std::vector<MultinomialParams> test_params;
|
||||||
|
for (auto& params : combo_params)
|
||||||
|
std::move(params.begin(), params.end(), std::back_inserter(test_params));
|
||||||
|
return test_params;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
TEST_P(ReferenceMultinomial, CompareWithRefs) {
|
||||||
|
Exec();
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(smoke,
|
||||||
|
ReferenceMultinomial,
|
||||||
|
::testing::ValuesIn(generateMultinomialParams()),
|
||||||
|
ReferenceMultinomial::getTestCaseName);
|
@ -606,6 +606,15 @@ std::shared_ptr<ov::Model> generate(const std::shared_ptr<ov::op::v0::MatMul> &n
|
|||||||
return std::make_shared<ov::Model>(results, params, "MatMul-1");
|
return std::make_shared<ov::Model>(results, params, "MatMul-1");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<ov::Model> generate(const std::shared_ptr<ov::op::v13::Multinomial>& node) {
|
||||||
|
ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{{1, 5}}),
|
||||||
|
std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::Shape{1})};
|
||||||
|
auto multinomial =
|
||||||
|
std::make_shared<ov::op::v13::Multinomial>(params[0], params[1], ov::element::i32, false, false, 0, 0);
|
||||||
|
ov::ResultVector results{std::make_shared<ov::op::v0::Result>(multinomial)};
|
||||||
|
return std::make_shared<ov::Model>(results, params, "Multinomial-13");
|
||||||
|
}
|
||||||
|
|
||||||
std::shared_ptr<ov::Model> generate(const std::shared_ptr<ov::op::v13::NMSRotated> &node) {
|
std::shared_ptr<ov::Model> generate(const std::shared_ptr<ov::op::v13::NMSRotated> &node) {
|
||||||
ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{{1, 6, 5}}),
|
ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{{1, 6, 5}}),
|
||||||
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{{1, 1, 6}}),
|
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{{1, 1, 6}}),
|
||||||
|
Loading…
Reference in New Issue
Block a user