[TF FE] Support Unique operation via transformation (#13463)
* [TF FE] Support Unique operation via transformation
* Fix other part: place header for UniqueReplacer and test name
* Remove redundant included headers in the transformation test
* Add one more case in the transformation test
* Implement the right way for mapping tensor names
* Apply code-review: rename out_idx, add check for 1D
* Fix unique_replacer test
* Fix typo in the test
* Add FunctionsComparator CmpValues ACCURACY to the test

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent 2d8faed15b
commit b08fa945bc
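For context (not part of the diff): the TensorFlow semantics the transformation has to reproduce are those of tf.unique, which returns the unique elements of a 1-D tensor in the order of their first appearance plus an index vector of the same length as the input, with the index dtype controlled by the out_idx attribute. A minimal sketch of that contract (illustration only):

import numpy as np
import tensorflow as tf

x = tf.constant([1, 1, 2, 4, 4, 4, 7, 8, 8], dtype=tf.int64)
y, idx = tf.unique(x, out_idx=tf.int32)

# y keeps unique values in the order of their first appearance in x
# idx has the same length as x and maps every element of x back into y
assert np.array_equal(y.numpy(), [1, 2, 4, 7, 8])
assert np.array_equal(idx.numpy(), [0, 0, 1, 2, 2, 2, 3, 4, 4])
# gathering y by idx reconstructs the original input
assert np.array_equal(tf.gather(y, idx).numpy(), x.numpy())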
@@ -8,6 +8,7 @@
 #include "helper_transforms/block_lstm_replacer.hpp"
 #include "helper_transforms/embedding_segments_feature_fusing.hpp"
 #include "helper_transforms/gru_block_cell_replacer.hpp"
+#include "helper_transforms/unique_replacer.hpp"
 #include "input_model.hpp"
 #include "op_table.hpp"
 #include "openvino/frontend/tensorflow/extension/conversion.hpp"
@@ -429,9 +430,10 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& function) const {
 
     // Runs middle transformations to convert sub-graphs with intermediate (frontend internal) operations
     // into sub-graphs with only OpenVINO operations
-    manager.register_pass<ov::frontend::tensorflow::pass::EmbeddingSegmentSingleFeatureFusion>();
-    manager.register_pass<ov::frontend::tensorflow::pass::BlockLSTMReplacer>();
-    manager.register_pass<ov::frontend::tensorflow::pass::GRUBlockCellReplacer>();
+    manager.register_pass<pass::EmbeddingSegmentSingleFeatureFusion>();
+    manager.register_pass<pass::BlockLSTMReplacer>();
+    manager.register_pass<pass::GRUBlockCellReplacer>();
+    manager.register_pass<pass::UniqueReplacer>();
 
     // TODO: reimplement TransposeSinking that does not corrupt filters for Convolution
     // and preserve tensor names in case of sinking
@@ -21,7 +21,7 @@ public:
            ov::element::Type output_indices_type,
            const std::shared_ptr<DecoderBase>& decoder = nullptr)
         : ov::frontend::tensorflow::InternalOperation(decoder, OutputVector{input_values}, 2),
-          out_idx(output_indices_type) {
+          m_output_indices_type(output_indices_type) {
         validate_and_infer_types();
     }
 
@@ -33,11 +33,15 @@ public:
         // 0) 1D tensor of unique elements
         // 1) 1D tensor of indices of the unique elements in the input
         set_output_type(0, get_input_element_type(0), ov::PartialShape({ov::Dimension::dynamic()}));
-        set_output_type(1, out_idx, ov::PartialShape({ov::Dimension::dynamic()}));
+        set_output_type(1, m_output_indices_type, ov::PartialShape({ov::Dimension::dynamic()}));
     }
 
+    ov::element::Type get_output_indices_type() const {
+        return m_output_indices_type;
+    }
+
 private:
-    ov::element::Type out_idx;
+    ov::element::Type m_output_indices_type;
 };
 }  // namespace tensorflow
 }  // namespace frontend
@@ -0,0 +1,130 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "helper_transforms/unique_replacer.hpp"

#include <memory>
#include <vector>

#include "helper_ops/unique.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/opsets/opset9.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "utils.hpp"

using namespace std;
using namespace ov;
using namespace ov::pass;
using namespace ov::opset9;
using namespace ov::frontend::tensorflow;

ov::frontend::tensorflow::pass::UniqueReplacer::UniqueReplacer() {
    auto unique = pattern::wrap_type<Unique>();

    matcher_pass_callback callback = [=](pattern::Matcher& matcher) {
        NodeRegistry rg;

        auto unique_node = std::dynamic_pointer_cast<Unique>(matcher.get_match_root());
        if (!unique_node) {
            return false;
        }

        auto x = unique_node->input_value(0);
        auto output_indices_type = unique_node->get_output_indices_type();
        auto x_type = x.get_element_type();
        if (!x_type.is_real() && !x_type.is_integral_number()) {
            return false;
        }

        // denote the number of elements in x as n
        auto n = get_elements_number_1d(x, element::i32, rg);

        // create auxiliary constants to be re-used by different operations
        auto zero_const = rg.make<Constant>(element::i32, Shape{1}, 0);
        auto one_const = rg.make<Constant>(element::i32, Shape{1}, 1);
        auto one_const_scalar = rg.make<Constant>(element::i32, Shape{}, 1);
        auto minus_one_const = rg.make<Constant>(element::i32, Shape{1}, -1);
        auto true_const = rg.make<Constant>(element::boolean, Shape{1}, true);
        auto one_const_out_idx = rg.make<Constant>(output_indices_type, Shape{1}, 1);
        auto zero_const_out_idx = rg.make<Constant>(output_indices_type, Shape{1}, 0);

        // compute unique elements, not yet in the original order
        // 1. sort elements in x in order to compute unique elements
        auto x_sorted = rg.make<TopK>(x, n, 0, TopK::Mode::MIN, TopK::SortType::SORT_VALUES, element::i32);
        // 2. generate two vectors from the x_sorted vector by padding it at the beginning and at the end:
        // x1 = [0, x0, x1, ..., xn]
        // x2 = [x0, x1, ..., xn, 0]
        auto pad = rg.make<Constant>(x_type, Shape{1}, 0);
        auto x1 = rg.make<Concat>(OutputVector{pad, x_sorted->output(0)}, 0);
        auto x2 = rg.make<Concat>(OutputVector{x_sorted->output(0), pad}, 0);
        // 3. compare the two vectors to see where unique elements are placed,
        // then correct the mask: the first element is always unique,
        // and the last boolean element must be dropped because the vectors are padded
        auto mask1 = rg.make<NotEqual>(x1, x2);
        auto mask1_part = rg.make<Slice>(mask1, one_const, minus_one_const, one_const, zero_const);
        auto is_unique = rg.make<Concat>(OutputVector{true_const, mask1_part}, 0);
        // 4. compute positions where unique elements are placed in the sorted x
        auto is_unique_01 = rg.make<Select>(is_unique, one_const, zero_const);
        auto indices = rg.make<NonZero>(is_unique_01, element::i64);
        auto unique_element_indices = rg.make<Squeeze>(indices, zero_const);
        // 5. collect unique elements; at this point they are not yet in the original order
        auto unique_elements = rg.make<Gather>(x_sorted->output(0), unique_element_indices, zero_const);

        // compute unique elements in the original order
        auto unsqueeze_x = rg.make<Unsqueeze>(x, zero_const);
        auto unsqueeze_unique_elements = rg.make<Unsqueeze>(unique_elements, one_const);
        // 1. compute a pairwise comparison mask showing where each unique element is placed in the original x
        auto nplus1 = rg.make<Add>(n, one_const_scalar);
        auto unique_vs_x = rg.make<Equal>(unsqueeze_unique_elements, unsqueeze_x);
        auto unique_vs_x_01 = rg.make<Select>(unique_vs_x, one_const_scalar, nplus1);
        auto range_1nplus1 = rg.make<Range>(one_const_scalar, nplus1, one_const_scalar, element::i32);
        auto unsqueeze_range_1nplus1 = rg.make<Unsqueeze>(range_1nplus1, zero_const);
        // 2. compute a mask with indices counting from one
        auto unique_vs_x_ind = rg.make<Multiply>(unique_vs_x_01, unsqueeze_range_1nplus1);
        // 3. compute positions of the first occurrence of each unique element,
        // i.e. the positions of unique elements in the original order
        auto minimum_indices_plus1 = rg.make<ReduceMin>(unique_vs_x_ind, one_const);
        auto minimum_indices = rg.make<Subtract>(minimum_indices_plus1, one_const);
        // denote the number of unique elements as m
        auto m = get_elements_number_1d(minimum_indices, element::i32, rg);
        auto sorted_minumum_indices =
            rg.make<TopK>(minimum_indices, m, 0, TopK::Mode::MIN, TopK::SortType::SORT_VALUES, element::i32);
        auto output_unique_elements = rg.make<Gather>(x, sorted_minumum_indices->output(0), zero_const);

        if (!unique_node->get_output_target_inputs(0).empty()) {
            output_unique_elements->set_friendly_name(unique_node->get_friendly_name() + ":0");
            unique_node->output(0).replace(output_unique_elements->output(0));
        }

        if (!unique_node->get_output_target_inputs(1).empty()) {
            // compute the second output:
            // indices of elements of x in the vector of unique elements
            // 1. compute a mask for unique elements in the original order
            auto unsqueeze_output_unique_elements = rg.make<Unsqueeze>(output_unique_elements, one_const);
            auto unique_vs_x_orig = rg.make<Equal>(unsqueeze_output_unique_elements, unsqueeze_x);
            auto mplus1 = rg.make<Add>(m, one_const_scalar);
            auto unique_vs_x_orig_01 = rg.make<Select>(unique_vs_x_orig, one_const_out_idx, zero_const_out_idx);
            // 2. compute positions where each element of x is located in the unique elements vector;
            // the position counts from 1
            auto range_1mplus1 = rg.make<Range>(one_const_scalar, mplus1, one_const_scalar, output_indices_type);
            auto unsqueeze_range_1mplus1 = rg.make<Unsqueeze>(range_1mplus1, one_const);
            auto unique_vs_x_ind_orig = rg.make<Multiply>(unique_vs_x_orig_01, unsqueeze_range_1mplus1);
            auto output_idx_plus1 = rg.make<ReduceMax>(unique_vs_x_ind_orig, zero_const);
            auto output_idx = rg.make<Subtract>(output_idx_plus1, one_const_out_idx);

            output_idx->set_friendly_name(unique_node->get_friendly_name() + ":1");
            unique_node->output(1).replace(output_idx->output(0));
        }

        copy_runtime_info(unique_node, rg.get());

        return true;
    };

    auto m = make_shared<pattern::Matcher>(unique, "ov::frontend::tensorflow::pass::UniqueReplacer");
    register_matcher(m, callback);
}
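To make the decomposition above easier to follow, here is a NumPy sketch (illustration only, not part of the commit; the helper name unique_via_masks is invented) that mirrors the same sub-graph: sorting, the shifted-comparison mask for unique values, a ReduceMin over 1-based positions for first occurrences, and a ReduceMax over positions for the index output. It assumes a 1-D input, as the transformation does.

import numpy as np

def unique_via_masks(x, out_idx=np.int32):
    # NumPy mirror of the UniqueReplacer sub-graph; x is assumed to be 1-D
    n = x.shape[0]

    # unique values, not yet in the original order:
    # sort, then compare each element with its left neighbour (pad + Concat + NotEqual + Slice)
    x_sorted = np.sort(x)
    x1 = np.concatenate(([0], x_sorted))      # [0, x0, ..., xn]
    x2 = np.concatenate((x_sorted, [0]))      # [x0, ..., xn, 0]
    is_unique = np.concatenate(([True], (x1 != x2)[1:-1]))
    unique_sorted = x_sorted[np.nonzero(is_unique)[0]]

    # restore the original order: for each unique value find its first occurrence in x
    # by masking 1-based positions and taking a minimum (Equal + Select + ReduceMin)
    positions = np.arange(1, n + 1)
    eq_sorted = unique_sorted[:, None] == x[None, :]
    first_occurrence = np.where(eq_sorted, positions[None, :], n + 1).min(axis=1) - 1
    y = x[np.sort(first_occurrence)]          # output 0: unique elements in the original order

    # output 1: index of every x element inside y (Equal + Select + Multiply + ReduceMax)
    eq_orig = y[:, None] == x[None, :]
    pos_1m = np.arange(1, y.shape[0] + 1)
    idx = (np.where(eq_orig, pos_1m[:, None], 0).max(axis=0) - 1).astype(out_idx)
    return y, idx

# quick check against the expected Unique semantics
y, idx = unique_via_masks(np.array([7, 3, 7, 1, 3]))
assert np.array_equal(y, [7, 3, 1]) and np.array_equal(idx, [0, 1, 0, 2, 1])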
@@ -0,0 +1,29 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>
#include <utility>

#include "openvino/frontend/tensorflow/visibility.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"

namespace ov {
namespace frontend {
namespace tensorflow {
namespace pass {

// This transformation expresses Unique with a sub-graph of OpenVINO operations
class TENSORFLOW_API UniqueReplacer : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ov::frontend::tensorflow::pass::UniqueReplacer");
    UniqueReplacer();
};

}  // namespace pass
}  // namespace tensorflow
}  // namespace frontend
}  // namespace ov
@@ -115,18 +115,6 @@ OutputVector translate_sparse_segment_sum_op(const NodeContext& node) {
     return sparse_segment_sum->outputs();
 }
 
-OutputVector translate_unique_op(const NodeContext& node) {
-    default_op_checks(node, 1, {"Unique"});
-    auto input_values = node.get_input(0);
-
-    // retrieve attribute
-    auto output_indices_type = node.get_attribute<ov::element::Type>("out_idx", ov::element::i32);
-
-    auto unique = make_shared<ov::frontend::tensorflow::Unique>(input_values, output_indices_type, node.get_decoder());
-    set_node_name(node.get_name(), unique);
-    return unique->outputs();
-}
-
 }  // namespace op
 }  // namespace tensorflow
 }  // namespace frontend
src/frontends/tensorflow/src/op/unique.cpp (new file, 30 lines)
@@ -0,0 +1,30 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "helper_ops/unique.hpp"

#include "op_table.hpp"

using namespace std;

namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
OutputVector translate_unique_op(const NodeContext& node) {
    default_op_checks(node, 1, {"Unique"});
    auto input_values = node.get_input(0);

    // retrieve attribute
    auto output_indices_type = node.get_attribute<ov::element::Type>("out_idx", ov::element::i32);

    auto unique = make_shared<ov::frontend::tensorflow::Unique>(input_values, output_indices_type, node.get_decoder());
    set_node_name(node.get_name(), unique);
    return unique->outputs();
}

}  // namespace op
}  // namespace tensorflow
}  // namespace frontend
}  // namespace ov
@@ -198,3 +198,16 @@ void ov::frontend::tensorflow::default_op_checks(const ov::frontend::tensorflow:
 bool ov::frontend::tensorflow::is_conditional_edge(const std::string& input_tensor_name) {
     return input_tensor_name.length() > 0 && input_tensor_name[0] == '^';
 }
+
+ov::Output<ov::Node> ov::frontend::tensorflow::get_elements_number_1d(const ov::Output<ov::Node>& output,
+                                                                      ov::element::Type output_type,
+                                                                      ov::pass::NodeRegistry& rg) {
+    auto output_rank = output.get_partial_shape().rank();
+    if (output_rank.is_static() && output_rank.get_length() != 1) {
+        FRONT_END_OP_CONVERSION_CHECK(false,
+                                      "Internal error: get_elements_number_1d method supports only 1D input tensor.");
+    }
+    auto shape = rg.make<ShapeOf>(output, output_type);
+    auto num_elements = rg.make<Squeeze>(shape);
+    return num_elements;
+}
@@ -7,6 +7,7 @@
 #include "openvino/core/validation_util.hpp"
 #include "openvino/frontend/tensorflow/node_context.hpp"
 #include "openvino/opsets/opset8.hpp"
+#include "openvino/pass/graph_rewrite.hpp"
 
 namespace ov {
 namespace frontend {
@@ -24,34 +25,6 @@ void set_node_name(const std::string& node_name, const std::shared_ptr<Node>& no
 
 bool is_conditional_edge(const std::string& input_tensor_name);
 
-static bool vec_str_cmp(const std::vector<std::string>& a, const std::vector<std::string>& b) {
-    return a == b;
-}
-
-template <typename T>
-void make_padding(const std::string& tf_padding_type,
-                  const ov::Shape& ng_image_shape,
-                  const ov::Shape& ng_kernel_shape,
-                  const ov::Strides& ng_strides,
-                  const ov::Shape& ng_dilations,
-                  T& ng_padding_below,
-                  T& ng_padding_above) {
-    if (tf_padding_type == "SAME") {
-        ov::Shape img_shape = {0, 0};
-        img_shape.insert(img_shape.end(), ng_image_shape.begin(), ng_image_shape.end());
-        ov::infer_auto_padding(img_shape,
-                               ng_kernel_shape,
-                               ng_strides,
-                               ng_dilations,
-                               ov::op::PadType::SAME_UPPER,
-                               ng_padding_above,
-                               ng_padding_below);
-    } else if (tf_padding_type == "VALID") {
-        ng_padding_below.assign(ng_image_shape.size(), 0);
-        ng_padding_above.assign(ng_image_shape.size(), 0);
-    }
-}
-
 template <typename T>
 void get_const_input(const NodeContext& node, int64_t input_index, std::vector<T>* vector) {
     auto ng_input = node.get_input(static_cast<int>(input_index));
@@ -75,6 +48,10 @@ void fill_explicit_pads_vectors(const NodeContext& node,
 
 void default_op_checks(const NodeContext& node, int min_input_size, const std::vector<std::string>& supported_ops);
 
+ov::Output<Node> get_elements_number_1d(const Output<Node>& output,
+                                        ov::element::Type output_type,
+                                        ov::pass::NodeRegistry& rg);
+
 }  // namespace tensorflow
 }  // namespace frontend
 }  // namespace ov
src/frontends/tensorflow/tests/unique_replacer.cpp (new file, 134 lines)
@@ -0,0 +1,134 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "helper_transforms/unique_replacer.hpp"

#include <gtest/gtest.h>

#include <frontend/shared/include/utils.hpp>
#include <openvino/opsets/opset9.hpp>
#include <openvino/pass/manager.hpp>

#include "common_test_utils/ngraph_test_utils.hpp"
#include "gtest/gtest.h"
#include "helper_ops/unique.hpp"

using namespace std;
using namespace ov;
using namespace opset9;
using namespace element;
using namespace frontend::tensorflow;
using namespace frontend::tensorflow::pass;

namespace {
Output<Node> get_elements_number_1d(const Output<Node>& output, element::Type output_type) {
    auto shape = make_shared<ShapeOf>(output, output_type);
    auto num_elements = make_shared<Squeeze>(shape);
    return num_elements;
}

shared_ptr<Model> gen_model(PartialShape input_shape, element::Type out_idx) {
    auto x = make_shared<Parameter>(f32, input_shape);
    auto unique = make_shared<Unique>(x, out_idx);

    return make_shared<Model>(OutputVector{unique->output(0), unique->output(1)}, ParameterVector{x});
}

shared_ptr<Model> gen_model_ref(PartialShape input_shape, element::Type out_idx) {
    auto x = make_shared<Parameter>(f32, input_shape);

    // denote the number of elements in x as n
    auto n = get_elements_number_1d(x, element::i32);

    // create auxiliary constants to be re-used by different operations
    auto zero_const = make_shared<Constant>(element::i32, Shape{1}, 0);
    auto one_const = make_shared<Constant>(element::i32, Shape{1}, 1);
    auto one_const_scalar = make_shared<Constant>(element::i32, Shape{}, 1);
    auto minus_one_const = make_shared<Constant>(element::i32, Shape{1}, -1);
    auto true_const = make_shared<Constant>(element::boolean, Shape{1}, true);
    auto one_const_out_idx = make_shared<Constant>(out_idx, Shape{1}, 1);
    auto zero_const_out_idx = make_shared<Constant>(out_idx, Shape{1}, 0);

    // compute unique elements, not yet in the original order
    // 1. sort elements in x in order to compute unique elements
    auto x_sorted = make_shared<TopK>(x, n, 0, TopK::Mode::MIN, TopK::SortType::SORT_VALUES, element::i32);
    // 2. generate two vectors from the x_sorted vector by padding it at the beginning and at the end:
    // x1 = [0, x0, x1, ..., xn]
    // x2 = [x0, x1, ..., xn, 0]
    auto pad = make_shared<Constant>(x->get_element_type(), Shape{1}, 0);
    auto x1 = make_shared<Concat>(OutputVector{pad, x_sorted->output(0)}, 0);
    auto x2 = make_shared<Concat>(OutputVector{x_sorted->output(0), pad}, 0);
    // 3. compare the two vectors to see where unique elements are placed,
    // then correct the mask: the first element is always unique,
    // and the last boolean element must be dropped because the vectors are padded
    auto mask1 = make_shared<NotEqual>(x1, x2);
    auto mask1_part = make_shared<Slice>(mask1, one_const, minus_one_const, one_const, zero_const);
    auto is_unique = make_shared<Concat>(OutputVector{true_const, mask1_part}, 0);
    // 4. compute positions where unique elements are placed in the sorted x
    auto is_unique_01 = make_shared<Select>(is_unique, one_const, zero_const);
    auto indices = make_shared<NonZero>(is_unique_01, element::i64);
    auto unique_element_indices = make_shared<Squeeze>(indices, zero_const);
    // 5. collect unique elements; at this point they are not yet in the original order
    auto unique_elements = make_shared<Gather>(x_sorted->output(0), unique_element_indices, zero_const);

    // compute unique elements in the original order
    auto unsqueeze_x = make_shared<Unsqueeze>(x, zero_const);
    auto unsqueeze_unique_elements = make_shared<Unsqueeze>(unique_elements, one_const);
    // 1. compute a pairwise comparison mask showing where each unique element is placed in the original x
    auto nplus1 = make_shared<Add>(n, one_const_scalar);
    auto unique_vs_x = make_shared<Equal>(unsqueeze_unique_elements, unsqueeze_x);
    auto unique_vs_x_01 = make_shared<Select>(unique_vs_x, one_const_scalar, nplus1);
    auto range_1nplus1 = make_shared<Range>(one_const_scalar, nplus1, one_const_scalar, element::i32);
    auto unsqueeze_range_1nplus1 = make_shared<Unsqueeze>(range_1nplus1, zero_const);
    // 2. compute a mask with indices counting from one
    auto unique_vs_x_ind = make_shared<Multiply>(unique_vs_x_01, unsqueeze_range_1nplus1);
    // 3. compute positions of the first occurrence of each unique element,
    // i.e. the positions of unique elements in the original order
    auto minimum_indices_plus1 = make_shared<ReduceMin>(unique_vs_x_ind, one_const);
    auto minimum_indices = make_shared<Subtract>(minimum_indices_plus1, one_const);
    // denote the number of unique elements as m
    auto m = get_elements_number_1d(minimum_indices, element::i32);
    auto sorted_minumum_indices =
        make_shared<TopK>(minimum_indices, m, 0, TopK::Mode::MIN, TopK::SortType::SORT_VALUES, element::i32);
    auto output_unique_elements = make_shared<Gather>(x, sorted_minumum_indices->output(0), zero_const);

    // compute the second output:
    // indices of elements of x in the vector of unique elements
    // 1. compute a mask for unique elements in the original order
    auto unsqueeze_output_unique_elements = make_shared<Unsqueeze>(unique_elements, one_const);
    auto unique_vs_x_orig = make_shared<Equal>(unsqueeze_output_unique_elements, unsqueeze_x);
    auto unique_vs_x_orig_01 = make_shared<Select>(unique_vs_x_orig, one_const_out_idx, zero_const_out_idx);
    // 2. compute positions where each element of x is located in the unique elements vector;
    // the position counts from 1
    auto mplus1 = make_shared<Add>(m, one_const_scalar);
    auto range_1mplus1 = make_shared<Range>(one_const_scalar, mplus1, one_const_scalar, out_idx);
    auto unsqueeze_range_1mplus1 = make_shared<Unsqueeze>(range_1mplus1, one_const);
    auto unique_vs_x_ind_orig = make_shared<Multiply>(unique_vs_x_orig_01, unsqueeze_range_1mplus1);
    auto output_idx_plus1 = make_shared<ReduceMax>(unique_vs_x_ind_orig, zero_const);
    auto output_idx = make_shared<Subtract>(output_idx_plus1, one_const_out_idx);

    return make_shared<Model>(OutputVector{output_idx->output(0), output_unique_elements->output(0)},
                              ParameterVector{x});
}

}  // namespace

TEST_F(TransformationTestsF, UniqueReplacerInt32) {
    {
        function = gen_model(PartialShape{10}, element::i32);
        manager.register_pass<UniqueReplacer>();
    }
    { function_ref = gen_model_ref(PartialShape{10}, element::i32); }
    comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}

TEST_F(TransformationTestsF, UniqueReplacerInt64) {
    {
        function = gen_model(PartialShape{42}, element::i64);
        manager.register_pass<UniqueReplacer>();
    }
    { function_ref = gen_model_ref(PartialShape{42}, element::i64); }
    comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@@ -83,7 +83,8 @@ class CommonLayerTest:
         else:
             ie_engine = InferAPI20(model=path_to_xml,
                                    weights=path_to_bin,
-                                   device=ie_device)
+                                   device=ie_device,
+                                   use_new_frontend=use_new_frontend)
         # Prepare feed dict
         if 'kwargs_to_prepare_input' in kwargs and kwargs['kwargs_to_prepare_input']:
             inputs_dict = self._prepare_input(ie_engine.get_inputs_info(precision),
@@ -71,11 +71,12 @@ class IEInfer(BaseInfer):
 
 
 class InferAPI20(BaseInfer):
-    def __init__(self, model, weights, device):
+    def __init__(self, model, weights, device, use_new_frontend):
         super().__init__('Inference Engine')
         self.device = device
         self.model = model
         self.weights = weights
+        self.use_new_frontend = use_new_frontend
 
     def fw_infer(self, input_data):
         print("Inference Engine version: {}".format(ie2_get_version()))
@@ -94,8 +95,18 @@ class InferAPI20(BaseInfer):
             # all input and output tensors have to be named
             assert out_obj.names, "Output tensor {} has no names".format(out_obj)
 
-            tensor_name = out_obj.get_any_name().split(':')[0]
-            result[tensor_name] = out_tensor
+            # For the new frontend, map results by all tensor names:
+            # a tensor can have several names due to fusing,
+            # and the original framework may use any of them
+            if self.use_new_frontend:
+                for tensor_name in out_obj.get_names():
+                    result[tensor_name] = out_tensor
+            else:
+                # do not change the behaviour of mapping tensor names
+                # between the original framework and OpenVINO,
+                # because that would require fixing this functionality in the legacy frontend
+                tensor_name = out_obj.get_any_name().split(':')[0]
+                result[tensor_name] = out_tensor
 
         if "exec_net" in locals():
             del exec_net
tests/layer_tests/tensorflow_tests/test_tf_Unique.py (new file, 60 lines)
@@ -0,0 +1,60 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest


class TestUnique(CommonTFLayerTest):
    def _prepare_input(self, inputs_info):
        assert 'x' in inputs_info, "Test error: inputs_info must contain `x`"
        x_shape = inputs_info['x']
        inputs_data = {}
        inputs_data['x'] = np.random.randint(-10, 10, x_shape)
        return inputs_data

    def create_unique_net(self, x_shape, data_type, out_idx):
        tf.compat.v1.reset_default_graph()
        # Create the graph and model
        with tf.compat.v1.Session() as sess:
            x = tf.compat.v1.placeholder(data_type, x_shape, 'x')
            unique = tf.unique(x, out_idx)
            tf.identity(unique[0], name='unique_elements')
            tf.identity(unique[1], name='unique_indices')
            tf.compat.v1.global_variables_initializer()

            tf_net = sess.graph_def

        return tf_net, None

    test_data_basic = [
        dict(x_shape=[50], data_type=tf.float32, out_idx=tf.int32),
        dict(x_shape=[100], data_type=tf.float32, out_idx=tf.int64),
    ]

    @pytest.mark.parametrize("params", test_data_basic)
    @pytest.mark.precommit_tf_fe
    def test_unique_basic(self, params, ie_device, precision, ir_version, temp_dir,
                          use_new_frontend, use_old_api):
        if not use_new_frontend:
            pytest.skip("Unique operation is not supported via legacy frontend.")
        self._test(*self.create_unique_net(**params),
                   ie_device, precision, ir_version, temp_dir=temp_dir,
                   use_new_frontend=use_new_frontend, use_old_api=use_old_api)

    test_data_other_types = [
        dict(x_shape=[10], data_type=tf.int32, out_idx=tf.int32),
        dict(x_shape=[4], data_type=tf.int64, out_idx=tf.int32),
    ]

    @pytest.mark.parametrize("params", test_data_other_types)
    @pytest.mark.nightly
    def test_unique_other_types(self, params, ie_device, precision, ir_version, temp_dir,
                                use_new_frontend, use_old_api):
        if not use_new_frontend:
            pytest.skip("Unique operation is not supported via legacy frontend.")
        self._test(*self.create_unique_net(**params),
                   ie_device, precision, ir_version, temp_dir=temp_dir,
                   use_new_frontend=use_new_frontend, use_old_api=use_old_api)