moved TransformationsTestsF method's definitions from .hpp to .cpp (#11359)
* moved * fix style
This commit is contained in:
parent
16a5962698
commit
a87e8f7880
@ -4,6 +4,52 @@
|
|||||||
|
|
||||||
#include "ngraph_test_utils.hpp"
|
#include "ngraph_test_utils.hpp"
|
||||||
|
|
||||||
|
TransformationTestsF::TransformationTestsF() : comparator(FunctionsComparator::no_default()) {
|
||||||
|
m_unh = std::make_shared<ngraph::pass::UniqueNamesHolder>();
|
||||||
|
comparator.enable(FunctionsComparator::CmpValues::NODES);
|
||||||
|
comparator.enable(FunctionsComparator::CmpValues::PRECISIONS);
|
||||||
|
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
|
||||||
|
// TODO: enable attributes and constant values comparison by default XXX-68694
|
||||||
|
// comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||||
|
// comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TransformationTestsF::SetUp() {
|
||||||
|
manager.register_pass<ngraph::pass::InitUniqueNames>(m_unh);
|
||||||
|
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void TransformationTestsF::TearDown() {
|
||||||
|
auto cloned_function = ngraph::clone_function(*function);
|
||||||
|
if (!function_ref) {
|
||||||
|
function_ref = cloned_function;
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.register_pass<ngraph::pass::CheckUniqueNames>(m_unh, m_soft_names_comparison);
|
||||||
|
manager.run_passes(function);
|
||||||
|
if (!m_disable_rt_info_check) {
|
||||||
|
ASSERT_NO_THROW(check_rt_info(function));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (comparator.should_compare(FunctionsComparator::ACCURACY)) {
|
||||||
|
auto acc_comparator = FunctionsComparator::no_default();
|
||||||
|
acc_comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
|
auto res = acc_comparator.compare(function, cloned_function);
|
||||||
|
ASSERT_TRUE(res.valid) << res.message;
|
||||||
|
comparator.disable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
|
}
|
||||||
|
auto res = comparator.compare(function, function_ref);
|
||||||
|
ASSERT_TRUE(res.valid) << res.message;
|
||||||
|
}
|
||||||
|
|
||||||
|
void TransformationTestsF::disable_rt_info_check() {
|
||||||
|
m_disable_rt_info_check = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void TransformationTestsF::enable_soft_names_comparison() {
|
||||||
|
m_soft_names_comparison = true;
|
||||||
|
}
|
||||||
|
|
||||||
void init_unique_names(std::shared_ptr<ngraph::Function> f, const std::shared_ptr<ngraph::pass::UniqueNamesHolder>& unh) {
|
void init_unique_names(std::shared_ptr<ngraph::Function> f, const std::shared_ptr<ngraph::pass::UniqueNamesHolder>& unh) {
|
||||||
ngraph::pass::Manager manager;
|
ngraph::pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::InitUniqueNames>(unh);
|
manager.register_pass<ngraph::pass::InitUniqueNames>(unh);
|
||||||
|
@ -30,54 +30,17 @@ using TransformationTests = CommonTestUtils::TestsCommon;
|
|||||||
|
|
||||||
class TransformationTestsF : public CommonTestUtils::TestsCommon {
|
class TransformationTestsF : public CommonTestUtils::TestsCommon {
|
||||||
public:
|
public:
|
||||||
TransformationTestsF() : comparator(FunctionsComparator::no_default()) {
|
TransformationTestsF();
|
||||||
m_unh = std::make_shared<ngraph::pass::UniqueNamesHolder>();
|
|
||||||
comparator.enable(FunctionsComparator::CmpValues::NODES);
|
|
||||||
comparator.enable(FunctionsComparator::CmpValues::PRECISIONS);
|
|
||||||
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
|
|
||||||
// TODO: enable attributes and constant values comparison by default XXX-68694
|
|
||||||
// comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
|
||||||
// comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
|
||||||
}
|
|
||||||
|
|
||||||
void SetUp() override {
|
void SetUp() override;
|
||||||
manager.register_pass<ngraph::pass::InitUniqueNames>(m_unh);
|
|
||||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
|
||||||
}
|
|
||||||
|
|
||||||
void TearDown() override {
|
void TearDown() override;
|
||||||
auto cloned_function = ngraph::clone_function(*function);
|
|
||||||
if (!function_ref) {
|
|
||||||
function_ref = cloned_function;
|
|
||||||
}
|
|
||||||
|
|
||||||
manager.register_pass<ngraph::pass::CheckUniqueNames>(m_unh, m_soft_names_comparison);
|
|
||||||
manager.run_passes(function);
|
|
||||||
if (!m_disable_rt_info_check) {
|
|
||||||
ASSERT_NO_THROW(check_rt_info(function));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (comparator.should_compare(FunctionsComparator::ACCURACY)) {
|
|
||||||
auto acc_comparator = FunctionsComparator::no_default();
|
|
||||||
acc_comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
|
||||||
auto res = acc_comparator.compare(function, cloned_function);
|
|
||||||
ASSERT_TRUE(res.valid) << res.message;
|
|
||||||
comparator.disable(FunctionsComparator::CmpValues::ACCURACY);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto res = comparator.compare(function, function_ref);
|
|
||||||
ASSERT_TRUE(res.valid) << res.message;
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: this is temporary solution to disable rt info checks that must be applied by default
|
// TODO: this is temporary solution to disable rt info checks that must be applied by default
|
||||||
// first tests must be fixed then this method must be removed XXX-68696
|
// first tests must be fixed then this method must be removed XXX-68696
|
||||||
void disable_rt_info_check() {
|
void disable_rt_info_check();
|
||||||
m_disable_rt_info_check = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void enable_soft_names_comparison() {
|
void enable_soft_names_comparison();
|
||||||
m_soft_names_comparison = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::shared_ptr<ov::Model> function, function_ref;
|
std::shared_ptr<ov::Model> function, function_ref;
|
||||||
ngraph::pass::Manager manager;
|
ngraph::pass::Manager manager;
|
||||||
@ -103,4 +66,4 @@ size_t count_ops_of_type(const std::shared_ptr<ngraph::Function>& f) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return count;
|
return count;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user