[BUGFIX][Core][Template] Multinomial shape filling for 1D input (#20359)
* [BUGFIX] Fix incorrect shape filling for 1D tensor smaller than requested sample size * [FIX] Remove redeclaration
This commit is contained in:
parent
df55e282e3
commit
e24b6211e3
@ -113,8 +113,7 @@ void multinomial(const T* probs,
|
||||
auto batch_size = probs_shape.size() == 2 ? static_cast<size_t>(probs_shape[0]) : static_cast<size_t>(1);
|
||||
auto class_size =
|
||||
probs_shape.size() == 2 ? static_cast<size_t>(probs_shape[1]) : static_cast<size_t>(probs_shape[0]);
|
||||
auto samples_size =
|
||||
probs_shape.size() == 2 ? static_cast<size_t>(num_samples[0]) : static_cast<size_t>(probs_shape[0]);
|
||||
auto samples_size = static_cast<size_t>(num_samples[0]);
|
||||
|
||||
// Iterate over each channel in uniform samples
|
||||
std::vector<U> output_samples(total_output_elements_count);
|
||||
@ -132,8 +131,8 @@ void multinomial(const T* probs,
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Additional step with replacement - change probability of a given class to 0, and update the cdf
|
||||
if (with_replacement) {
|
||||
// Additional step without replacement - change probability of a given class to 0, and update the cdf
|
||||
if (!with_replacement) {
|
||||
T class_probability = selected_class_idx ? cdf[i_translated + selected_class_idx] -
|
||||
cdf[i_translated + selected_class_idx - 1]
|
||||
: cdf[i_translated + selected_class_idx];
|
||||
|
@ -116,12 +116,13 @@ namespace multinomial {
|
||||
namespace validate {
|
||||
void input_types(const Node* op) {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
op->get_input_element_type(0).is_real(),
|
||||
op->get_input_element_type(0).is_real() || op->get_input_element_type(0).is_dynamic(),
|
||||
"Expected floating point type as element type for the 'probs' input.");
|
||||
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
op->get_input_element_type(1).is_integral_number(),
|
||||
"Expected integer type as element type for the 'num_samples' input.");
|
||||
NODE_VALIDATION_CHECK(
|
||||
op,
|
||||
op->get_input_element_type(1).is_integral_number() || op->get_input_element_type(1).is_dynamic(),
|
||||
"Expected integer type as element type for the 'num_samples' input.");
|
||||
}
|
||||
} // namespace validate
|
||||
} // namespace multinomial
|
||||
|
@ -86,8 +86,11 @@ std::vector<MultinomialParams> generateMultinomialParams() {
|
||||
const ov::Shape prob_2d_shape{2, 4};
|
||||
const ov::Shape prob_1d_shape{4};
|
||||
const ov::Shape num_samples_shape{1};
|
||||
const ov::Shape prob_1d_shape_expand_small{2};
|
||||
const ov::Shape out_1d_shape_expand_big{16};
|
||||
|
||||
reference_tests::Tensor num_samples(num_samples_shape, ov::element::Type_t::i32, std::vector<int32_t>{4});
|
||||
reference_tests::Tensor num_samples_big(num_samples_shape, ov::element::Type_t::i32, std::vector<int32_t>{16});
|
||||
|
||||
reference_tests::Tensor probabilities_2d_no_log(prob_2d_shape,
|
||||
et,
|
||||
@ -95,50 +98,61 @@ std::vector<MultinomialParams> generateMultinomialParams() {
|
||||
reference_tests::Tensor probabilities_2d_log(prob_2d_shape, et, std::vector<vt>{1, 2, 3, 4, 2, 4, 6, 8});
|
||||
reference_tests::Tensor probabilities_1d_no_log(prob_1d_shape, et, std::vector<vt>{0.001, 0.01, 0.1, 0.899});
|
||||
reference_tests::Tensor probabilities_1d_log(prob_1d_shape, et, std::vector<vt>{1, 10, 7, 3});
|
||||
reference_tests::Tensor probabilities_1d_expand(prob_1d_shape_expand_small, et, std::vector<vt>{0.00001, 0.99999});
|
||||
|
||||
reference_tests::Tensor output_2d_no_log_no_replacement(prob_2d_shape,
|
||||
ov::element::Type_t::i32,
|
||||
std::vector<int32_t>{3, 3, 3, 3, 0, 0, 0, 0});
|
||||
reference_tests::Tensor output_2d_log_no_replacement(prob_2d_shape,
|
||||
reference_tests::Tensor output_2d_no_log_replacement(prob_2d_shape,
|
||||
ov::element::Type_t::i32,
|
||||
std::vector<int32_t>{3, 3, 2, 3, 3, 3, 3, 3});
|
||||
reference_tests::Tensor output_1d_no_log_replacement(prob_1d_shape,
|
||||
std::vector<int32_t>{3, 3, 3, 3, 0, 0, 0, 0});
|
||||
reference_tests::Tensor output_2d_log_replacement(prob_2d_shape,
|
||||
ov::element::Type_t::i32,
|
||||
std::vector<int32_t>{3, 3, 2, 3, 3, 3, 3, 3});
|
||||
reference_tests::Tensor output_1d_no_log_no_replacement(prob_1d_shape,
|
||||
ov::element::Type_t::i64,
|
||||
std::vector<int64_t>{3, 2, 1, 0});
|
||||
reference_tests::Tensor output_1d_log_no_replacement(prob_1d_shape,
|
||||
ov::element::Type_t::i64,
|
||||
std::vector<int64_t>{3, 2, 1, 0});
|
||||
reference_tests::Tensor output_1d_log_replacement(prob_1d_shape,
|
||||
ov::element::Type_t::i64,
|
||||
std::vector<int64_t>{1, 2, 3, 0});
|
||||
std::vector<int64_t>{1, 2, 3, 0});
|
||||
reference_tests::Tensor output_1d_expand(out_1d_shape_expand_big,
|
||||
ov::element::Type_t::i64,
|
||||
std::vector<int64_t>{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
|
||||
|
||||
std::vector<MultinomialParams> params;
|
||||
// probabilities, num_samples, output, convert_type, log_probs, with_replacement, name
|
||||
params.emplace_back(probabilities_2d_no_log,
|
||||
num_samples,
|
||||
output_2d_no_log_no_replacement,
|
||||
output_2d_no_log_replacement,
|
||||
ov::element::Type_t::i32,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
"input_2d");
|
||||
params.emplace_back(probabilities_2d_log,
|
||||
num_samples,
|
||||
output_2d_log_no_replacement,
|
||||
output_2d_log_replacement,
|
||||
ov::element::Type_t::i32,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
"input_2d");
|
||||
params.emplace_back(probabilities_1d_no_log,
|
||||
num_samples,
|
||||
output_1d_no_log_replacement,
|
||||
output_1d_no_log_no_replacement,
|
||||
ov::element::Type_t::i64,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
"input_1d");
|
||||
params.emplace_back(probabilities_1d_log,
|
||||
num_samples,
|
||||
output_1d_log_replacement,
|
||||
output_1d_log_no_replacement,
|
||||
ov::element::Type_t::i64,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
"input_1d");
|
||||
params.emplace_back(probabilities_1d_expand,
|
||||
num_samples_big,
|
||||
output_1d_expand,
|
||||
ov::element::Type_t::i64,
|
||||
false,
|
||||
true,
|
||||
"input_1d_expand");
|
||||
return params;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user