[TF FE] Support SaveV2 operation (#15572)
* [TF FE] Support SaveV2 operation Also, implement transformation to remove UnsupportedConstant to Result isolated sub-graphs Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Revert change for Constant --------- Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
82286cc2af
commit
a350bd7e85
@ -8,6 +8,7 @@
|
||||
#include "helper_transforms/block_lstm_replacer.hpp"
|
||||
#include "helper_transforms/embedding_segments_feature_fusing.hpp"
|
||||
#include "helper_transforms/gru_block_cell_replacer.hpp"
|
||||
#include "helper_transforms/unsupported_const_to_result_remover.hpp"
|
||||
#include "input_model.hpp"
|
||||
#include "op_table.hpp"
|
||||
#include "openvino/frontend/tensorflow/extension/conversion.hpp"
|
||||
@ -251,6 +252,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& function) const {
|
||||
manager.register_pass<pass::EmbeddingSegmentSingleFeatureFusion>();
|
||||
manager.register_pass<pass::BlockLSTMReplacer>();
|
||||
manager.register_pass<pass::GRUBlockCellReplacer>();
|
||||
manager.register_pass<pass::UnsupportedConstToResultRemover>();
|
||||
|
||||
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
|
||||
manager.register_pass<ov::pass::ReverseShapeAndTypeInfer>();
|
||||
|
@ -189,6 +189,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"Roll", translate_roll_op},
|
||||
{"Round", translate_round_op},
|
||||
{"Rsqrt", translate_rsqrt_op},
|
||||
{"SaveV2", translate_no_op},
|
||||
{"ScatterNd", translate_scatter_nd_op},
|
||||
{"SegmentSum", translate_segment_sum_op},
|
||||
{"SparseToDense", translate_sparse_to_dense_op},
|
||||
|
@ -246,3 +246,19 @@ TEST_F(TransformationTestsF, DISABLED_ModelWithDilatedGroupConvolution) {
|
||||
model_ref = make_shared<Model>(OutputVector{transpose_after}, ParameterVector{x});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ModelWithSaveV2) {
|
||||
{
|
||||
model = convert_model("model_savev2/model_savev2.pb");
|
||||
// need to call shape inference since body graphs can be injected with undefined shapes
|
||||
model->validate_nodes_and_infer_types();
|
||||
}
|
||||
{
|
||||
// create a reference graph
|
||||
auto x = make_shared<Parameter>(element::f32, Shape{2});
|
||||
auto const_2 = make_shared<Constant>(element::f32, Shape{2}, vector<float>{1, 2});
|
||||
auto add = make_shared<Add>(x, const_2);
|
||||
|
||||
model_ref = make_shared<Model>(OutputVector{add}, ParameterVector{x});
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,152 @@
|
||||
node {
|
||||
name: "x"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Const"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
}
|
||||
tensor_content: "\000\000\200?\000\000\000@"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "add"
|
||||
op: "AddV2"
|
||||
input: "x"
|
||||
input: "Const"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "save/Const"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_STRING
|
||||
tensor_shape {
|
||||
}
|
||||
string_val: "model"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "tensor_names"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_STRING
|
||||
tensor_shape {
|
||||
dim {
|
||||
size: 1
|
||||
}
|
||||
}
|
||||
string_val: "Const"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "shape_and_slices"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_STRING
|
||||
tensor_shape {
|
||||
dim {
|
||||
size: 1
|
||||
}
|
||||
}
|
||||
string_val: ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "SaveV2"
|
||||
op: "SaveV2"
|
||||
input: "save/Const"
|
||||
input: "tensor_names"
|
||||
input: "shape_and_slices"
|
||||
input: "Const"
|
||||
attr {
|
||||
key: "dtypes"
|
||||
value {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "save/control_dependency"
|
||||
op: "Identity"
|
||||
input: "save/Const"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,26 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/pass.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace tensorflow {
|
||||
namespace pass {
|
||||
|
||||
// This transformation removes isolated subgraph Unsupported constant going to the Result node
|
||||
class UnsupportedConstToResultRemover : public ov::pass::ModelPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::frontend::tensorflow::pass::UnsupportedConstToResultRemover");
|
||||
UnsupportedConstToResultRemover() {}
|
||||
|
||||
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
|
||||
};
|
||||
|
||||
} // namespace pass
|
||||
} // namespace tensorflow
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -0,0 +1,36 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "helper_transforms/unsupported_const_to_result_remover.hpp"
|
||||
|
||||
#include "helper_ops/unsupported_constant.hpp"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace tensorflow {
|
||||
namespace pass {
|
||||
|
||||
bool UnsupportedConstToResultRemover::run_on_model(const std::shared_ptr<ov::Model>& m) {
|
||||
ResultVector results_to_remove;
|
||||
// look for isolated UnsupportedConst->Result sub-graphs to remove
|
||||
for (const auto& result : m->get_results()) {
|
||||
auto unsupported_const = as_type_ptr<UnsupportedConstant>(result->get_input_node_shared_ptr(0));
|
||||
if (unsupported_const && unsupported_const->output(0).get_target_inputs().size() == 1) {
|
||||
results_to_remove.push_back(result);
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& result : results_to_remove) {
|
||||
m->remove_result(result);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace pass
|
||||
} // namespace tensorflow
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -14,18 +14,9 @@ namespace tensorflow {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_no_op(const NodeContext& node) {
|
||||
if (node.get_input_size() == 0) {
|
||||
return OutputVector{};
|
||||
}
|
||||
|
||||
TENSORFLOW_OP_VALIDATION(node,
|
||||
node.get_input_size() == 1,
|
||||
"NoOp has " + to_string(node.get_input_size()) + " inputs, should have 1");
|
||||
|
||||
auto input = node.get_input(0);
|
||||
set_out_name(node.get_name(), input);
|
||||
set_out_name(node.get_name() + ":" + "0", input);
|
||||
return {input};
|
||||
// the operation does nothing in terms of data generation
|
||||
default_op_checks(node, 0, {"NoOp", "SaveV2"});
|
||||
return {};
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user