MO: Support of discovering of suitable MOC frontend (#6888)

* MO: Support of discovering of suitable MOC frontend if --framework is not specified

* Ready for review

* Fix: don't use FrontEndManager if framework is not in list of available frontends

* Apply review comments
This commit is contained in:
Mikhail Nosov 2021-08-10 09:23:30 +03:00 committed by GitHub
parent 2a5584791c
commit c4bd0a45d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 213 additions and 31 deletions

View File

@ -97,16 +97,29 @@ def print_argv(argv: argparse.Namespace, is_caffe: bool, is_tf: bool, is_mxnet:
def prepare_ir(argv: argparse.Namespace):
is_tf, is_caffe, is_mxnet, is_kaldi, is_onnx = deduce_framework_by_namespace(argv)
fem = argv.feManager
new_front_ends = []
if fem is not None: # in future, check of 'use_legacy_frontend' in argv can be added here
new_front_ends = fem.get_available_front_ends()
available_moc_front_ends = []
moc_front_end = None
# TODO: in future, check of 'use_legacy_frontend' in argv can be added here (issue 61973)
force_use_legacy_frontend = False
if fem and not force_use_legacy_frontend:
available_moc_front_ends = fem.get_available_front_ends()
if argv.input_model:
if not argv.framework:
moc_front_end = fem.load_by_model(argv.input_model)
if moc_front_end:
argv.framework = moc_front_end.get_name()
elif argv.framework in available_moc_front_ends:
moc_front_end = fem.load_by_framework(argv.framework)
is_tf, is_caffe, is_mxnet, is_kaldi, is_onnx =\
deduce_framework_by_namespace(argv) if not moc_front_end else [False, False, False, False, False]
if not any([is_tf, is_caffe, is_mxnet, is_kaldi, is_onnx]):
frameworks = ['tf', 'caffe', 'mxnet', 'kaldi', 'onnx']
frameworks = list(set(frameworks + new_front_ends))
frameworks = list(set(frameworks + available_moc_front_ends))
if argv.framework not in frameworks:
raise Error('Framework {} is not a valid target. Please use --framework with one from the list: {}. ' +
refer_to_faq_msg(15), argv.framework, frameworks)
@ -173,7 +186,7 @@ def prepare_ir(argv: argparse.Namespace):
if argv.legacy_ir_generation and len(argv.transform) != 0:
raise Error("--legacy_ir_generation and --transform keys can not be used at the same time.")
use_legacy_fe = argv.framework not in new_front_ends
use_legacy_fe = argv.framework not in available_moc_front_ends
# For C++ frontends there is no specific python installation requirements, thus check only generic ones
ret_code = check_requirements(framework=argv.framework if use_legacy_fe else None)
if ret_code:
@ -258,7 +271,7 @@ def prepare_ir(argv: argparse.Namespace):
send_framework_info('kaldi')
from mo.front.kaldi.register_custom_ops import get_front_classes
import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
elif is_onnx: # in future check of 'use_legacy_frontend' can be added here
elif is_onnx:
send_framework_info('onnx')
from mo.front.onnx.register_custom_ops import get_front_classes
import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
@ -266,11 +279,10 @@ def prepare_ir(argv: argparse.Namespace):
graph = None
ngraph_function = None
# In future check of use_legacy_frontend option can be added here
if argv.feManager is None or argv.framework not in new_front_ends:
if argv.framework not in available_moc_front_ends:
graph = unified_pipeline(argv)
else:
ngraph_function = moc_pipeline(argv)
ngraph_function = moc_pipeline(argv, moc_front_end)
return graph, ngraph_function
@ -389,7 +401,6 @@ def main(cli_parser: argparse.ArgumentParser, fem: FrontEndManager, framework: s
argv = cli_parser.parse_args()
send_params_info(argv, cli_parser)
if framework:
argv.framework = framework
argv.feManager = fem
@ -435,5 +446,5 @@ def main(cli_parser: argparse.ArgumentParser, fem: FrontEndManager, framework: s
if __name__ == "__main__":
from mo.utils.cli_parser import get_all_cli_parser
fem = FrontEndManager()
sys.exit(main(get_all_cli_parser(fem), fem, None))
fe_manager = FrontEndManager()
sys.exit(main(get_all_cli_parser(fe_manager), fe_manager, None))

View File

@ -9,22 +9,18 @@ from mo.moc_frontend.extractor import fe_user_data_repack
from mo.middle.passes.infer import validate_batch_in_shape
from ngraph import Dimension, PartialShape # pylint: disable=no-name-in-module,import-error
from ngraph.frontend import Place # pylint: disable=no-name-in-module,import-error
from ngraph.frontend import FrontEnd, Place # pylint: disable=no-name-in-module,import-error
from ngraph.utils.types import get_element_type # pylint: disable=no-name-in-module,import-error
def moc_pipeline(argv: argparse.Namespace):
def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):
"""
Load input model and convert it to nGraph function
:param: parsed command line arguments
:param: argv: parsed command line arguments
:param: moc_front_end: Loaded Frontend for converting input model
:return: converted nGraph function ready for serialization
"""
fem = argv.feManager
log.debug('Available front ends: {}'.format(
str(fem.get_available_front_ends())))
log.debug('Initializing new FE for framework {}'.format(argv.framework))
fe = fem.load_by_framework(argv.framework)
input_model = fe.load(argv.input_model)
input_model = moc_front_end.load(argv.input_model)
user_shapes, outputs, freeze_placeholder = fe_user_data_repack(
input_model, argv.placeholder_shapes, argv.placeholder_data_types,
@ -78,7 +74,6 @@ def moc_pipeline(argv: argparse.Namespace):
def shape_to_array(shape: PartialShape):
return [shape.get_dimension(i) for i in range(shape.rank.get_length())]
return
# Set batch size
if argv.batch is not None and argv.batch > 0:
@ -100,5 +95,5 @@ def moc_pipeline(argv: argparse.Namespace):
joined_name, old_shape_array, new_shape))
input_model.set_partial_shape(place, new_partial_shape)
ngraph_function = fe.convert(input_model)
ngraph_function = moc_front_end.convert(input_model)
return ngraph_function

View File

@ -46,7 +46,7 @@ mock_needed = pytest.mark.skipif(not mock_available,
def replaceArgsHelper(log_level='DEBUG',
silent=False,
model_name='abc',
input_model='abc.abc',
input_model='abc.test_mo_mock_mdl',
transform=[],
legacy_ir_generation=False,
scale=None,
@ -73,7 +73,8 @@ def replaceArgsHelper(log_level='DEBUG',
mean_values=mean_values,
scale_values=scale_values,
output_dir=output_dir,
freeze_placeholder_with_value=freeze_placeholder_with_value)
freeze_placeholder_with_value=freeze_placeholder_with_value,
framework=None)
class TestMainFrontend(unittest.TestCase):
@ -97,9 +98,35 @@ class TestMainFrontend(unittest.TestCase):
group(1).replace("\r", "")
assert xml_file and bin_file
# verify that 'convert' was called
# verify that 'convert' was called, and 'supported' was not
stat = get_frontend_statistic()
assert stat.convert_model == 1
assert stat.supported == 0
# verify that meta info is added to XML file
with open(xml_file) as file:
assert 'mock_mo_ngraph_frontend' in file.read()
@mock_needed
@patch('argparse.ArgumentParser.parse_args',
return_value=replaceArgsHelper())
def test_convert_framework_discover(self, mock_argparse):
f = io.StringIO()
with redirect_stdout(f):
main(argparse.ArgumentParser(), fem, None)
out = f.getvalue()
xml_file = re.search(r'\[ SUCCESS \] XML file: (.*)', out). \
group(1).replace("\r", "")
bin_file = re.search(r'\[ SUCCESS \] BIN file: (.*)', out). \
group(1).replace("\r", "")
assert xml_file and bin_file
# verify that 'convert', 'supported' and 'get_name' were called
stat = get_frontend_statistic()
assert stat.convert_model == 1
assert stat.supported == 1
assert stat.get_name > 0
# verify that meta info is added to XML file
with open(xml_file) as file:
assert 'mock_mo_ngraph_frontend' in file.read()
@ -227,3 +254,19 @@ class TestMainFrontend(unittest.TestCase):
assert stat.get_partial_shape == 1
# verify that 'set_element_type' was not called
assert stat.set_partial_shape == 0
@mock_needed
@patch('argparse.ArgumentParser.parse_args',
return_value=replaceArgsHelper(input_model='abc.qwerty'))
def test_error_input_model_no_framework(self, mock_argparse):
# Framework is not specified and 'abc.qwerty' is not supported
# so MO shall not convert anything and produce specified error
with self.assertLogs() as logger:
main(argparse.ArgumentParser(), fem, None)
stat = get_frontend_statistic()
assert [s for s in logger.output if 'can not be deduced' in s]
# verify that 'supported' was called
assert stat.supported == 1

View File

@ -393,9 +393,13 @@ struct MOCK_API FeStat
{
std::vector<std::string> m_load_paths;
int m_convert_model = 0;
int m_supported = 0;
int m_get_name = 0;
// Getters
std::vector<std::string> load_paths() const { return m_load_paths; }
int convert_model() const { return m_convert_model; }
int supported() const { return m_supported; }
int get_name() const { return m_get_name; }
};
/// \brief Mock implementation of FrontEnd
@ -428,4 +432,24 @@ private:
}
return std::make_shared<InputModelMockPy>();
}
bool supported_impl(const std::vector<std::shared_ptr<Variant>>& params) const override
{
m_stat.m_supported++;
if (params.size() > 0 && is_type<VariantWrapper<std::string>>(params[0]))
{
auto path = as_type_ptr<VariantWrapper<std::string>>(params[0])->get();
if (path.find(".test_mo_mock_mdl") != std::string::npos)
{
return true;
}
}
return false;
}
std::string get_name() const override
{
m_stat.m_get_name++;
return "mock_mo_ngraph_frontend";
}
};

