Call transformation_extensions in convert_partially method
This commit is contained in:
parent
e53f4cf2f6
commit
0211880f29
@ -309,13 +309,27 @@ void FrontEndPDPD::convert(std::shared_ptr<ov::Function> partiallyConverted) con
|
||||
pdpd::get_supported_ops());
|
||||
}
|
||||
}
|
||||
for (auto result : partiallyConverted->get_results()) {
|
||||
for (const auto& result : partiallyConverted->get_results()) {
|
||||
result->validate_and_infer_types();
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Function> FrontEndPDPD::convert_partially(InputModel::Ptr model) const {
|
||||
auto pdpd_model = std::dynamic_pointer_cast<InputModelPDPD>(model);
|
||||
FRONT_END_GENERAL_CHECK(pdpd_model != nullptr, "Invalid input model");
|
||||
|
||||
if (!m_transformation_extensions.empty()) {
|
||||
auto function = decode(model);
|
||||
|
||||
pass::Manager manager;
|
||||
for (const auto& transformation : m_transformation_extensions) {
|
||||
transformation->register_pass(manager);
|
||||
}
|
||||
manager.run_passes(function);
|
||||
convert(function);
|
||||
return function;
|
||||
}
|
||||
|
||||
std::map<std::string, pdpd::CreatorFunction> CREATORS_MAP = pdpd::get_supported_ops();
|
||||
auto f = convert_each_node(
|
||||
pdpd_model,
|
||||
@ -333,6 +347,8 @@ std::shared_ptr<ov::Function> FrontEndPDPD::convert_partially(InputModel::Ptr mo
|
||||
|
||||
std::shared_ptr<ov::Function> FrontEndPDPD::decode(InputModel::Ptr model) const {
|
||||
auto pdpd_model = std::dynamic_pointer_cast<InputModelPDPD>(model);
|
||||
FRONT_END_GENERAL_CHECK(pdpd_model != nullptr, "Invalid input model");
|
||||
|
||||
std::map<std::string, pdpd::CreatorFunction> CREATORS_MAP = pdpd::get_supported_ops();
|
||||
auto f = convert_each_node(pdpd_model, pdpd::make_framework_node);
|
||||
return f;
|
||||
|
@ -332,6 +332,20 @@ std::shared_ptr<ov::Function> FrontEndTF::convert(ov::frontend::InputModel::Ptr
|
||||
|
||||
std::shared_ptr<ov::Function> FrontEndTF::convert_partially(ov::frontend::InputModel::Ptr model) const {
|
||||
auto model_tf = std::dynamic_pointer_cast<InputModelTF>(model);
|
||||
FRONT_END_GENERAL_CHECK(model_tf != nullptr, "Invalid input model");
|
||||
|
||||
if (!m_transformation_extensions.empty()) {
|
||||
auto function = decode(model);
|
||||
|
||||
pass::Manager manager;
|
||||
for (const auto& transformation : m_transformation_extensions) {
|
||||
transformation->register_pass(manager);
|
||||
}
|
||||
manager.run_passes(function);
|
||||
convert(function);
|
||||
return function;
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Function> f;
|
||||
translate_graph(model_tf, "here_should_be_a_graph_name", false, false, f);
|
||||
normalize(f);
|
||||
@ -340,6 +354,8 @@ std::shared_ptr<ov::Function> FrontEndTF::convert_partially(ov::frontend::InputM
|
||||
|
||||
std::shared_ptr<ov::Function> FrontEndTF::decode(ov::frontend::InputModel::Ptr model) const {
|
||||
auto model_tf = std::dynamic_pointer_cast<InputModelTF>(model);
|
||||
FRONT_END_GENERAL_CHECK(model_tf != nullptr, "Invalid input model");
|
||||
|
||||
std::shared_ptr<ov::Function> f;
|
||||
translate_graph(model_tf, "here_should_be_a_graph_name", false, true, f);
|
||||
return f;
|
||||
|
Loading…
Reference in New Issue
Block a user