[TF FE] Fix Wide and Deep model conversion (#14931)

* [TF FE] Fix Wide and Deep model conversion

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Fix build

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2023-01-05 21:29:39 +04:00 committed by GitHub
parent 10253c1b12
commit a38366a707
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 2092 additions and 1 deletions

View File

@ -14,6 +14,7 @@
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"
#include "utils.hpp"
using namespace std;
using namespace ov::pass;
@ -92,7 +93,22 @@ ov::frontend::tensorflow::pass::EmbeddingSegmentSingleFeatureFusion::EmbeddingSe
auto zeros_like = make_shared<Broadcast>(make_shared<Constant>(ov::element::f32, Shape{1}, std::vector<int64_t>{0}),
make_shared<ShapeOf>(sparse_segment_op));
auto select_pattern = make_shared<Select>(tile, zeros_like, sparse_segment_op);
// compute number of dimensions to unsqueeze the condition
auto cond_rank = compute_subgraph_scalar_rank(tile, element::i32);
auto x_rank = compute_subgraph_scalar_rank(zeros_like, element::i32);
auto num_new_axes = make_shared<Subtract>(x_rank, cond_rank);
// generate a new shape for the condition
auto const_one = make_shared<Constant>(element::i32, Shape{1}, 1);
auto new_subshape = make_shared<Broadcast>(const_one, num_new_axes);
auto cond_shape = make_shared<ShapeOf>(tile, element::i32);
auto new_cond_shape = make_shared<Concat>(OutputVector{cond_shape, new_subshape}, 0);
// prepare the condition to have the same rank as operands `x` and `y`
auto prep_cond = make_shared<Reshape>(tile, new_cond_shape, false);
auto select_pattern = make_shared<Select>(prep_cond, zeros_like, sparse_segment_op);
matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();

View File

@ -59,6 +59,24 @@ TEST(FrontEndConvertTrickyModels, undefined_input_shape) {
}
}
TEST(FrontEndConvertTrickyModels, simple_wide_and_deep) {
shared_ptr<Model> model;
try {
model = convert_model("simple_wide_and_deep/simple_wide_and_deep.pb");
} catch (std::exception& ex) {
ASSERT_TRUE(false) << ex.what();
}
int num_emb_segment_sum = 0;
for (auto& node : model->get_ordered_ops()) {
if (std::dynamic_pointer_cast<EmbeddingSegmentsSum>(node)) {
++num_emb_segment_sum;
}
}
ASSERT_EQ(num_emb_segment_sum, 1) << "The number of EmbeddingSegmentsSum nodes must be 1";
}
TEST(FrontEndConvertTrickyModels, model_with_output_shapes) {
shared_ptr<Model> model;
try {