View File

@ -19,6 +19,8 @@ static void register_mock_frontend_stat(py::module m)
py::class_<FeStat> feStat(m, "FeStat", py::dynamic_attr());
feStat.def_property_readonly("load_paths", &FeStat::load_paths);
feStat.def_property_readonly("convert_model", &FeStat::convert_model);
feStat.def_property_readonly("supported", &FeStat::supported);
feStat.def_property_readonly("get_name", &FeStat::get_name);
}
static void register_mock_setup(py::module m)

View File

@ -83,6 +83,12 @@ namespace ngraph
/// \param function partially converted nGraph function
virtual void normalize(std::shared_ptr<ngraph::Function> function) const;
/// \brief Gets name of this FrontEnd. Can be used by clients
/// if frontend is selected automatically by FrontEndManager::load_by_model
///
/// \return Current frontend name. Empty string if not implemented
virtual std::string get_name() const;
protected:
virtual bool
supported_impl(const std::vector<std::shared_ptr<Variant>>& variants) const;

View File

@ -47,10 +47,13 @@ namespace ngraph
/// Selects and loads appropriate frontend depending on model file extension and other
/// file info (header)
///
/// \param framework
/// Framework name. Throws exception if name is not in list of available frontends
/// \param vars Any number of parameters of any type. What kind of parameters
/// are accepted is determined by each FrontEnd individually, typically it is
/// std::string containing path to the model file. For more information please
/// refer to specific FrontEnd documentation.
///
/// \return Frontend interface for further loading of model
/// \return Frontend interface for further loading of model. Returns 'nullptr'
/// if no suitable frontend is found
template <typename... Types>
FrontEnd::Ptr load_by_model(const Types&... vars)
{

View File

@ -167,6 +167,11 @@ void FrontEnd::normalize(std::shared_ptr<ngraph::Function> function) const
FRONT_END_NOT_IMPLEMENTED(normalize);
}
std::string FrontEnd::get_name() const
{
return std::string();
}
//----------- InputModel ---------------------------
std::vector<Place::Ptr> InputModel::get_inputs() const
{

View File

@ -43,6 +43,12 @@ namespace ngraph
/// \return nGraph function after decoding
std::shared_ptr<Function> decode(InputModel::Ptr model) const override;
/// \brief Gets name of this FrontEnd. Can be used by clients
/// if frontend is selected automatically by FrontEndManager::load_by_model
///
/// \return Paddle frontend name.
std::string get_name() const override;
protected:
/// \brief Check if FrontEndPDPD can recognize model from given parts
/// \param params Can be path to folder which contains __model__ file or path to

View File

@ -399,6 +399,8 @@ namespace ngraph
auto f = convert_each_node(pdpd_model, pdpd::make_framework_node);
return f;
}
std::string FrontEndPDPD::get_name() const { return "paddle"; }
} // namespace frontend
} // namespace ngraph

