diff --git a/ngraph/frontend/tensorflow/CMakeLists.txt b/ngraph/frontend/tensorflow/CMakeLists.txt index ee3102b9db2..8f58386f4cb 100644 --- a/ngraph/frontend/tensorflow/CMakeLists.txt +++ b/ngraph/frontend/tensorflow/CMakeLists.txt @@ -77,7 +77,7 @@ ie_add_api_validator_post_build_step(TARGET ${TARGET_NAME}) link_system_libraries(${TARGET_NAME} PRIVATE ${Protobuf_LITE_LIBRARIES}) target_link_libraries(${TARGET_NAME} PRIVATE frontend_manager::static - PRIVATE ngraph::builder inference_engine_transformations libprotobuf openvino::util) + PRIVATE inference_engine_transformations libprotobuf openvino::util) add_clang_format_target(${TARGET_NAME}_clang FOR_TARGETS ${TARGET_NAME} EXCLUDE_PATTERNS ${PROTO_SRCS} ${PROTO_HDRS}) diff --git a/ngraph/frontend/tensorflow/include/tensorflow_frontend/decoder.hpp b/ngraph/frontend/tensorflow/include/tensorflow_frontend/decoder.hpp index 46bc8963d8f..7d9f2a32f4a 100644 --- a/ngraph/frontend/tensorflow/include/tensorflow_frontend/decoder.hpp +++ b/ngraph/frontend/tensorflow/include/tensorflow_frontend/decoder.hpp @@ -4,8 +4,8 @@ #pragma once -#include -#include +#include "openvino/core/variant.hpp" +#include "tensorflow_frontend/utility.hpp" namespace ov { namespace frontend { diff --git a/ngraph/frontend/tensorflow/src/decoder_proto.cpp b/ngraph/frontend/tensorflow/src/decoder_proto.cpp index 784d352df6f..31ec6657ebe 100644 --- a/ngraph/frontend/tensorflow/src/decoder_proto.cpp +++ b/ngraph/frontend/tensorflow/src/decoder_proto.cpp @@ -47,8 +47,7 @@ shared_ptr DecoderTFProto::get_attribute(const string& name, const Vari if (is_type(type_info)) { return create_variant(attrs[0].s()); - } - if (is_type(type_info)) { + } else if (is_type(type_info)) { return create_variant(attrs[0].i()); } else if (is_type>(type_info)) { vector longs; diff --git a/ngraph/frontend/tensorflow/src/decoder_proto.hpp b/ngraph/frontend/tensorflow/src/decoder_proto.hpp index a446cdecf79..087986deeb3 100644 --- a/ngraph/frontend/tensorflow/src/decoder_proto.hpp +++ b/ngraph/frontend/tensorflow/src/decoder_proto.hpp @@ -4,13 +4,12 @@ #pragma once -#include #include -#include #include #include "attr_value.pb.h" #include "node_def.pb.h" +#include "tensorflow_frontend/decoder.hpp" #include "types.pb.h" namespace ov { diff --git a/ngraph/frontend/tensorflow/src/exceptions.hpp b/ngraph/frontend/tensorflow/src/exceptions.hpp index 211cebbed10..1dd00a5328a 100644 --- a/ngraph/frontend/tensorflow/src/exceptions.hpp +++ b/ngraph/frontend/tensorflow/src/exceptions.hpp @@ -3,8 +3,7 @@ // #pragma once -#include - +#include "frontend_manager/frontend_exceptions.hpp" #include "openvino/core/node.hpp" namespace ov { diff --git a/ngraph/frontend/tensorflow/src/frontend.cpp b/ngraph/frontend/tensorflow/src/frontend.cpp index 694b04be1fa..c2392c76e90 100644 --- a/ngraph/frontend/tensorflow/src/frontend.cpp +++ b/ngraph/frontend/tensorflow/src/frontend.cpp @@ -2,13 +2,14 @@ // SPDX-License-Identifier: Apache-2.0 // -#include -#include +#include "tensorflow_frontend/frontend.hpp" #include "model.hpp" #include "op_table.hpp" +#include "openvino/pass/manager.hpp" #include "openvino/util/common_util.hpp" #include "pass/transpose_sinking.hpp" +#include "tensorflow_frontend/graph_iterator.hpp" #include "tf_framework_node.hpp" #include "utils.hpp" @@ -52,13 +53,13 @@ void FrontEndTF::translate_graph(const ov::frontend::InputModel::Ptr& model, bool fail_fast, bool no_conversion, std::shared_ptr& ng_function) const { - // a map from operation names to generated nGraph Output + // a map from operation names to 
generated OV Output tf::OpMap ng_op_map; ov::ParameterVector params; ov::ResultVector results; const auto& model_tf = std::dynamic_pointer_cast(model); - FRONT_END_GENERAL_CHECK(model_tf, "nullptr for InputModel is given for translation into nGraph function"); + FRONT_END_GENERAL_CHECK(model_tf, "nullptr for InputModel is given for translation into OV function"); const auto& operation_places = model_tf->get_op_places(); const auto& model_inputs = model_tf->get_inputs(); const auto& model_outputs = model_tf->get_outputs(); @@ -101,7 +102,7 @@ void FrontEndTF::translate_graph(const ov::frontend::InputModel::Ptr& model, ng_op_map[input_name] = {param}; } - // create the nGraph ops from TensorFlow ops + // create the OV ops from TensorFlow ops for (const auto& operation_place : operation_places) { auto operation_decoder = operation_place->get_decoder(); auto operation_name = operation_place->get_names()[0]; @@ -110,7 +111,7 @@ void FrontEndTF::translate_graph(const ov::frontend::InputModel::Ptr& model, continue; } - // prepare a list of nGraph node inputs for each node + // prepare a list of OV node inputs for each node ov::OutputVector ng_inputs; ::ov::frontend::tf::NamedInputs named_inputs; for (size_t input_port_idx = 0; input_port_idx < operation_decoder->get_input_size(); ++input_port_idx) { @@ -157,7 +158,7 @@ void FrontEndTF::translate_graph(const ov::frontend::InputModel::Ptr& model, } } - // generate nGraph node output vector for the current operation node + // generate OV node output vector for the current operation node ov::OutputVector ng_outputs; try { FRONT_END_OP_CONVERSION_CHECK(translate_map.count(operation_decoder->get_op_type()), @@ -166,7 +167,7 @@ void FrontEndTF::translate_graph(const ov::frontend::InputModel::Ptr& model, // NodeContext node_context(ng_inputs, operation_decoder, model_inputs); // TODO: Check why NodeContextNew doesn't have ngOutputVector ng_inputs input in constructor ::ov::frontend::tf::NodeContext node_context(*operation_decoder, named_inputs); - // generate nGraph node output vector using translator for given operation type + // generate OV node output vector using translator for given operation type ng_outputs = (*op_fun)(node_context); } catch (...) 
{ if (fail_fast) { @@ -181,7 +182,7 @@ void FrontEndTF::translate_graph(const ov::frontend::InputModel::Ptr& model, } } - // register nGraph node outputs in the map for new operation node + // register OV node outputs in the map for new operation node for (const auto& output : ng_outputs) { if (auto result = std::dynamic_pointer_cast(output.get_node_shared_ptr())) { // do not add RetVal type operation to ng_op_map @@ -247,7 +248,7 @@ void FrontEndTF::translate_graph(const ov::frontend::InputModel::Ptr& model, results.push_back(std::make_shared(node_outputs[producer_port_idx])); } } - // find all terminal nodes in ngraph graph to complete list of results + // find all terminal nodes in OV graph to complete list of results if (results.empty()) { for (const auto& node_output_vector : ng_op_map) { for (const auto& output : node_output_vector.second) { @@ -261,9 +262,9 @@ void FrontEndTF::translate_graph(const ov::frontend::InputModel::Ptr& model, // TODO: reorder results and params according to indices given in RT info (if any) - // create the nGraph function + // create the OV function ng_function = std::make_shared(results, params, model_name); - NGRAPH_DEBUG << "Done with translations"; + OPENVINO_DEBUG << "Done with translations"; } /// \brief Check if FrontEndTensorflow can recognize model from given parts @@ -309,7 +310,7 @@ std::shared_ptr FrontEndTF::convert(ov::frontend::InputModel::Ptr std::shared_ptr f; translate_graph(model_tf, "here_should_be_a_graph_name", true, false, f); normalize(f); - // TODO: check that nGraph function does not contain operations which are not in the opset + // TODO: check that OV function does not contain operations which are not in the opset return f; } diff --git a/ngraph/frontend/tensorflow/src/graph_iterator_proto.hpp b/ngraph/frontend/tensorflow/src/graph_iterator_proto.hpp index d4004d10f2c..50b0796ff6a 100644 --- a/ngraph/frontend/tensorflow/src/graph_iterator_proto.hpp +++ b/ngraph/frontend/tensorflow/src/graph_iterator_proto.hpp @@ -5,12 +5,12 @@ #pragma once #include -#include -#include #include "decoder_proto.hpp" #include "graph.pb.h" #include "node_def.pb.h" +#include "tensorflow_frontend/decoder.hpp" +#include "tensorflow_frontend/graph_iterator.hpp" namespace ov { namespace frontend { diff --git a/ngraph/frontend/tensorflow/src/model.cpp b/ngraph/frontend/tensorflow/src/model.cpp index 98a922a4d3c..a8cfbf62ab4 100644 --- a/ngraph/frontend/tensorflow/src/model.cpp +++ b/ngraph/frontend/tensorflow/src/model.cpp @@ -4,16 +4,14 @@ #include "model.hpp" -#include #include -#include #include -#include -#include "graph_iterator_proto.hpp" -#include "ngraph_conversions.hpp" +#include "frontend_manager/frontend_exceptions.hpp" #include "node_context.hpp" +#include "openvino/opsets/opset7.hpp" #include "place.hpp" +#include "tensorflow_frontend/graph_iterator.hpp" #include "utils.hpp" using namespace google; diff --git a/ngraph/frontend/tensorflow/src/model.hpp b/ngraph/frontend/tensorflow/src/model.hpp index 94936e0684c..4a7a0279fee 100644 --- a/ngraph/frontend/tensorflow/src/model.hpp +++ b/ngraph/frontend/tensorflow/src/model.hpp @@ -4,9 +4,9 @@ #pragma once -#include -#include -#include +#include "frontend_manager/input_model.hpp" +#include "frontend_manager/place.hpp" +#include "tensorflow_frontend/graph_iterator.hpp" namespace ov { namespace frontend { diff --git a/ngraph/frontend/tensorflow/src/node_context.hpp b/ngraph/frontend/tensorflow/src/node_context.hpp index e2f92e53884..31d583def6d 100644 --- 
a/ngraph/frontend/tensorflow/src/node_context.hpp +++ b/ngraph/frontend/tensorflow/src/node_context.hpp @@ -3,15 +3,15 @@ // #pragma once -#include #include "exceptions.hpp" #include "openvino/core/variant.hpp" #include "place.hpp" #include "tensor.pb.h" +#include "tensorflow_frontend/utility.hpp" #include "types.pb.h" -#define NGRAPH_VARIANT_DECLARATION(TYPE, info) \ +#define OPENVINO_VARIANT_DECLARATION(TYPE, info) \ template <> \ class VariantWrapper<TYPE> : public VariantImpl<TYPE> { \ public: \ @@ -20,18 +20,18 @@ } namespace ov { -NGRAPH_VARIANT_DECLARATION(int32_t, "Variant::int32"); -NGRAPH_VARIANT_DECLARATION(uint64_t, "Variant::uint64_t"); -NGRAPH_VARIANT_DECLARATION(std::vector<int32_t>, "Variant::int32_vector"); -NGRAPH_VARIANT_DECLARATION(float, "Variant::float"); -NGRAPH_VARIANT_DECLARATION(std::vector<float>, "Variant::float_vector"); -NGRAPH_VARIANT_DECLARATION(bool, "Variant::bool"); -NGRAPH_VARIANT_DECLARATION(ov::element::Type, "Variant::ov_element_type"); -NGRAPH_VARIANT_DECLARATION(std::vector<int64_t>, "Variant::int64_vector"); -NGRAPH_VARIANT_DECLARATION(ov::PartialShape, "Variant::ngraph_PartialShape"); -NGRAPH_VARIANT_DECLARATION(std::vector<std::string>, "Variant::string_vector"); -NGRAPH_VARIANT_DECLARATION(::tensorflow::DataType, "Variant::DataType"); -NGRAPH_VARIANT_DECLARATION(::tensorflow::TensorProto, "Variant::TensorProto"); +OPENVINO_VARIANT_DECLARATION(int32_t, "Variant::int32"); +OPENVINO_VARIANT_DECLARATION(uint64_t, "Variant::uint64_t"); +OPENVINO_VARIANT_DECLARATION(std::vector<int32_t>, "Variant::int32_vector"); +OPENVINO_VARIANT_DECLARATION(float, "Variant::float"); +OPENVINO_VARIANT_DECLARATION(std::vector<float>, "Variant::float_vector"); +OPENVINO_VARIANT_DECLARATION(bool, "Variant::bool"); +OPENVINO_VARIANT_DECLARATION(ov::element::Type, "Variant::ov_element_type"); +OPENVINO_VARIANT_DECLARATION(std::vector<int64_t>, "Variant::int64_vector"); +OPENVINO_VARIANT_DECLARATION(ov::PartialShape, "Variant::ov_PartialShape"); +OPENVINO_VARIANT_DECLARATION(std::vector<std::string>, "Variant::string_vector"); +OPENVINO_VARIANT_DECLARATION(::tensorflow::DataType, "Variant::DataType"); +OPENVINO_VARIANT_DECLARATION(::tensorflow::TensorProto, "Variant::TensorProto"); } // namespace ov namespace ov {
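
For context, each declaration above expands to a specialization of VariantWrapper. A minimal sketch of the expansion for the PartialShape entry, assuming the usual VariantImpl pattern from openvino/core/variant.hpp (the macro body is elided by the hunk above, so the member list here is illustrative):

template <>
class VariantWrapper<ov::PartialShape> : public VariantImpl<ov::PartialShape> {
public:
    // "Variant::ov_PartialShape" is the `info` string passed to the macro
    static constexpr VariantTypeInfo type_info{"Variant::ov_PartialShape", 0};
    const VariantTypeInfo& get_type_info() const override { return type_info; }
    VariantWrapper(const value_type& value) : VariantImpl<value_type>(value) {}
};

These wrappers are what DecoderTFProto::get_attribute dispatches on via the is_type checks in decoder_proto.cpp above (the template arguments are stripped by the rendering there).
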
diff --git a/ngraph/frontend/tensorflow/src/op/arg_min_max.cpp b/ngraph/frontend/tensorflow/src/op/arg_min_max.cpp index 50146a386c6..ed88729ea8a 100644 --- a/ngraph/frontend/tensorflow/src/op/arg_min_max.cpp +++ b/ngraph/frontend/tensorflow/src/op/arg_min_max.cpp @@ -13,11 +13,11 @@ namespace frontend { namespace tf { namespace op { -OutputVector TranslateArgMinMax(const NodeContext& node, std::string mode) { +OutputVector translate_arg_min_max(const NodeContext& node, std::string mode) { Output ng_input = node.get_input(0); std::vector tf_dim; - get_static_input_vec(node, 1, &tf_dim); + get_const_input(node, 1, &tf_dim); Shape input_shape = ng_input.get_shape(); size_t input_rank = input_shape.size(); @@ -26,10 +26,10 @@ OutputVector translate_arg_min_max(const NodeContext& node, std::string mode) { // If input dimension is negative, make it positive if (tf_dim[0] < 0) { - NGRAPH_DEBUG << "Input dimension is negative, make it positive " << tf_dim[0]; + OPENVINO_DEBUG << "Input dimension is negative, make it positive " << tf_dim[0]; tf_dim[0] = (int64_t)input_rank + tf_dim[0]; } - NGRAPH_DEBUG << "Axis along which to compute " << tf_dim[0]; + OPENVINO_DEBUG << "Axis along which to compute " << tf_dim[0]; size_t k_axis = tf_dim[0]; auto ng_et = node.get_attribute("output_type"); @@ -47,11 +47,11 @@ OutputVector translate_arg_min_max(const NodeContext& node, std::string mode) { } OutputVector translate_arg_max_op(const NodeContext& node) { - return (TranslateArgMinMax(node, "max")); + return (translate_arg_min_max(node, "max")); } OutputVector translate_arg_min_op(const NodeContext& node) { - return (TranslateArgMinMax(node, "min")); + return (translate_arg_min_max(node, "min")); } } // namespace op } // namespace tf
diff --git a/ngraph/frontend/tensorflow/src/op/avg_pool.cpp b/ngraph/frontend/tensorflow/src/op/avg_pool.cpp index ef34cd07cde..9430f6d90cb 100644 --- a/ngraph/frontend/tensorflow/src/op/avg_pool.cpp +++ b/ngraph/frontend/tensorflow/src/op/avg_pool.cpp @@ -46,7 +46,7 @@ OutputVector translate_avg_pool_op(const NodeContext& node) { padding_below, padding_above); - // TODO: remove this once nGraph supports negative padding + // TODO: remove this once OV supports negative padding // (CoordinateDiff) for AvgPool Shape ng_padding_below(padding_below.begin(), padding_below.end()); Shape ng_padding_above(padding_above.begin(), padding_above.end());
diff --git a/ngraph/frontend/tensorflow/src/op/concat.cpp b/ngraph/frontend/tensorflow/src/op/concat.cpp index 513770039aa..4e2bbe26039 100644 --- a/ngraph/frontend/tensorflow/src/op/concat.cpp +++ b/ngraph/frontend/tensorflow/src/op/concat.cpp @@ -31,7 +31,7 @@ OutputVector translate_concat_op(const NodeContext& node) { } std::vector tf_concat_axis_vec; - get_static_input_vec(node, axis_idx, &tf_concat_axis_vec); + get_const_input(node, axis_idx, &tf_concat_axis_vec); int64_t concat_axis = tf_concat_axis_vec[0]; OutputVector ng_args;
diff --git a/ngraph/frontend/tensorflow/src/op/const.cpp b/ngraph/frontend/tensorflow/src/op/const.cpp index f0dee56b301..054f2e44eb8 100644 --- a/ngraph/frontend/tensorflow/src/op/const.cpp +++ b/ngraph/frontend/tensorflow/src/op/const.cpp @@ -18,7 +18,7 @@ using ConstMap = std::map&)>, const ov::element::Type>>; -const ConstMap& TF_NGRAPH_CONST_MAP() { +const ConstMap& TF_OPENVINO_CONST_MAP() { static const ConstMap the_map = { {ov::element::f32, make_pair(make_const_op, ov::element::f32)}, {ov::element::f64, make_pair(make_const_op, ov::element::f64)}, @@ -43,23 +43,13 @@ OutputVector translate_const_op(const NodeContext& node) { auto dt = node.get_attribute("dtype"); Output res; - // For some reason the following do not work (no specialization of - // tensorflow::checkpoint::SavedTypeTraits...) - // case DataType::DT_UINT32: - // TF_RETURN_IF_ERROR(make_const_op(op, element::u32, - // &ng_node)); - // break; - // case DataType::DT_UINT64: - // TF_RETURN_IF_ERROR(make_const_op(op, element::u64, - // &ng_node)); - // break; + // TODO: fix DT_UINT32 and DT_UINT64 support + // (no specialization of tensorflow::checkpoint::SavedTypeTraits...)
try { - const auto& func_param = TF_NGRAPH_CONST_MAP().at(dt); + const auto& func_param = TF_OPENVINO_CONST_MAP().at(dt); func_param.first(node, func_param.second, res); } catch (const std::out_of_range&) { - TF_OP_VALIDATION_CHECK(node, - false, - "Failed to translate Constant with target ngraph type:" + dt.get_type_name()); + TF_OP_VALIDATION_CHECK(node, false, "Failed to translate Constant with target OV type:" + dt.get_type_name()); } set_node_name(node.get_name(), res.get_node_shared_ptr()); return {res}; diff --git a/ngraph/frontend/tensorflow/src/op/conv_2d.cpp b/ngraph/frontend/tensorflow/src/op/conv_2d.cpp index 599354781a9..53f132415b0 100644 --- a/ngraph/frontend/tensorflow/src/op/conv_2d.cpp +++ b/ngraph/frontend/tensorflow/src/op/conv_2d.cpp @@ -48,7 +48,7 @@ OutputVector translate_conv_2d_op(const NodeContext& node) { auto& ng_filter_shape = ng_filter.get_shape(); ng_kernel_shape[0] = ng_filter_shape[0]; ng_kernel_shape[1] = ng_filter_shape[1]; - Transpose<3, 2, 0, 1>(ng_filter); + transpose<3, 2, 0, 1>(ng_filter); CoordinateDiff ng_padding_below; CoordinateDiff ng_padding_above; diff --git a/ngraph/frontend/tensorflow/src/op/conv_2d_backprop.cpp b/ngraph/frontend/tensorflow/src/op/conv_2d_backprop.cpp index 6c40767a904..59ba2346bea 100644 --- a/ngraph/frontend/tensorflow/src/op/conv_2d_backprop.cpp +++ b/ngraph/frontend/tensorflow/src/op/conv_2d_backprop.cpp @@ -27,7 +27,7 @@ OutputVector translate_conv_2d_backprop_input_op(const NodeContext& node) { "Conv2DBackpropInput data format is neither NHWC nor NCHW"); std::vector tf_input_sizes; - get_static_input_vec(node, 0, &tf_input_sizes); + get_const_input(node, 0, &tf_input_sizes); if (std::any_of(tf_input_sizes.begin(), tf_input_sizes.end(), [](int32_t size) { return size <= 0; @@ -62,7 +62,7 @@ OutputVector translate_conv_2d_backprop_input_op(const NodeContext& node) { auto& ng_filter_shape = ng_filter.get_shape(); ng_kernel_shape[0] = ng_filter_shape[0]; ng_kernel_shape[1] = ng_filter_shape[1]; - Transpose<3, 2, 0, 1>(ng_filter); + transpose<3, 2, 0, 1>(ng_filter); CoordinateDiff ng_padding_below; CoordinateDiff ng_padding_above; diff --git a/ngraph/frontend/tensorflow/src/op/conv_3d.cpp b/ngraph/frontend/tensorflow/src/op/conv_3d.cpp index 1a76e9419ab..a8c008e2abb 100644 --- a/ngraph/frontend/tensorflow/src/op/conv_3d.cpp +++ b/ngraph/frontend/tensorflow/src/op/conv_3d.cpp @@ -51,7 +51,7 @@ OutputVector translate_conv_3d_op(const NodeContext& node) { ng_kernel_shape[0] = ng_filter_shape[0]; ng_kernel_shape[1] = ng_filter_shape[1]; ng_kernel_shape[2] = ng_filter_shape[2]; - Transpose3D<4, 3, 0, 1, 2>(ng_filter); + transpose_3d<4, 3, 0, 1, 2>(ng_filter); CoordinateDiff ng_padding_below; CoordinateDiff ng_padding_above; diff --git a/ngraph/frontend/tensorflow/src/op/crop_and_resize.cpp b/ngraph/frontend/tensorflow/src/op/crop_and_resize.cpp index 763e9dea540..57578f3f7ea 100644 --- a/ngraph/frontend/tensorflow/src/op/crop_and_resize.cpp +++ b/ngraph/frontend/tensorflow/src/op/crop_and_resize.cpp @@ -123,10 +123,10 @@ OutputVector translate_crop_and_resize_op(const NodeContext& node) { interpolate_attrs.mode = Interpolate::InterpolateMode::NEAREST; } - Transpose<0, 3, 1, 2>(ng_crop); + transpose<0, 3, 1, 2>(ng_crop); auto ng_output = make_shared(ng_crop, ng_size, ng_scales, ng_axes, interpolate_attrs)->output(0); - Transpose<0, 2, 3, 1>(ng_output); + transpose<0, 2, 3, 1>(ng_output); ng_crop_outputs.at(i) = ng_output; } diff --git a/ngraph/frontend/tensorflow/src/op/fake_quant_min_max_vars.cpp 
b/ngraph/frontend/tensorflow/src/op/fake_quant_min_max_vars.cpp index bdb2595225f..526a6fd1722 100644 --- a/ngraph/frontend/tensorflow/src/op/fake_quant_min_max_vars.cpp +++ b/ngraph/frontend/tensorflow/src/op/fake_quant_min_max_vars.cpp @@ -52,10 +52,10 @@ OutputVector translate_fake_quant_op(const NodeContext& node) { auto ng_input_shape = ng_input.get_shape(); if (ng_input_shape.size() == 4) - Transpose<0, 3, 1, 2>(ng_input); + transpose<0, 3, 1, 2>(ng_input); auto res = make_shared(ng_input, min_adj, max_adj, min_adj, max_adj, levels)->output(0); if (ng_input_shape.size() == 4) - Transpose<0, 2, 3, 1>(res); + transpose<0, 2, 3, 1>(res); set_node_name(node.get_name(), res.get_node_shared_ptr()); return {res};
diff --git a/ngraph/frontend/tensorflow/src/op/fused_batch_norm.cpp b/ngraph/frontend/tensorflow/src/op/fused_batch_norm.cpp index 93542365b16..b5a7b6016d4 100644 --- a/ngraph/frontend/tensorflow/src/op/fused_batch_norm.cpp +++ b/ngraph/frontend/tensorflow/src/op/fused_batch_norm.cpp @@ -27,12 +27,12 @@ OutputVector translate_fused_batch_norm_op(const NodeContext& node) { bool is_nhwc = (data_format == "NHWC"); - NGRAPH_DEBUG << "data_format: " << data_format; + OPENVINO_DEBUG << "data_format: " << data_format; // TODO: where does 0.0001 come from? auto tf_epsilon = node.get_attribute("epsilon", 0.0001); - NGRAPH_DEBUG << "epsilon: " << tf_epsilon; + OPENVINO_DEBUG << "epsilon: " << tf_epsilon; convert_nhwc_to_nchw(node.get_name(), is_nhwc, ng_input);
diff --git a/ngraph/frontend/tensorflow/src/op/identity.cpp b/ngraph/frontend/tensorflow/src/op/identity.cpp index cc6bd45463b..285f9da8db5 100644 --- a/ngraph/frontend/tensorflow/src/op/identity.cpp +++ b/ngraph/frontend/tensorflow/src/op/identity.cpp @@ -15,6 +15,12 @@ namespace op { OutputVector translate_identity_op(const NodeContext& node) { auto input = node.get_input(0); + + // since the input node can have several outputs, and Identity has only one input, + // we cannot use the set_node_name(..) helper; we have to set names only for the output + // connected to this Identity: + // Node_1 -> Node_2 + // -(identity name) -> Identity set_out_name(node.get_name(), input); set_out_name(node.get_name() + ":" + "0", input); return {input}; }
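
To make the naming contract above concrete, a small sketch (hypothetical node names, not from the patch):

// TF graph: `conv` has two outputs; output 1 feeds an Identity named "ident".
// After the Identity is bypassed, lookups of "ident" and "ident:0" must
// resolve to the forwarded producer output:
ov::Output<ov::Node> out = conv->output(1);
set_out_name("ident", out);      // TF node name
set_out_name("ident:0", out);    // TF tensor name
// set_node_name() would instead also rename the producer and all of its
// outputs, which is wrong when only one output flows through the Identity.
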
diff --git a/ngraph/frontend/tensorflow/src/op/interpolate.cpp b/ngraph/frontend/tensorflow/src/op/interpolate.cpp index ed98d112ae7..cf6bf93e85d 100644 --- a/ngraph/frontend/tensorflow/src/op/interpolate.cpp +++ b/ngraph/frontend/tensorflow/src/op/interpolate.cpp @@ -36,9 +36,9 @@ ov::OutputVector translate_interpolate_op(const NodeContext& node) { auto ng_scales = make_shared(ng_sizes, ng_spatial_shape); auto ng_axes = make_shared(element::i32, Shape{2}, std::vector({2, 3})); - Transpose<0, 3, 1, 2>(input); + transpose<0, 3, 1, 2>(input); auto res = make_shared(input, input_sizes, ng_scales, ng_axes, interpolate_attrs)->output(0); - Transpose<0, 2, 3, 1>(res); + transpose<0, 2, 3, 1>(res); set_node_name(node.get_name(), res.get_node_shared_ptr()); return {res}; }
diff --git a/ngraph/frontend/tensorflow/src/op/max_pool.cpp b/ngraph/frontend/tensorflow/src/op/max_pool.cpp index 82c72a5263c..54585f1af0a 100644 --- a/ngraph/frontend/tensorflow/src/op/max_pool.cpp +++ b/ngraph/frontend/tensorflow/src/op/max_pool.cpp @@ -49,7 +49,7 @@ OutputVector translate_max_pool_op(const NodeContext& node) { padding_below, padding_above); - // TODO: remove this once nGraph supports negative padding + // TODO: remove this once OV supports negative padding // (CoordinateDiff) for MaxPool Shape ng_padding_below(padding_below.begin(), padding_below.end()); Shape ng_padding_above(padding_above.begin(), padding_above.end());
diff --git a/ngraph/frontend/tensorflow/src/op/pad.cpp b/ngraph/frontend/tensorflow/src/op/pad.cpp index ef4e54b39b0..fec2fbd3854 100644 --- a/ngraph/frontend/tensorflow/src/op/pad.cpp +++ b/ngraph/frontend/tensorflow/src/op/pad.cpp @@ -46,7 +46,7 @@ OutputVector translate_pad_op(const NodeContext& node) { // Set pads_begin & pads_end (from the pad_val_op) std::vector paddings; - get_static_input_vec(node, 1, &paddings); + get_const_input(node, 1, &paddings); if (paddings.size() % 2 != 0) { TF_OP_VALIDATION_CHECK(node, false,
diff --git a/ngraph/frontend/tensorflow/src/op/space_to_batch_nd.cpp b/ngraph/frontend/tensorflow/src/op/space_to_batch_nd.cpp index d1b2897ceca..940cc344677 100644 --- a/ngraph/frontend/tensorflow/src/op/space_to_batch_nd.cpp +++ b/ngraph/frontend/tensorflow/src/op/space_to_batch_nd.cpp @@ -20,10 +20,10 @@ OutputVector translate_batch_nd_and_space_nd_op(const NodeContext& node) { // ng_crops should be of shape N=[ng_input.get_shape().size()] // But TF's ng_crops input is limited only to the spatial dimensions (neither // batch nor innermost), - // which would mean ngraph inputs have missing ng_crops[0] and ng_crops[N]. + // which would mean OV inputs have missing ng_crops[0] and ng_crops[N].
// Hence, pad ng_crops with zeros at both ends - // return with input if rank < 2 as ngraph's impl doesn't support it + // return with input if rank < 2 as OV's impl doesn't support it const auto& input_pshape = input.get_partial_shape(); const auto& block_shape_pshape = block_shape.get_partial_shape(); if (input_pshape.rank().is_static() && block_shape_pshape.rank().is_static()) { diff --git a/ngraph/frontend/tensorflow/src/op_table.hpp b/ngraph/frontend/tensorflow/src/op_table.hpp index 7cf9f4ba47e..f1785cd9956 100644 --- a/ngraph/frontend/tensorflow/src/op_table.hpp +++ b/ngraph/frontend/tensorflow/src/op_table.hpp @@ -8,9 +8,9 @@ #include #include -#include "ngraph_conversions.hpp" #include "node_context.hpp" #include "openvino/core/node_vector.hpp" +#include "openvino_conversions.hpp" #include "utils.hpp" namespace ov { diff --git a/ngraph/frontend/tensorflow/src/ngraph_conversions.cpp b/ngraph/frontend/tensorflow/src/openvino_conversions.cpp similarity index 77% rename from ngraph/frontend/tensorflow/src/ngraph_conversions.cpp rename to ngraph/frontend/tensorflow/src/openvino_conversions.cpp index 8f1814da8eb..6f98ae472a7 100644 --- a/ngraph/frontend/tensorflow/src/ngraph_conversions.cpp +++ b/ngraph/frontend/tensorflow/src/openvino_conversions.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "ngraph_conversions.hpp" +#include "openvino_conversions.hpp" #include "utils.hpp" @@ -14,9 +14,9 @@ void convert_nhwc_to_nchw(const std::string& op_name, bool need_convert, ov::Out if (need_convert) { auto rank = node.get_shape().size(); if (rank == 4) { - Transpose<0, 3, 1, 2>(node); + transpose<0, 3, 1, 2>(node); } else if (rank == 5) { - Transpose3D<0, 4, 1, 2, 3>(node); + transpose_3d<0, 4, 1, 2, 3>(node); } } } @@ -25,9 +25,9 @@ void convert_nchw_to_nhwc(const std::string& op_name, bool need_convert, ov::Out if (need_convert) { auto rank = node.get_shape().size(); if (rank == 4) { - Transpose<0, 2, 3, 1>(node); + transpose<0, 2, 3, 1>(node); } else if (rank == 5) { - Transpose3D<0, 2, 3, 4, 1>(node); + transpose_3d<0, 2, 3, 4, 1>(node); } } } diff --git a/ngraph/frontend/tensorflow/src/ngraph_conversions.hpp b/ngraph/frontend/tensorflow/src/openvino_conversions.hpp similarity index 82% rename from ngraph/frontend/tensorflow/src/ngraph_conversions.hpp rename to ngraph/frontend/tensorflow/src/openvino_conversions.hpp index fb44b2472ff..bf8d6d589a0 100644 --- a/ngraph/frontend/tensorflow/src/ngraph_conversions.hpp +++ b/ngraph/frontend/tensorflow/src/openvino_conversions.hpp @@ -5,10 +5,10 @@ #pragma once #include -#include #include "graph.pb.h" #include "openvino/opsets/opset8.hpp" +#include "tensorflow_frontend/utility.hpp" #include "types.pb.h" namespace ov { @@ -18,11 +18,9 @@ namespace tf { using ::tensorflow::DataType; template -void Transpose(ov::Output& node) { +void transpose(ov::Output& node) { static_assert(a < 4 && b < 4 && c < 4 && d < 4, "Number of dimensions cannot exceed 4"); static_assert(a != b && a != c && a != d && b != c && b != d && c != d, "Dimensions indices cannot be equal"); - auto& s = node.get_shape(); - ov::Shape reshaped_shape{s[a], s[b], s[c], s[d]}; ov::Shape transpose_order{a, b, c, d}; auto input_order = std::make_shared(ov::element::u64, ov::Shape{transpose_order.size()}, transpose_order); @@ -30,17 +28,15 @@ void Transpose(ov::Output& node) { } template -void Transpose(std::shared_ptr& node) { - Transpose(node->get_default_output()); +void transpose(std::shared_ptr& node) { + transpose(node->get_default_output()); } template -void 
Transpose3D(ov::Output& node) { +void transpose_3d(ov::Output& node) { static_assert(a < 5 && b < 5 && c < 5 && d < 5 && e < 5, "Number of dimensions cannot exceed 5"); static_assert(a != b && a != c && a != d && a != e && b != c && b != d && b != e && c != d && c != e && d != e, "Dimensions indices cannot be equal"); - auto& s = node.get_shape(); - ov::Shape reshaped_shape{s[a], s[b], s[c], s[d], s[e]}; ov::Shape transpose_order{a, b, c, d, e}; auto input_order = std::make_shared(ov::element::u64, ov::Shape{transpose_order.size()}, transpose_order); @@ -48,8 +44,8 @@ void Transpose3D(ov::Output& node) { } template -void Transpose3D(std::shared_ptr& node) { - Transpose3D(node->get_default_output()); +void transpose_3d(std::shared_ptr& node) { + transpose_3d(node->get_default_output()); } namespace detail { diff --git a/ngraph/frontend/tensorflow/src/pass/transpose_sinking.cpp b/ngraph/frontend/tensorflow/src/pass/transpose_sinking.cpp index cdfccebf598..1beb636dd15 100644 --- a/ngraph/frontend/tensorflow/src/pass/transpose_sinking.cpp +++ b/ngraph/frontend/tensorflow/src/pass/transpose_sinking.cpp @@ -4,7 +4,10 @@ #include "transpose_sinking.hpp" +#include "openvino/op/util/op_types.hpp" #include "openvino/opsets/opset8.hpp" +#include "openvino/pass/pattern/op/label.hpp" +#include "openvino/util/common_util.hpp" #include "utils.hpp" using namespace std; @@ -30,6 +33,12 @@ static AxisVector permutation_to_default_order(const AxisVector& axis_order) { return out; } +static AxisVector get_default_order(size_t rank) { + AxisVector default_order(rank); + std::iota(begin(default_order), end(default_order), 0); + return default_order; +} + template static string describe(shared_ptr node) { // ensure that it's either a reshape or a transpose @@ -41,8 +50,8 @@ static string describe(shared_ptr node) { stringstream ss; auto transpose = as_type_ptr(node); auto const1 = as_type_ptr(transpose->get_input_node_shared_ptr(1)); - ss << transpose->get_name() << " ( axis order = " << ngraph::vector_to_string(const1->get_axis_vector_val()) - << " , shape = " << ngraph::vector_to_string(transpose->get_shape()) << " ) " + ss << transpose->get_name() << " ( axis order = " << ov::util::vector_to_string(const1->get_axis_vector_val()) + << " , shape = " << ov::util::vector_to_string(transpose->get_shape()) << " ) " << " , input = " << transpose->input_value(0).get_node()->get_name(); return ss.str(); } @@ -50,14 +59,14 @@ static string describe(shared_ptr node) { static shared_ptr make_transpose(const Output& arg, const AxisVector& input_order) { auto order = std::make_shared(element::u64, Shape{input_order.size()}, input_order); auto transpose = make_shared(arg, order); - NGRAPH_DEBUG << "Make Transpose " << describe(transpose); + OPENVINO_DEBUG << "Make Transpose " << describe(transpose); return transpose; } static shared_ptr make_reshape(const Output& arg, const AxisVector& input_order) { auto order = std::make_shared(element::u64, Shape{input_order.size()}, input_order); auto transpose = make_shared(arg, order, false); - NGRAPH_DEBUG << "Make Reshape " << describe(transpose); + OPENVINO_DEBUG << "Make Reshape " << describe(transpose); return transpose; } @@ -65,19 +74,19 @@ static void write_transposemap(TransposeMap& reorders, const Output& target, const shared_ptr& transpose) { auto name = target.get_node()->get_name() + "." 
+ to_string(target.get_index()); - NGRAPH_DEBUG << "Write TransposeMap[" << name << "] = " << describe(transpose); + OPENVINO_DEBUG << "Write TransposeMap[" << name << "] = " << describe(transpose); reorders[name] = transpose; } static shared_ptr read_transposemap(TransposeMap& reorders, const Output& target) { auto name = target.get_node()->get_name() + "." + to_string(target.get_index()); auto transpose = reorders[name]; - NGRAPH_DEBUG << "Read TransposeMap[" << name << "] -> " << describe(transpose); + OPENVINO_DEBUG << "Read TransposeMap[" << name << "] -> " << describe(transpose); return transpose; } static shared_ptr combine_transposes(const shared_ptr& t1, const shared_ptr& t2) { - auto default_order = ngraph::get_default_order(t1->get_shape()); + auto default_order = get_default_order(t1->get_shape().size()); auto t1_const = as_type_ptr(t1->input_value(1).get_node_shared_ptr()); auto t2_const = as_type_ptr(t2->input_value(1).get_node_shared_ptr()); @@ -85,31 +94,31 @@ auto perm_t2 = apply_permutation(perm_t1, t2_const->get_axis_vector_val()); auto combined = make_transpose(t2->input_value(0), perm_t2); - NGRAPH_DEBUG << "Combining " << describe(t1) << " and " << describe(t2) << " into " - << describe(combined); + OPENVINO_DEBUG << "Combining " << describe(t1) << " and " << describe(t2) << " into " + << describe(combined); return combined; }
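
// A worked example of the permutation composition above (illustrative only;
// assumes apply_permutation(v, order)[i] == v[order[i]], matching the helper
// already used by this pass):
//   t1 axis order: {0, 3, 1, 2}   (NHWC -> NCHW)
//   t2 axis order: {0, 2, 3, 1}   (NCHW -> NHWC)
//   perm_t1 = apply_permutation({0, 1, 2, 3}, {0, 3, 1, 2}) = {0, 3, 1, 2}
//   perm_t2 = apply_permutation({0, 3, 1, 2}, {0, 2, 3, 1}) = {0, 1, 2, 3}
// The combined transpose carries the default order, so the pair cancels and
// both transposes become dead once marked for deletion.
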
"Marking transpose " << transpose->get_name() << " for deletion"; transposes_to_delete.insert(transpose); } static shared_ptr create_default_transpose(const Output& n) { - auto default_order = ngraph::get_default_order(n.get_shape()); + auto default_order = get_default_order(n.get_shape().size()); auto order = std::make_shared(element::u64, Shape{default_order.size()}, default_order); return make_shared(n, order); } @@ -160,8 +169,8 @@ static void convert_binary_to_default_order(const shared_ptr& binary, } input.replace_source_output(new_node->output(0)); - NGRAPH_DEBUG << "right = " << ngraph::vector_to_string(right.get_shape()) << ", " - << right.get_node_shared_ptr()->get_name(); + OPENVINO_DEBUG << "right = " << ov::util::vector_to_string(right.get_shape()) << ", " + << right.get_node_shared_ptr()->get_name(); // this should now insert transpose on right mark_transpose_for_deletion(right_t, transposes_to_delete); write_transposemap(reorders, binary, right_t); @@ -180,11 +189,11 @@ static void materialize_shapes(const shared_ptr& n, // materialize all pending transposes, flush pending transposes auto arg = n->input_value(i); auto arg_transpose = read_transposemap(reorders, arg); - NGRAPH_DEBUG << "Materializing " << describe(arg_transpose) << " for " - << arg.get_node_shared_ptr()->get_name(); + OPENVINO_DEBUG << "Materializing " << describe(arg_transpose) << " for " + << arg.get_node_shared_ptr()->get_name(); mark_transpose_for_deletion(arg_transpose, transposes_to_delete); auto arg_transpose_order = as_type_ptr(arg_transpose->input_value(1).get_node_shared_ptr()); - if (arg_transpose_order->get_axis_vector_val() != ngraph::get_default_order(arg.get_shape())) { + if (arg_transpose_order->get_axis_vector_val() != get_default_order(arg.get_shape().size())) { // Insert if arg needs to be transposed. 
insert_transpose(n, arg_transpose, i); } @@ -194,7 +203,7 @@ static void materialize_shapes(const shared_ptr& n, static void sink_transpose(const shared_ptr& transpose, TransposeMap& reorders, set>& transposes_to_delete) { - NGRAPH_DEBUG << "Sinking Transpose :" << describe(transpose); + OPENVINO_DEBUG << "Sinking Transpose :" << describe(transpose); auto transpose_in = transpose->input_value(0); auto orig_transpose = read_transposemap(reorders, transpose_in); // combine both transposes @@ -212,7 +221,7 @@ static void sink_unary(const shared_ptr& n, TransposeMap& reorders, set>& /* transposes_to_delete */) { auto arg_transpose = read_transposemap(reorders, n->input_value(0)); - NGRAPH_DEBUG << "Propagating " << describe(arg_transpose) << " for " << n->get_name(); + OPENVINO_DEBUG << "Propagating " << describe(arg_transpose) << " for " << n->get_name(); write_transposemap(reorders, n, arg_transpose); } @@ -229,18 +238,19 @@ static void sink_binary(const shared_ptr& binary, auto left_order = left_const->get_axis_vector_val(); auto right_order = right_const->get_axis_vector_val(); - auto left_mismatch = left_order != ngraph::get_default_order(left.get_shape()); - auto right_mismatch = right_order != ngraph::get_default_order(right.get_shape()); + auto left_mismatch = left_order != get_default_order(left.get_shape().size()); + auto right_mismatch = right_order != get_default_order(right.get_shape().size()); - NGRAPH_DEBUG << "Sink binary " << binary->get_name() << " left transpose: " << ngraph::vector_to_string(left_order) - << " left default: " << ngraph::vector_to_string(ngraph::get_default_order(left.get_shape())) - << " right transpose: " << ngraph::vector_to_string(right_order) - << " right default: " << ngraph::vector_to_string(ngraph::get_default_order(right.get_shape())); + OPENVINO_DEBUG << "Sink binary " << binary->get_name() + << " left transpose: " << ov::util::vector_to_string(left_order) + << " left default: " << ov::util::vector_to_string(get_default_order(left.get_shape().size())) + << " right transpose: " << ov::util::vector_to_string(right_order) + << " right default: " << ov::util::vector_to_string(get_default_order(right.get_shape().size())); if ((left_order.size() == right_order.size() && left_order == right_order) || (!left_mismatch && !right_mismatch)) { // Propagate the reshape which matches the shape of the binary node auto new_transpose = (binary->get_output_shape(0) == left.get_shape()) ? 
left_t : right_t; - NGRAPH_DEBUG << "Propagating " << describe(new_transpose) << " for " << binary->get_name(); + OPENVINO_DEBUG << "Propagating " << describe(new_transpose) << " for " << binary->get_name(); write_transposemap(reorders, binary, new_transpose); // at this point, both transposes will be eventually removed mark_transpose_for_deletion(left_t, transposes_to_delete); @@ -273,7 +283,8 @@ static void sink_pad(shared_ptr n, TransposeMap& reorders, setget_shape(), def_order); - auto dummy_correct_shape = make_shared(arg_transpose->get_element_type(), input_shape); + auto dummy_correct_shape = + make_shared(arg_transpose->get_element_type(), input_shape); auto pad_begin = apply_permutation(n->get_pads_begin(), def_order); auto pad_end = apply_permutation(n->get_pads_end(), def_order); @@ -282,10 +293,10 @@ static void sink_pad(shared_ptr n, TransposeMap& reorders, set(element::i64, Shape{pad_end.size()}, pad_end); auto new_pad = make_shared(dummy_correct_shape, new_begin, new_end, n->input_value(3), n->get_pad_mode()); replace_node(dummy_correct_shape, n->input_value(0).get_node_shared_ptr()); - NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_pad->get_name(); + OPENVINO_DEBUG << "Replacing " << n->get_name() << " with " << new_pad->get_name(); replace_node(n, new_pad); auto new_transpose = make_transpose(new_pad, order); - NGRAPH_DEBUG << "Propagating " << describe(new_transpose) << " for " << n->get_name(); + OPENVINO_DEBUG << "Propagating " << describe(new_transpose) << " for " << n->get_name(); write_transposemap(reorders, new_pad, new_transpose); } @@ -303,7 +314,8 @@ static void sink_concat(const shared_ptr& n, auto input_shape = apply_permutation(arg_transpose->get_shape(), def_order); - auto dummy_correct_shape = make_shared(arg_transpose->get_element_type(), input_shape); + auto dummy_correct_shape = + make_shared(arg_transpose->get_element_type(), input_shape); NodeVector new_args; new_args.push_back(dummy_correct_shape); @@ -314,7 +326,7 @@ static void sink_concat(const shared_ptr& n, auto iarg_transpose_order = as_type_ptr(iarg_transpose->input_value(1).get_node_shared_ptr()); auto iorder = iarg_transpose_order->get_axis_vector_val(); if (iorder != order) { - NGRAPH_DEBUG << " input order at " << i << "-th arg is different from first arg"; + OPENVINO_DEBUG << " input order at " << i << "-th arg is different from first arg"; materialize_shapes(n, reorders, transposes_to_delete); return; } @@ -322,7 +334,7 @@ static void sink_concat(const shared_ptr& n, auto iinput_shape = apply_permutation(iarg_transpose->get_shape(), def_order); auto idummy_correct_shape = - make_shared(iarg_transpose->get_element_type(), iinput_shape); + make_shared(iarg_transpose->get_element_type(), iinput_shape); new_args.push_back(idummy_correct_shape); } @@ -330,14 +342,14 @@ static void sink_concat(const shared_ptr& n, auto new_concat = make_shared(new_args, new_axis); // put back the original arguments for (size_t i = 0; i < new_concat->get_input_size(); i++) { - NGRAPH_DEBUG << "Replacing " << new_concat->get_name() << " input " << i << " with " << n->get_name() - << " input " << i; + OPENVINO_DEBUG << "Replacing " << new_concat->get_name() << " input " << i << " with " << n->get_name() + << " input " << i; new_concat->input(i).replace_source_output(n->input_value(i)); } - NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_concat->get_name(); + OPENVINO_DEBUG << "Replacing " << n->get_name() << " with " << new_concat->get_name(); replace_node(n, new_concat); auto 
new_transpose = make_transpose(new_concat, order); - NGRAPH_DEBUG << "Propagating " << describe(new_transpose) << " for " << n->get_name(); + OPENVINO_DEBUG << "Propagating " << describe(new_transpose) << " for " << n->get_name(); write_transposemap(reorders, new_concat, new_transpose); } @@ -358,16 +370,16 @@ bool ov::frontend::tf::pass::TransposeSinkingOVTF::run_on_function(shared_ptrget_ordered_ops()) { - NGRAPH_DEBUG << "Processing " << n->get_name(); + OPENVINO_DEBUG << "Processing " << n->get_name(); // collect output shape of all Result nodes for a sanity check - if (ngraph::op::is_output(n)) { + if (ov::op::util::is_output(n)) { orig_result_out_shape[n->get_name()] = n->get_output_shape(0); } if (auto transpose = as_type_ptr(n)) { sink_transpose(transpose, reorders, transposes_to_delete); - } else if (ngraph::op::is_unary_elementwise_arithmetic(n)) { + } else if (ov::op::util::is_unary_elementwise_arithmetic(n)) { sink_unary(n, reorders, transposes_to_delete); - } else if (ngraph::op::is_binary_elementwise_arithmetic(n)) { + } else if (ov::op::util::is_binary_elementwise_arithmetic(n)) { sink_binary(n, reorders, transposes_to_delete); } else if (auto pad = as_type_ptr(n)) { sink_pad(pad, reorders, transposes_to_delete); @@ -378,18 +390,18 @@ bool ov::frontend::tf::pass::TransposeSinkingOVTF::run_on_function(shared_ptrget_ordered_ops()) { n->revalidate_and_infer_types(); } @@ -397,7 +409,7 @@ bool ov::frontend::tf::pass::TransposeSinkingOVTF::run_on_function(shared_ptrget_results(); for (const auto& r : results) { // make sure shapes are always materialized before results - NGRAPH_CHECK( + FRONT_END_GENERAL_CHECK( r->get_shape() == r->get_input_shape(0) && r->get_element_type() == r->input_value(0).get_element_type(), " op::Result = ", *r, @@ -406,11 +418,11 @@ bool ov::frontend::tf::pass::TransposeSinkingOVTF::run_on_function(shared_ptrget_output_shape(0) == orig_result_out_shape[r->get_name()], - " op::Result = ", - *r, - " expected output shape = ", - orig_result_out_shape[r->get_name()]); + FRONT_END_GENERAL_CHECK(r->get_output_shape(0) == orig_result_out_shape[r->get_name()], + " op::Result = ", + *r, + " expected output shape = ", + orig_result_out_shape[r->get_name()]); } return true; } diff --git a/ngraph/frontend/tensorflow/src/place.cpp b/ngraph/frontend/tensorflow/src/place.cpp index e4cad115fb4..43166fa7318 100644 --- a/ngraph/frontend/tensorflow/src/place.cpp +++ b/ngraph/frontend/tensorflow/src/place.cpp @@ -4,8 +4,7 @@ #include "place.hpp" -#include - +#include "frontend_manager/frontend_exceptions.hpp" #include "node_context.hpp" #include "op_def.pb.h" #include "tensor.pb.h" diff --git a/ngraph/frontend/tensorflow/src/place.hpp b/ngraph/frontend/tensorflow/src/place.hpp index 2fea68d2bee..639a669f654 100644 --- a/ngraph/frontend/tensorflow/src/place.hpp +++ b/ngraph/frontend/tensorflow/src/place.hpp @@ -4,8 +4,8 @@ #pragma once -#include -#include +#include "frontend_manager/frontend.hpp" +#include "tensorflow_frontend/decoder.hpp" namespace ov { namespace frontend { diff --git a/ngraph/frontend/tensorflow/src/tensorflow.cpp b/ngraph/frontend/tensorflow/src/tensorflow.cpp index ac8ff086cca..fcfed5dc16d 100644 --- a/ngraph/frontend/tensorflow/src/tensorflow.cpp +++ b/ngraph/frontend/tensorflow/src/tensorflow.cpp @@ -2,8 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 // -#include -#include +#include "frontend_manager/frontend_manager.hpp" +#include "tensorflow_frontend/frontend.hpp" extern "C" OPENVINO_CORE_EXPORTS ov::frontend::FrontEndVersion GetAPIVersion() { return 
OV_FRONTEND_API_VERSION; diff --git a/ngraph/frontend/tensorflow/src/tf_framework_node.hpp b/ngraph/frontend/tensorflow/src/tf_framework_node.hpp index 894b89b5f56..1ed553c299f 100644 --- a/ngraph/frontend/tensorflow/src/tf_framework_node.hpp +++ b/ngraph/frontend/tensorflow/src/tf_framework_node.hpp @@ -5,9 +5,9 @@ #pragma once #include -#include #include "openvino/op/util/framework_node.hpp" +#include "tensorflow_frontend/decoder.hpp" namespace ov { namespace frontend { diff --git a/ngraph/frontend/tensorflow/src/utils.cpp b/ngraph/frontend/tensorflow/src/utils.cpp index a61d49438f4..b8557330e78 100644 --- a/ngraph/frontend/tensorflow/src/utils.cpp +++ b/ngraph/frontend/tensorflow/src/utils.cpp @@ -4,8 +4,7 @@ #include "utils.hpp" -void ov::frontend::tf::tf_shape_to_ngraph_shape(const tensorflow::TensorShapeProto& tf_shape, - ov::PartialShape* ng_shape) { +void ov::frontend::tf::tf_shape_to_ov_shape(const tensorflow::TensorShapeProto& tf_shape, ov::PartialShape* ng_shape) { std::vector dims; for (int i = 0; i < tf_shape.dim_size(); i++) { dims.emplace_back(tf_shape.dim(i).size()); diff --git a/ngraph/frontend/tensorflow/src/utils.hpp b/ngraph/frontend/tensorflow/src/utils.hpp index 63bc29d01ce..27183ea5ec8 100644 --- a/ngraph/frontend/tensorflow/src/utils.hpp +++ b/ngraph/frontend/tensorflow/src/utils.hpp @@ -21,11 +21,11 @@ #pragma once #include "graph_iterator_proto.hpp" -#include "ngraph/log.hpp" -#include "ngraph/ngraph.hpp" -#include "ngraph_conversions.hpp" #include "node_context.hpp" +#include "openvino/core/validation_util.hpp" #include "openvino/opsets/opset8.hpp" +#include "openvino/util/log.hpp" +#include "openvino_conversions.hpp" namespace ov { namespace frontend { @@ -69,10 +69,10 @@ void make_padding(const std::string& tf_padding_type, } } -void tf_shape_to_ngraph_shape(const ::tensorflow::TensorShapeProto& tf_shape, ov::PartialShape* ng_shape); +void tf_shape_to_ov_shape(const ::tensorflow::TensorShapeProto& tf_shape, ov::PartialShape* ng_shape); template -void get_static_input_vec(const NodeContext& node, int64_t input_index, std::vector* vector) { +void get_const_input(const NodeContext& node, int64_t input_index, std::vector* vector) { auto ng_input = node.get_input(input_index); if (auto constant = std::dynamic_pointer_cast(ng_input.get_node_shared_ptr())) { *vector = constant->cast_vector(); @@ -87,7 +87,7 @@ void get_static_input_vec(const NodeContext& node, int64_t input_index, std::vec // Modified with an extra `VecT` parameter to handle the case where the type // in the std::vector does not match TensorFlow's notion of what the C++ type // should be (e.g. when T is `bool`, we actually need a std::vector of `char` for -// compatibility with nGraph). +// compatibility with OpenVINO). 
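
// Illustrative instantiation of the `VecT` escape hatch described above (the
// template parameter list is elided by the rendering; a plausible form is
// template <typename T, typename VecT = T>):
//   values_from_const_node<int64_t>(node, &shape, &int64_values);    // T == VecT
//   values_from_const_node<bool, char>(node, &shape, &bool_values);  // T = bool, VecT = char
// std::vector<bool> is a bit-packed proxy, so the raw tensor bytes cannot be
// read into it directly; they are read as char and converted element-wise.
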
template void values_from_const_node(const NodeContext& node, ov::Shape* const_tensor_shape, std::vector* values) { TF_OP_VALIDATION_CHECK(node, node.get_op_type() == "Const", "Node is expected to be Constant."); @@ -104,7 +104,7 @@ void values_from_const_node(const NodeContext& node, ov::Shape* const_tensor_sha const tensorflow::TensorShapeProto& shape = tensor_proto.tensor_shape(); ov::PartialShape pshape; - tf_shape_to_ngraph_shape(shape, &pshape); + tf_shape_to_ov_shape(shape, &pshape); *const_tensor_shape = pshape.get_shape(); TF_OP_VALIDATION_CHECK(node, pshape.is_static(), "Dynamic shapes are not supported in Constant conversion."); auto tensor_content = tensor_proto.tensor_content(); @@ -171,8 +171,8 @@ void values_from_const_node(const NodeContext& node, ov::Shape* const_tensor_sha val_i = tensor_proto.double_val()[i]; break; default: - NGRAPH_DEBUG << "Const node has empty tensor_proto and we don't know how to " - "handle this element type"; + OPENVINO_DEBUG << "Const node has empty tensor_proto and we don't know how to " + "handle this element type"; FRONT_END_THROW("Encountered unknown element type " + DataType_Name(dt) + " on an empty tensor_proto"); } TF_OP_VALIDATION_CHECK(node, val_size != 0, "Empty values vector");
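
For reference, a minimal usage sketch of the renamed layout helpers from openvino_conversions.hpp (standalone example, not part of the patch; assumes the declarations shown in that file's hunk):

#include "openvino_conversions.hpp"

#include "openvino/opsets/opset8.hpp"

void example() {
    // 4D NHWC tensor, e.g. {N, H, W, C} = {1, 224, 224, 3}
    auto param = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::Shape{1, 224, 224, 3});
    ov::Output<ov::Node> data = param;
    ov::frontend::tf::transpose<0, 3, 1, 2>(data);  // data now yields {1, 3, 224, 224} (NCHW)
    ov::frontend::tf::transpose<0, 2, 3, 1>(data);  // and back to {1, 224, 224, 3} (NHWC)
}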