fix conflicts with master branch
This commit is contained in:
parent
22dc27b7fb
commit
c14a46a3f8
@ -41,7 +41,7 @@ void regclass_frontend_TelemetryExtension(py::module m) {
|
||||
}
|
||||
}
|
||||
|
||||
void regclass_frontend_DecoderTransformationNExtension(py::module m) {
|
||||
void regclass_frontend_DecoderTransformationExtension(py::module m) {
|
||||
{
|
||||
py::class_<ov::frontend::DecoderTransformationExtension,
|
||||
std::shared_ptr<ov::frontend::DecoderTransformationExtension>,
|
||||
|
@ -10,5 +10,5 @@ namespace py = pybind11;
|
||||
|
||||
void regclass_frontend_Extension(py::module m);
|
||||
void regclass_frontend_TelemetryExtension(py::module m);
|
||||
void regclass_frontend_DecoderTransformationNExtension(py::module m);
|
||||
void regclass_frontend_DecoderTransformationExtension(py::module m);
|
||||
void regclass_frontend_JsonConfigExtension(py::module m);
|
||||
|
@ -21,13 +21,17 @@ add_library(${TARGET_NAME} STATIC EXCLUDE_FROM_ALL ${LIBRARY_SRC} ${PUBLIC_HEADE
|
||||
|
||||
target_link_libraries(${TARGET_NAME} PUBLIC ngraph inference_engine_transformations ngraph::reference
|
||||
nlohmann_json_schema_validator
|
||||
PRIVATE openvino::itt pugixml::static)
|
||||
PRIVATE openvino::itt pugixml::static openvino::frontend::common)
|
||||
|
||||
target_include_directories(${TARGET_NAME} PUBLIC ${PUBLIC_HEADERS_DIR}
|
||||
PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/src")
|
||||
|
||||
add_cpplint_target(${TARGET_NAME}_cpplint FOR_TARGETS ${TARGET_NAME})
|
||||
|
||||
# Add include path to so_extension.hpp
|
||||
set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/src/json_extension/json_config_extension.cpp
|
||||
PROPERTIES INCLUDE_DIRECTORIES "${OpenVINO_SOURCE_DIR}/src/core/src/")
|
||||
|
||||
# developer package
|
||||
|
||||
openvino_developer_export_targets(COMPONENT core TARGETS ${TARGET_NAME})
|
||||
|
@ -42,7 +42,7 @@ public:
|
||||
return m_id;
|
||||
}
|
||||
|
||||
virtual bool transform(std::shared_ptr<ov::Function>& function,
|
||||
virtual bool transform(std::shared_ptr<ov::Model>& function,
|
||||
const nlohmann::json& replacement_descriptions) const = 0;
|
||||
|
||||
private:
|
@ -13,7 +13,7 @@ using namespace ov;
|
||||
using namespace ov::frontend;
|
||||
|
||||
JsonConfigExtension::JsonConfigExtension(const std::string& config_path)
|
||||
: DecoderTransformationExtension([this](std::shared_ptr<ov::Function> f) {
|
||||
: DecoderTransformationExtension([this](std::shared_ptr<ov::Model> f) {
|
||||
bool res = true;
|
||||
for (const auto& target_extension : m_target_extensions) {
|
||||
if (auto extension = std::dynamic_pointer_cast<JsonTransformationExtension>(target_extension.first)) {
|
@ -89,7 +89,7 @@ private:
|
||||
|
||||
class OPENVINO_API ModelPass : public PassBase {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::FunctionPass");
|
||||
OPENVINO_RTTI("ov::pass::ModelPass");
|
||||
~ModelPass() override;
|
||||
OPENVINO_DEPRECATED("run_on_function() method is deprecated. Please use run_on_model() instead.")
|
||||
virtual bool run_on_function(std::shared_ptr<ov::Model> m);
|
||||
|
@ -17,30 +17,30 @@ TEST(DecoderTransformation, MatcherPass) {
|
||||
|
||||
ov::pass::Manager manager;
|
||||
decoder_ext.register_pass(manager);
|
||||
manager.run_passes(std::make_shared<ov::Function>(ov::ResultVector{}, ov::ParameterVector{}));
|
||||
manager.run_passes(std::make_shared<ov::Model>(ov::ResultVector{}, ov::ParameterVector{}));
|
||||
EXPECT_EQ(flag, true);
|
||||
}
|
||||
|
||||
TEST(DecoderTransformation, FunctionPass) {
|
||||
bool flag = false;
|
||||
DecoderTransformationExtension decoder_ext([&](const std::shared_ptr<ov::Function>&) {
|
||||
DecoderTransformationExtension decoder_ext([&](const std::shared_ptr<ov::Model>&) {
|
||||
flag = true;
|
||||
return flag;
|
||||
});
|
||||
|
||||
ov::pass::Manager manager;
|
||||
decoder_ext.register_pass(manager);
|
||||
manager.run_passes(std::make_shared<ov::Function>(ov::ResultVector{}, ov::ParameterVector{}));
|
||||
manager.run_passes(std::make_shared<ov::Model>(ov::ResultVector{}, ov::ParameterVector{}));
|
||||
EXPECT_EQ(flag, true);
|
||||
}
|
||||
|
||||
TEST(DecoderTransformation, TestPass) {
|
||||
class TestPass : public ov::pass::FunctionPass {
|
||||
class TestPass : public ov::pass::ModelPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TestPass");
|
||||
TestPass() = default;
|
||||
TestPass(const TestPass& tp) = default;
|
||||
bool run_on_function(std::shared_ptr<ov::Function>) override {
|
||||
bool run_on_function(std::shared_ptr<ov::Model>) override {
|
||||
*m_flag = true;
|
||||
return *m_flag;
|
||||
}
|
||||
@ -50,6 +50,6 @@ TEST(DecoderTransformation, TestPass) {
|
||||
|
||||
ov::pass::Manager manager;
|
||||
decoder_ext.register_pass(manager);
|
||||
manager.run_passes(std::make_shared<ov::Function>(ov::ResultVector{}, ov::ParameterVector{}));
|
||||
manager.run_passes(std::make_shared<ov::Model>(ov::ResultVector{}, ov::ParameterVector{}));
|
||||
EXPECT_EQ(*test_pass.m_flag, true);
|
||||
}
|
||||
|
@ -1,98 +1,98 @@
|
||||
# #
|
||||
# # slice paddle model generator
|
||||
# #
|
||||
# import sys
|
||||
# import os
|
||||
#
|
||||
# slice paddle model generator
|
||||
# import numpy as np
|
||||
# import paddle as pdpd
|
||||
#
|
||||
import sys
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import paddle as pdpd
|
||||
|
||||
from save_model import exportModel
|
||||
from save_model import saveModel
|
||||
|
||||
data_type = 'float32'
|
||||
|
||||
def slice(name : str, x, axes : list, start : list, end : list):
|
||||
pdpd.enable_static()
|
||||
|
||||
with pdpd.static.program_guard(pdpd.static.Program(), pdpd.static.Program()):
|
||||
node_x = pdpd.static.data(name='x', shape=x.shape, dtype = data_type)
|
||||
out = pdpd.fluid.layers.slice(node_x, axes = axes, starts = start, ends = end)
|
||||
|
||||
cpu = pdpd.static.cpu_places(1)
|
||||
exe = pdpd.static.Executor(cpu[0])
|
||||
# startup program will call initializer to initialize the parameters.
|
||||
exe.run(pdpd.static.default_startup_program())
|
||||
|
||||
outs = exe.run(
|
||||
feed={'x': x},
|
||||
fetch_list=[out])
|
||||
|
||||
saveModel(name, exe, feedkeys=['x'], fetchlist=[out], inputs=[x], outputs=[outs[0]], target_dir=sys.argv[1])
|
||||
|
||||
return outs[0]
|
||||
|
||||
|
||||
def slice_dyn(test_shape=[2,8,10,10]):
|
||||
pdpd.disable_static()
|
||||
|
||||
data = pdpd.rand(shape=test_shape, dtype='float32')
|
||||
|
||||
'''
|
||||
slice w/ decrease_axis
|
||||
'''
|
||||
@pdpd.jit.to_static
|
||||
def test_slice_decrease_axis(x):
|
||||
return x[0, 1:3, :, 5]
|
||||
exportModel('slice_decrease_axis', test_slice_decrease_axis, [data], target_dir=sys.argv[1]) # output shape (2, 10)
|
||||
|
||||
'''
|
||||
slice w/o decrease_axis
|
||||
'''
|
||||
@pdpd.jit.to_static
|
||||
def test_slice(x):
|
||||
return pdpd.slice(x, axes=[0,1,3], starts=[0,1,5], ends=[1,3,6])
|
||||
# exportModel('slice_dyn', test_slice, [data], target_dir=sys.argv[1]) # output shape (1, 2, 10, 1) # disable it by default as this kind of test model already there. It's for comparsion only.
|
||||
|
||||
'''
|
||||
slice w/ decrease_axis of all dims
|
||||
'''
|
||||
@pdpd.jit.to_static
|
||||
def test_slice_decrease_axis_all(x):
|
||||
return x[0, 0, 0, 0]
|
||||
exportModel('slice_decrease_axis_all', test_slice_decrease_axis_all, [data], target_dir=sys.argv[1]) # output shape (1,)
|
||||
|
||||
'''
|
||||
slice w/o decrease_axis of all dims
|
||||
'''
|
||||
@pdpd.jit.to_static
|
||||
def test_slice_alldim(x):
|
||||
return pdpd.slice(x, axes=[0,1,2,3], starts=[0,0,0,0], ends=[1,1,1,1])
|
||||
# exportModel('slice_alldim', test_slice_alldim, [data], target_dir=sys.argv[1]) # output shape (1, 1, 1, 1) # disable it by default as this kind of test model already there. It's for comparsion only.
|
||||
|
||||
'''
|
||||
a test case simulating the last reshape2 of ocrnet which accepts slice (with decrease_axes in all dims) as its parents.
|
||||
'''
|
||||
def slice_reshape(B=1, C=256, H=16, W=32):
|
||||
pdpd.disable_static()
|
||||
|
||||
data = pdpd.rand(shape=[B, C, H*W], dtype='float32')
|
||||
|
||||
@pdpd.jit.to_static
|
||||
def test_model(x):
|
||||
x2 = pdpd.assign([-1, -1, 16, 32]).astype('int32')
|
||||
node_reshape = pdpd.reshape(x, [0, 256, x2[2], x2[3]])
|
||||
return node_reshape
|
||||
exportModel('slice_reshape', test_model, [data], target_dir=sys.argv[1])
|
||||
|
||||
def main():
|
||||
x = np.linspace(1, 60, num = 60, dtype=np.int32).reshape(4, 3, 5).astype(data_type)
|
||||
slice("slice", x, axes=[1, 2], start=(0, 1), end=(-1, 3))
|
||||
|
||||
x = np.linspace(1, 60, num = 60, dtype=np.int32).reshape(2, 30).astype(data_type)
|
||||
slice("slice_1d", x, axes=[0], start=[0], end=[1])
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
slice_dyn()
|
||||
slice_reshape()
|
||||
# from save_model import exportModel
|
||||
# from save_model import saveModel
|
||||
#
|
||||
# data_type = 'float32'
|
||||
#
|
||||
# def slice(name : str, x, axes : list, start : list, end : list):
|
||||
# pdpd.enable_static()
|
||||
#
|
||||
# with pdpd.static.program_guard(pdpd.static.Program(), pdpd.static.Program()):
|
||||
# node_x = pdpd.static.data(name='x', shape=x.shape, dtype = data_type)
|
||||
# out = pdpd.fluid.layers.slice(node_x, axes = axes, starts = start, ends = end)
|
||||
#
|
||||
# cpu = pdpd.static.cpu_places(1)
|
||||
# exe = pdpd.static.Executor(cpu[0])
|
||||
# # startup program will call initializer to initialize the parameters.
|
||||
# exe.run(pdpd.static.default_startup_program())
|
||||
#
|
||||
# outs = exe.run(
|
||||
# feed={'x': x},
|
||||
# fetch_list=[out])
|
||||
#
|
||||
# saveModel(name, exe, feedkeys=['x'], fetchlist=[out], inputs=[x], outputs=[outs[0]], target_dir=sys.argv[1])
|
||||
#
|
||||
# return outs[0]
|
||||
#
|
||||
#
|
||||
# def slice_dyn(test_shape=[2,8,10,10]):
|
||||
# pdpd.disable_static()
|
||||
#
|
||||
# data = pdpd.rand(shape=test_shape, dtype='float32')
|
||||
#
|
||||
# '''
|
||||
# slice w/ decrease_axis
|
||||
# '''
|
||||
# @pdpd.jit.to_static
|
||||
# def test_slice_decrease_axis(x):
|
||||
# return x[0, 1:3, :, 5]
|
||||
# exportModel('slice_decrease_axis', test_slice_decrease_axis, [data], target_dir=sys.argv[1]) # output shape (2, 10)
|
||||
#
|
||||
# '''
|
||||
# slice w/o decrease_axis
|
||||
# '''
|
||||
# @pdpd.jit.to_static
|
||||
# def test_slice(x):
|
||||
# return pdpd.slice(x, axes=[0,1,3], starts=[0,1,5], ends=[1,3,6])
|
||||
# # exportModel('slice_dyn', test_slice, [data], target_dir=sys.argv[1]) # output shape (1, 2, 10, 1) # disable it by default as this kind of test model already there. It's for comparsion only.
|
||||
#
|
||||
# '''
|
||||
# slice w/ decrease_axis of all dims
|
||||
# '''
|
||||
# @pdpd.jit.to_static
|
||||
# def test_slice_decrease_axis_all(x):
|
||||
# return x[0, 0, 0, 0]
|
||||
# exportModel('slice_decrease_axis_all', test_slice_decrease_axis_all, [data], target_dir=sys.argv[1]) # output shape (1,)
|
||||
#
|
||||
# '''
|
||||
# slice w/o decrease_axis of all dims
|
||||
# '''
|
||||
# @pdpd.jit.to_static
|
||||
# def test_slice_alldim(x):
|
||||
# return pdpd.slice(x, axes=[0,1,2,3], starts=[0,0,0,0], ends=[1,1,1,1])
|
||||
# # exportModel('slice_alldim', test_slice_alldim, [data], target_dir=sys.argv[1]) # output shape (1, 1, 1, 1) # disable it by default as this kind of test model already there. It's for comparsion only.
|
||||
#
|
||||
# '''
|
||||
# a test case simulating the last reshape2 of ocrnet which accepts slice (with decrease_axes in all dims) as its parents.
|
||||
# '''
|
||||
# def slice_reshape(B=1, C=256, H=16, W=32):
|
||||
# pdpd.disable_static()
|
||||
#
|
||||
# data = pdpd.rand(shape=[B, C, H*W], dtype='float32')
|
||||
#
|
||||
# @pdpd.jit.to_static
|
||||
# def test_model(x):
|
||||
# x2 = pdpd.assign([-1, -1, 16, 32]).astype('int32')
|
||||
# node_reshape = pdpd.reshape(x, [0, 256, x2[2], x2[3]])
|
||||
# return node_reshape
|
||||
# exportModel('slice_reshape', test_model, [data], target_dir=sys.argv[1])
|
||||
#
|
||||
# def main():
|
||||
# x = np.linspace(1, 60, num = 60, dtype=np.int32).reshape(4, 3, 5).astype(data_type)
|
||||
# slice("slice", x, axes=[1, 2], start=(0, 1), end=(-1, 3))
|
||||
#
|
||||
# x = np.linspace(1, 60, num = 60, dtype=np.int32).reshape(2, 30).astype(data_type)
|
||||
# slice("slice_1d", x, axes=[0], start=[0], end=[1])
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# main()
|
||||
# slice_dyn()
|
||||
# slice_reshape()
|
@ -9,10 +9,14 @@ file(GLOB_RECURSE LIBRARY_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/include/*.hpp)
|
||||
|
||||
add_library(${TARGET_NAME} STATIC EXCLUDE_FROM_ALL ${LIBRARY_SRC} ${LIBRARY_HEADERS})
|
||||
|
||||
add_subdirectory(test_builtin_extensions_1)
|
||||
add_subdirectory(test_builtin_extensions_2)
|
||||
|
||||
target_include_directories(${TARGET_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
|
||||
target_include_directories(${TARGET_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../..)
|
||||
target_link_libraries(${TARGET_NAME} PUBLIC frontend_common interpreter_backend engines_test_util
|
||||
ngraph cnpy commonTestUtils ngraph_test_util openvino::util)
|
||||
offline_transformations ngraph cnpy commonTestUtils ngraph_test_util openvino::util
|
||||
test_builtin_extensions_1 test_builtin_extensions_2)
|
||||
|
||||
target_compile_definitions(${TARGET_NAME}
|
||||
PRIVATE
|
||||
|
@ -130,7 +130,7 @@ TEST_P(FrontEndJsonConfigTest, testAddJsonConfigExtension) {
|
||||
ref_val.second["library"] = get_lib_path(ref_val.second["library"]);
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Function> function;
|
||||
std::shared_ptr<ov::Model> function;
|
||||
{
|
||||
ov::frontend::FrontEnd::Ptr m_frontEnd;
|
||||
ov::frontend::InputModel::Ptr m_inputModel;
|
||||
@ -162,7 +162,7 @@ TEST_P(FrontEndJsonConfigTest, testAddJsonConfigExtension) {
|
||||
TEST_P(FrontEndJsonConfigTest, compareFunctions) {
|
||||
auto path_to_json = generate_json_config();
|
||||
|
||||
std::shared_ptr<ov::Function> function;
|
||||
std::shared_ptr<ov::Model> function;
|
||||
{
|
||||
ov::frontend::FrontEnd::Ptr m_frontEnd;
|
||||
ov::frontend::InputModel::Ptr m_inputModel;
|
||||
@ -178,7 +178,7 @@ TEST_P(FrontEndJsonConfigTest, compareFunctions) {
|
||||
ASSERT_NE(function, nullptr);
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Function> function_ref;
|
||||
std::shared_ptr<ov::Model> function_ref;
|
||||
{
|
||||
ov::frontend::FrontEnd::Ptr m_frontEnd;
|
||||
ov::frontend::InputModel::Ptr m_inputModel;
|
||||
|
@ -6,7 +6,7 @@
|
||||
|
||||
#include <openvino/core/core.hpp>
|
||||
|
||||
bool TestExtension1::transform(std::shared_ptr<ov::Function>& function, const nlohmann::json& config) const {
|
||||
bool TestExtension1::transform(std::shared_ptr<ov::Model>& function, const nlohmann::json& config) const {
|
||||
function->set_friendly_name("TestFunction");
|
||||
return true;
|
||||
}
|
||||
|
@ -11,5 +11,5 @@ class TestExtension1 : public ov::frontend::JsonTransformationExtension {
|
||||
public:
|
||||
TestExtension1();
|
||||
|
||||
bool transform(std::shared_ptr<ov::Function>& function, const nlohmann::json& config) const override;
|
||||
bool transform(std::shared_ptr<ov::Model>& function, const nlohmann::json& config) const override;
|
||||
};
|
||||
|
@ -6,14 +6,14 @@
|
||||
|
||||
#include <openvino/core/core.hpp>
|
||||
|
||||
bool TestExtension1::transform(std::shared_ptr<ov::Function>& function, const nlohmann::json& config) const {
|
||||
bool TestExtension1::transform(std::shared_ptr<ov::Model>& function, const nlohmann::json& config) const {
|
||||
function->set_friendly_name("TestFunction");
|
||||
return true;
|
||||
}
|
||||
|
||||
TestExtension1::TestExtension1() : ov::frontend::JsonTransformationExtension("buildin_extensions_2::TestExtension1") {}
|
||||
|
||||
bool TestExtension2::transform(std::shared_ptr<ov::Function>& function, const nlohmann::json& config) const {
|
||||
bool TestExtension2::transform(std::shared_ptr<ov::Model>& function, const nlohmann::json& config) const {
|
||||
function->set_friendly_name("TestFunction");
|
||||
return true;
|
||||
}
|
||||
|
@ -11,12 +11,12 @@ class TestExtension1 : public ov::frontend::JsonTransformationExtension {
|
||||
public:
|
||||
TestExtension1();
|
||||
|
||||
bool transform(std::shared_ptr<ov::Function>& function, const nlohmann::json& config) const override;
|
||||
bool transform(std::shared_ptr<ov::Model>& function, const nlohmann::json& config) const override;
|
||||
};
|
||||
|
||||
class TestExtension2 : public ov::frontend::JsonTransformationExtension {
|
||||
public:
|
||||
TestExtension2();
|
||||
|
||||
bool transform(std::shared_ptr<ov::Function>& function, const nlohmann::json& config) const override;
|
||||
bool transform(std::shared_ptr<ov::Model>& function, const nlohmann::json& config) const override;
|
||||
};
|
||||
|
@ -27,7 +27,7 @@ public:
|
||||
DecoderTransformationExtension() = default;
|
||||
|
||||
/// \brief Create a custom functional pass where code of the pass is implemented as a function.
|
||||
explicit DecoderTransformationExtension(const std::function<bool(std::shared_ptr<ov::Function>)>& function_pass);
|
||||
explicit DecoderTransformationExtension(const std::function<bool(std::shared_ptr<ov::Model>)>& function_pass);
|
||||
|
||||
/// \brief Create a custom matcher pass where the code of matcher pass initialization is a given function.
|
||||
explicit DecoderTransformationExtension(
|
||||
|
@ -10,16 +10,16 @@ using namespace ov;
|
||||
using namespace ov::frontend;
|
||||
|
||||
/// \brief Helper class to register user function as a FunctionPass
|
||||
class CustomFunctionPass : public ov::pass::FunctionPass {
|
||||
class CustomFunctionPass : public ov::pass::ModelPass {
|
||||
public:
|
||||
explicit CustomFunctionPass(std::function<bool(std::shared_ptr<ov::Function>)> pass) : m_pass(std::move(pass)) {}
|
||||
explicit CustomFunctionPass(std::function<bool(std::shared_ptr<ov::Model>)> pass) : m_pass(std::move(pass)) {}
|
||||
|
||||
bool run_on_function(std::shared_ptr<ov::Function> f) override {
|
||||
bool run_on_function(std::shared_ptr<ov::Model> f) override {
|
||||
return m_pass(f);
|
||||
}
|
||||
|
||||
private:
|
||||
std::function<bool(std::shared_ptr<ov::Function>)> m_pass;
|
||||
std::function<bool(std::shared_ptr<ov::Model>)> m_pass;
|
||||
};
|
||||
|
||||
/// \brief Helper class to register user matcher pass initialization as a MatcherPass
|
||||
@ -31,7 +31,7 @@ public:
|
||||
};
|
||||
|
||||
DecoderTransformationExtension::DecoderTransformationExtension(
|
||||
const std::function<bool(std::shared_ptr<ov::Function>)>& function_pass)
|
||||
const std::function<bool(std::shared_ptr<ov::Model>)>& function_pass)
|
||||
: m_registration([=](ov::pass::Manager& manager) {
|
||||
manager.register_pass<CustomFunctionPass>(function_pass);
|
||||
}) {}
|
||||
|
@ -10,7 +10,8 @@
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pdpd {using InPortName = std::string;
|
||||
namespace pdpd {
|
||||
using InPortName = std::string;
|
||||
using OutPortName = std::string;
|
||||
using TensorName = std::string;
|
||||
using NamedOutputs = std::map<OutPortName, OutputVector>;
|
||||
|
@ -308,6 +308,20 @@ ov::frontend::InputModel::Ptr FrontEndTF::load_impl(const std::vector<ov::Any>&
|
||||
|
||||
std::shared_ptr<ov::Model> FrontEndTF::convert(ov::frontend::InputModel::Ptr model) const {
|
||||
auto model_tf = std::dynamic_pointer_cast<InputModelTF>(model);
|
||||
FRONT_END_GENERAL_CHECK(model_tf != nullptr, "Invalid input model");
|
||||
|
||||
if (!m_transformation_extensions.empty()) {
|
||||
auto function = decode(model);
|
||||
|
||||
pass::Manager manager;
|
||||
for (const auto& transformation : m_transformation_extensions) {
|
||||
transformation->register_pass(manager);
|
||||
}
|
||||
manager.run_passes(function);
|
||||
convert(function);
|
||||
return function;
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> f;
|
||||
translate_graph(model_tf, "here_should_be_a_graph_name", true, false, f);
|
||||
normalize(f);
|
||||
@ -318,6 +332,20 @@ std::shared_ptr<ov::Model> FrontEndTF::convert(ov::frontend::InputModel::Ptr mod
|
||||
|
||||
std::shared_ptr<ov::Model> FrontEndTF::convert_partially(ov::frontend::InputModel::Ptr model) const {
|
||||
auto model_tf = std::dynamic_pointer_cast<InputModelTF>(model);
|
||||
FRONT_END_GENERAL_CHECK(model_tf != nullptr, "Invalid input model");
|
||||
|
||||
if (!m_transformation_extensions.empty()) {
|
||||
auto function = decode(model);
|
||||
|
||||
pass::Manager manager;
|
||||
for (const auto& transformation : m_transformation_extensions) {
|
||||
transformation->register_pass(manager);
|
||||
}
|
||||
manager.run_passes(function);
|
||||
convert(function);
|
||||
return function;
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> f;
|
||||
translate_graph(model_tf, "here_should_be_a_graph_name", false, false, f);
|
||||
normalize(f);
|
||||
@ -353,5 +381,7 @@ void FrontEndTF::normalize(std::shared_ptr<ov::Model> function) const {
|
||||
void FrontEndTF::add_extension(const std::shared_ptr<ov::Extension>& extension) {
|
||||
if (auto telemetry = std::dynamic_pointer_cast<TelemetryExtension>(extension)) {
|
||||
m_telemetry = telemetry;
|
||||
} else if (auto transformation = std::dynamic_pointer_cast<DecoderTransformationExtension>(extension)) {
|
||||
m_transformation_extensions.push_back(transformation);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user