diff --git a/src/frontends/paddle/include/openvino/frontend/paddle/node_context.hpp b/src/frontends/paddle/include/openvino/frontend/paddle/node_context.hpp index 73a500e5605..545ac6391cf 100644 --- a/src/frontends/paddle/include/openvino/frontend/paddle/node_context.hpp +++ b/src/frontends/paddle/include/openvino/frontend/paddle/node_context.hpp @@ -23,8 +23,8 @@ using NamedInputs = std::map; class NodeContext : public ov::frontend::NodeContext { public: using Ptr = std::shared_ptr; - NodeContext(const DecoderBase& _decoder, const NamedInputs& _name_map) - : ov::frontend::NodeContext(_decoder.get_op_type()), + NodeContext(const std::shared_ptr& _decoder, const NamedInputs& _name_map) + : ov::frontend::NodeContext(_decoder->get_op_type()), decoder(_decoder), name_map(_name_map) {} @@ -67,45 +67,45 @@ public: } std::vector get_output_names() const { - return decoder.get_output_names(); + return decoder->get_output_names(); } std::vector get_output_var_names(const std::string& var_name) const { - return decoder.get_output_var_names(var_name); + return decoder->get_output_var_names(var_name); } std::vector get_input_var_names(const std::string& var_name) const { - return decoder.get_input_var_names(var_name); + return decoder->get_input_var_names(var_name); } ov::element::Type get_out_port_type(const std::string& port_name) const { - return decoder.get_out_port_type(port_name); + return decoder->get_out_port_type(port_name); } NamedOutputs default_single_output_mapping(const std::shared_ptr& node, const std::vector& required_pdpd_out_names) const; ov::Any get_attribute_as_any(const std::string& name) const override { - auto res = decoder.get_attribute(name); + auto res = decoder->get_attribute(name); return res; } size_t get_output_size(const std::string& port_name) const { - return decoder.get_output_size(port_name); + return decoder->get_output_size(port_name); } std::vector> get_output_port_infos( const std::string& port_name) const { - return decoder.get_output_port_infos(port_name); + return decoder->get_output_port_infos(port_name); } private: ov::Any apply_additional_conversion_rules(const ov::Any& any, const std::type_info& type_info) const override { - auto res = decoder.convert_attribute(any, type_info); + auto res = decoder->convert_attribute(any, type_info); return res; } - const DecoderBase& decoder; + const std::shared_ptr decoder; const NamedInputs& name_map; }; diff --git a/src/frontends/paddle/src/decoder_proto.cpp b/src/frontends/paddle/src/decoder_proto.cpp index 51f6ec25b7d..2f4e7a0b372 100644 --- a/src/frontends/paddle/src/decoder_proto.cpp +++ b/src/frontends/paddle/src/decoder_proto.cpp @@ -86,7 +86,7 @@ ov::Any DecoderProto::convert_attribute(const Any& data, const std::type_info& t std::vector DecoderProto::get_output_names() const { std::vector output_names; - for (const auto& output : op_place->get_desc().outputs()) { + for (const auto& output : get_place()->get_desc().outputs()) { output_names.push_back(output.parameter()); } return output_names; @@ -94,7 +94,7 @@ std::vector DecoderProto::get_output_names() const { std::vector DecoderProto::get_output_var_names(const std::string& var_name) const { std::vector output_names; - for (const auto& output : op_place->get_desc().outputs()) { + for (const auto& output : get_place()->get_desc().outputs()) { if (output.parameter() == var_name) { for (int idx = 0; idx < output.arguments_size(); ++idx) { output_names.push_back(output.arguments()[idx]); @@ -106,7 +106,7 @@ std::vector DecoderProto::get_output_var_names(const std::st std::vector DecoderProto::get_input_var_names(const std::string& var_name) const { std::vector input_names; - for (const auto& input : op_place->get_desc().inputs()) { + for (const auto& input : get_place()->get_desc().inputs()) { if (input.parameter() == var_name) { for (int idx = 0; idx < input.arguments_size(); ++idx) { input_names.push_back(input.arguments()[idx]); @@ -117,13 +117,13 @@ std::vector DecoderProto::get_input_var_names(const std::str } size_t DecoderProto::get_output_size(const std::string& port_name) const { - const auto out_port = op_place->get_output_ports().at(port_name); + const auto out_port = get_place()->get_output_ports().at(port_name); return out_port.size(); } size_t DecoderProto::get_output_size() const { size_t res = 0; - for (const auto& output : op_place->get_desc().outputs()) { + for (const auto& output : get_place()->get_desc().outputs()) { res += output.arguments().size(); } return res; @@ -131,7 +131,7 @@ size_t DecoderProto::get_output_size() const { std::map> DecoderProto::get_output_type_map() const { std::map> output_types; - for (const auto& out_port_pair : op_place->get_output_ports()) { + for (const auto& out_port_pair : get_place()->get_output_ports()) { for (const auto& p_place : out_port_pair.second) { output_types[out_port_pair.first].push_back(p_place->get_target_tensor_paddle()->get_element_type()); } @@ -142,7 +142,7 @@ std::map> DecoderProto::get_output_t std::vector> DecoderProto::get_output_port_infos( const std::string& port_name) const { std::vector> output_types; - for (const auto& out_port : op_place->get_output_ports().at(port_name)) { + for (const auto& out_port : get_place()->get_output_ports().at(port_name)) { output_types.push_back({out_port->get_target_tensor_paddle()->get_element_type(), out_port->get_target_tensor_paddle()->get_partial_shape()}); } @@ -151,7 +151,7 @@ std::vector> DecoderProto::get_ou ov::element::Type DecoderProto::get_out_port_type(const std::string& port_name) const { std::vector output_types; - for (const auto& out_port : op_place->get_output_ports().at(port_name)) { + for (const auto& out_port : get_place()->get_output_ports().at(port_name)) { output_types.push_back(out_port->get_target_tensor_paddle()->get_element_type()); } FRONT_END_GENERAL_CHECK(!output_types.empty(), "Port has no tensors connected."); @@ -161,12 +161,12 @@ ov::element::Type DecoderProto::get_out_port_type(const std::string& port_name) } std::string DecoderProto::get_op_type() const { - return op_place->get_desc().type(); + return get_place()->get_desc().type(); } std::vector DecoderProto::decode_attribute_helper(const std::string& name) const { std::vector attrs; - for (const auto& attr : op_place->get_desc().attrs()) { + for (const auto& attr : get_place()->get_desc().attrs()) { if (attr.name() == name) attrs.push_back(attr); } @@ -174,7 +174,7 @@ std::vector DecoderProto::decode_attribute_helper(const std: "An error occurred while parsing the ", name, " attribute of ", - op_place->get_desc().type(), + get_place()->get_desc().type(), "node. Unsupported number of attributes. Current number: ", attrs.size(), " Expected number: 0 or 1"); @@ -201,12 +201,12 @@ inline std::map map_for_each_input_impl( std::map DecoderProto::map_for_each_input( const std::function(const std::string&, size_t)>& func) const { - return map_for_each_input_impl(op_place->get_desc().inputs(), func); + return map_for_each_input_impl(get_place()->get_desc().inputs(), func); } std::map DecoderProto::map_for_each_output( const std::function(const std::string&, size_t)>& func) const { - return map_for_each_input_impl(op_place->get_desc().outputs(), func); + return map_for_each_input_impl(get_place()->get_desc().outputs(), func); } } // namespace paddle diff --git a/src/frontends/paddle/src/decoder_proto.hpp b/src/frontends/paddle/src/decoder_proto.hpp index 9188c147386..febe023b4dd 100644 --- a/src/frontends/paddle/src/decoder_proto.hpp +++ b/src/frontends/paddle/src/decoder_proto.hpp @@ -56,7 +56,14 @@ public: private: std::vector<::paddle::framework::proto::OpDesc_Attr> decode_attribute_helper(const std::string& name) const; - std::shared_ptr op_place; + std::weak_ptr op_place; + + const std::shared_ptr get_place() const { + auto place = op_place.lock(); + if (!place) + FRONT_END_THROW("This proto decoder contains empty op place."); + return place; + } }; } // namespace paddle diff --git a/src/frontends/paddle/src/frontend.cpp b/src/frontends/paddle/src/frontend.cpp index adafbff2579..cbaaa598af9 100644 --- a/src/frontends/paddle/src/frontend.cpp +++ b/src/frontends/paddle/src/frontend.cpp @@ -60,7 +60,7 @@ NamedOutputs make_ng_node(const std::map>& node NamedOutputs outputs; // In case the conversion function throws exception try { - outputs = creator_it->second(paddle::NodeContext(DecoderProto(op_place), named_inputs)); + outputs = creator_it->second(paddle::NodeContext(op_place->get_decoder(), named_inputs)); } catch (std::exception& ex) { FRONT_END_OP_CONVERSION_CHECK(false, "Fail to convert " + op_desc.type() + " Exception " + ex.what()); } @@ -90,7 +90,10 @@ NamedOutputs make_framework_node(const std::map } } - auto node = std::make_shared(DecoderProto(op_place), inputs_vector, inputs_names); + auto decoder_proto = std::dynamic_pointer_cast(op_place->get_decoder()); + if (!decoder_proto) + FRONT_END_THROW("Failed to cast to DecoderProto."); + auto node = std::make_shared(decoder_proto, inputs_vector, inputs_names); return node->return_named_outputs(); } @@ -192,7 +195,7 @@ void try_update_sublock_info(const std::shared_ptr& op_place, SubblockI inp_tensors.push_back(inp_tensor); } - auto tmp_node = paddle::NodeContext(DecoderProto(op_place), paddle::NamedInputs()); + auto tmp_node = paddle::NodeContext(op_place->get_decoder(), paddle::NamedInputs()); auto block_idx = tmp_node.get_attribute("sub_block"); subblock_info[block_idx] = std::make_tuple(op_desc.type(), inp_tensors, outp_tensors); @@ -214,7 +217,7 @@ void try_update_sublock_info(const std::shared_ptr& op_place, SubblockI } FRONT_END_GENERAL_CHECK(inp_tensors.size() > 0, "Port has no tensors connected."); - auto tmp_node = paddle::NodeContext(DecoderProto(op_place), paddle::NamedInputs()); + auto tmp_node = paddle::NodeContext(op_place->get_decoder(), paddle::NamedInputs()); auto block_idx = tmp_node.get_attribute("sub_block"); subblock_info[block_idx] = std::make_tuple(op_desc.type(), inp_tensors, outp_tensors); diff --git a/src/frontends/paddle/src/input_model.cpp b/src/frontends/paddle/src/input_model.cpp index eb981cca37a..10fd4878f1c 100644 --- a/src/frontends/paddle/src/input_model.cpp +++ b/src/frontends/paddle/src/input_model.cpp @@ -93,6 +93,7 @@ void InputModel::InputModelImpl::loadPlaces() { for (const auto& op : block.ops()) { auto op_place = std::make_shared(m_input_model, op); + op_place->set_decoder(std::make_shared(op_place)); if (m_telemetry) { op_statistics[op.type()]++; diff --git a/src/frontends/paddle/src/paddle_fw_node.cpp b/src/frontends/paddle/src/paddle_fw_node.cpp index 365090bc0cf..a1d76860603 100644 --- a/src/frontends/paddle/src/paddle_fw_node.cpp +++ b/src/frontends/paddle/src/paddle_fw_node.cpp @@ -10,7 +10,7 @@ namespace paddle { void FrameworkNode::validate_and_infer_types() { ov::op::util::FrameworkNode::validate_and_infer_types(); size_t idx = 0; - for (const auto& port_pair : m_decoder.get_output_type_map()) { + for (const auto& port_pair : m_decoder->get_output_type_map()) { for (const auto& p_type : port_pair.second) { set_output_type(idx++, p_type, PartialShape::dynamic()); } @@ -18,7 +18,7 @@ void FrameworkNode::validate_and_infer_types() { } std::map FrameworkNode::get_named_inputs() const { - return m_decoder.map_for_each_input([&](const std::string& name, size_t) { + return m_decoder->map_for_each_input([&](const std::string& name, size_t) { auto it = std::find(m_inputs_names.begin(), m_inputs_names.end(), name); if (it != m_inputs_names.end()) { return input(it - m_inputs_names.begin()).get_source_output(); @@ -29,7 +29,7 @@ std::map FrameworkNode::get_named_inputs() const { } std::map FrameworkNode::return_named_outputs() { - return m_decoder.map_for_each_output([&](const std::string&, size_t idx) { + return m_decoder->map_for_each_output([&](const std::string&, size_t idx) { return output(idx); }); } diff --git a/src/frontends/paddle/src/paddle_fw_node.hpp b/src/frontends/paddle/src/paddle_fw_node.hpp index f18c29d11c3..f168ab3e305 100644 --- a/src/frontends/paddle/src/paddle_fw_node.hpp +++ b/src/frontends/paddle/src/paddle_fw_node.hpp @@ -14,12 +14,14 @@ class FrameworkNode : public ov::op::util::FrameworkNode { public: OPENVINO_OP("FrameworkNode", "util", ov::op::util::FrameworkNode); - FrameworkNode(const DecoderProto& decoder, const OutputVector& inputs, const std::vector& inputs_names) - : ov::op::util::FrameworkNode(inputs, decoder.get_output_size()), + FrameworkNode(const std::shared_ptr& decoder, + const OutputVector& inputs, + const std::vector& inputs_names) + : ov::op::util::FrameworkNode(inputs, decoder->get_output_size()), m_decoder{decoder}, m_inputs_names{inputs_names} { ov::op::util::FrameworkNodeAttrs attrs; - attrs.set_type_name(m_decoder.get_op_type()); + attrs.set_type_name(m_decoder->get_op_type()); set_attrs(attrs); validate_and_infer_types(); @@ -32,10 +34,10 @@ public: } std::string get_op_type() const { - return m_decoder.get_op_type(); + return m_decoder->get_op_type(); } - const DecoderProto& get_decoder() const { + const std::shared_ptr get_decoder() const { return m_decoder; } @@ -44,7 +46,7 @@ public: std::map return_named_outputs(); private: - const DecoderProto m_decoder; + const std::shared_ptr m_decoder; std::vector m_inputs_names; }; } // namespace paddle diff --git a/src/frontends/paddle/src/place.cpp b/src/frontends/paddle/src/place.cpp index 6f393bca0e9..508fa1d7a33 100644 --- a/src/frontends/paddle/src/place.cpp +++ b/src/frontends/paddle/src/place.cpp @@ -61,6 +61,14 @@ const ::paddle::framework::proto::OpDesc& OpPlace::get_desc() const { return m_op_desc; } +const std::shared_ptr OpPlace::get_decoder() const { + return m_op_decoder; +} + +void OpPlace::set_decoder(const std::shared_ptr op_decoder) { + m_op_decoder = op_decoder; +} + void OpPlace::add_out_port(const std::shared_ptr& output, const std::string& name) { m_output_ports[name].push_back(output); } diff --git a/src/frontends/paddle/src/place.hpp b/src/frontends/paddle/src/place.hpp index 3abb40462bc..1f6c34f0a1e 100644 --- a/src/frontends/paddle/src/place.hpp +++ b/src/frontends/paddle/src/place.hpp @@ -115,6 +115,8 @@ public: std::shared_ptr get_output_port_paddle(const std::string& outputName, int outputPortIndex) const; std::shared_ptr get_input_port_paddle(const std::string& inputName, int inputPortIndex) const; const ::paddle::framework::proto::OpDesc& get_desc() const; + const std::shared_ptr get_decoder() const; + void set_decoder(const std::shared_ptr op_decoder); // External API methods std::vector get_consuming_ports() const override; @@ -150,7 +152,8 @@ public: Ptr get_target_tensor(const std::string& outputName, int outputPortIndex) const override; private: - const ::paddle::framework::proto::OpDesc& m_op_desc; + const ::paddle::framework::proto::OpDesc& m_op_desc; // TODO: to conceal it behind decoder. + std::shared_ptr m_op_decoder; std::map>> m_input_ports; std::map>> m_output_ports; };