[TF FE] Support SaveV2 operation (#15572)

* [TF FE] Support SaveV2 operation

Also, implement transformation to remove UnsupportedConstant to Result
isolated sub-graphs

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Revert change for Constant

---------

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2023-02-08 14:35:39 +04:00 committed by GitHub
parent 82286cc2af
commit a350bd7e85
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 236 additions and 12 deletions

View File

@ -8,6 +8,7 @@
#include "helper_transforms/block_lstm_replacer.hpp"
#include "helper_transforms/embedding_segments_feature_fusing.hpp"
#include "helper_transforms/gru_block_cell_replacer.hpp"
#include "helper_transforms/unsupported_const_to_result_remover.hpp"
#include "input_model.hpp"
#include "op_table.hpp"
#include "openvino/frontend/tensorflow/extension/conversion.hpp"
@ -251,6 +252,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& function) const {
manager.register_pass<pass::EmbeddingSegmentSingleFeatureFusion>();
manager.register_pass<pass::BlockLSTMReplacer>();
manager.register_pass<pass::GRUBlockCellReplacer>();
manager.register_pass<pass::UnsupportedConstToResultRemover>();
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
manager.register_pass<ov::pass::ReverseShapeAndTypeInfer>();

View File

@ -189,6 +189,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"Roll", translate_roll_op},
{"Round", translate_round_op},
{"Rsqrt", translate_rsqrt_op},
{"SaveV2", translate_no_op},
{"ScatterNd", translate_scatter_nd_op},
{"SegmentSum", translate_segment_sum_op},
{"SparseToDense", translate_sparse_to_dense_op},

View File

@ -246,3 +246,19 @@ TEST_F(TransformationTestsF, DISABLED_ModelWithDilatedGroupConvolution) {
model_ref = make_shared<Model>(OutputVector{transpose_after}, ParameterVector{x});
}
}
TEST_F(TransformationTestsF, ModelWithSaveV2) {
{
model = convert_model("model_savev2/model_savev2.pb");
// need to call shape inference since body graphs can be injected with undefined shapes
model->validate_nodes_and_infer_types();
}
{
// create a reference graph
auto x = make_shared<Parameter>(element::f32, Shape{2});
auto const_2 = make_shared<Constant>(element::f32, Shape{2}, vector<float>{1, 2});
auto add = make_shared<Add>(x, const_2);
model_ref = make_shared<Model>(OutputVector{add}, ParameterVector{x});
}
}

View File

@ -0,0 +1,152 @@
node {
name: "x"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 2
}
}
}
}
}
node {
name: "Const"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 2
}
}
tensor_content: "\000\000\200?\000\000\000@"
}
}
}
}
node {
name: "add"
op: "AddV2"
input: "x"
input: "Const"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "save/Const"
op: "Const"
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_STRING
tensor_shape {
}
string_val: "model"
}
}
}
}
node {
name: "tensor_names"
op: "Const"
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_STRING
tensor_shape {
dim {
size: 1
}
}
string_val: "Const"
}
}
}
}
node {
name: "shape_and_slices"
op: "Const"
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_STRING
tensor_shape {
dim {
size: 1
}
}
string_val: ""
}
}
}
}
node {
name: "SaveV2"
op: "SaveV2"
input: "save/Const"
input: "tensor_names"
input: "shape_and_slices"
input: "Const"
attr {
key: "dtypes"
value {
list {
type: DT_FLOAT
}
}
}
}
node {
name: "save/control_dependency"
op: "Identity"
input: "save/Const"
attr {
key: "T"
value {
type: DT_STRING
}
}
}

View File

@ -0,0 +1,26 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/pass.hpp"
namespace ov {
namespace frontend {
namespace tensorflow {
namespace pass {
// This transformation removes isolated subgraph Unsupported constant going to the Result node
class UnsupportedConstToResultRemover : public ov::pass::ModelPass {
public:
OPENVINO_RTTI("ov::frontend::tensorflow::pass::UnsupportedConstToResultRemover");
UnsupportedConstToResultRemover() {}
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
};
} // namespace pass
} // namespace tensorflow
} // namespace frontend
} // namespace ov

View File

@ -0,0 +1,36 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "helper_transforms/unsupported_const_to_result_remover.hpp"
#include "helper_ops/unsupported_constant.hpp"
using namespace std;
namespace ov {
namespace frontend {
namespace tensorflow {
namespace pass {
bool UnsupportedConstToResultRemover::run_on_model(const std::shared_ptr<ov::Model>& m) {
ResultVector results_to_remove;
// look for isolated UnsupportedConst->Result sub-graphs to remove
for (const auto& result : m->get_results()) {
auto unsupported_const = as_type_ptr<UnsupportedConstant>(result->get_input_node_shared_ptr(0));
if (unsupported_const && unsupported_const->output(0).get_target_inputs().size() == 1) {
results_to_remove.push_back(result);
}
}
for (const auto& result : results_to_remove) {
m->remove_result(result);
}
return true;
}
} // namespace pass
} // namespace tensorflow
} // namespace frontend
} // namespace ov

View File

@ -14,18 +14,9 @@ namespace tensorflow {
namespace op {
OutputVector translate_no_op(const NodeContext& node) {
if (node.get_input_size() == 0) {
return OutputVector{};
}
TENSORFLOW_OP_VALIDATION(node,
node.get_input_size() == 1,
"NoOp has " + to_string(node.get_input_size()) + " inputs, should have 1");
auto input = node.get_input(0);
set_out_name(node.get_name(), input);
set_out_name(node.get_name() + ":" + "0", input);
return {input};
// the operation does nothing in terms of data generation
default_op_checks(node, 0, {"NoOp", "SaveV2"});
return {};
}
} // namespace op
} // namespace tensorflow