Call transformation_extensions in convert_partially method

This commit is contained in:
Ivan Tikhonov 2021-12-08 09:36:40 +03:00
parent e53f4cf2f6
commit 0211880f29
2 changed files with 33 additions and 1 deletions

View File

@ -309,13 +309,27 @@ void FrontEndPDPD::convert(std::shared_ptr<ov::Function> partiallyConverted) con
pdpd::get_supported_ops());
}
}
for (auto result : partiallyConverted->get_results()) {
for (const auto& result : partiallyConverted->get_results()) {
result->validate_and_infer_types();
}
}
std::shared_ptr<ov::Function> FrontEndPDPD::convert_partially(InputModel::Ptr model) const {
auto pdpd_model = std::dynamic_pointer_cast<InputModelPDPD>(model);
FRONT_END_GENERAL_CHECK(pdpd_model != nullptr, "Invalid input model");
if (!m_transformation_extensions.empty()) {
auto function = decode(model);
pass::Manager manager;
for (const auto& transformation : m_transformation_extensions) {
transformation->register_pass(manager);
}
manager.run_passes(function);
convert(function);
return function;
}
std::map<std::string, pdpd::CreatorFunction> CREATORS_MAP = pdpd::get_supported_ops();
auto f = convert_each_node(
pdpd_model,
@ -333,6 +347,8 @@ std::shared_ptr<ov::Function> FrontEndPDPD::convert_partially(InputModel::Ptr mo
std::shared_ptr<ov::Function> FrontEndPDPD::decode(InputModel::Ptr model) const {
auto pdpd_model = std::dynamic_pointer_cast<InputModelPDPD>(model);
FRONT_END_GENERAL_CHECK(pdpd_model != nullptr, "Invalid input model");
std::map<std::string, pdpd::CreatorFunction> CREATORS_MAP = pdpd::get_supported_ops();
auto f = convert_each_node(pdpd_model, pdpd::make_framework_node);
return f;

View File

@ -332,6 +332,20 @@ std::shared_ptr<ov::Function> FrontEndTF::convert(ov::frontend::InputModel::Ptr
std::shared_ptr<ov::Function> FrontEndTF::convert_partially(ov::frontend::InputModel::Ptr model) const {
auto model_tf = std::dynamic_pointer_cast<InputModelTF>(model);
FRONT_END_GENERAL_CHECK(model_tf != nullptr, "Invalid input model");
if (!m_transformation_extensions.empty()) {
auto function = decode(model);
pass::Manager manager;
for (const auto& transformation : m_transformation_extensions) {
transformation->register_pass(manager);
}
manager.run_passes(function);
convert(function);
return function;
}
std::shared_ptr<ov::Function> f;
translate_graph(model_tf, "here_should_be_a_graph_name", false, false, f);
normalize(f);
@ -340,6 +354,8 @@ std::shared_ptr<ov::Function> FrontEndTF::convert_partially(ov::frontend::InputM
std::shared_ptr<ov::Function> FrontEndTF::decode(ov::frontend::InputModel::Ptr model) const {
auto model_tf = std::dynamic_pointer_cast<InputModelTF>(model);
FRONT_END_GENERAL_CHECK(model_tf != nullptr, "Invalid input model");
std::shared_ptr<ov::Function> f;
translate_graph(model_tf, "here_should_be_a_graph_name", false, true, f);
return f;