MO: Support of discovering of suitable MOC frontend (#6888)
* MO: Support of discovering of suitable MOC frontend if --framework is not specified * Ready for review * Fix: don't use FrontEndManager if framework is not in list of available frontends * Apply review comments
This commit is contained in:
parent
2a5584791c
commit
c4bd0a45d3
@ -97,16 +97,29 @@ def print_argv(argv: argparse.Namespace, is_caffe: bool, is_tf: bool, is_mxnet:
|
||||
|
||||
|
||||
def prepare_ir(argv: argparse.Namespace):
|
||||
is_tf, is_caffe, is_mxnet, is_kaldi, is_onnx = deduce_framework_by_namespace(argv)
|
||||
|
||||
fem = argv.feManager
|
||||
new_front_ends = []
|
||||
if fem is not None: # in future, check of 'use_legacy_frontend' in argv can be added here
|
||||
new_front_ends = fem.get_available_front_ends()
|
||||
available_moc_front_ends = []
|
||||
moc_front_end = None
|
||||
|
||||
# TODO: in future, check of 'use_legacy_frontend' in argv can be added here (issue 61973)
|
||||
force_use_legacy_frontend = False
|
||||
|
||||
if fem and not force_use_legacy_frontend:
|
||||
available_moc_front_ends = fem.get_available_front_ends()
|
||||
if argv.input_model:
|
||||
if not argv.framework:
|
||||
moc_front_end = fem.load_by_model(argv.input_model)
|
||||
if moc_front_end:
|
||||
argv.framework = moc_front_end.get_name()
|
||||
elif argv.framework in available_moc_front_ends:
|
||||
moc_front_end = fem.load_by_framework(argv.framework)
|
||||
|
||||
is_tf, is_caffe, is_mxnet, is_kaldi, is_onnx =\
|
||||
deduce_framework_by_namespace(argv) if not moc_front_end else [False, False, False, False, False]
|
||||
|
||||
if not any([is_tf, is_caffe, is_mxnet, is_kaldi, is_onnx]):
|
||||
frameworks = ['tf', 'caffe', 'mxnet', 'kaldi', 'onnx']
|
||||
frameworks = list(set(frameworks + new_front_ends))
|
||||
frameworks = list(set(frameworks + available_moc_front_ends))
|
||||
if argv.framework not in frameworks:
|
||||
raise Error('Framework {} is not a valid target. Please use --framework with one from the list: {}. ' +
|
||||
refer_to_faq_msg(15), argv.framework, frameworks)
|
||||
@ -173,7 +186,7 @@ def prepare_ir(argv: argparse.Namespace):
|
||||
if argv.legacy_ir_generation and len(argv.transform) != 0:
|
||||
raise Error("--legacy_ir_generation and --transform keys can not be used at the same time.")
|
||||
|
||||
use_legacy_fe = argv.framework not in new_front_ends
|
||||
use_legacy_fe = argv.framework not in available_moc_front_ends
|
||||
# For C++ frontends there is no specific python installation requirements, thus check only generic ones
|
||||
ret_code = check_requirements(framework=argv.framework if use_legacy_fe else None)
|
||||
if ret_code:
|
||||
@ -258,7 +271,7 @@ def prepare_ir(argv: argparse.Namespace):
|
||||
send_framework_info('kaldi')
|
||||
from mo.front.kaldi.register_custom_ops import get_front_classes
|
||||
import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
|
||||
elif is_onnx: # in future check of 'use_legacy_frontend' can be added here
|
||||
elif is_onnx:
|
||||
send_framework_info('onnx')
|
||||
from mo.front.onnx.register_custom_ops import get_front_classes
|
||||
import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
|
||||
@ -266,11 +279,10 @@ def prepare_ir(argv: argparse.Namespace):
|
||||
graph = None
|
||||
ngraph_function = None
|
||||
|
||||
# In future check of use_legacy_frontend option can be added here
|
||||
if argv.feManager is None or argv.framework not in new_front_ends:
|
||||
if argv.framework not in available_moc_front_ends:
|
||||
graph = unified_pipeline(argv)
|
||||
else:
|
||||
ngraph_function = moc_pipeline(argv)
|
||||
ngraph_function = moc_pipeline(argv, moc_front_end)
|
||||
return graph, ngraph_function
|
||||
|
||||
|
||||
@ -389,7 +401,6 @@ def main(cli_parser: argparse.ArgumentParser, fem: FrontEndManager, framework: s
|
||||
|
||||
argv = cli_parser.parse_args()
|
||||
send_params_info(argv, cli_parser)
|
||||
|
||||
if framework:
|
||||
argv.framework = framework
|
||||
argv.feManager = fem
|
||||
@ -435,5 +446,5 @@ def main(cli_parser: argparse.ArgumentParser, fem: FrontEndManager, framework: s
|
||||
|
||||
if __name__ == "__main__":
|
||||
from mo.utils.cli_parser import get_all_cli_parser
|
||||
fem = FrontEndManager()
|
||||
sys.exit(main(get_all_cli_parser(fem), fem, None))
|
||||
fe_manager = FrontEndManager()
|
||||
sys.exit(main(get_all_cli_parser(fe_manager), fe_manager, None))
|
||||
|
@ -9,22 +9,18 @@ from mo.moc_frontend.extractor import fe_user_data_repack
|
||||
from mo.middle.passes.infer import validate_batch_in_shape
|
||||
|
||||
from ngraph import Dimension, PartialShape # pylint: disable=no-name-in-module,import-error
|
||||
from ngraph.frontend import Place # pylint: disable=no-name-in-module,import-error
|
||||
from ngraph.frontend import FrontEnd, Place # pylint: disable=no-name-in-module,import-error
|
||||
from ngraph.utils.types import get_element_type # pylint: disable=no-name-in-module,import-error
|
||||
|
||||
|
||||
def moc_pipeline(argv: argparse.Namespace):
|
||||
def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):
|
||||
"""
|
||||
Load input model and convert it to nGraph function
|
||||
:param: parsed command line arguments
|
||||
:param: argv: parsed command line arguments
|
||||
:param: moc_front_end: Loaded Frontend for converting input model
|
||||
:return: converted nGraph function ready for serialization
|
||||
"""
|
||||
fem = argv.feManager
|
||||
log.debug('Available front ends: {}'.format(
|
||||
str(fem.get_available_front_ends())))
|
||||
log.debug('Initializing new FE for framework {}'.format(argv.framework))
|
||||
fe = fem.load_by_framework(argv.framework)
|
||||
input_model = fe.load(argv.input_model)
|
||||
input_model = moc_front_end.load(argv.input_model)
|
||||
|
||||
user_shapes, outputs, freeze_placeholder = fe_user_data_repack(
|
||||
input_model, argv.placeholder_shapes, argv.placeholder_data_types,
|
||||
@ -78,7 +74,6 @@ def moc_pipeline(argv: argparse.Namespace):
|
||||
|
||||
def shape_to_array(shape: PartialShape):
|
||||
return [shape.get_dimension(i) for i in range(shape.rank.get_length())]
|
||||
return
|
||||
|
||||
# Set batch size
|
||||
if argv.batch is not None and argv.batch > 0:
|
||||
@ -100,5 +95,5 @@ def moc_pipeline(argv: argparse.Namespace):
|
||||
joined_name, old_shape_array, new_shape))
|
||||
input_model.set_partial_shape(place, new_partial_shape)
|
||||
|
||||
ngraph_function = fe.convert(input_model)
|
||||
ngraph_function = moc_front_end.convert(input_model)
|
||||
return ngraph_function
|
||||
|
@ -46,7 +46,7 @@ mock_needed = pytest.mark.skipif(not mock_available,
|
||||
def replaceArgsHelper(log_level='DEBUG',
|
||||
silent=False,
|
||||
model_name='abc',
|
||||
input_model='abc.abc',
|
||||
input_model='abc.test_mo_mock_mdl',
|
||||
transform=[],
|
||||
legacy_ir_generation=False,
|
||||
scale=None,
|
||||
@ -73,7 +73,8 @@ def replaceArgsHelper(log_level='DEBUG',
|
||||
mean_values=mean_values,
|
||||
scale_values=scale_values,
|
||||
output_dir=output_dir,
|
||||
freeze_placeholder_with_value=freeze_placeholder_with_value)
|
||||
freeze_placeholder_with_value=freeze_placeholder_with_value,
|
||||
framework=None)
|
||||
|
||||
|
||||
class TestMainFrontend(unittest.TestCase):
|
||||
@ -97,9 +98,35 @@ class TestMainFrontend(unittest.TestCase):
|
||||
group(1).replace("\r", "")
|
||||
assert xml_file and bin_file
|
||||
|
||||
# verify that 'convert' was called
|
||||
# verify that 'convert' was called, and 'supported' was not
|
||||
stat = get_frontend_statistic()
|
||||
assert stat.convert_model == 1
|
||||
assert stat.supported == 0
|
||||
# verify that meta info is added to XML file
|
||||
with open(xml_file) as file:
|
||||
assert 'mock_mo_ngraph_frontend' in file.read()
|
||||
|
||||
@mock_needed
|
||||
@patch('argparse.ArgumentParser.parse_args',
|
||||
return_value=replaceArgsHelper())
|
||||
def test_convert_framework_discover(self, mock_argparse):
|
||||
f = io.StringIO()
|
||||
with redirect_stdout(f):
|
||||
main(argparse.ArgumentParser(), fem, None)
|
||||
out = f.getvalue()
|
||||
|
||||
xml_file = re.search(r'\[ SUCCESS \] XML file: (.*)', out). \
|
||||
group(1).replace("\r", "")
|
||||
bin_file = re.search(r'\[ SUCCESS \] BIN file: (.*)', out). \
|
||||
group(1).replace("\r", "")
|
||||
assert xml_file and bin_file
|
||||
|
||||
# verify that 'convert', 'supported' and 'get_name' were called
|
||||
stat = get_frontend_statistic()
|
||||
assert stat.convert_model == 1
|
||||
assert stat.supported == 1
|
||||
assert stat.get_name > 0
|
||||
|
||||
# verify that meta info is added to XML file
|
||||
with open(xml_file) as file:
|
||||
assert 'mock_mo_ngraph_frontend' in file.read()
|
||||
@ -227,3 +254,19 @@ class TestMainFrontend(unittest.TestCase):
|
||||
assert stat.get_partial_shape == 1
|
||||
# verify that 'set_element_type' was not called
|
||||
assert stat.set_partial_shape == 0
|
||||
|
||||
@mock_needed
|
||||
@patch('argparse.ArgumentParser.parse_args',
|
||||
return_value=replaceArgsHelper(input_model='abc.qwerty'))
|
||||
def test_error_input_model_no_framework(self, mock_argparse):
|
||||
# Framework is not specified and 'abc.qwerty' is not supported
|
||||
# so MO shall not convert anything and produce specified error
|
||||
with self.assertLogs() as logger:
|
||||
main(argparse.ArgumentParser(), fem, None)
|
||||
|
||||
stat = get_frontend_statistic()
|
||||
|
||||
assert [s for s in logger.output if 'can not be deduced' in s]
|
||||
|
||||
# verify that 'supported' was called
|
||||
assert stat.supported == 1
|
||||
|
@ -393,9 +393,13 @@ struct MOCK_API FeStat
|
||||
{
|
||||
std::vector<std::string> m_load_paths;
|
||||
int m_convert_model = 0;
|
||||
int m_supported = 0;
|
||||
int m_get_name = 0;
|
||||
// Getters
|
||||
std::vector<std::string> load_paths() const { return m_load_paths; }
|
||||
int convert_model() const { return m_convert_model; }
|
||||
int supported() const { return m_supported; }
|
||||
int get_name() const { return m_get_name; }
|
||||
};
|
||||
|
||||
/// \brief Mock implementation of FrontEnd
|
||||
@ -428,4 +432,24 @@ private:
|
||||
}
|
||||
return std::make_shared<InputModelMockPy>();
|
||||
}
|
||||
|
||||
bool supported_impl(const std::vector<std::shared_ptr<Variant>>& params) const override
|
||||
{
|
||||
m_stat.m_supported++;
|
||||
if (params.size() > 0 && is_type<VariantWrapper<std::string>>(params[0]))
|
||||
{
|
||||
auto path = as_type_ptr<VariantWrapper<std::string>>(params[0])->get();
|
||||
if (path.find(".test_mo_mock_mdl") != std::string::npos)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string get_name() const override
|
||||
{
|
||||
m_stat.m_get_name++;
|
||||
return "mock_mo_ngraph_frontend";
|
||||
}
|
||||
};
|
||||
|
@ -19,6 +19,8 @@ static void register_mock_frontend_stat(py::module m)
|
||||
py::class_<FeStat> feStat(m, "FeStat", py::dynamic_attr());
|
||||
feStat.def_property_readonly("load_paths", &FeStat::load_paths);
|
||||
feStat.def_property_readonly("convert_model", &FeStat::convert_model);
|
||||
feStat.def_property_readonly("supported", &FeStat::supported);
|
||||
feStat.def_property_readonly("get_name", &FeStat::get_name);
|
||||
}
|
||||
|
||||
static void register_mock_setup(py::module m)
|
||||
|
@ -83,6 +83,12 @@ namespace ngraph
|
||||
/// \param function partially converted nGraph function
|
||||
virtual void normalize(std::shared_ptr<ngraph::Function> function) const;
|
||||
|
||||
/// \brief Gets name of this FrontEnd. Can be used by clients
|
||||
/// if frontend is selected automatically by FrontEndManager::load_by_model
|
||||
///
|
||||
/// \return Current frontend name. Empty string if not implemented
|
||||
virtual std::string get_name() const;
|
||||
|
||||
protected:
|
||||
virtual bool
|
||||
supported_impl(const std::vector<std::shared_ptr<Variant>>& variants) const;
|
||||
|
@ -47,10 +47,13 @@ namespace ngraph
|
||||
/// Selects and loads appropriate frontend depending on model file extension and other
|
||||
/// file info (header)
|
||||
///
|
||||
/// \param framework
|
||||
/// Framework name. Throws exception if name is not in list of available frontends
|
||||
/// \param vars Any number of parameters of any type. What kind of parameters
|
||||
/// are accepted is determined by each FrontEnd individually, typically it is
|
||||
/// std::string containing path to the model file. For more information please
|
||||
/// refer to specific FrontEnd documentation.
|
||||
///
|
||||
/// \return Frontend interface for further loading of model
|
||||
/// \return Frontend interface for further loading of model. Returns 'nullptr'
|
||||
/// if no suitable frontend is found
|
||||
template <typename... Types>
|
||||
FrontEnd::Ptr load_by_model(const Types&... vars)
|
||||
{
|
||||
|
@ -167,6 +167,11 @@ void FrontEnd::normalize(std::shared_ptr<ngraph::Function> function) const
|
||||
FRONT_END_NOT_IMPLEMENTED(normalize);
|
||||
}
|
||||
|
||||
std::string FrontEnd::get_name() const
|
||||
{
|
||||
return std::string();
|
||||
}
|
||||
|
||||
//----------- InputModel ---------------------------
|
||||
std::vector<Place::Ptr> InputModel::get_inputs() const
|
||||
{
|
||||
|
@ -43,6 +43,12 @@ namespace ngraph
|
||||
/// \return nGraph function after decoding
|
||||
std::shared_ptr<Function> decode(InputModel::Ptr model) const override;
|
||||
|
||||
/// \brief Gets name of this FrontEnd. Can be used by clients
|
||||
/// if frontend is selected automatically by FrontEndManager::load_by_model
|
||||
///
|
||||
/// \return Paddle frontend name.
|
||||
std::string get_name() const override;
|
||||
|
||||
protected:
|
||||
/// \brief Check if FrontEndPDPD can recognize model from given parts
|
||||
/// \param params Can be path to folder which contains __model__ file or path to
|
||||
|
@ -399,6 +399,8 @@ namespace ngraph
|
||||
auto f = convert_each_node(pdpd_model, pdpd::make_framework_node);
|
||||
return f;
|
||||
}
|
||||
|
||||
std::string FrontEndPDPD::get_name() const { return "paddle"; }
|
||||
} // namespace frontend
|
||||
} // namespace ngraph
|
||||
|
||||
|
@ -123,4 +123,16 @@ void regclass_pyngraph_FrontEnd(py::module m)
|
||||
function : Function
|
||||
Partially converted nGraph function.
|
||||
)");
|
||||
|
||||
fem.def("get_name",
|
||||
&ngraph::frontend::FrontEnd::get_name,
|
||||
R"(
|
||||
Gets name of this FrontEnd. Can be used by clients
|
||||
if frontend is selected automatically by FrontEndManager::load_by_model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
get_name : str
|
||||
Current frontend name. Empty string if not implemented.
|
||||
)");
|
||||
}
|
||||
|
@ -35,6 +35,7 @@ void regclass_pyngraph_FrontEndManager(py::module m)
|
||||
get_available_front_ends : List[str]
|
||||
List of available frontend names.
|
||||
)");
|
||||
|
||||
fem.def("load_by_framework",
|
||||
&ngraph::frontend::FrontEndManager::load_by_framework,
|
||||
py::arg("framework"),
|
||||
@ -51,6 +52,25 @@ void regclass_pyngraph_FrontEndManager(py::module m)
|
||||
load_by_framework : FrontEnd
|
||||
Frontend interface for further loading of models.
|
||||
)");
|
||||
|
||||
fem.def(
|
||||
"load_by_model",
|
||||
[](const std::shared_ptr<ngraph::frontend::FrontEndManager>& fem,
|
||||
const std::string& model_path) { return fem->load_by_model(model_path); },
|
||||
py::arg("model_path"),
|
||||
R"(
|
||||
Selects and loads appropriate frontend depending on model file extension and other file info (header).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_path : str
|
||||
Path to model file/directory.
|
||||
|
||||
Returns
|
||||
----------
|
||||
load_by_model : FrontEnd
|
||||
Frontend interface for further loading of models. 'None' if no suitable frontend is found
|
||||
)");
|
||||
}
|
||||
|
||||
void regclass_pyngraph_GeneralFailureFrontEnd(py::module m)
|
||||
|
@ -485,6 +485,8 @@ struct MOCK_API FeStat
|
||||
int m_convert_partially = 0;
|
||||
int m_decode = 0;
|
||||
int m_normalize = 0;
|
||||
int m_get_name = 0;
|
||||
int m_supported = 0;
|
||||
// Getters
|
||||
std::vector<std::string> load_paths() const { return m_load_paths; }
|
||||
int convert_model() const { return m_convert_model; }
|
||||
@ -492,6 +494,8 @@ struct MOCK_API FeStat
|
||||
int convert_partially() const { return m_convert_partially; }
|
||||
int decode() const { return m_decode; }
|
||||
int normalize() const { return m_normalize; }
|
||||
int get_name() const { return m_get_name; }
|
||||
int supported() const { return m_supported; }
|
||||
};
|
||||
|
||||
class MOCK_API FrontEndMockPy : public FrontEnd
|
||||
@ -509,6 +513,20 @@ public:
|
||||
return std::make_shared<InputModelMockPy>();
|
||||
}
|
||||
|
||||
bool supported_impl(const std::vector<std::shared_ptr<Variant>>& params) const override
|
||||
{
|
||||
m_stat.m_supported++;
|
||||
if (params.size() > 0 && is_type<VariantWrapper<std::string>>(params[0]))
|
||||
{
|
||||
auto path = as_type_ptr<VariantWrapper<std::string>>(params[0])->get();
|
||||
if (path.find(".test_mock_py_mdl") != std::string::npos)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Function> convert(InputModel::Ptr model) const override
|
||||
{
|
||||
m_stat.m_convert_model++;
|
||||
@ -534,5 +552,11 @@ public:
|
||||
m_stat.m_normalize++;
|
||||
}
|
||||
|
||||
std::string get_name() const override
|
||||
{
|
||||
m_stat.m_get_name++;
|
||||
return "mock_py";
|
||||
}
|
||||
|
||||
FeStat get_stat() const { return m_stat; }
|
||||
};
|
||||
|
@ -33,6 +33,8 @@ static void register_mock_frontend_stat(py::module m)
|
||||
feStat.def_property_readonly("convert_partially", &FeStat::convert_partially);
|
||||
feStat.def_property_readonly("decode", &FeStat::decode);
|
||||
feStat.def_property_readonly("normalize", &FeStat::normalize);
|
||||
feStat.def_property_readonly("get_name", &FeStat::get_name);
|
||||
feStat.def_property_readonly("supported", &FeStat::supported);
|
||||
}
|
||||
|
||||
static void register_mock_model_stat(py::module m)
|
||||
|
@ -52,6 +52,16 @@ def test_load():
|
||||
assert "abc.bin" in stat.load_paths
|
||||
|
||||
|
||||
@mock_needed
|
||||
def test_load_by_model():
|
||||
fe = fem.load_by_model(model_path="abc.test_mock_py_mdl")
|
||||
assert fe is not None
|
||||
assert fe.get_name() == "mock_py"
|
||||
stat = get_fe_stat(fe)
|
||||
assert stat.get_name == 1
|
||||
assert stat.supported == 1
|
||||
|
||||
|
||||
@mock_needed
|
||||
def test_convert_model():
|
||||
fe = fem.load_by_framework(framework="mock_py")
|
||||
@ -90,6 +100,16 @@ def test_decode_and_normalize():
|
||||
assert stat.decode == 1
|
||||
|
||||
|
||||
@mock_needed
|
||||
def test_get_name():
|
||||
fe = fem.load_by_framework(framework="mock_py")
|
||||
assert fe is not None
|
||||
name = fe.get_name()
|
||||
assert name == "mock_py"
|
||||
stat = get_fe_stat(fe)
|
||||
assert stat.get_name == 1
|
||||
|
||||
|
||||
# --------InputModel tests-----------------
|
||||
@mock_needed
|
||||
def init_model():
|
||||
|
@ -59,6 +59,9 @@ TEST(FrontEndManagerTest, testMockPluginFrontEnd)
|
||||
FrontEndManager fem;
|
||||
auto frontends = fem.get_available_front_ends();
|
||||
ASSERT_NE(std::find(frontends.begin(), frontends.end(), "mock1"), frontends.end());
|
||||
FrontEnd::Ptr fe;
|
||||
ASSERT_NO_THROW(fe = fem.load_by_framework("mock1"));
|
||||
ASSERT_EQ(fe->get_name(), "mock1");
|
||||
set_test_env("OV_FRONTEND_PATH", "");
|
||||
}
|
||||
|
||||
@ -77,6 +80,7 @@ TEST(FrontEndManagerTest, testDefaultFrontEnd)
|
||||
ASSERT_ANY_THROW(fe->convert_partially(nullptr));
|
||||
ASSERT_ANY_THROW(fe->decode(nullptr));
|
||||
ASSERT_ANY_THROW(fe->normalize(nullptr));
|
||||
ASSERT_EQ(fe->get_name(), std::string());
|
||||
}
|
||||
|
||||
TEST(FrontEndManagerTest, testDefaultInputModel)
|
||||
|
@ -18,6 +18,8 @@ using namespace ngraph::frontend;
|
||||
|
||||
class FrontEndMock : public FrontEnd
|
||||
{
|
||||
public:
|
||||
std::string get_name() const override { return "mock1"; }
|
||||
};
|
||||
|
||||
extern "C" MOCK_API FrontEndVersion GetAPIVersion()
|
||||
|
@ -37,6 +37,7 @@ void FrontEndBasicTest::doLoadFromFile()
|
||||
TEST_P(FrontEndBasicTest, testLoadFromFile)
|
||||
{
|
||||
ASSERT_NO_THROW(doLoadFromFile());
|
||||
ASSERT_EQ(m_frontEnd->get_name(), m_feName);
|
||||
std::shared_ptr<ngraph::Function> function;
|
||||
ASSERT_NO_THROW(function = m_frontEnd->convert(m_inputModel));
|
||||
ASSERT_NE(function, nullptr);
|
||||
|
Loading…
Reference in New Issue
Block a user