View File

@ -123,4 +123,16 @@ void regclass_pyngraph_FrontEnd(py::module m)
function : Function
Partially converted nGraph function.
)");
fem.def("get_name",
&ngraph::frontend::FrontEnd::get_name,
R"(
Gets name of this FrontEnd. Can be used by clients
if frontend is selected automatically by FrontEndManager::load_by_model.
Parameters
----------
get_name : str
Current frontend name. Empty string if not implemented.
)");
}

View File

@ -35,6 +35,7 @@ void regclass_pyngraph_FrontEndManager(py::module m)
get_available_front_ends : List[str]
List of available frontend names.
)");
fem.def("load_by_framework",
&ngraph::frontend::FrontEndManager::load_by_framework,
py::arg("framework"),
@ -51,6 +52,25 @@ void regclass_pyngraph_FrontEndManager(py::module m)
load_by_framework : FrontEnd
Frontend interface for further loading of models.
)");
fem.def(
"load_by_model",
[](const std::shared_ptr<ngraph::frontend::FrontEndManager>& fem,
const std::string& model_path) { return fem->load_by_model(model_path); },
py::arg("model_path"),
R"(
Selects and loads appropriate frontend depending on model file extension and other file info (header).
Parameters
----------
model_path : str
Path to model file/directory.
Returns
----------
load_by_model : FrontEnd
Frontend interface for further loading of models. 'None' if no suitable frontend is found
)");
}
void regclass_pyngraph_GeneralFailureFrontEnd(py::module m)

