[Unique-10] Reference implementation fixes (#14160)
* Add more tests for Unique op reference impl * Use stable sort to preserve elements order * Update tensor comparator and unify approach for sorted * Slices comparator revert and test * Fix slices_ascending_order comparator
This commit is contained in:
@@ -115,20 +115,24 @@ UniqueElements<Index_t, Count_t> find_unique_elements(const Data_t* data,
|
||||
|
||||
const auto slices_ascending_order = [&](const TensorSlice<Index_t, Count_t>& lhs,
|
||||
const TensorSlice<Index_t, Count_t>& rhs) {
|
||||
const auto slices_offset = calc_slices_offset(lhs, rhs, data_shape_strides, *axis);
|
||||
const auto shape_to_iterate = slice_shape_to_iterate(data_shape, *axis);
|
||||
|
||||
for (auto it = CoordinateIterator(shape_to_iterate); it != CoordinateIterator::end(); ++it) {
|
||||
auto elem_coord = *it;
|
||||
elem_coord.insert(elem_coord.cbegin() + *axis, lhs.idx);
|
||||
const auto lhs_elem_idx = ngraph::coordinate_index(elem_coord, data_shape);
|
||||
const auto rhs_elem_idx = lhs_elem_idx + slices_offset;
|
||||
if (*(data + rhs_elem_idx) > *(data + lhs_elem_idx)) {
|
||||
return false;
|
||||
auto elem_coord_lhs = *it;
|
||||
elem_coord_lhs.insert(elem_coord_lhs.cbegin() + *axis, lhs.idx);
|
||||
|
||||
auto elem_coord_rhs = *it;
|
||||
elem_coord_rhs.insert(elem_coord_rhs.cbegin() + *axis, rhs.idx);
|
||||
|
||||
const auto lhs_elem_idx = ngraph::coordinate_index(elem_coord_lhs, data_shape);
|
||||
const auto rhs_elem_idx = ngraph::coordinate_index(elem_coord_rhs, data_shape);
|
||||
|
||||
if (*(data + lhs_elem_idx) < *(data + rhs_elem_idx)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
return false;
|
||||
};
|
||||
|
||||
const auto elements_are_equal = [&data](const TensorSlice<Index_t, Count_t>& lhs,
|
||||
@@ -187,7 +191,7 @@ UniqueElements<Index_t, Count_t> find_unique_elements(const Data_t* data,
|
||||
generate_descriptors<Index_t, Count_t>(data_elems_count, DescriptorType::SINGLE_VALUE);
|
||||
|
||||
if (sorted) {
|
||||
std::sort(begin(ret.all_tensor_elements), end(ret.all_tensor_elements), ascending_order);
|
||||
std::stable_sort(begin(ret.all_tensor_elements), end(ret.all_tensor_elements), ascending_order);
|
||||
}
|
||||
|
||||
ret.all_tensor_elements[0].rev_idx = 0;
|
||||
@@ -219,15 +223,15 @@ UniqueElements<Index_t, Count_t> find_unique_elements(const Data_t* data,
|
||||
ret.all_tensor_elements = generate_descriptors<Index_t, Count_t>(data_shape[*axis], DescriptorType::SLICE);
|
||||
|
||||
if (sorted) {
|
||||
std::sort(begin(ret.all_tensor_elements), end(ret.all_tensor_elements), slices_ascending_order);
|
||||
std::stable_sort(begin(ret.all_tensor_elements), end(ret.all_tensor_elements), slices_ascending_order);
|
||||
}
|
||||
|
||||
ret.all_tensor_elements[0].rev_idx = 0;
|
||||
ret.unique_tensor_elements.push_back(ret.all_tensor_elements[0]);
|
||||
|
||||
for (size_t i = 1; i < data_shape[*axis]; ++i) {
|
||||
auto& tensor_element = ret.all_tensor_elements[i];
|
||||
auto existing_unique = end(ret.unique_tensor_elements);
|
||||
|
||||
if (sorted) {
|
||||
existing_unique = std::lower_bound(begin(ret.unique_tensor_elements),
|
||||
end(ret.unique_tensor_elements),
|
||||
|
||||
@@ -179,6 +179,24 @@ std::vector<UniqueParams> params_unique_int() {
|
||||
nullptr,
|
||||
true,
|
||||
"1D with duplicates"},
|
||||
UniqueParams{Shape{7},
|
||||
std::vector<Data_t>{3, 1, 5, 3, 2, 4, 2},
|
||||
std::vector<Data_t>{1, 2, 3, 4, 5},
|
||||
std::vector<Index_t>{1, 4, 0, 5, 2},
|
||||
std::vector<Index_t>{2, 0, 4, 2, 1, 3, 1},
|
||||
std::vector<int64_t>{1, 2, 2, 1, 1},
|
||||
nullptr,
|
||||
true,
|
||||
"1D with duplicates, sort 1st element"},
|
||||
UniqueParams{Shape{7},
|
||||
std::vector<Data_t>{3, 3, 5, 3, 2, 4, 2},
|
||||
std::vector<Data_t>{2, 3, 4, 5},
|
||||
std::vector<Index_t>{4, 0, 5, 2},
|
||||
std::vector<Index_t>{1, 1, 3, 1, 0, 2, 0},
|
||||
std::vector<int64_t>{2, 3, 1, 1},
|
||||
nullptr,
|
||||
true,
|
||||
"1D with duplicates in row, sort 1st element"},
|
||||
UniqueParams{Shape{7},
|
||||
std::vector<Data_t>{1, 3, 5, 3, 2, 4, 2},
|
||||
std::vector<Data_t>{1, 2, 3, 4, 5},
|
||||
@@ -265,7 +283,54 @@ std::vector<UniqueParams> params_unique_int() {
|
||||
std::vector<int64_t>{2, 1},
|
||||
make_axis(2),
|
||||
false,
|
||||
"3D with duplicates (1 & 2)"}};
|
||||
"3D with duplicates (1 & 2)"},
|
||||
UniqueParams{Shape{2, 2, 3},
|
||||
// the second and the third slice over axis 2 are equal
|
||||
std::vector<Data_t>{2, -1, -1, -3, 5, 5, -8, 7, 7, 4, 4, 4},
|
||||
std::vector<Data_t>{2, -1, -3, 5, -8, 7, 4, 4},
|
||||
std::vector<Index_t>{0, 1},
|
||||
std::vector<Index_t>{0, 1, 1},
|
||||
std::vector<int64_t>{1, 2},
|
||||
make_axis(2),
|
||||
false,
|
||||
"3D with duplicates (2 & 3)"},
|
||||
UniqueParams{Shape{2, 2, 3},
|
||||
// the second and the third slice over axis 2 are equal
|
||||
std::vector<Data_t>{2, -1, -1, -3, 5, 5, -8, 7, 7, 4, 4, 4},
|
||||
std::vector<Data_t>{-1, 2, 5, -3, 7, -8, 4, 4},
|
||||
std::vector<Index_t>{1, 0},
|
||||
std::vector<Index_t>{1, 0, 0},
|
||||
std::vector<int64_t>{2, 1},
|
||||
make_axis(2),
|
||||
true,
|
||||
"3D with duplicates (2 & 3), output sorted"},
|
||||
UniqueParams{Shape{2, 2, 3},
|
||||
// the second and the third slice over axis 2 are equal
|
||||
std::vector<Data_t>{-1, -1, -1, 3, 2, 2, 6, 7, 7, 4, 4, 4},
|
||||
std::vector<Data_t>{-1, -1, 2, 3, 7, 6, 4, 4},
|
||||
std::vector<Index_t>{1, 0},
|
||||
std::vector<Index_t>{1, 0, 0},
|
||||
std::vector<int64_t>{2, 1},
|
||||
make_axis(2),
|
||||
true,
|
||||
"3D with duplicates (2 & 3), first elements equal, output sorted"},
|
||||
UniqueParams{
|
||||
Shape{1, 3, 16},
|
||||
std::vector<Data_t>{15, -20, -11, 10, -21, 8, -15, -10, 7, 20, -19, -14, -13, -16, -7, -2,
|
||||
-17, -4, 21, -6, 11, 8, 17, 6, 7, 20, -3, 2, -13, -16, -23, 14,
|
||||
-1, 12, 5, -6, 11, -8, 1, -10, 23, 20, -19, 18, 3, -16, -7, 14},
|
||||
std::vector<Data_t>{-23, -21, -20, -19, -17, -16, -15, -14, -13, -11, -10, -8, -7, -6, -4, -3, -2, -1,
|
||||
1, 2, 3, 5, 6, 7, 8, 10, 11, 12, 14, 15, 17, 18, 20, 21, 23},
|
||||
std::vector<Index_t>{30, 4, 1, 10, 16, 13, 6, 11, 12, 2, 7, 37, 14, 19, 17, 26, 15, 32,
|
||||
38, 27, 44, 34, 23, 8, 5, 3, 20, 33, 31, 0, 22, 43, 9, 18, 40},
|
||||
std::vector<Index_t>{29, 2, 9, 25, 1, 24, 6, 10, 23, 32, 3, 7, 8, 5, 12, 16,
|
||||
4, 14, 33, 13, 26, 24, 30, 22, 23, 32, 15, 19, 8, 5, 0, 28,
|
||||
17, 27, 21, 13, 26, 11, 18, 10, 34, 32, 3, 31, 20, 5, 12, 28},
|
||||
std::vector<int64_t>{1, 1, 1, 2, 1, 3, 1, 1, 2, 1, 2, 1, 2, 2, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 2, 2, 1, 2, 1, 2, 1, 1, 1, 3, 1, 1},
|
||||
nullptr,
|
||||
true,
|
||||
"3D flattened with duplicates, output sorted"}};
|
||||
|
||||
return flatten({std::move(scalar_and_1D), std::move(N_C_layout), std::move(N_D_layout)});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user