[TF FE] Operations refactoring (#8477)

* Resolve review comments

* Fix code style

* Remove mentions of nGraph from the TF frontend
Ivan Tikhonov 2021-11-23 13:14:29 +03:00 committed by GitHub
parent 277ff96564
commit cfe33fdf08
35 changed files with 174 additions and 176 deletions

View File

@@ -77,7 +77,7 @@ ie_add_api_validator_post_build_step(TARGET ${TARGET_NAME})
 link_system_libraries(${TARGET_NAME} PRIVATE ${Protobuf_LITE_LIBRARIES})
 target_link_libraries(${TARGET_NAME} PRIVATE frontend_manager::static
-    PRIVATE ngraph::builder inference_engine_transformations libprotobuf openvino::util)
+    PRIVATE inference_engine_transformations libprotobuf openvino::util)
 add_clang_format_target(${TARGET_NAME}_clang FOR_TARGETS ${TARGET_NAME}
     EXCLUDE_PATTERNS ${PROTO_SRCS} ${PROTO_HDRS})

View File

@@ -4,8 +4,8 @@
 #pragma once
-#include <openvino/core/variant.hpp>
-#include <tensorflow_frontend/utility.hpp>
+#include "openvino/core/variant.hpp"
+#include "tensorflow_frontend/utility.hpp"
 namespace ov {
 namespace frontend {

View File

@@ -47,8 +47,7 @@ shared_ptr<Variant> DecoderTFProto::get_attribute(const string& name, const Vari
     if (is_type<string>(type_info)) {
         return create_variant<string>(attrs[0].s());
-    }
-    if (is_type<int64_t>(type_info)) {
+    } else if (is_type<int64_t>(type_info)) {
         return create_variant<int64_t>(attrs[0].i());
     } else if (is_type<vector<int64_t>>(type_info)) {
         vector<int64_t> longs;
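
Note: translators consume these decoded attributes through NodeContext::get_attribute, as the op hunks further down show. A minimal call-site sketch; the attribute names are borrowed from the ArgMinMax and FusedBatchNorm hunks below, and `node` stands for the translator's NodeContext:

    // Hedged sketch: the decoder above materializes the raw proto attribute
    // as the Variant matching the requested type.
    auto out_type = node.get_attribute<ov::element::Type>("output_type");
    auto epsilon = node.get_attribute<float>("epsilon", 0.0001f);  // with a default value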

View File

@@ -4,13 +4,12 @@
 #pragma once
-#include <ngraph/ngraph.hpp>
 #include <string>
-#include <tensorflow_frontend/decoder.hpp>
 #include <vector>
 #include "attr_value.pb.h"
 #include "node_def.pb.h"
+#include "tensorflow_frontend/decoder.hpp"
 #include "types.pb.h"
 namespace ov {

View File

@@ -3,8 +3,7 @@
 //
 #pragma once
-#include <frontend_manager/frontend_exceptions.hpp>
+#include "frontend_manager/frontend_exceptions.hpp"
 #include "openvino/core/node.hpp"
 namespace ov {

View File

@@ -2,13 +2,14 @@
 // SPDX-License-Identifier: Apache-2.0
 //
-#include <tensorflow_frontend/frontend.hpp>
-#include <tensorflow_frontend/graph_iterator.hpp>
+#include "tensorflow_frontend/frontend.hpp"
 #include "model.hpp"
 #include "op_table.hpp"
+#include "openvino/pass/manager.hpp"
 #include "openvino/util/common_util.hpp"
 #include "pass/transpose_sinking.hpp"
+#include "tensorflow_frontend/graph_iterator.hpp"
 #include "tf_framework_node.hpp"
 #include "utils.hpp"
@@ -52,13 +53,13 @@ void FrontEndTF::translate_graph(const ov::frontend::InputModel::Ptr& model,
                                  bool fail_fast,
                                  bool no_conversion,
                                  std::shared_ptr<ov::Function>& ng_function) const {
-    // a map from operation names to generated nGraph Output<TFNodeDecoder>
+    // a map from operation names to generated OV Output<TFNodeDecoder>
     tf::OpMap ng_op_map;
     ov::ParameterVector params;
     ov::ResultVector results;
     const auto& model_tf = std::dynamic_pointer_cast<InputModelTF>(model);
-    FRONT_END_GENERAL_CHECK(model_tf, "nullptr for InputModel is given for translation into nGraph function");
+    FRONT_END_GENERAL_CHECK(model_tf, "nullptr for InputModel is given for translation into OV function");
     const auto& operation_places = model_tf->get_op_places();
     const auto& model_inputs = model_tf->get_inputs();
     const auto& model_outputs = model_tf->get_outputs();
@@ -101,7 +102,7 @@ void FrontEndTF::translate_graph(const ov::frontend::InputModel::Ptr& model,
         ng_op_map[input_name] = {param};
     }
-    // create the nGraph ops from TensorFlow ops
+    // create the OV ops from TensorFlow ops
     for (const auto& operation_place : operation_places) {
         auto operation_decoder = operation_place->get_decoder();
         auto operation_name = operation_place->get_names()[0];
@@ -110,7 +111,7 @@ void FrontEndTF::translate_graph(const ov::frontend::InputModel::Ptr& model,
             continue;
         }
-        // prepare a list of nGraph node inputs for each node
+        // prepare a list of OV node inputs for each node
        ov::OutputVector ng_inputs;
        ::ov::frontend::tf::NamedInputs named_inputs;
        for (size_t input_port_idx = 0; input_port_idx < operation_decoder->get_input_size(); ++input_port_idx) {
@@ -157,7 +158,7 @@ void FrontEndTF::translate_graph(const ov::frontend::InputModel::Ptr& model,
             }
         }
-        // generate nGraph node output vector for the current operation node
+        // generate OV node output vector for the current operation node
         ov::OutputVector ng_outputs;
         try {
             FRONT_END_OP_CONVERSION_CHECK(translate_map.count(operation_decoder->get_op_type()),
@@ -166,7 +167,7 @@ void FrontEndTF::translate_graph(const ov::frontend::InputModel::Ptr& model,
             // NodeContext node_context(ng_inputs, operation_decoder, model_inputs);
             // TODO: Check why NodeContextNew doesn't have ngOutputVector ng_inputs input in constructor
             ::ov::frontend::tf::NodeContext node_context(*operation_decoder, named_inputs);
-            // generate nGraph node output vector using translator for given operation type
+            // generate OV node output vector using translator for given operation type
             ng_outputs = (*op_fun)(node_context);
         } catch (...) {
             if (fail_fast) {
@@ -181,7 +182,7 @@ void FrontEndTF::translate_graph(const ov::frontend::InputModel::Ptr& model,
             }
         }
-        // register nGraph node outputs in the map for new operation node
+        // register OV node outputs in the map for new operation node
         for (const auto& output : ng_outputs) {
             if (auto result = std::dynamic_pointer_cast<ov::opset8::Result>(output.get_node_shared_ptr())) {
                 // do not add RetVal type operation to ng_op_map
@@ -247,7 +248,7 @@ void FrontEndTF::translate_graph(const ov::frontend::InputModel::Ptr& model,
             results.push_back(std::make_shared<ov::opset8::Result>(node_outputs[producer_port_idx]));
         }
     }
-    // find all terminal nodes in ngraph graph to complete list of results
+    // find all terminal nodes in OV graph to complete list of results
    if (results.empty()) {
        for (const auto& node_output_vector : ng_op_map) {
            for (const auto& output : node_output_vector.second) {
@@ -261,9 +262,9 @@ void FrontEndTF::translate_graph(const ov::frontend::InputModel::Ptr& model,
     // TODO: reorder results and params according to indices given in RT info (if any)
-    // create the nGraph function
+    // create the OV function
     ng_function = std::make_shared<ov::Function>(results, params, model_name);
-    NGRAPH_DEBUG << "Done with translations";
+    OPENVINO_DEBUG << "Done with translations";
 }
 /// \brief Check if FrontEndTensorflow can recognize model from given parts
@@ -309,7 +310,7 @@ std::shared_ptr<ov::Function> FrontEndTF::convert(ov::frontend::InputModel::Ptr
     std::shared_ptr<ov::Function> f;
     translate_graph(model_tf, "here_should_be_a_graph_name", true, false, f);
     normalize(f);
-    // TODO: check that nGraph function does not contain operations which are not in the opset
+    // TODO: check that OV function does not contain operations which are not in the opset
     return f;
 }
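
Note: the core of translate_graph above is the translate_map lookup guarded by FRONT_END_OP_CONVERSION_CHECK, followed by calling the matched translator. A self-contained analog of that dispatch pattern (stand-in types invented for illustration, not the frontend's real ones):

    #include <functional>
    #include <map>
    #include <stdexcept>
    #include <string>
    #include <vector>

    using Outputs = std::vector<int>;  // stand-in for ov::OutputVector
    using Translator = std::function<Outputs()>;

    // Mirrors FRONT_END_OP_CONVERSION_CHECK(translate_map.count(op_type), ...)
    // followed by ng_outputs = (*op_fun)(node_context);
    Outputs translate(const std::map<std::string, Translator>& translate_map,
                      const std::string& op_type) {
        auto it = translate_map.find(op_type);
        if (it == translate_map.end())
            throw std::runtime_error("No translator for operation " + op_type);
        return it->second();
    }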

View File

@@ -5,12 +5,12 @@
 #pragma once
 #include <fstream>
-#include <tensorflow_frontend/decoder.hpp>
-#include <tensorflow_frontend/graph_iterator.hpp>
 #include "decoder_proto.hpp"
 #include "graph.pb.h"
 #include "node_def.pb.h"
+#include "tensorflow_frontend/decoder.hpp"
+#include "tensorflow_frontend/graph_iterator.hpp"
 #include "types.pb.h"
 namespace ov {
 namespace frontend {

View File

@@ -4,16 +4,14 @@
 #include "model.hpp"
-#include <frontend_manager/frontend_exceptions.hpp>
 #include <fstream>
-#include <openvino/opsets/opset7.hpp>
 #include <queue>
-#include <tensorflow_frontend/graph_iterator.hpp>
-#include "graph_iterator_proto.hpp"
-#include "ngraph_conversions.hpp"
+#include "frontend_manager/frontend_exceptions.hpp"
 #include "node_context.hpp"
+#include "openvino/opsets/opset7.hpp"
 #include "place.hpp"
+#include "tensorflow_frontend/graph_iterator.hpp"
 #include "utils.hpp"
 using namespace google;

View File

@@ -4,9 +4,9 @@
 #pragma once
-#include <frontend_manager/input_model.hpp>
-#include <frontend_manager/place.hpp>
-#include <tensorflow_frontend/graph_iterator.hpp>
+#include "frontend_manager/input_model.hpp"
+#include "frontend_manager/place.hpp"
+#include "tensorflow_frontend/graph_iterator.hpp"
 namespace ov {
 namespace frontend {

View File

@@ -3,15 +3,15 @@
 //
 #pragma once
-#include <tensorflow_frontend/utility.hpp>
 #include "exceptions.hpp"
 #include "openvino/core/variant.hpp"
 #include "place.hpp"
 #include "tensor.pb.h"
+#include "tensorflow_frontend/utility.hpp"
 #include "types.pb.h"
-#define NGRAPH_VARIANT_DECLARATION(TYPE, info)              \
+#define OPENVINO_VARIANT_DECLARATION(TYPE, info)            \
     template <>                                             \
     class VariantWrapper<TYPE> : public VariantImpl<TYPE> { \
    public:                                                  \
@@ -20,18 +20,18 @@
 }
 namespace ov {
-NGRAPH_VARIANT_DECLARATION(int32_t, "Variant::int32");
-NGRAPH_VARIANT_DECLARATION(uint64_t, "Variant::uint64_t");
-NGRAPH_VARIANT_DECLARATION(std::vector<int32_t>, "Variant::int32_vector");
-NGRAPH_VARIANT_DECLARATION(float, "Variant::float");
-NGRAPH_VARIANT_DECLARATION(std::vector<float>, "Variant::float_vector");
-NGRAPH_VARIANT_DECLARATION(bool, "Variant::bool");
-NGRAPH_VARIANT_DECLARATION(ov::element::Type, "Variant::ov_element_type");
-NGRAPH_VARIANT_DECLARATION(std::vector<int64_t>, "Variant::int64_vector");
-NGRAPH_VARIANT_DECLARATION(ov::PartialShape, "Variant::ngraph_PartialShape");
-NGRAPH_VARIANT_DECLARATION(std::vector<std::string>, "Variant::string_vector");
-NGRAPH_VARIANT_DECLARATION(::tensorflow::DataType, "Variant::DataType");
-NGRAPH_VARIANT_DECLARATION(::tensorflow::TensorProto, "Variant::TensorProto");
+OPENVINO_VARIANT_DECLARATION(int32_t, "Variant::int32");
+OPENVINO_VARIANT_DECLARATION(uint64_t, "Variant::uint64_t");
+OPENVINO_VARIANT_DECLARATION(std::vector<int32_t>, "Variant::int32_vector");
+OPENVINO_VARIANT_DECLARATION(float, "Variant::float");
+OPENVINO_VARIANT_DECLARATION(std::vector<float>, "Variant::float_vector");
+OPENVINO_VARIANT_DECLARATION(bool, "Variant::bool");
+OPENVINO_VARIANT_DECLARATION(ov::element::Type, "Variant::ov_element_type");
+OPENVINO_VARIANT_DECLARATION(std::vector<int64_t>, "Variant::int64_vector");
+OPENVINO_VARIANT_DECLARATION(ov::PartialShape, "Variant:ov_PartialShape");
+OPENVINO_VARIANT_DECLARATION(std::vector<std::string>, "Variant::string_vector");
+OPENVINO_VARIANT_DECLARATION(::tensorflow::DataType, "Variant::DataType");
+OPENVINO_VARIANT_DECLARATION(::tensorflow::TensorProto, "Variant::TensorProto");
 }  // namespace ov
 namespace ov {
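
Note: each declaration above stamps out one VariantWrapper specialization through the renamed macro. A self-contained, demo-only analog of the pattern (the Demo names are invented for illustration, not the frontend's classes):

    #include <cstdint>
    #include <string>
    #include <utility>

    template <typename T>
    struct VariantImplDemo {
        explicit VariantImplDemo(T v) : value(std::move(v)) {}
        T value;
    };

    template <typename T>
    struct VariantWrapperDemo;  // only explicit specializations exist

    // One macro invocation per payload type, keyed by a type-info string,
    // just like OPENVINO_VARIANT_DECLARATION above.
    #define DEMO_VARIANT_DECLARATION(TYPE, INFO)                  \
        template <>                                               \
        struct VariantWrapperDemo<TYPE> : VariantImplDemo<TYPE> { \
            using VariantImplDemo<TYPE>::VariantImplDemo;         \
            static const char* type_info() { return INFO; }      \
        };

    DEMO_VARIANT_DECLARATION(int32_t, "Variant::int32")
    DEMO_VARIANT_DECLARATION(float, "Variant::float")

    int main() {
        VariantWrapperDemo<float> v(1.5f);
        return (v.value == 1.5f && std::string(v.type_info()) == "Variant::float") ? 0 : 1;
    }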

View File

@@ -13,11 +13,11 @@ namespace frontend {
 namespace tf {
 namespace op {
-OutputVector TranslateArgMinMax(const NodeContext& node, std::string mode) {
+OutputVector translate_arg_min_max(const NodeContext& node, std::string mode) {
     Output<Node> ng_input = node.get_input(0);
     std::vector<int64_t> tf_dim;
-    get_static_input_vec(node, 1, &tf_dim);
+    get_const_input(node, 1, &tf_dim);
     Shape input_shape = ng_input.get_shape();
     size_t input_rank = input_shape.size();
@@ -26,10 +26,10 @@ OutputVector TranslateArgMinMax(const NodeContext& node, std::string mode) {
     // If input dimension is negative, make it positive
     if (tf_dim[0] < 0) {
-        NGRAPH_DEBUG << "Input dimension is negative, make it positive " << tf_dim[0];
+        OPENVINO_DEBUG << "Input dimension is negative, make it positive " << tf_dim[0];
         tf_dim[0] = (int64_t)input_rank + tf_dim[0];
     }
-    NGRAPH_DEBUG << "Axis along which to compute " << tf_dim[0];
+    OPENVINO_DEBUG << "Axis along which to compute " << tf_dim[0];
     size_t k_axis = tf_dim[0];
     auto ng_et = node.get_attribute<element::Type>("output_type");
@@ -47,11 +47,11 @@
 }
 OutputVector translate_arg_max_op(const NodeContext& node) {
-    return (TranslateArgMinMax(node, "max"));
+    return (translate_arg_min_max(node, "max"));
 }
 OutputVector translate_arg_min_op(const NodeContext& node) {
-    return (TranslateArgMinMax(node, "min"));
+    return (translate_arg_min_max(node, "min"));
 }
 } // namespace op
 } // namespace tf

View File

@@ -46,7 +46,7 @@ OutputVector translate_avg_pool_op(const NodeContext& node) {
                  padding_below,
                  padding_above);
-    // TODO: remove this once nGraph supports negative padding
+    // TODO: remove this once OV supports negative padding
     // (CoordinateDiff) for AvgPool
     Shape ng_padding_below(padding_below.begin(), padding_below.end());
     Shape ng_padding_above(padding_above.begin(), padding_above.end());

View File

@@ -31,7 +31,7 @@ OutputVector translate_concat_op(const NodeContext& node) {
     }
     std::vector<int64_t> tf_concat_axis_vec;
-    get_static_input_vec(node, axis_idx, &tf_concat_axis_vec);
+    get_const_input(node, axis_idx, &tf_concat_axis_vec);
     int64_t concat_axis = tf_concat_axis_vec[0];
     OutputVector ng_args;

View File

@@ -18,7 +18,7 @@ using ConstMap = std::map<ov::element::Type,
                           std::pair<std::function<void(const NodeContext&, ov::element::Type, ov::Output<ov::Node>&)>,
                                     const ov::element::Type>>;
-const ConstMap& TF_NGRAPH_CONST_MAP() {
+const ConstMap& TF_OPENVINO_CONST_MAP() {
     static const ConstMap the_map = {
         {ov::element::f32, make_pair(make_const_op<float>, ov::element::f32)},
         {ov::element::f64, make_pair(make_const_op<double>, ov::element::f64)},
@@ -43,23 +43,13 @@ OutputVector translate_const_op(const NodeContext& node) {
     auto dt = node.get_attribute<ov::element::Type>("dtype");
     Output<Node> res;
-    // For some reason the following do not work (no specialization of
-    // tensorflow::checkpoint::SavedTypeTraits...)
-    // case DataType::DT_UINT32:
-    //   TF_RETURN_IF_ERROR(make_const_op<uint32>(op, element::u32,
-    //   &ng_node));
-    //   break;
-    // case DataType::DT_UINT64:
-    //   TF_RETURN_IF_ERROR(make_const_op<uint64>(op, element::u64,
-    //   &ng_node));
-    //   break;
+    // TODO: fix DT_UINT32 and DT_UINT64 support
+    // no specialization of tensorflow::checkpoint::SavedTypeTraits...)
     try {
-        const auto& func_param = TF_NGRAPH_CONST_MAP().at(dt);
+        const auto& func_param = TF_OPENVINO_CONST_MAP().at(dt);
         func_param.first(node, func_param.second, res);
     } catch (const std::out_of_range&) {
-        TF_OP_VALIDATION_CHECK(node,
-                               false,
-                               "Failed to translate Constant with target ngraph type:" + dt.get_type_name());
+        TF_OP_VALIDATION_CHECK(node, false, "Failed to translate Constant with target OV type:" + dt.get_type_name());
     }
     set_node_name(node.get_name(), res.get_node_shared_ptr());
     return {res};
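
Note: the collapsed error path relies on std::map::at throwing std::out_of_range for a dtype that is missing from TF_OPENVINO_CONST_MAP. A minimal analog of that lookup-with-fallback (string keys stand in for ov::element::Type):

    #include <iostream>
    #include <map>
    #include <stdexcept>
    #include <string>

    int main() {
        std::map<std::string, int> const_map{{"f32", 0}, {"i64", 1}};
        try {
            (void)const_map.at("u32");  // unsupported dtype: .at() throws
        } catch (const std::out_of_range&) {
            // the frontend reports this via TF_OP_VALIDATION_CHECK instead
            std::cerr << "Failed to translate Constant with target OV type: u32\n";
            return 1;
        }
        return 0;
    }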

View File

@@ -48,7 +48,7 @@ OutputVector translate_conv_2d_op(const NodeContext& node) {
     auto& ng_filter_shape = ng_filter.get_shape();
     ng_kernel_shape[0] = ng_filter_shape[0];
     ng_kernel_shape[1] = ng_filter_shape[1];
-    Transpose<3, 2, 0, 1>(ng_filter);
+    transpose<3, 2, 0, 1>(ng_filter);
     CoordinateDiff ng_padding_below;
     CoordinateDiff ng_padding_above;

View File

@@ -27,7 +27,7 @@ OutputVector translate_conv_2d_backprop_input_op(const NodeContext& node) {
                           "Conv2DBackpropInput data format is neither NHWC nor NCHW");
     std::vector<int64_t> tf_input_sizes;
-    get_static_input_vec(node, 0, &tf_input_sizes);
+    get_const_input(node, 0, &tf_input_sizes);
     if (std::any_of(tf_input_sizes.begin(), tf_input_sizes.end(), [](int32_t size) {
             return size <= 0;
@@ -62,7 +62,7 @@ OutputVector translate_conv_2d_backprop_input_op(const NodeContext& node) {
     auto& ng_filter_shape = ng_filter.get_shape();
     ng_kernel_shape[0] = ng_filter_shape[0];
     ng_kernel_shape[1] = ng_filter_shape[1];
-    Transpose<3, 2, 0, 1>(ng_filter);
+    transpose<3, 2, 0, 1>(ng_filter);
     CoordinateDiff ng_padding_below;
     CoordinateDiff ng_padding_above;

View File

@@ -51,7 +51,7 @@ OutputVector translate_conv_3d_op(const NodeContext& node) {
     ng_kernel_shape[0] = ng_filter_shape[0];
     ng_kernel_shape[1] = ng_filter_shape[1];
     ng_kernel_shape[2] = ng_filter_shape[2];
-    Transpose3D<4, 3, 0, 1, 2>(ng_filter);
+    transpose_3d<4, 3, 0, 1, 2>(ng_filter);
     CoordinateDiff ng_padding_below;
     CoordinateDiff ng_padding_above;

View File

@@ -123,10 +123,10 @@ OutputVector translate_crop_and_resize_op(const NodeContext& node) {
         interpolate_attrs.mode = Interpolate::InterpolateMode::NEAREST;
     }
-    Transpose<0, 3, 1, 2>(ng_crop);
+    transpose<0, 3, 1, 2>(ng_crop);
     auto ng_output =
         make_shared<Interpolate>(ng_crop, ng_size, ng_scales, ng_axes, interpolate_attrs)->output(0);
-    Transpose<0, 2, 3, 1>(ng_output);
+    transpose<0, 2, 3, 1>(ng_output);
     ng_crop_outputs.at(i) = ng_output;
 }

View File

@@ -52,10 +52,10 @@ OutputVector translate_fake_quant_op(const NodeContext& node) {
     auto ng_input_shape = ng_input.get_shape();
     if (ng_input_shape.size() == 4)
-        Transpose<0, 3, 1, 2>(ng_input);
+        transpose<0, 3, 1, 2>(ng_input);
     auto res = make_shared<FakeQuantize>(ng_input, min_adj, max_adj, min_adj, max_adj, levels)->output(0);
     if (ng_input_shape.size() == 4)
-        Transpose<0, 2, 3, 1>(res);
+        transpose<0, 2, 3, 1>(res);
     set_node_name(node.get_name(), res.get_node_shared_ptr());
     return {res};

View File

@@ -27,12 +27,12 @@ OutputVector translate_fused_batch_norm_op(const NodeContext& node) {
     bool is_nhwc = (data_format == "NHWC");
-    NGRAPH_DEBUG << "data_format: " << data_format;
+    OPENVINO_DEBUG << "data_format: " << data_format;
     // TODO: where does 0.0001 come from?
     auto tf_epsilon = node.get_attribute<float>("epsilon", 0.0001);
-    NGRAPH_DEBUG << "epsilon: " << tf_epsilon;
+    OPENVINO_DEBUG << "epsilon: " << tf_epsilon;
     convert_nhwc_to_nchw(node.get_name(), is_nhwc, ng_input);

View File

@@ -15,6 +15,12 @@ namespace op {
 OutputVector translate_identity_op(const NodeContext& node) {
     auto input = node.get_input(0);
+    // since the input node can have several outputs, and identity have only one input,
+    // we cannot use set_node_name(..) helper, we have to set names for output connected
+    // to this identity only.
+    // Node_1 -> Node_2
+    //        -(identity name) -> Identity
     set_out_name(node.get_name(), input);
     set_out_name(node.get_name() + ":" + "0", input);
     return {input};
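
Note: the naming follows TF's tensor-addressing convention, where a tensor is referred to as "op_name:port"; because the Identity node itself is elided, the producer's output must answer to both the plain identity name and its ":0"-qualified alias.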

View File

@@ -36,9 +36,9 @@ ov::OutputVector translate_interpolate_op(const NodeContext& node) {
     auto ng_scales = make_shared<Divide>(ng_sizes, ng_spatial_shape);
     auto ng_axes = make_shared<Constant>(element::i32, Shape{2}, std::vector<int>({2, 3}));
-    Transpose<0, 3, 1, 2>(input);
+    transpose<0, 3, 1, 2>(input);
     auto res = make_shared<Interpolate>(input, input_sizes, ng_scales, ng_axes, interpolate_attrs)->output(0);
-    Transpose<0, 2, 3, 1>(res);
+    transpose<0, 2, 3, 1>(res);
     set_node_name(node.get_name(), res.get_node_shared_ptr());
     return {res};
 }

View File

@@ -49,7 +49,7 @@ OutputVector translate_max_pool_op(const NodeContext& node) {
                  padding_below,
                  padding_above);
-    // TODO: remove this once nGraph supports negative padding
+    // TODO: remove this once OV supports negative padding
     // (CoordinateDiff) for MaxPool
     Shape ng_padding_below(padding_below.begin(), padding_below.end());
     Shape ng_padding_above(padding_above.begin(), padding_above.end());

View File

@@ -46,7 +46,7 @@ OutputVector translate_pad_op(const NodeContext& node) {
     // Set pads_begin & pads_end (from the pad_val_op)
     std::vector<int64_t> paddings;
-    get_static_input_vec(node, 1, &paddings);
+    get_const_input(node, 1, &paddings);
     if (paddings.size() % 2 != 0) {
         TF_OP_VALIDATION_CHECK(node,
                                false,

View File

@@ -20,10 +20,10 @@ OutputVector translate_batch_nd_and_space_nd_op(const NodeContext& node) {
     // ng_crops should be of shape N=[ng_input.get_shape()).size()]
     // But TF's ng_crops input is limited only to the spatial dimensions (neither
     // batch nor innermost),
-    // which would mean ngraph inputs have missing ng_crops[0] and ng_crops[N].
+    // which would mean OV inputs have missing ng_crops[0] and ng_crops[N].
     // Hence, pad ng_crops with zeros at both ends
-    // return with input if rank < 2 as ngraph's impl doesn't support it
+    // return with input if rank < 2 as OV's impl doesn't support it
     const auto& input_pshape = input.get_partial_shape();
     const auto& block_shape_pshape = block_shape.get_partial_shape();
     if (input_pshape.rank().is_static() && block_shape_pshape.rank().is_static()) {

View File

@@ -8,9 +8,9 @@
 #include <map>
 #include <string>
-#include "ngraph_conversions.hpp"
 #include "node_context.hpp"
 #include "openvino/core/node_vector.hpp"
+#include "openvino_conversions.hpp"
 #include "utils.hpp"
 namespace ov {

View File

@@ -2,7 +2,7 @@
 // SPDX-License-Identifier: Apache-2.0
 //
-#include "ngraph_conversions.hpp"
+#include "openvino_conversions.hpp"
 #include "utils.hpp"
@@ -14,9 +14,9 @@ void convert_nhwc_to_nchw(const std::string& op_name, bool need_convert, ov::Out
     if (need_convert) {
         auto rank = node.get_shape().size();
         if (rank == 4) {
-            Transpose<0, 3, 1, 2>(node);
+            transpose<0, 3, 1, 2>(node);
         } else if (rank == 5) {
-            Transpose3D<0, 4, 1, 2, 3>(node);
+            transpose_3d<0, 4, 1, 2, 3>(node);
         }
     }
 }
@@ -25,9 +25,9 @@ void convert_nchw_to_nhwc(const std::string& op_name, bool need_convert, ov::Out
     if (need_convert) {
         auto rank = node.get_shape().size();
         if (rank == 4) {
-            Transpose<0, 2, 3, 1>(node);
+            transpose<0, 2, 3, 1>(node);
         } else if (rank == 5) {
-            Transpose3D<0, 2, 3, 4, 1>(node);
+            transpose_3d<0, 2, 3, 4, 1>(node);
        }
    }
 }
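
Note: read together, the two helpers select the layout permutation by rank: rank-4 NHWC<->NCHW uses transpose<0, 3, 1, 2> and transpose<0, 2, 3, 1>, while rank-5 NDHWC<->NCDHW uses transpose_3d<0, 4, 1, 2, 3> and transpose_3d<0, 2, 3, 4, 1>.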

View File

@@ -5,10 +5,10 @@
 #pragma once
 #include <string>
-#include <tensorflow_frontend/utility.hpp>
 #include "graph.pb.h"
 #include "openvino/opsets/opset8.hpp"
+#include "tensorflow_frontend/utility.hpp"
 #include "types.pb.h"
 namespace ov {
@@ -18,11 +18,9 @@ namespace tf {
 using ::tensorflow::DataType;
 template <size_t a, size_t b, size_t c, size_t d>
-void Transpose(ov::Output<ov::Node>& node) {
+void transpose(ov::Output<ov::Node>& node) {
     static_assert(a < 4 && b < 4 && c < 4 && d < 4, "Number of dimensions cannot exceed 4");
     static_assert(a != b && a != c && a != d && b != c && b != d && c != d, "Dimensions indices cannot be equal");
-    auto& s = node.get_shape();
-    ov::Shape reshaped_shape{s[a], s[b], s[c], s[d]};
     ov::Shape transpose_order{a, b, c, d};
     auto input_order =
         std::make_shared<ov::opset8::Constant>(ov::element::u64, ov::Shape{transpose_order.size()}, transpose_order);
@@ -30,17 +28,15 @@ void Transpose(ov::Output<ov::Node>& node) {
 }
 template <size_t a, size_t b, size_t c, size_t d>
-void Transpose(std::shared_ptr<ov::Node>& node) {
-    Transpose<a, b, c, d>(node->get_default_output());
+void transpose(std::shared_ptr<ov::Node>& node) {
+    transpose<a, b, c, d>(node->get_default_output());
 }
 template <size_t a, size_t b, size_t c, size_t d, size_t e>
-void Transpose3D(ov::Output<ov::Node>& node) {
+void transpose_3d(ov::Output<ov::Node>& node) {
     static_assert(a < 5 && b < 5 && c < 5 && d < 5 && e < 5, "Number of dimensions cannot exceed 5");
     static_assert(a != b && a != c && a != d && a != e && b != c && b != d && b != e && c != d && c != e && d != e,
                   "Dimensions indices cannot be equal");
-    auto& s = node.get_shape();
-    ov::Shape reshaped_shape{s[a], s[b], s[c], s[d], s[e]};
     ov::Shape transpose_order{a, b, c, d, e};
     auto input_order =
         std::make_shared<ov::opset8::Constant>(ov::element::u64, ov::Shape{transpose_order.size()}, transpose_order);
@@ -48,8 +44,8 @@ void Transpose3D(ov::Output<ov::Node>& node) {
 }
 template <size_t a, size_t b, size_t c, size_t d, size_t e>
-void Transpose3D(std::shared_ptr<ov::Node>& node) {
-    Transpose3D<a, b, c, d, e>(node->get_default_output());
+void transpose_3d(std::shared_ptr<ov::Node>& node) {
+    transpose_3d<a, b, c, d, e>(node->get_default_output());
 }
 namespace detail {
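
Note: a minimal usage sketch of the renamed helper, assuming (as the call sites in the conversion hunk above imply) that the truncated template body reassigns the argument Output to the freshly created Transpose node's output:

    #include <memory>
    #include "openvino/opsets/opset8.hpp"
    // plus the header from this hunk for ov::frontend::tf::transpose

    int main() {
        // NHWC parameter, rewired to NCHW through the helper.
        auto param = std::make_shared<ov::opset8::Parameter>(ov::element::f32,
                                                             ov::Shape{1, 224, 224, 3});
        ov::Output<ov::Node> out = param->output(0);
        ov::frontend::tf::transpose<0, 3, 1, 2>(out);  // NHWC -> NCHW
        return out.get_shape() == ov::Shape{1, 3, 224, 224} ? 0 : 1;
    }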

View File

@@ -4,7 +4,10 @@
 #include "transpose_sinking.hpp"
+#include "openvino/op/util/op_types.hpp"
 #include "openvino/opsets/opset8.hpp"
+#include "openvino/pass/pattern/op/label.hpp"
+#include "openvino/util/common_util.hpp"
 #include "utils.hpp"
 using namespace std;
@@ -30,6 +33,12 @@ static AxisVector permutation_to_default_order(const AxisVector& axis_order) {
     return out;
 }
+static AxisVector get_default_order(size_t rank) {
+    AxisVector default_order(rank);
+    std::iota(begin(default_order), end(default_order), 0);
+    return default_order;
+}
 template <typename T>
 static string describe(shared_ptr<Node> node) {
     // ensure that it's either a reshape or a transpose
@@ -41,8 +50,8 @@ static string describe(shared_ptr<Node> node) {
     stringstream ss;
     auto transpose = as_type_ptr<T>(node);
     auto const1 = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
-    ss << transpose->get_name() << " ( axis order = " << ngraph::vector_to_string(const1->get_axis_vector_val())
-       << " , shape = " << ngraph::vector_to_string(transpose->get_shape()) << " ) "
+    ss << transpose->get_name() << " ( axis order = " << ov::util::vector_to_string(const1->get_axis_vector_val())
+       << " , shape = " << ov::util::vector_to_string(transpose->get_shape()) << " ) "
        << " , input = " << transpose->input_value(0).get_node()->get_name();
     return ss.str();
 }
@@ -50,14 +59,14 @@ static string describe(shared_ptr<Node> node) {
 static shared_ptr<Transpose> make_transpose(const Output<Node>& arg, const AxisVector& input_order) {
     auto order = std::make_shared<Constant>(element::u64, Shape{input_order.size()}, input_order);
     auto transpose = make_shared<Transpose>(arg, order);
-    NGRAPH_DEBUG << "Make Transpose " << describe<Transpose>(transpose);
+    OPENVINO_DEBUG << "Make Transpose " << describe<Transpose>(transpose);
     return transpose;
 }
 static shared_ptr<Reshape> make_reshape(const Output<Node>& arg, const AxisVector& input_order) {
     auto order = std::make_shared<Constant>(element::u64, Shape{input_order.size()}, input_order);
     auto transpose = make_shared<Reshape>(arg, order, false);
-    NGRAPH_DEBUG << "Make Reshape " << describe<Reshape>(transpose);
+    OPENVINO_DEBUG << "Make Reshape " << describe<Reshape>(transpose);
     return transpose;
 }
@@ -65,19 +74,19 @@ static void write_transposemap(TransposeMap& reorders,
                                const Output<Node>& target,
                                const shared_ptr<Transpose>& transpose) {
     auto name = target.get_node()->get_name() + "." + to_string(target.get_index());
-    NGRAPH_DEBUG << "Write TransposeMap[" << name << "] = " << describe<Transpose>(transpose);
+    OPENVINO_DEBUG << "Write TransposeMap[" << name << "] = " << describe<Transpose>(transpose);
     reorders[name] = transpose;
 }
 static shared_ptr<Transpose> read_transposemap(TransposeMap& reorders, const Output<Node>& target) {
     auto name = target.get_node()->get_name() + "." + to_string(target.get_index());
     auto transpose = reorders[name];
-    NGRAPH_DEBUG << "Read TransposeMap[" << name << "] -> " << describe<Transpose>(transpose);
+    OPENVINO_DEBUG << "Read TransposeMap[" << name << "] -> " << describe<Transpose>(transpose);
     return transpose;
 }
 static shared_ptr<Transpose> combine_transposes(const shared_ptr<Transpose>& t1, const shared_ptr<Transpose>& t2) {
-    auto default_order = ngraph::get_default_order(t1->get_shape());
+    auto default_order = get_default_order(t1->get_shape().size());
     auto t1_const = as_type_ptr<Constant>(t1->input_value(1).get_node_shared_ptr());
     auto t2_const = as_type_ptr<Constant>(t2->input_value(1).get_node_shared_ptr());
@@ -85,31 +94,31 @@ static shared_ptr<Transpose> combine_transposes(const shared_ptr<Transpose>& t1,
     auto perm_t2 = apply_permutation(perm_t1, t2_const->get_axis_vector_val());
     auto combined = make_transpose(t2->input_value(0), perm_t2);
-    NGRAPH_DEBUG << "Combining " << describe<Transpose>(t1) << " and " << describe<Transpose>(t2) << " into "
-                 << describe<Transpose>(combined);
+    OPENVINO_DEBUG << "Combining " << describe<Transpose>(t1) << " and " << describe<Transpose>(t2) << " into "
+                   << describe<Transpose>(combined);
     return combined;
 }
 static void insert_transpose(const shared_ptr<Node>& target, const shared_ptr<Node>& transpose, size_t input_index) {
-    NGRAPH_DEBUG << "Inserting transpose at input " << target->get_name() << " input index " << input_index;
+    OPENVINO_DEBUG << "Inserting transpose at input " << target->get_name() << " input index " << input_index;
     auto arg = target->input(input_index).get_source_output();
-    NGRAPH_DEBUG << "Arg shape: " << arg.get_shape();
+    OPENVINO_DEBUG << "Arg shape: " << arg.get_shape();
     auto new_order = as_type_ptr<Constant>(transpose->input_value(1).get_node_shared_ptr());
     auto new_transpose = make_transpose(arg.get_node_shared_ptr(), new_order->get_axis_vector_val());
-    NGRAPH_DEBUG << "Inserting transpose " << describe<Transpose>(new_transpose) << " at input " << target->get_name()
-                 << " input index " << input_index;
+    OPENVINO_DEBUG << "Inserting transpose " << describe<Transpose>(new_transpose) << " at input " << target->get_name()
+                   << " input index " << input_index;
     target->input(input_index).replace_source_output(new_transpose->output(0));
 }
 static void delete_transpose(const shared_ptr<Node>& transpose) {
-    NGRAPH_DEBUG << "Removing transpose " << transpose->get_name();
+    OPENVINO_DEBUG << "Removing transpose " << transpose->get_name();
     if (!transpose->get_users().empty()) {
         Output<Node> output = transpose->output(0);
-        NGRAPH_DEBUG << "output " << output.get_node_shared_ptr()->get_name();
-        NGRAPH_DEBUG << "target input size " << output.get_target_inputs().size();
+        OPENVINO_DEBUG << "output " << output.get_node_shared_ptr()->get_name();
+        OPENVINO_DEBUG << "target input size " << output.get_target_inputs().size();
         for (auto input : output.get_target_inputs()) {
-            NGRAPH_DEBUG << "input " << input.get_node()->get_name();
+            OPENVINO_DEBUG << "input " << input.get_node()->get_name();
             input.replace_source_output(transpose->input_value(0));
         }
     }
@@ -117,12 +126,12 @@ static void delete_transpose(const shared_ptr<Node>& transpose) {
 static void mark_transpose_for_deletion(const shared_ptr<Node>& transpose,
                                         set<shared_ptr<Node>>& transposes_to_delete) {
-    NGRAPH_DEBUG << "Marking transpose " << transpose->get_name() << " for deletion";
+    OPENVINO_DEBUG << "Marking transpose " << transpose->get_name() << " for deletion";
     transposes_to_delete.insert(transpose);
 }
 static shared_ptr<Transpose> create_default_transpose(const Output<Node>& n) {
-    auto default_order = ngraph::get_default_order(n.get_shape());
+    auto default_order = get_default_order(n.get_shape().size());
     auto order = std::make_shared<Constant>(element::u64, Shape{default_order.size()}, default_order);
     return make_shared<Transpose>(n, order);
 }
@@ -160,8 +169,8 @@ static void convert_binary_to_default_order(const shared_ptr<Node>& binary,
     }
     input.replace_source_output(new_node->output(0));
-    NGRAPH_DEBUG << "right = " << ngraph::vector_to_string(right.get_shape()) << ", "
-                 << right.get_node_shared_ptr()->get_name();
+    OPENVINO_DEBUG << "right = " << ov::util::vector_to_string(right.get_shape()) << ", "
+                   << right.get_node_shared_ptr()->get_name();
     // this should now insert transpose on right
     mark_transpose_for_deletion(right_t, transposes_to_delete);
     write_transposemap(reorders, binary, right_t);
@@ -180,11 +189,11 @@ static void materialize_shapes(const shared_ptr<Node>& n,
         // materialize all pending transposes, flush pending transposes
         auto arg = n->input_value(i);
         auto arg_transpose = read_transposemap(reorders, arg);
-        NGRAPH_DEBUG << "Materializing " << describe<Transpose>(arg_transpose) << " for "
-                     << arg.get_node_shared_ptr()->get_name();
+        OPENVINO_DEBUG << "Materializing " << describe<Transpose>(arg_transpose) << " for "
+                       << arg.get_node_shared_ptr()->get_name();
         mark_transpose_for_deletion(arg_transpose, transposes_to_delete);
         auto arg_transpose_order = as_type_ptr<Constant>(arg_transpose->input_value(1).get_node_shared_ptr());
-        if (arg_transpose_order->get_axis_vector_val() != ngraph::get_default_order(arg.get_shape())) {
+        if (arg_transpose_order->get_axis_vector_val() != get_default_order(arg.get_shape().size())) {
             // Insert if arg needs to be transposed.
             insert_transpose(n, arg_transpose, i);
         }
@@ -194,7 +203,7 @@
 static void sink_transpose(const shared_ptr<Transpose>& transpose,
                            TransposeMap& reorders,
                            set<shared_ptr<Node>>& transposes_to_delete) {
-    NGRAPH_DEBUG << "Sinking Transpose :" << describe<Transpose>(transpose);
+    OPENVINO_DEBUG << "Sinking Transpose :" << describe<Transpose>(transpose);
     auto transpose_in = transpose->input_value(0);
     auto orig_transpose = read_transposemap(reorders, transpose_in);
     // combine both transposes
@@ -212,7 +221,7 @@ static void sink_unary(const shared_ptr<Node>& n,
                        TransposeMap& reorders,
                        set<shared_ptr<Node>>& /* transposes_to_delete */) {
     auto arg_transpose = read_transposemap(reorders, n->input_value(0));
-    NGRAPH_DEBUG << "Propagating " << describe<Transpose>(arg_transpose) << " for " << n->get_name();
+    OPENVINO_DEBUG << "Propagating " << describe<Transpose>(arg_transpose) << " for " << n->get_name();
     write_transposemap(reorders, n, arg_transpose);
 }
@@ -229,18 +238,19 @@ static void sink_binary(const shared_ptr<Node>& binary,
     auto left_order = left_const->get_axis_vector_val();
     auto right_order = right_const->get_axis_vector_val();
-    auto left_mismatch = left_order != ngraph::get_default_order(left.get_shape());
-    auto right_mismatch = right_order != ngraph::get_default_order(right.get_shape());
-    NGRAPH_DEBUG << "Sink binary " << binary->get_name() << " left transpose: " << ngraph::vector_to_string(left_order)
-                 << " left default: " << ngraph::vector_to_string(ngraph::get_default_order(left.get_shape()))
-                 << " right transpose: " << ngraph::vector_to_string(right_order)
-                 << " right default: " << ngraph::vector_to_string(ngraph::get_default_order(right.get_shape()));
+    auto left_mismatch = left_order != get_default_order(left.get_shape().size());
+    auto right_mismatch = right_order != get_default_order(right.get_shape().size());
+    OPENVINO_DEBUG << "Sink binary " << binary->get_name()
+                   << " left transpose: " << ov::util::vector_to_string(left_order)
+                   << " left default: " << ov::util::vector_to_string(get_default_order(left.get_shape().size()))
+                   << " right transpose: " << ov::util::vector_to_string(right_order)
+                   << " right default: " << ov::util::vector_to_string(get_default_order(right.get_shape().size()));
     if ((left_order.size() == right_order.size() && left_order == right_order) || (!left_mismatch && !right_mismatch)) {
         // Propagate the reshape which matches the shape of the binary node
         auto new_transpose = (binary->get_output_shape(0) == left.get_shape()) ? left_t : right_t;
-        NGRAPH_DEBUG << "Propagating " << describe<Transpose>(new_transpose) << " for " << binary->get_name();
+        OPENVINO_DEBUG << "Propagating " << describe<Transpose>(new_transpose) << " for " << binary->get_name();
        write_transposemap(reorders, binary, new_transpose);
        // at this point, both transposes will be eventually removed
        mark_transpose_for_deletion(left_t, transposes_to_delete);
@@ -273,7 +283,8 @@ static void sink_pad(shared_ptr<Pad> n, TransposeMap& reorders, set<shared_ptr<N
     auto input_shape = apply_permutation(arg_transpose->get_shape(), def_order);
-    auto dummy_correct_shape = make_shared<ngraph::pattern::op::Label>(arg_transpose->get_element_type(), input_shape);
+    auto dummy_correct_shape =
+        make_shared<ov::pass::pattern::op::Label>(arg_transpose->get_element_type(), input_shape);
     auto pad_begin = apply_permutation(n->get_pads_begin(), def_order);
     auto pad_end = apply_permutation(n->get_pads_end(), def_order);
@@ -282,10 +293,10 @@
     auto new_end = make_shared<Constant>(element::i64, Shape{pad_end.size()}, pad_end);
     auto new_pad = make_shared<Pad>(dummy_correct_shape, new_begin, new_end, n->input_value(3), n->get_pad_mode());
     replace_node(dummy_correct_shape, n->input_value(0).get_node_shared_ptr());
-    NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_pad->get_name();
+    OPENVINO_DEBUG << "Replacing " << n->get_name() << " with " << new_pad->get_name();
     replace_node(n, new_pad);
     auto new_transpose = make_transpose(new_pad, order);
-    NGRAPH_DEBUG << "Propagating " << describe<Transpose>(new_transpose) << " for " << n->get_name();
+    OPENVINO_DEBUG << "Propagating " << describe<Transpose>(new_transpose) << " for " << n->get_name();
     write_transposemap(reorders, new_pad, new_transpose);
 }
@@ -303,7 +314,8 @@ static void sink_concat(const shared_ptr<Concat>& n,
     auto input_shape = apply_permutation(arg_transpose->get_shape(), def_order);
-    auto dummy_correct_shape = make_shared<ngraph::pattern::op::Label>(arg_transpose->get_element_type(), input_shape);
+    auto dummy_correct_shape =
+        make_shared<ov::pass::pattern::op::Label>(arg_transpose->get_element_type(), input_shape);
     NodeVector new_args;
     new_args.push_back(dummy_correct_shape);
@@ -314,7 +326,7 @@
         auto iarg_transpose_order = as_type_ptr<Constant>(iarg_transpose->input_value(1).get_node_shared_ptr());
         auto iorder = iarg_transpose_order->get_axis_vector_val();
         if (iorder != order) {
-            NGRAPH_DEBUG << " input order at " << i << "-th arg is different from first arg";
+            OPENVINO_DEBUG << " input order at " << i << "-th arg is different from first arg";
             materialize_shapes(n, reorders, transposes_to_delete);
             return;
         }
@@ -322,7 +334,7 @@
         auto iinput_shape = apply_permutation(iarg_transpose->get_shape(), def_order);
         auto idummy_correct_shape =
-            make_shared<ngraph::pattern::op::Label>(iarg_transpose->get_element_type(), iinput_shape);
+            make_shared<ov::pass::pattern::op::Label>(iarg_transpose->get_element_type(), iinput_shape);
         new_args.push_back(idummy_correct_shape);
     }
@@ -330,14 +342,14 @@
     auto new_concat = make_shared<Concat>(new_args, new_axis);
     // put back the original arguments
     for (size_t i = 0; i < new_concat->get_input_size(); i++) {
-        NGRAPH_DEBUG << "Replacing " << new_concat->get_name() << " input " << i << " with " << n->get_name()
-                     << " input " << i;
+        OPENVINO_DEBUG << "Replacing " << new_concat->get_name() << " input " << i << " with " << n->get_name()
+                       << " input " << i;
         new_concat->input(i).replace_source_output(n->input_value(i));
     }
-    NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_concat->get_name();
+    OPENVINO_DEBUG << "Replacing " << n->get_name() << " with " << new_concat->get_name();
     replace_node(n, new_concat);
     auto new_transpose = make_transpose(new_concat, order);
-    NGRAPH_DEBUG << "Propagating " << describe<Transpose>(new_transpose) << " for " << n->get_name();
+    OPENVINO_DEBUG << "Propagating " << describe<Transpose>(new_transpose) << " for " << n->get_name();
     write_transposemap(reorders, new_concat, new_transpose);
 }
@@ -358,16 +370,16 @@ bool ov::frontend::tf::pass::TransposeSinkingOVTF::run_on_function(shared_ptr<Fu
     // STEP 1 : Sink or Swim transposes away for op clusters
     try {
         for (const auto& n : f->get_ordered_ops()) {
-            NGRAPH_DEBUG << "Processing " << n->get_name();
+            OPENVINO_DEBUG << "Processing " << n->get_name();
             // collect output shape of all Result nodes for a sanity check
-            if (ngraph::op::is_output(n)) {
+            if (ov::op::util::is_output(n)) {
                 orig_result_out_shape[n->get_name()] = n->get_output_shape(0);
             }
             if (auto transpose = as_type_ptr<opset8::Transpose>(n)) {
                 sink_transpose(transpose, reorders, transposes_to_delete);
-            } else if (ngraph::op::is_unary_elementwise_arithmetic(n)) {
+            } else if (ov::op::util::is_unary_elementwise_arithmetic(n)) {
                 sink_unary(n, reorders, transposes_to_delete);
-            } else if (ngraph::op::is_binary_elementwise_arithmetic(n)) {
+            } else if (ov::op::util::is_binary_elementwise_arithmetic(n)) {
                 sink_binary(n, reorders, transposes_to_delete);
             } else if (auto pad = as_type_ptr<Pad>(n)) {
                 sink_pad(pad, reorders, transposes_to_delete);
@@ -378,18 +390,18 @@
             }
         }
     } catch (...) {
-        NGRAPH_DEBUG << "Caught exception while sinking op";
+        OPENVINO_DEBUG << "Caught exception while sinking op";
         return false;
     }
     // STEP 2: purge all the transposes we either sunk or swam.
-    NGRAPH_DEBUG << "Purging transposes ";
+    OPENVINO_DEBUG << "Purging transposes ";
     for (const auto& r : transposes_to_delete) {
         delete_transpose(r);
     }
     // STEP 3: fix wrong shape info wholesale
-    NGRAPH_DEBUG << "Fixing wrong shape info for the whole graph";
+    OPENVINO_DEBUG << "Fixing wrong shape info for the whole graph";
     for (const auto& n : f->get_ordered_ops()) {
         n->revalidate_and_infer_types();
     }
@@ -397,7 +409,7 @@ bool ov::frontend::tf::pass::TransposeSinkingOVTF::run_on_function(shared_ptr<Fu
     const ResultVector& results = f->get_results();
     for (const auto& r : results) {
         // make sure shapes are always materialized before results
-        NGRAPH_CHECK(
+        FRONT_END_GENERAL_CHECK(
             r->get_shape() == r->get_input_shape(0) && r->get_element_type() == r->input_value(0).get_element_type(),
             " op::Result = ",
             *r,
@@ -406,11 +418,11 @@ bool ov::frontend::tf::pass::TransposeSinkingOVTF::run_on_function(shared_ptr<Fu
         // make sure that after TransposeSinking pass the output_shape for Result
         // does not change from the expected output_shape before the pass
-        NGRAPH_CHECK(r->get_output_shape(0) == orig_result_out_shape[r->get_name()],
-                     " op::Result = ",
-                     *r,
-                     " expected output shape = ",
-                     orig_result_out_shape[r->get_name()]);
+        FRONT_END_GENERAL_CHECK(r->get_output_shape(0) == orig_result_out_shape[r->get_name()],
+                                " op::Result = ",
+                                *r,
+                                " expected output shape = ",
+                                orig_result_out_shape[r->get_name()]);
     }
     return true;
 }
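
Note: a worked example of the permutation composition in combine_transposes above: an NHWC->NCHW Transpose (order {0, 3, 1, 2}) followed by an NCHW->NHWC Transpose (order {0, 2, 3, 1}) composes to the identity order, which equals get_default_order(4) and lets the pass purge both transposes.

    #include <array>
    #include <cstddef>

    int main() {
        std::array<std::size_t, 4> t1{0, 3, 1, 2}, t2{0, 2, 3, 1}, combined{};
        for (std::size_t i = 0; i < 4; ++i)
            combined[i] = t1[t2[i]];  // apply_permutation(perm_t1, t2_order)
        return combined == std::array<std::size_t, 4>{0, 1, 2, 3} ? 0 : 1;
    }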

View File

@@ -4,8 +4,7 @@
 #include "place.hpp"
-#include <frontend_manager/frontend_exceptions.hpp>
+#include "frontend_manager/frontend_exceptions.hpp"
 #include "node_context.hpp"
 #include "op_def.pb.h"
 #include "tensor.pb.h"

View File

@@ -4,8 +4,8 @@
 #pragma once
-#include <frontend_manager/frontend.hpp>
-#include <tensorflow_frontend/decoder.hpp>
+#include "frontend_manager/frontend.hpp"
+#include "tensorflow_frontend/decoder.hpp"
 namespace ov {
 namespace frontend {

View File

@@ -2,8 +2,8 @@
 // SPDX-License-Identifier: Apache-2.0
 //
-#include <frontend_manager/frontend_manager.hpp>
-#include <tensorflow_frontend/frontend.hpp>
+#include "frontend_manager/frontend_manager.hpp"
+#include "tensorflow_frontend/frontend.hpp"
 extern "C" OPENVINO_CORE_EXPORTS ov::frontend::FrontEndVersion GetAPIVersion() {
     return OV_FRONTEND_API_VERSION;

View File

@@ -5,9 +5,9 @@
 #pragma once
 #include <algorithm>
-#include <tensorflow_frontend/decoder.hpp>
 #include "openvino/op/util/framework_node.hpp"
+#include "tensorflow_frontend/decoder.hpp"
 namespace ov {
 namespace frontend {

View File

@@ -4,8 +4,7 @@
 #include "utils.hpp"
-void ov::frontend::tf::tf_shape_to_ngraph_shape(const tensorflow::TensorShapeProto& tf_shape,
-                                                ov::PartialShape* ng_shape) {
+void ov::frontend::tf::tf_shape_to_ov_shape(const tensorflow::TensorShapeProto& tf_shape, ov::PartialShape* ng_shape) {
     std::vector<ov::Dimension> dims;
     for (int i = 0; i < tf_shape.dim_size(); i++) {
         dims.emplace_back(tf_shape.dim(i).size());

View File

@@ -21,11 +21,11 @@
 #pragma once
 #include "graph_iterator_proto.hpp"
-#include "ngraph/log.hpp"
-#include "ngraph/ngraph.hpp"
-#include "ngraph_conversions.hpp"
 #include "node_context.hpp"
+#include "openvino/core/validation_util.hpp"
 #include "openvino/opsets/opset8.hpp"
+#include "openvino/util/log.hpp"
+#include "openvino_conversions.hpp"
 namespace ov {
 namespace frontend {
@@ -69,10 +69,10 @@ void make_padding(const std::string& tf_padding_type,
     }
 }
-void tf_shape_to_ngraph_shape(const ::tensorflow::TensorShapeProto& tf_shape, ov::PartialShape* ng_shape);
+void tf_shape_to_ov_shape(const ::tensorflow::TensorShapeProto& tf_shape, ov::PartialShape* ng_shape);
 template <typename T>
-void get_static_input_vec(const NodeContext& node, int64_t input_index, std::vector<T>* vector) {
+void get_const_input(const NodeContext& node, int64_t input_index, std::vector<T>* vector) {
     auto ng_input = node.get_input(input_index);
     if (auto constant = std::dynamic_pointer_cast<opset8::Constant>(ng_input.get_node_shared_ptr())) {
         *vector = constant->cast_vector<T>();
@@ -87,7 +87,7 @@ void get_static_input_vec(const NodeContext& node, int64_t input_index, std::vec
 // Modified with an extra `VecT` parameter to handle the case where the type
 // in the std::vector does not match TensorFlow's notion of what the C++ type
 // should be (e.g. when T is `bool`, we actually need a std::vector of `char` for
-// compatibility with nGraph).
+// compatibility with OpenVINO).
 template <typename T, typename VecT = T>
 void values_from_const_node(const NodeContext& node, ov::Shape* const_tensor_shape, std::vector<VecT>* values) {
     TF_OP_VALIDATION_CHECK(node, node.get_op_type() == "Const", "Node is expected to be Constant.");
@@ -104,7 +104,7 @@ void values_from_const_node(const NodeContext& node, ov::Shape* const_tensor_sha
     const tensorflow::TensorShapeProto& shape = tensor_proto.tensor_shape();
     ov::PartialShape pshape;
-    tf_shape_to_ngraph_shape(shape, &pshape);
+    tf_shape_to_ov_shape(shape, &pshape);
     *const_tensor_shape = pshape.get_shape();
     TF_OP_VALIDATION_CHECK(node, pshape.is_static(), "Dynamic shapes are not supported in Constant conversion.");
     auto tensor_content = tensor_proto.tensor_content();
@@ -171,8 +171,8 @@ void values_from_const_node(const NodeContext& node, ov::Shape* const_tensor_sha
                 val_i = tensor_proto.double_val()[i];
                 break;
             default:
-                NGRAPH_DEBUG << "Const node has empty tensor_proto and we don't know how to "
-                                "handle this element type";
+                OPENVINO_DEBUG << "Const node has empty tensor_proto and we don't know how to "
+                                  "handle this element type";
                 FRONT_END_THROW("Encountered unknown element type " + DataType_Name(dt) + " on an empty tensor_proto");
             }
             TF_OP_VALIDATION_CHECK(node, val_size != 0, "Empty values vector");
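
Note: the renamed get_const_input extracts a compile-time-constant input via Constant::cast_vector, as the Pad, Concat, and Conv2DBackpropInput hunks above show (e.g. get_const_input(node, 1, &paddings)). A self-contained analog of its happy path (the Demo type is invented for illustration):

    #include <cstdint>
    #include <memory>
    #include <vector>

    // Stand-in for opset8::Constant::cast_vector<T>() on the resolved input.
    struct ConstantDemo {
        std::vector<std::int64_t> payload{1, 1, 2, 2};
        template <typename T>
        std::vector<T> cast_vector() const {
            return std::vector<T>(payload.begin(), payload.end());
        }
    };

    int main() {
        // In the real helper this pointer comes from std::dynamic_pointer_cast
        // on node.get_input(input_index); non-Constant inputs take the error
        // branch, which is truncated in this view.
        auto constant = std::make_shared<ConstantDemo>();
        std::vector<std::int64_t> paddings;
        if (constant)
            paddings = constant->cast_vector<std::int64_t>();
        return paddings.size() == 4 ? 0 : 1;
    }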