View File

@ -485,6 +485,8 @@ struct MOCK_API FeStat
int m_convert_partially = 0;
int m_decode = 0;
int m_normalize = 0;
int m_get_name = 0;
int m_supported = 0;
// Getters
std::vector<std::string> load_paths() const { return m_load_paths; }
int convert_model() const { return m_convert_model; }
@ -492,6 +494,8 @@ struct MOCK_API FeStat
int convert_partially() const { return m_convert_partially; }
int decode() const { return m_decode; }
int normalize() const { return m_normalize; }
int get_name() const { return m_get_name; }
int supported() const { return m_supported; }
};
class MOCK_API FrontEndMockPy : public FrontEnd
@ -509,6 +513,20 @@ public:
return std::make_shared<InputModelMockPy>();
}
bool supported_impl(const std::vector<std::shared_ptr<Variant>>& params) const override
{
m_stat.m_supported++;
if (params.size() > 0 && is_type<VariantWrapper<std::string>>(params[0]))
{
auto path = as_type_ptr<VariantWrapper<std::string>>(params[0])->get();
if (path.find(".test_mock_py_mdl") != std::string::npos)
{
return true;
}
}
return false;
}
std::shared_ptr<ngraph::Function> convert(InputModel::Ptr model) const override
{
m_stat.m_convert_model++;
@ -534,5 +552,11 @@ public:
m_stat.m_normalize++;
}
std::string get_name() const override
{
m_stat.m_get_name++;
return "mock_py";
}
FeStat get_stat() const { return m_stat; }
};

View File

@ -33,6 +33,8 @@ static void register_mock_frontend_stat(py::module m)
feStat.def_property_readonly("convert_partially", &FeStat::convert_partially);
feStat.def_property_readonly("decode", &FeStat::decode);
feStat.def_property_readonly("normalize", &FeStat::normalize);
feStat.def_property_readonly("get_name", &FeStat::get_name);
feStat.def_property_readonly("supported", &FeStat::supported);
}
static void register_mock_model_stat(py::module m)

View File

@ -52,6 +52,16 @@ def test_load():
assert "abc.bin" in stat.load_paths
@mock_needed
def test_load_by_model():
fe = fem.load_by_model(model_path="abc.test_mock_py_mdl")
assert fe is not None
assert fe.get_name() == "mock_py"
stat = get_fe_stat(fe)
assert stat.get_name == 1
assert stat.supported == 1
@mock_needed
def test_convert_model():
fe = fem.load_by_framework(framework="mock_py")
@ -90,6 +100,16 @@ def test_decode_and_normalize():
assert stat.decode == 1
@mock_needed
def test_get_name():
fe = fem.load_by_framework(framework="mock_py")
assert fe is not None
name = fe.get_name()
assert name == "mock_py"
stat = get_fe_stat(fe)
assert stat.get_name == 1
# --------InputModel tests-----------------
@mock_needed
def init_model():

View File

@ -59,6 +59,9 @@ TEST(FrontEndManagerTest, testMockPluginFrontEnd)
FrontEndManager fem;
auto frontends = fem.get_available_front_ends();
ASSERT_NE(std::find(frontends.begin(), frontends.end(), "mock1"), frontends.end());
FrontEnd::Ptr fe;
ASSERT_NO_THROW(fe = fem.load_by_framework("mock1"));
ASSERT_EQ(fe->get_name(), "mock1");
set_test_env("OV_FRONTEND_PATH", "");
}
@ -77,6 +80,7 @@ TEST(FrontEndManagerTest, testDefaultFrontEnd)
ASSERT_ANY_THROW(fe->convert_partially(nullptr));
ASSERT_ANY_THROW(fe->decode(nullptr));
ASSERT_ANY_THROW(fe->normalize(nullptr));
ASSERT_EQ(fe->get_name(), std::string());
}
TEST(FrontEndManagerTest, testDefaultInputModel)

View File

@ -18,6 +18,8 @@ using namespace ngraph::frontend;
class FrontEndMock : public FrontEnd
{
public:
std::string get_name() const override { return "mock1"; }
};
extern "C" MOCK_API FrontEndVersion GetAPIVersion()

View File

@ -37,6 +37,7 @@ void FrontEndBasicTest::doLoadFromFile()
TEST_P(FrontEndBasicTest, testLoadFromFile)
{
ASSERT_NO_THROW(doLoadFromFile());
ASSERT_EQ(m_frontEnd->get_name(), m_feName);
std::shared_ptr<ngraph::Function> function;
ASSERT_NO_THROW(function = m_frontEnd->convert(m_inputModel));
ASSERT_NE(function, nullptr);