[TF FE] Support SaveV2 operation (#15572)
* [TF FE] Support SaveV2 operation Also, implement transformation to remove UnsupportedConstant to Result isolated sub-graphs Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Revert change for Constant --------- Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
82286cc2af
commit
a350bd7e85
@ -8,6 +8,7 @@
|
|||||||
#include "helper_transforms/block_lstm_replacer.hpp"
|
#include "helper_transforms/block_lstm_replacer.hpp"
|
||||||
#include "helper_transforms/embedding_segments_feature_fusing.hpp"
|
#include "helper_transforms/embedding_segments_feature_fusing.hpp"
|
||||||
#include "helper_transforms/gru_block_cell_replacer.hpp"
|
#include "helper_transforms/gru_block_cell_replacer.hpp"
|
||||||
|
#include "helper_transforms/unsupported_const_to_result_remover.hpp"
|
||||||
#include "input_model.hpp"
|
#include "input_model.hpp"
|
||||||
#include "op_table.hpp"
|
#include "op_table.hpp"
|
||||||
#include "openvino/frontend/tensorflow/extension/conversion.hpp"
|
#include "openvino/frontend/tensorflow/extension/conversion.hpp"
|
||||||
@ -251,6 +252,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& function) const {
|
|||||||
manager.register_pass<pass::EmbeddingSegmentSingleFeatureFusion>();
|
manager.register_pass<pass::EmbeddingSegmentSingleFeatureFusion>();
|
||||||
manager.register_pass<pass::BlockLSTMReplacer>();
|
manager.register_pass<pass::BlockLSTMReplacer>();
|
||||||
manager.register_pass<pass::GRUBlockCellReplacer>();
|
manager.register_pass<pass::GRUBlockCellReplacer>();
|
||||||
|
manager.register_pass<pass::UnsupportedConstToResultRemover>();
|
||||||
|
|
||||||
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
|
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
|
||||||
manager.register_pass<ov::pass::ReverseShapeAndTypeInfer>();
|
manager.register_pass<ov::pass::ReverseShapeAndTypeInfer>();
|
||||||
|
@ -189,6 +189,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
|||||||
{"Roll", translate_roll_op},
|
{"Roll", translate_roll_op},
|
||||||
{"Round", translate_round_op},
|
{"Round", translate_round_op},
|
||||||
{"Rsqrt", translate_rsqrt_op},
|
{"Rsqrt", translate_rsqrt_op},
|
||||||
|
{"SaveV2", translate_no_op},
|
||||||
{"ScatterNd", translate_scatter_nd_op},
|
{"ScatterNd", translate_scatter_nd_op},
|
||||||
{"SegmentSum", translate_segment_sum_op},
|
{"SegmentSum", translate_segment_sum_op},
|
||||||
{"SparseToDense", translate_sparse_to_dense_op},
|
{"SparseToDense", translate_sparse_to_dense_op},
|
||||||
|
@ -246,3 +246,19 @@ TEST_F(TransformationTestsF, DISABLED_ModelWithDilatedGroupConvolution) {
|
|||||||
model_ref = make_shared<Model>(OutputVector{transpose_after}, ParameterVector{x});
|
model_ref = make_shared<Model>(OutputVector{transpose_after}, ParameterVector{x});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(TransformationTestsF, ModelWithSaveV2) {
|
||||||
|
{
|
||||||
|
model = convert_model("model_savev2/model_savev2.pb");
|
||||||
|
// need to call shape inference since body graphs can be injected with undefined shapes
|
||||||
|
model->validate_nodes_and_infer_types();
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// create a reference graph
|
||||||
|
auto x = make_shared<Parameter>(element::f32, Shape{2});
|
||||||
|
auto const_2 = make_shared<Constant>(element::f32, Shape{2}, vector<float>{1, 2});
|
||||||
|
auto add = make_shared<Add>(x, const_2);
|
||||||
|
|
||||||
|
model_ref = make_shared<Model>(OutputVector{add}, ParameterVector{x});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -0,0 +1,152 @@
|
|||||||
|
node {
|
||||||
|
name: "x"
|
||||||
|
op: "Placeholder"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "shape"
|
||||||
|
value {
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
size: 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "Const"
|
||||||
|
op: "Const"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "value"
|
||||||
|
value {
|
||||||
|
tensor {
|
||||||
|
dtype: DT_FLOAT
|
||||||
|
tensor_shape {
|
||||||
|
dim {
|
||||||
|
size: 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tensor_content: "\000\000\200?\000\000\000@"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "add"
|
||||||
|
op: "AddV2"
|
||||||
|
input: "x"
|
||||||
|
input: "Const"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "save/Const"
|
||||||
|
op: "Const"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_STRING
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "value"
|
||||||
|
value {
|
||||||
|
tensor {
|
||||||
|
dtype: DT_STRING
|
||||||
|
tensor_shape {
|
||||||
|
}
|
||||||
|
string_val: "model"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "tensor_names"
|
||||||
|
op: "Const"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_STRING
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "value"
|
||||||
|
value {
|
||||||
|
tensor {
|
||||||
|
dtype: DT_STRING
|
||||||
|
tensor_shape {
|
||||||
|
dim {
|
||||||
|
size: 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
string_val: "Const"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "shape_and_slices"
|
||||||
|
op: "Const"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_STRING
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "value"
|
||||||
|
value {
|
||||||
|
tensor {
|
||||||
|
dtype: DT_STRING
|
||||||
|
tensor_shape {
|
||||||
|
dim {
|
||||||
|
size: 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
string_val: ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "SaveV2"
|
||||||
|
op: "SaveV2"
|
||||||
|
input: "save/Const"
|
||||||
|
input: "tensor_names"
|
||||||
|
input: "shape_and_slices"
|
||||||
|
input: "Const"
|
||||||
|
attr {
|
||||||
|
key: "dtypes"
|
||||||
|
value {
|
||||||
|
list {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "save/control_dependency"
|
||||||
|
op: "Identity"
|
||||||
|
input: "save/Const"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_STRING
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,26 @@
|
|||||||
|
// Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "openvino/pass/pass.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace frontend {
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace pass {
|
||||||
|
|
||||||
|
// This transformation removes isolated subgraph Unsupported constant going to the Result node
|
||||||
|
class UnsupportedConstToResultRemover : public ov::pass::ModelPass {
|
||||||
|
public:
|
||||||
|
OPENVINO_RTTI("ov::frontend::tensorflow::pass::UnsupportedConstToResultRemover");
|
||||||
|
UnsupportedConstToResultRemover() {}
|
||||||
|
|
||||||
|
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace tensorflow
|
||||||
|
} // namespace frontend
|
||||||
|
} // namespace ov
|
@ -0,0 +1,36 @@
|
|||||||
|
// Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "helper_transforms/unsupported_const_to_result_remover.hpp"
|
||||||
|
|
||||||
|
#include "helper_ops/unsupported_constant.hpp"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace frontend {
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace pass {
|
||||||
|
|
||||||
|
bool UnsupportedConstToResultRemover::run_on_model(const std::shared_ptr<ov::Model>& m) {
|
||||||
|
ResultVector results_to_remove;
|
||||||
|
// look for isolated UnsupportedConst->Result sub-graphs to remove
|
||||||
|
for (const auto& result : m->get_results()) {
|
||||||
|
auto unsupported_const = as_type_ptr<UnsupportedConstant>(result->get_input_node_shared_ptr(0));
|
||||||
|
if (unsupported_const && unsupported_const->output(0).get_target_inputs().size() == 1) {
|
||||||
|
results_to_remove.push_back(result);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto& result : results_to_remove) {
|
||||||
|
m->remove_result(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace tensorflow
|
||||||
|
} // namespace frontend
|
||||||
|
} // namespace ov
|
@ -14,18 +14,9 @@ namespace tensorflow {
|
|||||||
namespace op {
|
namespace op {
|
||||||
|
|
||||||
OutputVector translate_no_op(const NodeContext& node) {
|
OutputVector translate_no_op(const NodeContext& node) {
|
||||||
if (node.get_input_size() == 0) {
|
// the operation does nothing in terms of data generation
|
||||||
return OutputVector{};
|
default_op_checks(node, 0, {"NoOp", "SaveV2"});
|
||||||
}
|
return {};
|
||||||
|
|
||||||
TENSORFLOW_OP_VALIDATION(node,
|
|
||||||
node.get_input_size() == 1,
|
|
||||||
"NoOp has " + to_string(node.get_input_size()) + " inputs, should have 1");
|
|
||||||
|
|
||||||
auto input = node.get_input(0);
|
|
||||||
set_out_name(node.get_name(), input);
|
|
||||||
set_out_name(node.get_name() + ":" + "0", input);
|
|
||||||
return {input};
|
|
||||||
}
|
}
|
||||||
} // namespace op
|
} // namespace op
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
Loading…
Reference in New Issue
Block a user