[BUGFIX][Core][Template] Multinomial shape filling for 1D input (#20359)

* [BUGFIX] Fix incorrect shape filling for 1D tensor smaller than requested sample size

* [FIX] Remove redeclaration
This commit is contained in:
Piotr Krzemiński 2023-10-11 00:51:59 +02:00 committed by GitHub
parent df55e282e3
commit e24b6211e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 26 deletions

View File

@ -113,8 +113,7 @@ void multinomial(const T* probs,
auto batch_size = probs_shape.size() == 2 ? static_cast<size_t>(probs_shape[0]) : static_cast<size_t>(1);
auto class_size =
probs_shape.size() == 2 ? static_cast<size_t>(probs_shape[1]) : static_cast<size_t>(probs_shape[0]);
auto samples_size =
probs_shape.size() == 2 ? static_cast<size_t>(num_samples[0]) : static_cast<size_t>(probs_shape[0]);
auto samples_size = static_cast<size_t>(num_samples[0]);
// Iterate over each channel in uniform samples
std::vector<U> output_samples(total_output_elements_count);
@ -132,8 +131,8 @@ void multinomial(const T* probs,
break;
}
}
// Additional step with replacement - change probability of a given class to 0, and update the cdf
if (with_replacement) {
// Additional step without replacement - change probability of a given class to 0, and update the cdf
if (!with_replacement) {
T class_probability = selected_class_idx ? cdf[i_translated + selected_class_idx] -
cdf[i_translated + selected_class_idx - 1]
: cdf[i_translated + selected_class_idx];

View File

@ -116,12 +116,13 @@ namespace multinomial {
namespace validate {
void input_types(const Node* op) {
NODE_VALIDATION_CHECK(op,
op->get_input_element_type(0).is_real(),
op->get_input_element_type(0).is_real() || op->get_input_element_type(0).is_dynamic(),
"Expected floating point type as element type for the 'probs' input.");
NODE_VALIDATION_CHECK(op,
op->get_input_element_type(1).is_integral_number(),
"Expected integer type as element type for the 'num_samples' input.");
NODE_VALIDATION_CHECK(
op,
op->get_input_element_type(1).is_integral_number() || op->get_input_element_type(1).is_dynamic(),
"Expected integer type as element type for the 'num_samples' input.");
}
} // namespace validate
} // namespace multinomial

View File

@ -86,8 +86,11 @@ std::vector<MultinomialParams> generateMultinomialParams() {
const ov::Shape prob_2d_shape{2, 4};
const ov::Shape prob_1d_shape{4};
const ov::Shape num_samples_shape{1};
const ov::Shape prob_1d_shape_expand_small{2};
const ov::Shape out_1d_shape_expand_big{16};
reference_tests::Tensor num_samples(num_samples_shape, ov::element::Type_t::i32, std::vector<int32_t>{4});
reference_tests::Tensor num_samples_big(num_samples_shape, ov::element::Type_t::i32, std::vector<int32_t>{16});
reference_tests::Tensor probabilities_2d_no_log(prob_2d_shape,
et,
@ -95,50 +98,61 @@ std::vector<MultinomialParams> generateMultinomialParams() {
reference_tests::Tensor probabilities_2d_log(prob_2d_shape, et, std::vector<vt>{1, 2, 3, 4, 2, 4, 6, 8});
reference_tests::Tensor probabilities_1d_no_log(prob_1d_shape, et, std::vector<vt>{0.001, 0.01, 0.1, 0.899});
reference_tests::Tensor probabilities_1d_log(prob_1d_shape, et, std::vector<vt>{1, 10, 7, 3});
reference_tests::Tensor probabilities_1d_expand(prob_1d_shape_expand_small, et, std::vector<vt>{0.00001, 0.99999});
reference_tests::Tensor output_2d_no_log_no_replacement(prob_2d_shape,
ov::element::Type_t::i32,
std::vector<int32_t>{3, 3, 3, 3, 0, 0, 0, 0});
reference_tests::Tensor output_2d_log_no_replacement(prob_2d_shape,
reference_tests::Tensor output_2d_no_log_replacement(prob_2d_shape,
ov::element::Type_t::i32,
std::vector<int32_t>{3, 3, 2, 3, 3, 3, 3, 3});
reference_tests::Tensor output_1d_no_log_replacement(prob_1d_shape,
std::vector<int32_t>{3, 3, 3, 3, 0, 0, 0, 0});
reference_tests::Tensor output_2d_log_replacement(prob_2d_shape,
ov::element::Type_t::i32,
std::vector<int32_t>{3, 3, 2, 3, 3, 3, 3, 3});
reference_tests::Tensor output_1d_no_log_no_replacement(prob_1d_shape,
ov::element::Type_t::i64,
std::vector<int64_t>{3, 2, 1, 0});
reference_tests::Tensor output_1d_log_no_replacement(prob_1d_shape,
ov::element::Type_t::i64,
std::vector<int64_t>{3, 2, 1, 0});
reference_tests::Tensor output_1d_log_replacement(prob_1d_shape,
ov::element::Type_t::i64,
std::vector<int64_t>{1, 2, 3, 0});
std::vector<int64_t>{1, 2, 3, 0});
reference_tests::Tensor output_1d_expand(out_1d_shape_expand_big,
ov::element::Type_t::i64,
std::vector<int64_t>{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
std::vector<MultinomialParams> params;
// probabilities, num_samples, output, convert_type, log_probs, with_replacement, name
params.emplace_back(probabilities_2d_no_log,
num_samples,
output_2d_no_log_no_replacement,
output_2d_no_log_replacement,
ov::element::Type_t::i32,
false,
false,
true,
"input_2d");
params.emplace_back(probabilities_2d_log,
num_samples,
output_2d_log_no_replacement,
output_2d_log_replacement,
ov::element::Type_t::i32,
true,
false,
true,
"input_2d");
params.emplace_back(probabilities_1d_no_log,
num_samples,
output_1d_no_log_replacement,
output_1d_no_log_no_replacement,
ov::element::Type_t::i64,
false,
true,
false,
"input_1d");
params.emplace_back(probabilities_1d_log,
num_samples,
output_1d_log_replacement,
output_1d_log_no_replacement,
ov::element::Type_t::i64,
true,
true,
false,
"input_1d");
params.emplace_back(probabilities_1d_expand,
num_samples_big,
output_1d_expand,
ov::element::Type_t::i64,
false,
true,
"input_1d_expand");
return params;
}