[GNA] Use OV thread_local implementation (#21284)
* [GNA] Use OV thread_local implementation
This commit is contained in:
parent
ad12f114f4
commit
04f2485334
@ -677,7 +677,7 @@ constexpr uint32_t Limitations::kBytesPerCropElement;
|
|||||||
constexpr uint32_t Limitations::kBytesPerConcatElement;
|
constexpr uint32_t Limitations::kBytesPerConcatElement;
|
||||||
constexpr uint32_t Limitations::kMemoryPageSize;
|
constexpr uint32_t Limitations::kMemoryPageSize;
|
||||||
|
|
||||||
thread_local std::shared_ptr<Limitations> Limitations::k_instance{nullptr};
|
ov::threading::ThreadLocal<std::shared_ptr<Limitations>> Limitations::kInstance{nullptr};
|
||||||
|
|
||||||
Limitations::Limitations(const DeviceVersion& target) {
|
Limitations::Limitations(const DeviceVersion& target) {
|
||||||
m_use_only_16bit_conv_weights =
|
m_use_only_16bit_conv_weights =
|
||||||
@ -689,7 +689,13 @@ Limitations::Limitations(const DeviceVersion& target) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Limitations::init(const DeviceVersion& compile_target) {
|
void Limitations::init(const DeviceVersion& compile_target) {
|
||||||
k_instance = std::shared_ptr<Limitations>(new Limitations(compile_target));
|
auto& localInstance = kInstance.local();
|
||||||
|
localInstance.reset(new Limitations(compile_target));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Limitations::deinit() {
|
||||||
|
auto& localInstance = kInstance.local();
|
||||||
|
localInstance.reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t Limitations::get_min_batch_to_fit_in_buffer(InferenceEngine::DataPtr input) {
|
size_t Limitations::get_min_batch_to_fit_in_buffer(InferenceEngine::DataPtr input) {
|
||||||
|
@ -20,6 +20,7 @@
|
|||||||
#include "legacy/ngraph_ops/fully_connected.hpp"
|
#include "legacy/ngraph_ops/fully_connected.hpp"
|
||||||
#include "ngraph/opsets/opset7.hpp"
|
#include "ngraph/opsets/opset7.hpp"
|
||||||
#include "ngraph/opsets/opset9.hpp"
|
#include "ngraph/opsets/opset9.hpp"
|
||||||
|
#include "openvino/runtime/threading/thread_local.hpp"
|
||||||
#include "ops/gna_convolution.hpp"
|
#include "ops/gna_convolution.hpp"
|
||||||
#include "ops/gna_max_pool.hpp"
|
#include "ops/gna_max_pool.hpp"
|
||||||
|
|
||||||
@ -164,12 +165,17 @@ public:
|
|||||||
class Limitations {
|
class Limitations {
|
||||||
public:
|
public:
|
||||||
/**
|
/**
|
||||||
* @brief Create instance of the Limitations class. Due to Limitations being a singleton, multiple instances of the
|
* @brief Create an instance of the Limitations class. Since Limitations is designed as a singleton, multiple
|
||||||
* plugin with different compilation targets cannot exist at the same time
|
* instances of the plugin with different compilation targets cannot coexist simultaneously for the same thread.
|
||||||
* @param compile_target GNA compile target
|
* @param compile_target GNA compile target
|
||||||
*/
|
*/
|
||||||
static void init(const target::DeviceVersion& compile_target);
|
static void init(const target::DeviceVersion& compile_target);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Delete the instance of the Limitations class for the currently running thread.
|
||||||
|
*/
|
||||||
|
static void deinit();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Returns the instance of Limitations object. Requires an Init call before the first usage
|
* @brief Returns the instance of Limitations object. Requires an Init call before the first usage
|
||||||
*/
|
*/
|
||||||
@ -309,14 +315,16 @@ private:
|
|||||||
bool m_use_only_16bit_conv_weights = false;
|
bool m_use_only_16bit_conv_weights = false;
|
||||||
size_t m_mem_alignment = 0;
|
size_t m_mem_alignment = 0;
|
||||||
std::shared_ptr<cnn2d::AbstractValidator> m_cnn_validator;
|
std::shared_ptr<cnn2d::AbstractValidator> m_cnn_validator;
|
||||||
static thread_local std::shared_ptr<Limitations> k_instance;
|
|
||||||
|
static ov::threading::ThreadLocal<std::shared_ptr<Limitations>> kInstance;
|
||||||
};
|
};
|
||||||
|
|
||||||
inline std::shared_ptr<Limitations> Limitations::get_instance() {
|
inline std::shared_ptr<Limitations> Limitations::get_instance() {
|
||||||
if (!k_instance) {
|
auto& instance = kInstance.local();
|
||||||
|
if (!instance) {
|
||||||
THROW_GNA_EXCEPTION << "Limitations instance is not initialized.\n";
|
THROW_GNA_EXCEPTION << "Limitations instance is not initialized.\n";
|
||||||
}
|
}
|
||||||
return k_instance;
|
return instance;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline bool Limitations::is_crop_affined_offset(size_t numberOfElements) const {
|
inline bool Limitations::is_crop_affined_offset(size_t numberOfElements) const {
|
||||||
|
@ -1429,4 +1429,6 @@ InferenceEngine::QueryNetworkResult GNAPlugin::QueryNetwork(
|
|||||||
GNAPlugin::~GNAPlugin() {
|
GNAPlugin::~GNAPlugin() {
|
||||||
if (gnadevice)
|
if (gnadevice)
|
||||||
gnadevice->close();
|
gnadevice->close();
|
||||||
|
|
||||||
|
Limitations::deinit();
|
||||||
}
|
}
|
||||||
|
@ -52,6 +52,10 @@ protected:
|
|||||||
void SetUp() override {
|
void SetUp() override {
|
||||||
Limitations::init(target::DeviceVersion::Default);
|
Limitations::init(target::DeviceVersion::Default);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TearDown() override {
|
||||||
|
Limitations::deinit();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO: add test for FC weights after quantization
|
// TODO: add test for FC weights after quantization
|
||||||
|
@ -54,6 +54,10 @@ protected:
|
|||||||
void SetUp() override {
|
void SetUp() override {
|
||||||
Limitations::init(target::DeviceVersion::Default);
|
Limitations::init(target::DeviceVersion::Default);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TearDown() override {
|
||||||
|
Limitations::deinit();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
|
@ -287,6 +287,10 @@ protected:
|
|||||||
ASSERT_TRUE(validator);
|
ASSERT_TRUE(validator);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TearDown() override {
|
||||||
|
Limitations::deinit();
|
||||||
|
}
|
||||||
|
|
||||||
std::shared_ptr<cnn2d::AbstractValidator> validator;
|
std::shared_ptr<cnn2d::AbstractValidator> validator;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -65,6 +65,7 @@ void RunVariadicSplitSupportedTest(DeviceVersion device_version, std::vector<Var
|
|||||||
split_lengths));
|
split_lengths));
|
||||||
ASSERT_TRUE(Limitations::is_split_supported(split, false) == result);
|
ASSERT_TRUE(Limitations::is_split_supported(split, false) == result);
|
||||||
}
|
}
|
||||||
|
Limitations::deinit();
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CheckSplitSupported, CheckVariadicSplitSupported_GNA3_5) {
|
TEST(CheckSplitSupported, CheckVariadicSplitSupported_GNA3_5) {
|
||||||
@ -108,6 +109,7 @@ void RunSplitSupportedTest(DeviceVersion device_version, std::vector<SplitParame
|
|||||||
num_splits);
|
num_splits);
|
||||||
ASSERT_TRUE(Limitations::is_split_supported(split, false) == result);
|
ASSERT_TRUE(Limitations::is_split_supported(split, false) == result);
|
||||||
}
|
}
|
||||||
|
Limitations::deinit();
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CheckSplitSupported, CheckSplitSupported_GNA3_5) {
|
TEST(CheckSplitSupported, CheckSplitSupported_GNA3_5) {
|
||||||
|
@ -152,11 +152,13 @@ class MemoryAlignmentTest : public ::testing::Test {};
|
|||||||
TEST(MemoryAlignmentTest, getMemoryAlignmentBytes_Expect64ByteAlignmentWhenTargetIsGNA3_5) {
|
TEST(MemoryAlignmentTest, getMemoryAlignmentBytes_Expect64ByteAlignmentWhenTargetIsGNA3_5) {
|
||||||
Limitations::init(DeviceVersion::GNA3_5);
|
Limitations::init(DeviceVersion::GNA3_5);
|
||||||
EXPECT_EQ(Limitations::get_instance()->get_memory_alignment(), 64);
|
EXPECT_EQ(Limitations::get_instance()->get_memory_alignment(), 64);
|
||||||
|
Limitations::deinit();
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(MemoryAlignmentTest, getMemoryAlignmentBytes_Expect16ByteAlignmentWhenTargetIsGNA3_6) {
|
TEST(MemoryAlignmentTest, getMemoryAlignmentBytes_Expect16ByteAlignmentWhenTargetIsGNA3_6) {
|
||||||
Limitations::init(DeviceVersion::GNA3_6);
|
Limitations::init(DeviceVersion::GNA3_6);
|
||||||
EXPECT_EQ(Limitations::get_instance()->get_memory_alignment(), 16);
|
EXPECT_EQ(Limitations::get_instance()->get_memory_alignment(), 16);
|
||||||
|
Limitations::deinit();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace testing
|
} // namespace testing
|
||||||
|
@ -295,6 +295,7 @@ class Decompose2DConvTestInvalidFixture : public ov::test::TestsCommon,
|
|||||||
public ::testing::WithParamInterface<fqDecompose2DConvParams> {
|
public ::testing::WithParamInterface<fqDecompose2DConvParams> {
|
||||||
public:
|
public:
|
||||||
void SetUp() override;
|
void SetUp() override;
|
||||||
|
void TearDown() override;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
std::shared_ptr<ngraph::Function> function, reference_function;
|
std::shared_ptr<ngraph::Function> function, reference_function;
|
||||||
@ -339,12 +340,17 @@ void Decompose2DConvTestInvalidFixture::SetUp() {
|
|||||||
conv_params);
|
conv_params);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Decompose2DConvTestInvalidFixture::TearDown() {
|
||||||
|
Limitations::deinit();
|
||||||
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
class Decompose2DConvTestFixture : public ov::test::TestsCommon,
|
class Decompose2DConvTestFixture : public ov::test::TestsCommon,
|
||||||
public ::testing::WithParamInterface<fqDecompose2DConvParams> {
|
public ::testing::WithParamInterface<fqDecompose2DConvParams> {
|
||||||
public:
|
public:
|
||||||
void SetUp() override;
|
void SetUp() override;
|
||||||
|
void TearDown() override;
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Function> get_reference(const bool& fq,
|
std::shared_ptr<ngraph::Function> get_reference(const bool& fq,
|
||||||
const modelType& model,
|
const modelType& model,
|
||||||
@ -385,6 +391,10 @@ void Decompose2DConvTestFixture::SetUp() {
|
|||||||
reference_function = get_reference(fq, model, input_shape, graph_data, conv_params);
|
reference_function = get_reference(fq, model, input_shape, graph_data, conv_params);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Decompose2DConvTestFixture::TearDown() {
|
||||||
|
Limitations::deinit();
|
||||||
|
}
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Node> ReshapeBiasConst(std::shared_ptr<ngraph::opset7::Add> conv_bias,
|
std::shared_ptr<ngraph::Node> ReshapeBiasConst(std::shared_ptr<ngraph::opset7::Add> conv_bias,
|
||||||
const ConvParams& conv_params) {
|
const ConvParams& conv_params) {
|
||||||
auto add_const =
|
auto add_const =
|
||||||
|
@ -44,6 +44,7 @@ public:
|
|||||||
return result.str();
|
return result.str();
|
||||||
}
|
}
|
||||||
void SetUp() override;
|
void SetUp() override;
|
||||||
|
void TearDown() override;
|
||||||
virtual void Validate();
|
virtual void Validate();
|
||||||
virtual void Run();
|
virtual void Run();
|
||||||
|
|
||||||
@ -64,6 +65,9 @@ void InsertCopyLayerTest::SetUp() {
|
|||||||
std::tie(m_device_ver, m_axis, m_inputs_num) = this->GetParam();
|
std::tie(m_device_ver, m_axis, m_inputs_num) = this->GetParam();
|
||||||
Limitations::init(m_device_ver);
|
Limitations::init(m_device_ver);
|
||||||
}
|
}
|
||||||
|
void InsertCopyLayerTest::TearDown() {
|
||||||
|
Limitations::deinit();
|
||||||
|
}
|
||||||
|
|
||||||
void InsertCopyLayerTest::Run() {
|
void InsertCopyLayerTest::Run() {
|
||||||
Validate();
|
Validate();
|
||||||
@ -212,6 +216,7 @@ public:
|
|||||||
|
|
||||||
void TearDown() override {
|
void TearDown() override {
|
||||||
m_func.reset();
|
m_func.reset();
|
||||||
|
Limitations::deinit();
|
||||||
}
|
}
|
||||||
|
|
||||||
void RunPasses(ngraph::pass::Manager& m) {
|
void RunPasses(ngraph::pass::Manager& m) {
|
||||||
|
@ -270,6 +270,7 @@ class SplitConvolutionFixture : public ov::test::TestsCommon,
|
|||||||
public ::testing::WithParamInterface<std::tuple<DeviceVersion, TestParams>> {
|
public ::testing::WithParamInterface<std::tuple<DeviceVersion, TestParams>> {
|
||||||
public:
|
public:
|
||||||
void SetUp() override;
|
void SetUp() override;
|
||||||
|
void TearDown() override;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
std::shared_ptr<ngraph::Function> function, reference_function;
|
std::shared_ptr<ngraph::Function> function, reference_function;
|
||||||
@ -290,6 +291,10 @@ void SplitConvolutionFixture::SetUp() {
|
|||||||
reference_function = reference_graph.createFunction();
|
reference_function = reference_graph.createFunction();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SplitConvolutionFixture::TearDown() {
|
||||||
|
Limitations::deinit();
|
||||||
|
}
|
||||||
|
|
||||||
void execute_test(std::shared_ptr<ngraph::Function> function,
|
void execute_test(std::shared_ptr<ngraph::Function> function,
|
||||||
std::shared_ptr<ngraph::Function> reference_function,
|
std::shared_ptr<ngraph::Function> reference_function,
|
||||||
ngraph::pass::Manager& pass_manager) {
|
ngraph::pass::Manager& pass_manager) {
|
||||||
|
@ -134,6 +134,7 @@ class SplitEltwiseTestSuiteFixture : public ov::test::TestsCommon,
|
|||||||
public ::testing::WithParamInterface<EltwiseSplitParams> {
|
public ::testing::WithParamInterface<EltwiseSplitParams> {
|
||||||
public:
|
public:
|
||||||
void SetUp() override;
|
void SetUp() override;
|
||||||
|
void TearDown() override;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
std::shared_ptr<ngraph::Function> function, reference_function;
|
std::shared_ptr<ngraph::Function> function, reference_function;
|
||||||
@ -151,6 +152,10 @@ void SplitEltwiseTestSuiteFixture::SetUp() {
|
|||||||
reference_function = createFunction(shape, with_const, with_fq, type, true);
|
reference_function = createFunction(shape, with_const, with_fq, type, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SplitEltwiseTestSuiteFixture::TearDown() {
|
||||||
|
Limitations::deinit();
|
||||||
|
}
|
||||||
|
|
||||||
void execute_test(std::shared_ptr<ngraph::Function> function, std::shared_ptr<ngraph::Function> reference_function) {
|
void execute_test(std::shared_ptr<ngraph::Function> function, std::shared_ptr<ngraph::Function> reference_function) {
|
||||||
ngraph::pass::Manager manager;
|
ngraph::pass::Manager manager;
|
||||||
manager.register_pass<ov::pass::InitNodeInfo>();
|
manager.register_pass<ov::pass::InitNodeInfo>();
|
||||||
|
Loading…
Reference in New Issue
Block a user