[TF FE] Fix Wide and Deep model conversion (#14931)
* [TF FE] Fix Wide and Deep model conversion Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Fix build Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
10253c1b12
commit
a38366a707
@ -14,6 +14,7 @@
|
|||||||
#include "openvino/pass/pattern/op/or.hpp"
|
#include "openvino/pass/pattern/op/or.hpp"
|
||||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||||
#include "transformations/utils/utils.hpp"
|
#include "transformations/utils/utils.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace ov::pass;
|
using namespace ov::pass;
|
||||||
@ -92,7 +93,22 @@ ov::frontend::tensorflow::pass::EmbeddingSegmentSingleFeatureFusion::EmbeddingSe
|
|||||||
|
|
||||||
auto zeros_like = make_shared<Broadcast>(make_shared<Constant>(ov::element::f32, Shape{1}, std::vector<int64_t>{0}),
|
auto zeros_like = make_shared<Broadcast>(make_shared<Constant>(ov::element::f32, Shape{1}, std::vector<int64_t>{0}),
|
||||||
make_shared<ShapeOf>(sparse_segment_op));
|
make_shared<ShapeOf>(sparse_segment_op));
|
||||||
auto select_pattern = make_shared<Select>(tile, zeros_like, sparse_segment_op);
|
|
||||||
|
// compute number of dimensions to unsqueeze the condition
|
||||||
|
auto cond_rank = compute_subgraph_scalar_rank(tile, element::i32);
|
||||||
|
auto x_rank = compute_subgraph_scalar_rank(zeros_like, element::i32);
|
||||||
|
auto num_new_axes = make_shared<Subtract>(x_rank, cond_rank);
|
||||||
|
|
||||||
|
// generate a new shape for the condition
|
||||||
|
auto const_one = make_shared<Constant>(element::i32, Shape{1}, 1);
|
||||||
|
auto new_subshape = make_shared<Broadcast>(const_one, num_new_axes);
|
||||||
|
auto cond_shape = make_shared<ShapeOf>(tile, element::i32);
|
||||||
|
auto new_cond_shape = make_shared<Concat>(OutputVector{cond_shape, new_subshape}, 0);
|
||||||
|
|
||||||
|
// prepare the condition to have the same rank as operands `x` and `y`
|
||||||
|
auto prep_cond = make_shared<Reshape>(tile, new_cond_shape, false);
|
||||||
|
|
||||||
|
auto select_pattern = make_shared<Select>(prep_cond, zeros_like, sparse_segment_op);
|
||||||
|
|
||||||
matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
|
matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
|
||||||
const auto& pattern_map = m.get_pattern_value_map();
|
const auto& pattern_map = m.get_pattern_value_map();
|
||||||
|
@ -59,6 +59,24 @@ TEST(FrontEndConvertTrickyModels, undefined_input_shape) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(FrontEndConvertTrickyModels, simple_wide_and_deep) {
|
||||||
|
shared_ptr<Model> model;
|
||||||
|
try {
|
||||||
|
model = convert_model("simple_wide_and_deep/simple_wide_and_deep.pb");
|
||||||
|
} catch (std::exception& ex) {
|
||||||
|
ASSERT_TRUE(false) << ex.what();
|
||||||
|
}
|
||||||
|
|
||||||
|
int num_emb_segment_sum = 0;
|
||||||
|
for (auto& node : model->get_ordered_ops()) {
|
||||||
|
if (std::dynamic_pointer_cast<EmbeddingSegmentsSum>(node)) {
|
||||||
|
++num_emb_segment_sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ASSERT_EQ(num_emb_segment_sum, 1) << "The number of EmbeddingSegmentsSum nodes must be 1";
|
||||||
|
}
|
||||||
|
|
||||||
TEST(FrontEndConvertTrickyModels, model_with_output_shapes) {
|
TEST(FrontEndConvertTrickyModels, model_with_output_shapes) {
|
||||||
shared_ptr<Model> model;
|
shared_ptr<Model> model;
|
||||||
try {
|
try {
|
||||||
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user