Update Accuracy Check inside Func tests (#11314)
This commit is contained in:
parent
ee31b648d1
commit
8317493e65
@ -92,7 +92,9 @@ TEST_P(FQMulFusion, ExpectFusion) {
|
|||||||
manager.run_passes(m_function);
|
manager.run_passes(m_function);
|
||||||
ASSERT_NO_THROW(check_rt_info(m_function));
|
ASSERT_NO_THROW(check_rt_info(m_function));
|
||||||
|
|
||||||
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
|
auto fc = FunctionsComparator::no_default()
|
||||||
|
.enable(FunctionsComparator::PRECISIONS)
|
||||||
|
.enable(FunctionsComparator::NODES);
|
||||||
auto res = fc.compare(m_function, m_expected_function);
|
auto res = fc.compare(m_function, m_expected_function);
|
||||||
ASSERT_TRUE(res.valid) << res.message;
|
ASSERT_TRUE(res.valid) << res.message;
|
||||||
};
|
};
|
||||||
|
@ -113,7 +113,9 @@ TEST_P(nGraphFQReshapeFusionTests, ReshapeMatMul) {
|
|||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
ASSERT_NO_THROW(check_rt_info(f));
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
|
||||||
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
|
auto fc = FunctionsComparator::no_default()
|
||||||
|
.enable(FunctionsComparator::PRECISIONS)
|
||||||
|
.enable(FunctionsComparator::NODES);
|
||||||
auto res = fc.compare(f, ref_f);
|
auto res = fc.compare(f, ref_f);
|
||||||
ASSERT_TRUE(res.valid) << res.message;
|
ASSERT_TRUE(res.valid) << res.message;
|
||||||
}
|
}
|
||||||
|
@ -30,7 +30,7 @@ TEST_F(TransformationTestsF, MatMulConstTransposesExtractionConstantWeights) {
|
|||||||
function_ref = std::make_shared<Function>(matmul, ParameterVector{data});
|
function_ref = std::make_shared<Function>(matmul, ParameterVector{data});
|
||||||
}
|
}
|
||||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TransformationTestsF, MatMulConstTransposesExtractionFQOnWeights) {
|
TEST_F(TransformationTestsF, MatMulConstTransposesExtractionFQOnWeights) {
|
||||||
@ -57,7 +57,7 @@ TEST_F(TransformationTestsF, MatMulConstTransposesExtractionFQOnWeights) {
|
|||||||
function_ref = std::make_shared<Function>(matmul, ParameterVector{data});
|
function_ref = std::make_shared<Function>(matmul, ParameterVector{data});
|
||||||
}
|
}
|
||||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TransformationTestsF, NegativeMatMulConstTransposesExtractionInvalidRank) {
|
TEST_F(TransformationTestsF, NegativeMatMulConstTransposesExtractionInvalidRank) {
|
||||||
|
@ -998,7 +998,7 @@ TEST_P(EliminateEltwiseTests, eliminate_eltwise) {
|
|||||||
|
|
||||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||||
if (type == element::f32) {
|
if (type == element::f32) {
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1088,5 +1088,5 @@ TEST_F(TransformationTestsF, eliminate_eltwise_dequantization_subgraph) {
|
|||||||
manager.register_pass<pass::NopElimination>();
|
manager.register_pass<pass::NopElimination>();
|
||||||
|
|
||||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
@ -120,11 +120,11 @@ TEST_F(TransformationTestsF, PadFusionConvolutionBackpropData) {
|
|||||||
auto pads_begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 1, 1});
|
auto pads_begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 1, 1});
|
||||||
auto pads_end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
|
auto pads_end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
|
||||||
auto pad = std::make_shared<opset5::Pad>(data, pads_begin, pads_end, op::PadMode::CONSTANT);
|
auto pad = std::make_shared<opset5::Pad>(data, pads_begin, pads_end, op::PadMode::CONSTANT);
|
||||||
set_tensor_name(pad, "pad");
|
|
||||||
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
|
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
|
||||||
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(pad, filters, Strides{1, 1},
|
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(pad, filters, Strides{1, 1},
|
||||||
CoordinateDiff{4, 4}, CoordinateDiff{3, 3}, Shape{1, 1});
|
CoordinateDiff{4, 4}, CoordinateDiff{3, 3}, Shape{1, 1});
|
||||||
set_tensor_name(conv, "conv");
|
|
||||||
|
|
||||||
function = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
|
function = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
|
||||||
manager.register_pass<pass::PadFusion>();
|
manager.register_pass<pass::PadFusion>();
|
||||||
@ -134,7 +134,7 @@ TEST_F(TransformationTestsF, PadFusionConvolutionBackpropData) {
|
|||||||
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
|
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
|
||||||
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(data, filters, Strides{1, 1},
|
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(data, filters, Strides{1, 1},
|
||||||
CoordinateDiff{3, 3}, CoordinateDiff{1, 1}, Shape{1, 1});
|
CoordinateDiff{3, 3}, CoordinateDiff{1, 1}, Shape{1, 1});
|
||||||
set_tensor_name(conv, "conv");
|
|
||||||
|
|
||||||
function_ref = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
|
function_ref = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
|
||||||
}
|
}
|
||||||
@ -244,11 +244,11 @@ TEST_F(TransformationTestsF, PadFusionConvolutionBackpropDataNonConstPadValue) {
|
|||||||
std::shared_ptr<Node> pad_value = opset5::Constant::create(element::f16, Shape{}, {0});
|
std::shared_ptr<Node> pad_value = opset5::Constant::create(element::f16, Shape{}, {0});
|
||||||
pad_value = std::make_shared<opset5::Convert>(pad_value, element::f32);
|
pad_value = std::make_shared<opset5::Convert>(pad_value, element::f32);
|
||||||
auto pad = std::make_shared<opset5::Pad>(data, pads_begin, pads_end, pad_value, op::PadMode::CONSTANT);
|
auto pad = std::make_shared<opset5::Pad>(data, pads_begin, pads_end, pad_value, op::PadMode::CONSTANT);
|
||||||
set_tensor_name(pad, "pad");
|
|
||||||
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
|
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
|
||||||
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(pad, filters, Strides{1, 1},
|
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(pad, filters, Strides{1, 1},
|
||||||
CoordinateDiff{4, 4}, CoordinateDiff{3, 3}, Shape{1, 1});
|
CoordinateDiff{4, 4}, CoordinateDiff{3, 3}, Shape{1, 1});
|
||||||
set_tensor_name(conv, "conv");
|
|
||||||
|
|
||||||
function = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
|
function = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
|
||||||
manager.register_pass<pass::PadFusion>();
|
manager.register_pass<pass::PadFusion>();
|
||||||
@ -258,7 +258,7 @@ TEST_F(TransformationTestsF, PadFusionConvolutionBackpropDataNonConstPadValue) {
|
|||||||
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
|
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
|
||||||
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(data, filters, Strides{1, 1},
|
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(data, filters, Strides{1, 1},
|
||||||
CoordinateDiff{3, 3}, CoordinateDiff{1, 1}, Shape{1, 1});
|
CoordinateDiff{3, 3}, CoordinateDiff{1, 1}, Shape{1, 1});
|
||||||
set_tensor_name(conv, "conv");
|
|
||||||
|
|
||||||
function_ref = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
|
function_ref = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
|
||||||
}
|
}
|
||||||
@ -422,30 +422,29 @@ TEST_F(TransformationTestsF, NegativePadFusionConvolutionBackpropDataTooSmallPad
|
|||||||
Shape data_shape{1, 3, 14, 14};
|
Shape data_shape{1, 3, 14, 14};
|
||||||
{
|
{
|
||||||
auto data = std::make_shared<opset5::Parameter>(element::f32, data_shape);
|
auto data = std::make_shared<opset5::Parameter>(element::f32, data_shape);
|
||||||
set_tensor_name(data, "data");
|
|
||||||
auto pads_begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
|
auto pads_begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
|
||||||
auto pads_end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
|
auto pads_end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
|
||||||
auto pad = std::make_shared<opset5::Pad>(data, pads_begin, pads_end, op::PadMode::CONSTANT);
|
auto pad = std::make_shared<opset5::Pad>(data, pads_begin, pads_end, op::PadMode::CONSTANT);
|
||||||
set_tensor_name(pad, "pad");
|
|
||||||
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
|
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
|
||||||
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(pad, filters, Strides{1, 1},
|
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(pad, filters, Strides{1, 1},
|
||||||
CoordinateDiff{1, 1}, CoordinateDiff{1, 1}, Shape{1, 1});
|
CoordinateDiff{1, 1}, CoordinateDiff{1, 1}, Shape{1, 1});
|
||||||
set_tensor_name(conv, "conv");
|
|
||||||
|
|
||||||
function = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
|
function = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
|
||||||
manager.register_pass<pass::PadFusion>();
|
manager.register_pass<pass::PadFusion>();
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
auto data = std::make_shared<opset5::Parameter>(element::f32, data_shape);
|
auto data = std::make_shared<opset5::Parameter>(element::f32, data_shape);
|
||||||
set_tensor_name(data, "data");
|
|
||||||
auto pads_begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
|
auto pads_begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
|
||||||
auto pads_end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
|
auto pads_end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
|
||||||
auto pad = std::make_shared<opset5::Pad>(data, pads_begin, pads_end, op::PadMode::CONSTANT);
|
auto pad = std::make_shared<opset5::Pad>(data, pads_begin, pads_end, op::PadMode::CONSTANT);
|
||||||
set_tensor_name(pad, "pad");
|
|
||||||
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
|
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
|
||||||
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(pad, filters, Strides{1, 1},
|
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(pad, filters, Strides{1, 1},
|
||||||
CoordinateDiff{1, 1}, CoordinateDiff{1, 1}, Shape{1, 1});
|
CoordinateDiff{1, 1}, CoordinateDiff{1, 1}, Shape{1, 1});
|
||||||
set_tensor_name(conv, "conv");
|
|
||||||
|
|
||||||
function_ref = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
|
function_ref = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
|
||||||
}
|
}
|
||||||
|
@ -141,7 +141,7 @@ TEST_F(TransformationTestsF, RICFusionSimple) {
|
|||||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TransformationTestsF, RICFusionHard) {
|
TEST_F(TransformationTestsF, RICFusionHard) {
|
||||||
@ -198,7 +198,7 @@ TEST_F(TransformationTestsF, RICFusionHard) {
|
|||||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TransformationTestsF, RICFusionDynamic) {
|
TEST_F(TransformationTestsF, RICFusionDynamic) {
|
||||||
@ -248,7 +248,7 @@ TEST_F(TransformationTestsF, RICFusionEltwise1) {
|
|||||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TransformationTestsF, RICFusionEltwise2) {
|
TEST_F(TransformationTestsF, RICFusionEltwise2) {
|
||||||
@ -273,7 +273,7 @@ TEST_F(TransformationTestsF, RICFusionEltwise2) {
|
|||||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TransformationTestsF, RICFusionEltwise3) {
|
TEST_F(TransformationTestsF, RICFusionEltwise3) {
|
||||||
@ -298,7 +298,7 @@ TEST_F(TransformationTestsF, RICFusionEltwise3) {
|
|||||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TransformationTestsF, RICFusionEltwise4) {
|
TEST_F(TransformationTestsF, RICFusionEltwise4) {
|
||||||
@ -324,7 +324,7 @@ TEST_F(TransformationTestsF, RICFusionEltwise4) {
|
|||||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TransformationTestsF, RICFusionEltwise5) {
|
TEST_F(TransformationTestsF, RICFusionEltwise5) {
|
||||||
@ -350,7 +350,7 @@ TEST_F(TransformationTestsF, RICFusionEltwise5) {
|
|||||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TransformationTestsF, RICFusionEltwiseNegative) {
|
TEST_F(TransformationTestsF, RICFusionEltwiseNegative) {
|
||||||
@ -390,7 +390,7 @@ TEST_F(TransformationTestsF, RICFusionEltwiseTwoRIC) {
|
|||||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TransformationTestsF, RICFusionEltwiseNegative3) {
|
TEST_F(TransformationTestsF, RICFusionEltwiseNegative3) {
|
||||||
@ -431,7 +431,7 @@ TEST_F(TransformationTestsF, RICFusionGroupConv) {
|
|||||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TransformationTestsF, RICFusionGroupConvNegative) {
|
TEST_F(TransformationTestsF, RICFusionGroupConvNegative) {
|
||||||
@ -473,7 +473,7 @@ TEST_F(TransformationTestsF, RICFusionTranspose) {
|
|||||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TransformationTestsF, RICFusionFQOnTheWay) {
|
TEST_F(TransformationTestsF, RICFusionFQOnTheWay) {
|
||||||
@ -499,7 +499,7 @@ TEST_F(TransformationTestsF, RICFusionFQOnTheWay) {
|
|||||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TransformationTestsF, RICFusionFQOnTheWay2) {
|
TEST_F(TransformationTestsF, RICFusionFQOnTheWay2) {
|
||||||
@ -538,7 +538,7 @@ TEST_F(TransformationTestsF, RICFusionFQOnTheWay2) {
|
|||||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TransformationTestsF, RICFusionFQOnTheWay3) {
|
TEST_F(TransformationTestsF, RICFusionFQOnTheWay3) {
|
||||||
@ -578,7 +578,7 @@ TEST_F(TransformationTestsF, RICFusionFQOnTheWay3) {
|
|||||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TransformationTestsF, RICFusionShapeOf) {
|
TEST_F(TransformationTestsF, RICFusionShapeOf) {
|
||||||
@ -749,7 +749,7 @@ TEST_F(TransformationTestsF, FuseScaleValue) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TransformationTestsF, FuseScaleValues) {
|
TEST_F(TransformationTestsF, FuseScaleValues) {
|
||||||
@ -779,5 +779,5 @@ TEST_F(TransformationTestsF, FuseScaleValues) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
@ -90,7 +90,9 @@ protected:
|
|||||||
};
|
};
|
||||||
|
|
||||||
TEST_P(ShuffleChannelsFusion, CompareFunctions) {
|
TEST_P(ShuffleChannelsFusion, CompareFunctions) {
|
||||||
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
|
auto fc = FunctionsComparator::no_default()
|
||||||
|
.enable(FunctionsComparator::PRECISIONS)
|
||||||
|
.enable(FunctionsComparator::NODES);
|
||||||
auto res = fc.compare(f, f_ref);
|
auto res = fc.compare(f, f_ref);
|
||||||
ASSERT_TRUE(res.valid) << res.message;
|
ASSERT_TRUE(res.valid) << res.message;
|
||||||
}
|
}
|
||||||
|
@ -55,7 +55,9 @@ TEST_P(SoftmaxFusionFixture, SoftmaxFusion) {
|
|||||||
f_ref = std::make_shared<Function>(NodeVector{softmax}, ParameterVector{data});
|
f_ref = std::make_shared<Function>(NodeVector{softmax}, ParameterVector{data});
|
||||||
}
|
}
|
||||||
|
|
||||||
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
|
auto fc = FunctionsComparator::no_default()
|
||||||
|
.enable(FunctionsComparator::PRECISIONS)
|
||||||
|
.enable(FunctionsComparator::NODES);
|
||||||
auto res = fc.compare(f, f_ref);
|
auto res = fc.compare(f, f_ref);
|
||||||
ASSERT_TRUE(res.valid) << res.message;
|
ASSERT_TRUE(res.valid) << res.message;
|
||||||
}
|
}
|
||||||
@ -113,7 +115,9 @@ TEST_P(NegativeSoftmaxFusionFixture, NegativeSoftmaxFusion) {
|
|||||||
f_ref = std::make_shared<Function>(NodeVector{div}, ParameterVector{data});
|
f_ref = std::make_shared<Function>(NodeVector{div}, ParameterVector{data});
|
||||||
}
|
}
|
||||||
|
|
||||||
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
|
auto fc = FunctionsComparator::no_default()
|
||||||
|
.enable(FunctionsComparator::PRECISIONS)
|
||||||
|
.enable(FunctionsComparator::NODES);
|
||||||
auto res = fc.compare(f, f_ref);
|
auto res = fc.compare(f, f_ref);
|
||||||
ASSERT_TRUE(res.valid) << res.message;
|
ASSERT_TRUE(res.valid) << res.message;
|
||||||
}
|
}
|
||||||
|
@ -94,6 +94,7 @@ TEST_P(TransposeSinkingFQ, TransposeFQReduce) {
|
|||||||
ASSERT_NO_THROW(check_rt_info(f));
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
|
||||||
auto fc = FunctionsComparator::no_default()
|
auto fc = FunctionsComparator::no_default()
|
||||||
|
.enable(FunctionsComparator::NODES)
|
||||||
.enable(FunctionsComparator::PRECISIONS)
|
.enable(FunctionsComparator::PRECISIONS)
|
||||||
.enable(FunctionsComparator::CONST_VALUES);
|
.enable(FunctionsComparator::CONST_VALUES);
|
||||||
auto res = fc.compare(f, f_ref);
|
auto res = fc.compare(f, f_ref);
|
||||||
@ -180,6 +181,7 @@ TEST_P(TransposeSinking, TransposeReduction) {
|
|||||||
ASSERT_NO_THROW(check_rt_info(f));
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
|
||||||
auto fc = FunctionsComparator::no_default()
|
auto fc = FunctionsComparator::no_default()
|
||||||
|
.enable(FunctionsComparator::NODES)
|
||||||
.enable(FunctionsComparator::PRECISIONS)
|
.enable(FunctionsComparator::PRECISIONS)
|
||||||
.enable(FunctionsComparator::CONST_VALUES);
|
.enable(FunctionsComparator::CONST_VALUES);
|
||||||
|
|
||||||
|
@ -109,7 +109,9 @@ TEST_P(TransposeToReshapeTests, CompareFunctions) {
|
|||||||
f->validate_nodes_and_infer_types();
|
f->validate_nodes_and_infer_types();
|
||||||
ASSERT_NO_THROW(check_rt_info(f));
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
|
||||||
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
|
auto fc = FunctionsComparator::no_default()
|
||||||
|
.enable(FunctionsComparator::NODES)
|
||||||
|
.enable(FunctionsComparator::PRECISIONS);
|
||||||
auto res = fc.compare(f, f_ref);
|
auto res = fc.compare(f, f_ref);
|
||||||
ASSERT_TRUE(res.valid) << res.message;
|
ASSERT_TRUE(res.valid) << res.message;
|
||||||
}
|
}
|
||||||
|
@ -84,6 +84,7 @@ TEST_P(TranslateNewWeightFormatToOldOne, ReshapeMatMul) {
|
|||||||
ASSERT_NO_THROW(check_rt_info(f));
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
|
||||||
auto fc = FunctionsComparator::no_default()
|
auto fc = FunctionsComparator::no_default()
|
||||||
|
.enable(FunctionsComparator::NODES)
|
||||||
.enable(FunctionsComparator::PRECISIONS)
|
.enable(FunctionsComparator::PRECISIONS)
|
||||||
.enable(FunctionsComparator::CONST_VALUES);
|
.enable(FunctionsComparator::CONST_VALUES);
|
||||||
auto res = fc.compare(f, f_ref);
|
auto res = fc.compare(f, f_ref);
|
||||||
|
@ -127,7 +127,9 @@ TEST_P(ConvFusionTests, CompareFunctions) {
|
|||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
ASSERT_NO_THROW(check_rt_info(f));
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
|
||||||
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
|
auto fc = FunctionsComparator::no_default()
|
||||||
|
.enable(FunctionsComparator::NODES)
|
||||||
|
.enable(FunctionsComparator::PRECISIONS);
|
||||||
auto res = fc.compare(f, f_ref);
|
auto res = fc.compare(f, f_ref);
|
||||||
ASSERT_TRUE(res.valid) << res.message;
|
ASSERT_TRUE(res.valid) << res.message;
|
||||||
}
|
}
|
||||||
|
@ -157,7 +157,9 @@ TEST_P(MulAddConversionTests, CompareFunctions) {
|
|||||||
ASSERT_NO_THROW(manager.run_passes(f));
|
ASSERT_NO_THROW(manager.run_passes(f));
|
||||||
f->validate_nodes_and_infer_types();
|
f->validate_nodes_and_infer_types();
|
||||||
|
|
||||||
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
|
auto fc = FunctionsComparator::no_default()
|
||||||
|
.enable(FunctionsComparator::NODES)
|
||||||
|
.enable(FunctionsComparator::PRECISIONS);
|
||||||
auto res = fc.compare(f, f_ref);
|
auto res = fc.compare(f, f_ref);
|
||||||
ASSERT_TRUE(res.valid) << res.message;
|
ASSERT_TRUE(res.valid) << res.message;
|
||||||
}
|
}
|
||||||
@ -178,7 +180,9 @@ TEST_P(MulOrAddConversionTests, CompareFunctions) {
|
|||||||
|
|
||||||
f->validate_nodes_and_infer_types();
|
f->validate_nodes_and_infer_types();
|
||||||
|
|
||||||
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
|
auto fc = FunctionsComparator::no_default()
|
||||||
|
.enable(FunctionsComparator::NODES)
|
||||||
|
.enable(FunctionsComparator::PRECISIONS);
|
||||||
auto res = fc.compare(f, f_ref);
|
auto res = fc.compare(f, f_ref);
|
||||||
ASSERT_TRUE(res.valid) << res.message;
|
ASSERT_TRUE(res.valid) << res.message;
|
||||||
}
|
}
|
||||||
|
@ -246,7 +246,7 @@ TEST_F(TransformationTestsF, PropagateMasksBasic) {
|
|||||||
m.run_passes(function);
|
m.run_passes(function);
|
||||||
}
|
}
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -311,7 +311,7 @@ TEST_F(TransformationTestsF, PropagateMasksDynamicConvolution) {
|
|||||||
m.run_passes(function);
|
m.run_passes(function);
|
||||||
}
|
}
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -491,7 +491,7 @@ TEST_F(TransformationTestsF, PropagateMaskPassThrough) {
|
|||||||
m.run_passes(function);
|
m.run_passes(function);
|
||||||
}
|
}
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -643,7 +643,7 @@ TEST_F(TransformationTestsF, PropagateMasksHardDependencies) {
|
|||||||
m.run_passes(function);
|
m.run_passes(function);
|
||||||
}
|
}
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -761,7 +761,7 @@ TEST_F(TransformationTestsF, PropagateMasksQuantizedGroupConvolution) {
|
|||||||
m.run_passes(function);
|
m.run_passes(function);
|
||||||
}
|
}
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -899,7 +899,7 @@ TEST_F(TransformationTestsF, PropagateMasksQuantizedGroupConvolutionWithShapeOf)
|
|||||||
m.run_passes(function);
|
m.run_passes(function);
|
||||||
}
|
}
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -1017,7 +1017,7 @@ TEST_F(TransformationTestsF, PropagateMasksFakeQuantizePerTensor) {
|
|||||||
m.run_passes(function);
|
m.run_passes(function);
|
||||||
}
|
}
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -1202,7 +1202,7 @@ TEST_F(TransformationTestsF, PropagateMasksFakeQuantizePerChannel) {
|
|||||||
m.run_passes(function);
|
m.run_passes(function);
|
||||||
}
|
}
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -1299,7 +1299,7 @@ TEST_F(TransformationTestsF, TestConcatMaskPropagation) {
|
|||||||
m.run_passes(function);
|
m.run_passes(function);
|
||||||
}
|
}
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -1406,7 +1406,7 @@ TEST_F(TransformationTestsF, TestConcatMaskPropagationUp) {
|
|||||||
m.run_passes(function);
|
m.run_passes(function);
|
||||||
}
|
}
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -1554,7 +1554,7 @@ TEST_F(TransformationTestsF, PruneConvIsClosingAndInGroup) {
|
|||||||
m.run_passes(function);
|
m.run_passes(function);
|
||||||
}
|
}
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -1814,7 +1814,7 @@ TEST_F(TransformationTestsF, PruneReduceLayerDown) {
|
|||||||
m.run_passes(function);
|
m.run_passes(function);
|
||||||
}
|
}
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -1975,7 +1975,7 @@ TEST_F(TransformationTestsF, MaskPropagationReshapeUp) {
|
|||||||
m.run_passes(function);
|
m.run_passes(function);
|
||||||
}
|
}
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -2050,7 +2050,7 @@ TEST_F(TransformationTestsF, MaskPropagationReshapeUpWithShapeOf) {
|
|||||||
m.run_passes(function);
|
m.run_passes(function);
|
||||||
}
|
}
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -2149,7 +2149,7 @@ TEST_F(TransformationTestsF, MaskPropagationReshapeDown) {
|
|||||||
m.run_passes(function);
|
m.run_passes(function);
|
||||||
}
|
}
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -2403,7 +2403,7 @@ TEST_F(TransformationTestsF, PruneSEBlock) {
|
|||||||
m.run_passes(function);
|
m.run_passes(function);
|
||||||
}
|
}
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -2495,7 +2495,7 @@ TEST_F(TransformationTestsF, PropagateMasksLinear) {
|
|||||||
m.run_passes(function);
|
m.run_passes(function);
|
||||||
}
|
}
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -2679,7 +2679,7 @@ TEST_F(TransformationTestsF, PruneMasksMatMulColsStopRowsUp) {
|
|||||||
m.run_passes(function);
|
m.run_passes(function);
|
||||||
}
|
}
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -2760,7 +2760,7 @@ TEST_F(TransformationTestsF, PruneMasksMatMulRowsStopColsUp) {
|
|||||||
m.run_passes(function);
|
m.run_passes(function);
|
||||||
}
|
}
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -2855,5 +2855,5 @@ TEST_F(TransformationTestsF, PropagateFlattenUp) {
|
|||||||
m.run_passes(function);
|
m.run_passes(function);
|
||||||
}
|
}
|
||||||
disable_rt_info_check();
|
disable_rt_info_check();
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
@ -126,7 +126,9 @@ TEST_P(ConvertReduceToPoolingTests, CompareFunctions) {
|
|||||||
m.run_passes(f);
|
m.run_passes(f);
|
||||||
ASSERT_NO_THROW(check_rt_info(f));
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
|
||||||
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
|
auto fc = FunctionsComparator::no_default()
|
||||||
|
.enable(FunctionsComparator::NODES)
|
||||||
|
.enable(FunctionsComparator::PRECISIONS);
|
||||||
auto res = fc.compare(f, f_ref);
|
auto res = fc.compare(f, f_ref);
|
||||||
ASSERT_TRUE(res.valid) << res.message;
|
ASSERT_TRUE(res.valid) << res.message;
|
||||||
}
|
}
|
||||||
|
@ -90,7 +90,9 @@ void test(std::shared_ptr<ngraph::Function> f, std::shared_ptr<ngraph::Function>
|
|||||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||||
ASSERT_NO_THROW(manager.run_passes(f));
|
ASSERT_NO_THROW(manager.run_passes(f));
|
||||||
|
|
||||||
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
|
auto fc = FunctionsComparator::no_default()
|
||||||
|
.enable(FunctionsComparator::NODES)
|
||||||
|
.enable(FunctionsComparator::PRECISIONS);
|
||||||
auto res = fc.compare(f, f_ref);
|
auto res = fc.compare(f, f_ref);
|
||||||
ASSERT_TRUE(res.valid) << res.message;
|
ASSERT_TRUE(res.valid) << res.message;
|
||||||
}
|
}
|
||||||
|
@ -650,7 +650,9 @@ TEST(TransformationTests, DifferentPrecisionVersusAttributes) {
|
|||||||
///
|
///
|
||||||
|
|
||||||
{ // check precision only
|
{ // check precision only
|
||||||
const auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
|
const auto fc = FunctionsComparator::no_default()
|
||||||
|
.enable(FunctionsComparator::NODES)
|
||||||
|
.enable(FunctionsComparator::PRECISIONS);
|
||||||
const auto res = fc.compare(f1, f2);
|
const auto res = fc.compare(f1, f2);
|
||||||
EXPECT_FALSE(res.valid);
|
EXPECT_FALSE(res.valid);
|
||||||
EXPECT_THAT(res.message, HasSubstr("Different element type detected"));
|
EXPECT_THAT(res.message, HasSubstr("Different element type detected"));
|
||||||
@ -660,6 +662,7 @@ TEST(TransformationTests, DifferentPrecisionVersusAttributes) {
|
|||||||
|
|
||||||
{ // check precision and attributes
|
{ // check precision and attributes
|
||||||
const auto fc = FunctionsComparator::no_default()
|
const auto fc = FunctionsComparator::no_default()
|
||||||
|
.enable(FunctionsComparator::NODES)
|
||||||
.enable(FunctionsComparator::PRECISIONS)
|
.enable(FunctionsComparator::PRECISIONS)
|
||||||
.enable(FunctionsComparator::ATTRIBUTES);
|
.enable(FunctionsComparator::ATTRIBUTES);
|
||||||
const auto res = fc.compare(f1, f2);
|
const auto res = fc.compare(f1, f2);
|
||||||
|
@ -62,7 +62,7 @@ class CompressQuantizeWeightsTests
|
|||||||
function_ref = std::make_shared<Function>(mul, ParameterVector{});
|
function_ref = std::make_shared<Function>(mul, ParameterVector{});
|
||||||
}
|
}
|
||||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -108,7 +108,7 @@ TEST_F(TransformationTestsF, CompressQuantizeWeightsWithDequantizationSubgraph)
|
|||||||
function_ref = std::make_shared<Function>(NodeVector{mul}, ParameterVector{});
|
function_ref = std::make_shared<Function>(NodeVector{mul}, ParameterVector{});
|
||||||
}
|
}
|
||||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TransformationTestsF, CompressQuantizeWeightsWithZeroPointOptimizer) {
|
TEST_F(TransformationTestsF, CompressQuantizeWeightsWithZeroPointOptimizer) {
|
||||||
@ -133,7 +133,7 @@ TEST_F(TransformationTestsF, CompressQuantizeWeightsWithZeroPointOptimizer) {
|
|||||||
function_ref = std::make_shared<Function>(NodeVector{mul}, ParameterVector{});
|
function_ref = std::make_shared<Function>(NodeVector{mul}, ParameterVector{});
|
||||||
}
|
}
|
||||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TransformationTestsF, NegativeCompressQuantizeWeightsWithZeroPointOptimizer) {
|
TEST_F(TransformationTestsF, NegativeCompressQuantizeWeightsWithZeroPointOptimizer) {
|
||||||
@ -159,7 +159,7 @@ TEST_F(TransformationTestsF, NegativeCompressQuantizeWeightsWithZeroPointOptimiz
|
|||||||
function_ref = std::make_shared<Function>(NodeVector{mul}, ParameterVector{});
|
function_ref = std::make_shared<Function>(NodeVector{mul}, ParameterVector{});
|
||||||
}
|
}
|
||||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TransformationTestsF, NegativeCompressQuantizeWeightsNonConstantInput) {
|
TEST_F(TransformationTestsF, NegativeCompressQuantizeWeightsNonConstantInput) {
|
||||||
@ -175,5 +175,5 @@ TEST_F(TransformationTestsF, NegativeCompressQuantizeWeightsNonConstantInput) {
|
|||||||
manager.register_pass<pass::ZeroPointOptimizer>();
|
manager.register_pass<pass::ZeroPointOptimizer>();
|
||||||
|
|
||||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||||
enable_accuracy_check();
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,10 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
#include "graph_comparator.hpp"
|
#include "common_test_utils/graph_comparator.hpp"
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include "ie_common.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
@ -20,6 +23,10 @@
|
|||||||
#include <typeinfo>
|
#include <typeinfo>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "common_test_utils/ov_tensor_utils.hpp"
|
||||||
|
#include "ngraph_functions/utils/ngraph_helpers.hpp"
|
||||||
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
inline namespace tools {
|
inline namespace tools {
|
||||||
bool is_type_relaxed(const std::string& type) {
|
bool is_type_relaxed(const std::string& type) {
|
||||||
@ -31,12 +38,12 @@ bool compare_type_info(const ngraph::DiscreteTypeInfo& info1, const ngraph::Disc
|
|||||||
if (!is_type_relaxed(info1.name) && !is_type_relaxed(info2.name) && (info1.version != info2.version)) {
|
if (!is_type_relaxed(info1.name) && !is_type_relaxed(info2.name) && (info1.version != info2.version)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||||
|
|
||||||
const std::string info1Name =
|
const std::string info1Name =
|
||||||
is_type_relaxed(info1.name) && (info1.parent != nullptr) ? info1.parent->name : info1.name;
|
is_type_relaxed(info1.name) && (info1.parent != nullptr) ? info1.parent->name : info1.name;
|
||||||
const std::string info2Name =
|
const std::string info2Name =
|
||||||
is_type_relaxed(info2.name) && (info2.parent != nullptr) ? info2.parent->name : info2.name;
|
is_type_relaxed(info2.name) && (info2.parent != nullptr) ? info2.parent->name : info2.name;
|
||||||
return info1Name == info2Name;
|
return info1Name == info2Name;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -121,9 +128,9 @@ public:
|
|||||||
explicit NodeAndInputDescription(const InputNode& input,
|
explicit NodeAndInputDescription(const InputNode& input,
|
||||||
const Parameter* parameter,
|
const Parameter* parameter,
|
||||||
const InputDescripton* description)
|
const InputDescripton* description)
|
||||||
: m_input(input),
|
: m_input(input),
|
||||||
m_parameter(not_null(parameter)),
|
m_parameter(not_null(parameter)),
|
||||||
m_description(not_null(description)) {}
|
m_description(not_null(description)) {}
|
||||||
|
|
||||||
static bool equal_descriptions(const InputDescripton* lhs, const InputDescripton* rhs) {
|
static bool equal_descriptions(const InputDescripton* lhs, const InputDescripton* rhs) {
|
||||||
if (!lhs || !rhs || lhs->get_type_info() != rhs->get_type_info()) {
|
if (!lhs || !rhs || lhs->get_type_info() != rhs->get_type_info()) {
|
||||||
@ -172,7 +179,7 @@ public:
|
|||||||
}
|
}
|
||||||
for (size_t i = 0; i != param_shape.size(); ++i) {
|
for (size_t i = 0; i != param_shape.size(); ++i) {
|
||||||
const auto expected_axis_size =
|
const auto expected_axis_size =
|
||||||
i == slice_description->m_axis ? slice_description->m_part_size * num_iterations : param_shape[i];
|
i == slice_description->m_axis ? slice_description->m_part_size * num_iterations : param_shape[i];
|
||||||
if (input_shape[i] != expected_axis_size) {
|
if (input_shape[i] != expected_axis_size) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -215,9 +222,9 @@ public:
|
|||||||
explicit NodeAndOutputDescription(const OutputNode& output,
|
explicit NodeAndOutputDescription(const OutputNode& output,
|
||||||
const Result* result,
|
const Result* result,
|
||||||
const OutputDescription* description)
|
const OutputDescription* description)
|
||||||
: m_output(output),
|
: m_output(output),
|
||||||
m_result(not_null(result)),
|
m_result(not_null(result)),
|
||||||
m_description(not_null(description)) {}
|
m_description(not_null(description)) {}
|
||||||
|
|
||||||
static bool equal_descriptions(const OutputDescription* lhs, const OutputDescription* rhs) {
|
static bool equal_descriptions(const OutputDescription* lhs, const OutputDescription* rhs) {
|
||||||
if (!lhs || !rhs || lhs->get_type_info() != rhs->get_type_info()) {
|
if (!lhs || !rhs || lhs->get_type_info() != rhs->get_type_info()) {
|
||||||
@ -302,8 +309,8 @@ public:
|
|||||||
using Id = uint64_t;
|
using Id = uint64_t;
|
||||||
|
|
||||||
explicit BackEdge(const Parameter* parameter, const Result* result)
|
explicit BackEdge(const Parameter* parameter, const Result* result)
|
||||||
: m_parameter(not_null(parameter)),
|
: m_parameter(not_null(parameter)),
|
||||||
m_result(not_null(result)) {}
|
m_result(not_null(result)) {}
|
||||||
|
|
||||||
bool result_and_parameter_match() const {
|
bool result_and_parameter_match() const {
|
||||||
return equal_type_and_partial_shape(m_result->output(0), *m_parameter);
|
return equal_type_and_partial_shape(m_result->output(0), *m_parameter);
|
||||||
@ -539,103 +546,113 @@ Comparator::Result Comparator::compare(const std::shared_ptr<ngraph::Function>&
|
|||||||
* + Check node attributes by Visitor API
|
* + Check node attributes by Visitor API
|
||||||
*/
|
*/
|
||||||
|
|
||||||
auto f_results = f->get_results();
|
if (should_compare(CmpValues::NODES)) {
|
||||||
auto f_ref_results = f_ref->get_results();
|
auto f_results = f->get_results();
|
||||||
|
auto f_ref_results = f_ref->get_results();
|
||||||
|
|
||||||
auto cmp = less_by_name;
|
auto cmp = less_by_name;
|
||||||
// In case if Result source output has more than one name so the Result may have any of this names as a friendly
|
// In case if Result source output has more than one name so the Result may have any of this names as a friendly
|
||||||
// name An in case of multiple names we sort Result operation using their parent node names
|
// name An in case of multiple names we sort Result operation using their parent node names
|
||||||
if (std::any_of(f_results.begin(),
|
if (std::any_of(f_results.begin(),
|
||||||
f_results.end(),
|
f_results.end(),
|
||||||
[](const std::shared_ptr<ngraph::Node>& node) {
|
[](const std::shared_ptr<ngraph::Node> &node) {
|
||||||
const auto& t = node->input_value(0).get_tensor_ptr();
|
const auto &t = node->input_value(0).get_tensor_ptr();
|
||||||
return t->get_names().size() > 1;
|
return t->get_names().size() > 1;
|
||||||
}) ||
|
}) ||
|
||||||
std::any_of(f_ref_results.begin(), f_ref_results.end(), [](const std::shared_ptr<ngraph::Node>& node) {
|
std::any_of(f_ref_results.begin(), f_ref_results.end(), [](const std::shared_ptr<ngraph::Node> &node) {
|
||||||
const auto& t = node->input_value(0).get_tensor_ptr();
|
const auto &t = node->input_value(0).get_tensor_ptr();
|
||||||
return t->get_names().size() > 1;
|
return t->get_names().size() > 1;
|
||||||
})) {
|
})) {
|
||||||
cmp = less_by_parent_name;
|
cmp = less_by_parent_name;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::sort(f_results.begin(), f_results.end(), cmp);
|
std::sort(f_results.begin(), f_results.end(), cmp);
|
||||||
std::sort(f_ref_results.begin(), f_ref_results.end(), cmp);
|
std::sort(f_ref_results.begin(), f_ref_results.end(), cmp);
|
||||||
|
|
||||||
if (f_results.size() != f_ref_results.size()) {
|
if (f_results.size() != f_ref_results.size()) {
|
||||||
return Result::error("Number of results is different: " + to_str(f_results.size()) + " and " +
|
return Result::error("Number of results is different: " + to_str(f_results.size()) + " and " +
|
||||||
to_str(f_ref_results.size()));
|
to_str(f_ref_results.size()));
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto& f_sinks = f->get_sinks();
|
const auto &f_sinks = f->get_sinks();
|
||||||
const auto& f_ref_sinks = f_ref->get_sinks();
|
const auto &f_ref_sinks = f_ref->get_sinks();
|
||||||
if (f_sinks.size() != f_ref_sinks.size()) {
|
if (f_sinks.size() != f_ref_sinks.size()) {
|
||||||
return Result::error("Number of sinks is different: " + to_str(f_sinks.size()) + " and " +
|
return Result::error("Number of sinks is different: " + to_str(f_sinks.size()) + " and " +
|
||||||
to_str(f_ref_sinks.size()));
|
to_str(f_ref_sinks.size()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compare sinks
|
// Compare sinks
|
||||||
if (f_sinks.size() == 1) {
|
if (f_sinks.size() == 1) {
|
||||||
q.push({f_sinks[0].get(), f_ref_sinks[0].get()});
|
q.push({f_sinks[0].get(), f_ref_sinks[0].get()});
|
||||||
used.insert(f_sinks[0].get());
|
used.insert(f_sinks[0].get());
|
||||||
} else {
|
} else {
|
||||||
// Cast to Assign and find those that have same variable_id suffix
|
// Cast to Assign and find those that have same variable_id suffix
|
||||||
for (const auto& sink1 : f_sinks) {
|
for (const auto &sink1 : f_sinks) {
|
||||||
auto assign1 = std::dynamic_pointer_cast<ov::op::util::VariableExtension>(sink1);
|
auto assign1 = std::dynamic_pointer_cast<ov::op::util::VariableExtension>(sink1);
|
||||||
if (!assign1) {
|
if (!assign1) {
|
||||||
return Result::error("Sink '" + name(sink1) +
|
return Result::error("Sink '" + name(sink1) +
|
||||||
"' is not a variable - graph comparison is not supported");
|
|
||||||
}
|
|
||||||
auto name1 = assign1->get_variable_id();
|
|
||||||
std::shared_ptr<ov::op::Sink> found_sink2;
|
|
||||||
for (const auto& sink2 : f_ref_sinks) {
|
|
||||||
auto assign2 = std::dynamic_pointer_cast<ov::op::util::VariableExtension>(sink2);
|
|
||||||
if (!assign2) {
|
|
||||||
return Result::error("Sink '" + name(sink2) +
|
|
||||||
"' is not a variable - graph comparison is not supported");
|
"' is not a variable - graph comparison is not supported");
|
||||||
}
|
}
|
||||||
auto name2 = assign2->get_variable_id();
|
auto name1 = assign1->get_variable_id();
|
||||||
if (name2.find(name1) != std::string::npos || name1.find(name2) != std::string::npos) {
|
std::shared_ptr<ov::op::Sink> found_sink2;
|
||||||
found_sink2 = sink2;
|
for (const auto &sink2 : f_ref_sinks) {
|
||||||
break;
|
auto assign2 = std::dynamic_pointer_cast<ov::op::util::VariableExtension>(sink2);
|
||||||
|
if (!assign2) {
|
||||||
|
return Result::error("Sink '" + name(sink2) +
|
||||||
|
"' is not a variable - graph comparison is not supported");
|
||||||
|
}
|
||||||
|
auto name2 = assign2->get_variable_id();
|
||||||
|
if (name2.find(name1) != std::string::npos || name1.find(name2) != std::string::npos) {
|
||||||
|
found_sink2 = sink2;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!found_sink2) {
|
||||||
|
return Result::error("No suitable sink is found for: " + name(sink1) + ", var=" + name1);
|
||||||
|
}
|
||||||
|
q.push({sink1.get(), found_sink2.get()});
|
||||||
|
used.insert(sink1.get());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < f_results.size(); ++i) {
|
||||||
|
if (should_compare(CmpValues::NAMES)) {
|
||||||
|
if (name(f_results[i]->get_input_node_shared_ptr(0)) !=
|
||||||
|
name(f_ref_results[i]->get_input_node_shared_ptr(0))) {
|
||||||
|
return Result::error(
|
||||||
|
"Different output node names: " + name(f_results[i]->get_input_node_shared_ptr(0)) +
|
||||||
|
" and " +
|
||||||
|
name(f_ref_results[i]->get_input_node_shared_ptr(0)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!found_sink2) {
|
q.push({f_results[i].get(), f_ref_results[i].get()});
|
||||||
return Result::error("No suitable sink is found for: " + name(sink1) + ", var=" + name1);
|
used.insert(f_results[i].get());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::stringstream errors;
|
||||||
|
|
||||||
|
while (!q.empty()) {
|
||||||
|
ngraph::Node *const node1 = q.front().first;
|
||||||
|
ngraph::Node *const node2 = q.front().second;
|
||||||
|
q.pop();
|
||||||
|
|
||||||
|
const auto result = compare(node1, node2, errors);
|
||||||
|
if (!result.valid) {
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
q.push({sink1.get(), found_sink2.get()});
|
|
||||||
used.insert(sink1.get());
|
add_nodes_inputs_to_queue(node1, node2);
|
||||||
}
|
}
|
||||||
|
const auto msg = errors.str();
|
||||||
|
return msg.empty() ? Result::ok() : Result::error(msg);
|
||||||
|
|
||||||
|
} else if (should_compare(CmpValues::ACCURACY)) {
|
||||||
|
auto status = accuracy_check(f_ref, f);
|
||||||
|
return status.status ? Result::ok() : Result::error(status.message);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
return Result::error("CmpValues are invalid");
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t i = 0; i < f_results.size(); ++i) {
|
|
||||||
if (should_compare(CmpValues::NAMES)) {
|
|
||||||
if (name(f_results[i]->get_input_node_shared_ptr(0)) !=
|
|
||||||
name(f_ref_results[i]->get_input_node_shared_ptr(0))) {
|
|
||||||
return Result::error(
|
|
||||||
"Different output node names: " + name(f_results[i]->get_input_node_shared_ptr(0)) + " and " +
|
|
||||||
name(f_ref_results[i]->get_input_node_shared_ptr(0)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
q.push({f_results[i].get(), f_ref_results[i].get()});
|
|
||||||
used.insert(f_results[i].get());
|
|
||||||
}
|
|
||||||
|
|
||||||
std::stringstream errors;
|
|
||||||
|
|
||||||
while (!q.empty()) {
|
|
||||||
ngraph::Node* const node1 = q.front().first;
|
|
||||||
ngraph::Node* const node2 = q.front().second;
|
|
||||||
q.pop();
|
|
||||||
|
|
||||||
const auto result = compare(node1, node2, errors);
|
|
||||||
if (!result.valid) {
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
add_nodes_inputs_to_queue(node1, node2);
|
|
||||||
}
|
|
||||||
const auto msg = errors.str();
|
|
||||||
return msg.empty() ? Result::ok() : Result::error(msg);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Comparator::Result Comparator::compare(ngraph::Node* node1, ngraph::Node* node2, std::ostream& err_log) {
|
Comparator::Result Comparator::compare(ngraph::Node* node1, ngraph::Node* node2, std::ostream& err_log) {
|
||||||
@ -759,7 +776,7 @@ void Comparator::compare_outputs(ngraph::Node* node1, ngraph::Node* node2, std::
|
|||||||
void Comparator::compare_nodes(ngraph::Node* node1, ngraph::Node* node2, std::ostream& err_log) {
|
void Comparator::compare_nodes(ngraph::Node* node1, ngraph::Node* node2, std::ostream& err_log) {
|
||||||
if (should_compare(CmpValues::RUNTIME_KEYS) && !compare_rt_keys(*node1, *node2, err_log)) {
|
if (should_compare(CmpValues::RUNTIME_KEYS) && !compare_rt_keys(*node1, *node2, err_log)) {
|
||||||
err_log << "Different runtime info detected\n" + name(node1) + " and " + name(node2) +
|
err_log << "Different runtime info detected\n" + name(node1) + " and " + name(node2) +
|
||||||
" not equal runtime info.\n";
|
" not equal runtime info.\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
if (should_compare(CmpValues::ATTRIBUTES)) {
|
if (should_compare(CmpValues::ATTRIBUTES)) {
|
||||||
@ -805,123 +822,115 @@ void check_rt_info(const std::shared_ptr<ngraph::Function>& f) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void set_tensor_name(ngraph::Output<ngraph::Node> output, const std::string& name) {
|
|
||||||
output.get_tensor_ptr()->set_names({name});
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_tensor_names(ngraph::Output<ngraph::Node> output, const std::unordered_set<std::string>& names) {
|
|
||||||
output.get_tensor_ptr()->set_names(names);
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace attributes {
|
namespace attributes {
|
||||||
namespace detail {
|
namespace detail {
|
||||||
void ReadAndStoreAttributes::on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) {
|
void ReadAndStoreAttributes::on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) {
|
||||||
if (auto inputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpInputDescription>>(&adapter)) {
|
if (auto inputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpInputDescription>>(&adapter)) {
|
||||||
insert(name, inputs->get());
|
insert(name, inputs->get());
|
||||||
} else if (auto outputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpOutputDescription>>(&adapter)) {
|
} else if (auto outputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpOutputDescription>>(&adapter)) {
|
||||||
insert(name, outputs->get());
|
insert(name, outputs->get());
|
||||||
} else if (ngraph::is_type<ngraph::AttributeAdapter<SpecialBodyPorts>>(&adapter)) {
|
} else if (ngraph::is_type<ngraph::AttributeAdapter<SpecialBodyPorts>>(&adapter)) {
|
||||||
// drop comparison, no more info than port indexes which will be check in
|
// drop comparison, no more info than port indexes which will be check in
|
||||||
// subgraph::compare_io
|
// subgraph::compare_io
|
||||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::runtime::AlignedBuffer>>>(
|
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::runtime::AlignedBuffer>>>(
|
||||||
&adapter)) {
|
&adapter)) {
|
||||||
const auto beg = static_cast<unsigned char*>(a->get()->get_ptr());
|
const auto beg = static_cast<unsigned char*>(a->get()->get_ptr());
|
||||||
const auto end = beg + a->get()->size();
|
const auto end = beg + a->get()->size();
|
||||||
insert(name, storage::MemoryChunk{storage::MemoryChunk::Data(beg, end)});
|
insert(name, storage::MemoryChunk{storage::MemoryChunk::Data(beg, end)});
|
||||||
} else if (auto framework_node_attr =
|
} else if (auto framework_node_attr =
|
||||||
ngraph::as_type<ngraph::AttributeAdapter<ov::op::util::FrameworkNodeAttrs>>(&adapter)) {
|
ngraph::as_type<ngraph::AttributeAdapter<ov::op::util::FrameworkNodeAttrs>>(&adapter)) {
|
||||||
insert(name, framework_node_attr->get());
|
insert(name, framework_node_attr->get());
|
||||||
} else if (auto variable_ptr =
|
} else if (auto variable_ptr =
|
||||||
ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::Variable>>>(&adapter)) {
|
ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::Variable>>>(&adapter)) {
|
||||||
insert(name, variable_ptr->get());
|
insert(name, variable_ptr->get());
|
||||||
} else if (auto shape_ptr = ngraph::as_type<ngraph::AttributeAdapter<ov::PartialShape>>(&adapter)) {
|
} else if (auto shape_ptr = ngraph::as_type<ngraph::AttributeAdapter<ov::PartialShape>>(&adapter)) {
|
||||||
insert(name, shape_ptr->get());
|
insert(name, shape_ptr->get());
|
||||||
} else if (auto dim_ptr = ngraph::as_type<ngraph::AttributeAdapter<ov::Dimension>>(&adapter)) {
|
} else if (auto dim_ptr = ngraph::as_type<ngraph::AttributeAdapter<ov::Dimension>>(&adapter)) {
|
||||||
insert(name, dim_ptr->get());
|
insert(name, dim_ptr->get());
|
||||||
} else {
|
} else {
|
||||||
m_read_result += "store attr [ ERR ]: " + name + " [drop `void` comparison which is '" +
|
m_read_result += "store attr [ ERR ]: " + name + " [drop `void` comparison which is '" +
|
||||||
adapter.get_type_info().name + "']";
|
adapter.get_type_info().name + "']";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
template <typename AttrValue>
|
||||||
template <typename AttrValue>
|
void ReadAndCompareAttributes::verify(const std::string& name, const AttrValue& attr_value) {
|
||||||
void ReadAndCompareAttributes::verify(const std::string& name, const AttrValue& attr_value) {
|
if (should_return()) {
|
||||||
if (should_return()) {
|
return;
|
||||||
return;
|
}
|
||||||
}
|
m_visited_attributes.insert(name);
|
||||||
m_visited_attributes.insert(name);
|
const auto ref_value = m_attr_ref.get<AttrValue>(name);
|
||||||
const auto ref_value = m_attr_ref.get<AttrValue>(name);
|
if (!ref_value) {
|
||||||
if (!ref_value) {
|
m_cmp_result += "missing attribute name: '" + name + "'";
|
||||||
m_cmp_result += "missing attribute name: '" + name + "'";
|
return;
|
||||||
return;
|
}
|
||||||
|
|
||||||
|
if (!equal::Equal<AttrValue>::equal_value(*ref_value, attr_value)) {
|
||||||
|
m_cmp_result += "mismatch in value: '" + name + "' : " + str::Get<AttrValue>::value(*ref_value) + " vs " +
|
||||||
|
str::Get<AttrValue>::value(attr_value);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!equal::Equal<AttrValue>::equal_value(*ref_value, attr_value)) {
|
void ReadAndCompareAttributes::verify_mem_buf(const std::string& name,
|
||||||
m_cmp_result += "mismatch in value: '" + name + "' : " + str::Get<AttrValue>::value(*ref_value) + " vs " +
|
const std::shared_ptr<ngraph::runtime::AlignedBuffer>& buffer) {
|
||||||
str::Get<AttrValue>::value(attr_value);
|
if (should_return()) {
|
||||||
}
|
return;
|
||||||
}
|
}
|
||||||
|
m_visited_attributes.insert(name);
|
||||||
|
const auto ref_value = m_attr_ref.get<storage::MemoryChunk>(name);
|
||||||
|
if (!ref_value) {
|
||||||
|
m_cmp_result += "missing attribute name: '" + name + "'";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
void ReadAndCompareAttributes::verify_mem_buf(const std::string& name,
|
if (buffer->size() != ref_value->size() ||
|
||||||
const std::shared_ptr<ngraph::runtime::AlignedBuffer>& buffer) {
|
std::memcmp(ref_value->data(), buffer->get_ptr(), ref_value->size()) != 0) {
|
||||||
if (should_return()) {
|
m_cmp_result += "mismatch in value: '" + name + "' : look in to the mem buffer";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
m_visited_attributes.insert(name);
|
|
||||||
const auto ref_value = m_attr_ref.get<storage::MemoryChunk>(name);
|
|
||||||
if (!ref_value) {
|
|
||||||
m_cmp_result += "missing attribute name: '" + name + "'";
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (buffer->size() != ref_value->size() ||
|
void ReadAndCompareAttributes::verify_function(const std::string& name, ModelAccessor& adapter) {
|
||||||
std::memcmp(ref_value->data(), buffer->get_ptr(), ref_value->size()) != 0) {
|
if (should_return()) {
|
||||||
m_cmp_result += "mismatch in value: '" + name + "' : look in to the mem buffer";
|
return;
|
||||||
return;
|
}
|
||||||
|
m_visited_attributes.insert(name);
|
||||||
|
const auto ref_value = m_attr_ref.get<std::shared_ptr<ngraph::Function>>(name);
|
||||||
|
if (!ref_value) {
|
||||||
|
m_cmp_result += "missing attribute name: '" + name + "'";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Comparator c(m_check_flags);
|
||||||
|
const auto result = c.compare(*ref_value, adapter.get());
|
||||||
|
if (!result.valid) {
|
||||||
|
m_cmp_result += result.message;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
void ReadAndCompareAttributes::verify_function(const std::string& name, ModelAccessor& adapter) {
|
void ReadAndCompareAttributes::verify_others(const std::string& name, ngraph::ValueAccessor<void>& adapter) {
|
||||||
if (should_return()) {
|
if (auto inputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpInputDescription>>(&adapter)) {
|
||||||
return;
|
verify(name, inputs->get());
|
||||||
|
} else if (auto outputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpOutputDescription>>(&adapter)) {
|
||||||
|
verify(name, outputs->get());
|
||||||
|
} else if (ngraph::is_type<ngraph::AttributeAdapter<SpecialBodyPorts>>(&adapter)) {
|
||||||
|
// drop comparison, no more info than port indexes which will be check in
|
||||||
|
// subgraph::compare_io
|
||||||
|
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::runtime::AlignedBuffer>>>(
|
||||||
|
&adapter)) {
|
||||||
|
verify_mem_buf(name, a->get());
|
||||||
|
} else if (auto attrs = ngraph::as_type<ngraph::AttributeAdapter<ov::op::util::FrameworkNodeAttrs>>(&adapter)) {
|
||||||
|
verify(name, attrs->get());
|
||||||
|
} else if (auto variable_ptr =
|
||||||
|
ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::Variable>>>(&adapter)) {
|
||||||
|
verify(name, variable_ptr->get());
|
||||||
|
} else if (auto shape_ptr = ngraph::as_type<ngraph::AttributeAdapter<ov::PartialShape>>(&adapter)) {
|
||||||
|
verify(name, shape_ptr->get());
|
||||||
|
} else if (auto dim_ptr = ngraph::as_type<ngraph::AttributeAdapter<ov::Dimension>>(&adapter)) {
|
||||||
|
verify(name, dim_ptr->get());
|
||||||
|
} else {
|
||||||
|
m_cmp_result += "compare attr [ ERR ]: " + name + " [drop `void` comparison which is '" +
|
||||||
|
adapter.get_type_info().name + "']";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
m_visited_attributes.insert(name);
|
|
||||||
const auto ref_value = m_attr_ref.get<std::shared_ptr<ngraph::Function>>(name);
|
|
||||||
if (!ref_value) {
|
|
||||||
m_cmp_result += "missing attribute name: '" + name + "'";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
Comparator c(m_check_flags);
|
|
||||||
const auto result = c.compare(*ref_value, adapter.get());
|
|
||||||
if (!result.valid) {
|
|
||||||
m_cmp_result += result.message;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void ReadAndCompareAttributes::verify_others(const std::string& name, ngraph::ValueAccessor<void>& adapter) {
|
|
||||||
if (auto inputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpInputDescription>>(&adapter)) {
|
|
||||||
verify(name, inputs->get());
|
|
||||||
} else if (auto outputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpOutputDescription>>(&adapter)) {
|
|
||||||
verify(name, outputs->get());
|
|
||||||
} else if (ngraph::is_type<ngraph::AttributeAdapter<SpecialBodyPorts>>(&adapter)) {
|
|
||||||
// drop comparison, no more info than port indexes which will be check in
|
|
||||||
// subgraph::compare_io
|
|
||||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::runtime::AlignedBuffer>>>(
|
|
||||||
&adapter)) {
|
|
||||||
verify_mem_buf(name, a->get());
|
|
||||||
} else if (auto attrs = ngraph::as_type<ngraph::AttributeAdapter<ov::op::util::FrameworkNodeAttrs>>(&adapter)) {
|
|
||||||
verify(name, attrs->get());
|
|
||||||
} else if (auto variable_ptr =
|
|
||||||
ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::Variable>>>(&adapter)) {
|
|
||||||
verify(name, variable_ptr->get());
|
|
||||||
} else if (auto shape_ptr = ngraph::as_type<ngraph::AttributeAdapter<ov::PartialShape>>(&adapter)) {
|
|
||||||
verify(name, shape_ptr->get());
|
|
||||||
} else if (auto dim_ptr = ngraph::as_type<ngraph::AttributeAdapter<ov::Dimension>>(&adapter)) {
|
|
||||||
verify(name, dim_ptr->get());
|
|
||||||
} else {
|
|
||||||
m_cmp_result += "compare attr [ ERR ]: " + name + " [drop `void` comparison which is '" +
|
|
||||||
adapter.get_type_info().name + "']";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
@ -937,3 +946,42 @@ Comparator::Result compare(ngraph::Node* node1, ngraph::Node* node2, Comparator:
|
|||||||
}
|
}
|
||||||
|
|
||||||
} // namespace attributes
|
} // namespace attributes
|
||||||
|
|
||||||
|
AccuracyCheckResult accuracy_check(const std::shared_ptr<ov::Model>& ref_function,
|
||||||
|
const std::shared_ptr<ov::Model>& cur_function) {
|
||||||
|
if (ref_function->is_dynamic() || cur_function->is_dynamic()) {
|
||||||
|
return AccuracyCheckResult{true, ""};
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
IE_ASSERT(ref_function->get_parameters().size() == cur_function->get_parameters().size());
|
||||||
|
|
||||||
|
std::map<std::shared_ptr<ov::Node>, ov::Tensor> ref_input_data;
|
||||||
|
std::map<std::shared_ptr<ov::Node>, ov::Tensor> cur_input_data;
|
||||||
|
for (size_t i = 0; i < ref_function->get_parameters().size(); i++) {
|
||||||
|
const auto &tensor = ov::test::utils::create_and_fill_tensor(
|
||||||
|
ref_function->get_parameters()[i]->get_element_type(),
|
||||||
|
ref_function->get_parameters()[i]->get_shape());
|
||||||
|
ref_input_data[ref_function->get_parameters()[i]] = tensor;
|
||||||
|
cur_input_data[cur_function->get_parameters()[i]] = tensor;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ref_outputs = ngraph::helpers::interpretFunction(ref_function, ref_input_data);
|
||||||
|
auto outputs = ngraph::helpers::interpretFunction(cur_function, cur_input_data);
|
||||||
|
|
||||||
|
IE_ASSERT(ref_outputs.size() == outputs.size());
|
||||||
|
|
||||||
|
for (int i = 0; i < ref_outputs.size(); i++) {
|
||||||
|
ov::test::utils::compare(ref_outputs[i], outputs[i],
|
||||||
|
std::numeric_limits<double>::max(),
|
||||||
|
std::numeric_limits<double>::max());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
catch (const std::runtime_error &re) {
|
||||||
|
return AccuracyCheckResult{false, re.what()};
|
||||||
|
} catch (const std::exception &ex) {
|
||||||
|
return AccuracyCheckResult{false, ex.what()};
|
||||||
|
} catch (...) {
|
||||||
|
return AccuracyCheckResult{false, ""};
|
||||||
|
}
|
||||||
|
return AccuracyCheckResult{true, ""};
|
||||||
|
}
|
||||||
|
@ -20,12 +20,14 @@ class FunctionsComparator {
|
|||||||
public:
|
public:
|
||||||
enum CmpValues {
|
enum CmpValues {
|
||||||
NONE = 0,
|
NONE = 0,
|
||||||
CONST_VALUES = 1 << 0,
|
NODES = 1 << 0,
|
||||||
NAMES = 1 << 1,
|
CONST_VALUES = 1 << 1,
|
||||||
RUNTIME_KEYS = 1 << 2,
|
NAMES = 1 << 2,
|
||||||
PRECISIONS = 1 << 3,
|
RUNTIME_KEYS = 1 << 3,
|
||||||
ATTRIBUTES = 1 << 4,
|
PRECISIONS = 1 << 4,
|
||||||
TENSOR_NAMES = 1 << 5,
|
ATTRIBUTES = 1 << 5,
|
||||||
|
TENSOR_NAMES = 1 << 6,
|
||||||
|
ACCURACY = 1 << 7
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Result {
|
struct Result {
|
||||||
@ -45,6 +47,7 @@ public:
|
|||||||
}
|
}
|
||||||
static FunctionsComparator with_default() noexcept {
|
static FunctionsComparator with_default() noexcept {
|
||||||
auto fc = FunctionsComparator::no_default();
|
auto fc = FunctionsComparator::no_default();
|
||||||
|
fc.enable(NODES);
|
||||||
fc.enable(PRECISIONS);
|
fc.enable(PRECISIONS);
|
||||||
fc.enable(TENSOR_NAMES);
|
fc.enable(TENSOR_NAMES);
|
||||||
return fc;
|
return fc;
|
||||||
@ -55,6 +58,11 @@ public:
|
|||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
FunctionsComparator& disable(CmpValues f) noexcept {
|
||||||
|
m_comparison_flags = static_cast<CmpValues>(m_comparison_flags & ~f);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
bool should_compare(CmpValues f) const noexcept {
|
bool should_compare(CmpValues f) const noexcept {
|
||||||
return m_comparison_flags & f;
|
return m_comparison_flags & f;
|
||||||
}
|
}
|
||||||
@ -85,6 +93,7 @@ inline std::pair<bool, std::string> compare_functions(const std::shared_ptr<ngra
|
|||||||
auto fc = FunctionsComparator::no_default();
|
auto fc = FunctionsComparator::no_default();
|
||||||
|
|
||||||
using Cmp = FunctionsComparator::CmpValues;
|
using Cmp = FunctionsComparator::CmpValues;
|
||||||
|
fc.enable(Cmp::NODES);
|
||||||
if (compareConstValues)
|
if (compareConstValues)
|
||||||
fc.enable(Cmp::CONST_VALUES);
|
fc.enable(Cmp::CONST_VALUES);
|
||||||
if (compareNames)
|
if (compareNames)
|
||||||
@ -104,10 +113,6 @@ inline std::pair<bool, std::string> compare_functions(const std::shared_ptr<ngra
|
|||||||
|
|
||||||
void check_rt_info(const std::shared_ptr<ngraph::Function>& f);
|
void check_rt_info(const std::shared_ptr<ngraph::Function>& f);
|
||||||
|
|
||||||
void set_tensor_name(ngraph::Output<ngraph::Node> output, const std::string& name);
|
|
||||||
|
|
||||||
void set_tensor_names(ngraph::Output<ngraph::Node> output, const std::unordered_set<std::string>& names);
|
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ngraph {
|
||||||
namespace pass {
|
namespace pass {
|
||||||
class InjectionPass : public ov::pass::ModelPass {
|
class InjectionPass : public ov::pass::ModelPass {
|
||||||
@ -188,8 +193,8 @@ public:
|
|||||||
for (auto output : node->outputs()) {
|
for (auto output : node->outputs()) {
|
||||||
const auto& tensor_names = output.get_names();
|
const auto& tensor_names = output.get_names();
|
||||||
if (std::any_of(tensor_names.begin(), tensor_names.end(), [&](const std::string& name) {
|
if (std::any_of(tensor_names.begin(), tensor_names.end(), [&](const std::string& name) {
|
||||||
return unique_tensor_names.count(name);
|
return unique_tensor_names.count(name);
|
||||||
})) {
|
})) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "Node: " << node->get_type_info() << " with name " << node->get_friendly_name() << " ";
|
ss << "Node: " << node->get_type_info() << " with name " << node->get_friendly_name() << " ";
|
||||||
ss << "has non unique tensor name.";
|
ss << "has non unique tensor name.";
|
||||||
@ -478,9 +483,9 @@ public:
|
|||||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) override;
|
void on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) override;
|
||||||
|
|
||||||
#define ON_ADAPTER(TYPE) \
|
#define ON_ADAPTER(TYPE) \
|
||||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<TYPE>& adapter) override { \
|
void on_adapter(const std::string& name, ngraph::ValueAccessor<TYPE>& adapter) override { \
|
||||||
insert(name, adapter.get()); \
|
insert(name, adapter.get()); \
|
||||||
}
|
}
|
||||||
|
|
||||||
ON_ADAPTER(bool)
|
ON_ADAPTER(bool)
|
||||||
ON_ADAPTER(std::string)
|
ON_ADAPTER(std::string)
|
||||||
@ -711,37 +716,37 @@ struct Equal<std::shared_ptr<Constant>> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch (lhs_t) {
|
switch (lhs_t) {
|
||||||
case ngraph::element::Type_t::u1: {
|
case ngraph::element::Type_t::u1: {
|
||||||
const auto lhs_v = static_cast<const uint8_t*>(lhs->get_data_ptr());
|
const auto lhs_v = static_cast<const uint8_t*>(lhs->get_data_ptr());
|
||||||
const auto rhs_v = static_cast<const uint8_t*>(rhs->get_data_ptr());
|
const auto rhs_v = static_cast<const uint8_t*>(rhs->get_data_ptr());
|
||||||
const auto lhs_bit_size = shape_size(lhs->get_shape());
|
const auto lhs_bit_size = shape_size(lhs->get_shape());
|
||||||
const auto rhs_bit_size = shape_size(rhs->get_shape());
|
const auto rhs_bit_size = shape_size(rhs->get_shape());
|
||||||
return Equal<uint8_t*>::equal_value(lhs_v, rhs_v, lhs_bit_size, rhs_bit_size);
|
return Equal<uint8_t*>::equal_value(lhs_v, rhs_v, lhs_bit_size, rhs_bit_size);
|
||||||
}
|
}
|
||||||
case ngraph::element::Type_t::bf16: {
|
case ngraph::element::Type_t::bf16: {
|
||||||
auto lhs_v = lhs->cast_vector<ngraph::bfloat16>();
|
auto lhs_v = lhs->cast_vector<ngraph::bfloat16>();
|
||||||
auto rhs_v = rhs->cast_vector<ngraph::bfloat16>();
|
auto rhs_v = rhs->cast_vector<ngraph::bfloat16>();
|
||||||
return Equal<std::vector<ngraph::bfloat16>>::equal_value(lhs_v, rhs_v);
|
return Equal<std::vector<ngraph::bfloat16>>::equal_value(lhs_v, rhs_v);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case ngraph::element::Type_t::f16: {
|
case ngraph::element::Type_t::f16: {
|
||||||
const auto& lhs_v = lhs->cast_vector<ngraph::float16>();
|
const auto& lhs_v = lhs->cast_vector<ngraph::float16>();
|
||||||
const auto& rhs_v = rhs->cast_vector<ngraph::float16>();
|
const auto& rhs_v = rhs->cast_vector<ngraph::float16>();
|
||||||
return Equal<std::vector<ngraph::float16>>::equal_value(lhs_v, rhs_v);
|
return Equal<std::vector<ngraph::float16>>::equal_value(lhs_v, rhs_v);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case ngraph::element::Type_t::f32: {
|
case ngraph::element::Type_t::f32: {
|
||||||
const auto& lhs_v = lhs->cast_vector<float>();
|
const auto& lhs_v = lhs->cast_vector<float>();
|
||||||
const auto& rhs_v = rhs->cast_vector<float>();
|
const auto& rhs_v = rhs->cast_vector<float>();
|
||||||
return Equal<std::vector<float>>::equal_value(lhs_v, rhs_v);
|
return Equal<std::vector<float>>::equal_value(lhs_v, rhs_v);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
default: {
|
default: {
|
||||||
const auto& lhs_v = lhs->cast_vector<double>();
|
const auto& lhs_v = lhs->cast_vector<double>();
|
||||||
const auto& rhs_v = rhs->cast_vector<double>();
|
const auto& rhs_v = rhs->cast_vector<double>();
|
||||||
return Equal<std::vector<double>>::equal_value(lhs_v, rhs_v);
|
return Equal<std::vector<double>>::equal_value(lhs_v, rhs_v);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -862,18 +867,18 @@ struct Get<std::shared_ptr<ngraph::Variable>, void> {
|
|||||||
class ReadAndCompareAttributes : public ngraph::AttributeVisitor {
|
class ReadAndCompareAttributes : public ngraph::AttributeVisitor {
|
||||||
public:
|
public:
|
||||||
ReadAndCompareAttributes(const ReadAndStoreAttributes& ref, Comparator::CmpValues check_flags)
|
ReadAndCompareAttributes(const ReadAndStoreAttributes& ref, Comparator::CmpValues check_flags)
|
||||||
: m_attr_ref(ref),
|
: m_attr_ref(ref),
|
||||||
m_cmp_result{ref.read_result()},
|
m_cmp_result{ref.read_result()},
|
||||||
m_check_flags(check_flags) {}
|
m_check_flags(check_flags) {}
|
||||||
|
|
||||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) override {
|
void on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) override {
|
||||||
verify_others(name, adapter);
|
verify_others(name, adapter);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define ON_ADAPTER(TYPE) \
|
#define ON_ADAPTER(TYPE) \
|
||||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<TYPE>& adapter) override { \
|
void on_adapter(const std::string& name, ngraph::ValueAccessor<TYPE>& adapter) override { \
|
||||||
verify(name, adapter.get()); \
|
verify(name, adapter.get()); \
|
||||||
}
|
}
|
||||||
|
|
||||||
ON_ADAPTER(bool)
|
ON_ADAPTER(bool)
|
||||||
ON_ADAPTER(std::string)
|
ON_ADAPTER(std::string)
|
||||||
@ -981,3 +986,11 @@ private:
|
|||||||
Comparator::Result compare(ngraph::Node* node1, ngraph::Node* node2, Comparator::CmpValues comparition_flags);
|
Comparator::Result compare(ngraph::Node* node1, ngraph::Node* node2, Comparator::CmpValues comparition_flags);
|
||||||
|
|
||||||
} // namespace attributes
|
} // namespace attributes
|
||||||
|
|
||||||
|
struct AccuracyCheckResult {
|
||||||
|
bool status;
|
||||||
|
std::string message;
|
||||||
|
};
|
||||||
|
|
||||||
|
AccuracyCheckResult accuracy_check(const std::shared_ptr<ov::Model>& ref_function,
|
||||||
|
const std::shared_ptr<ov::Model>& cur_function);
|
||||||
|
@ -2,248 +2,7 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <functional_test_utils/include/functional_test_utils/blob_utils.hpp>
|
|
||||||
#include "ngraph_test_utils.hpp"
|
#include "ngraph_test_utils.hpp"
|
||||||
#include <ngraph_functions/utils/ngraph_helpers.hpp>
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
template<class T_IE, class T_NGRAPH>
|
|
||||||
static void Compare(const T_NGRAPH *expected, const T_IE *actual, std::size_t size, float threshold, float abs_threshold = -1.f) {
|
|
||||||
for (std::size_t i = 0; i < size; ++i) {
|
|
||||||
const T_NGRAPH &ref = expected[i];
|
|
||||||
const auto &res = actual[i];
|
|
||||||
const auto absoluteDifference = CommonTestUtils::ie_abs(res - ref);
|
|
||||||
if (abs_threshold > 0.f && absoluteDifference > abs_threshold) {
|
|
||||||
IE_THROW() << "Absolute comparison of values expected: " << std::to_string(ref) << " and actual: " << std::to_string(res)
|
|
||||||
<< " at index " << i << " with absolute threshold " << abs_threshold
|
|
||||||
<< " failed";
|
|
||||||
}
|
|
||||||
if (absoluteDifference <= threshold) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
double max;
|
|
||||||
if (sizeof(T_IE) < sizeof(T_NGRAPH)) {
|
|
||||||
max = std::max(CommonTestUtils::ie_abs(T_NGRAPH(res)), CommonTestUtils::ie_abs(ref));
|
|
||||||
} else {
|
|
||||||
max = std::max(CommonTestUtils::ie_abs(res), CommonTestUtils::ie_abs(T_IE(ref)));
|
|
||||||
}
|
|
||||||
double diff = static_cast<float>(absoluteDifference) / max;
|
|
||||||
if (max == 0 || (diff > static_cast<float>(threshold)) ||
|
|
||||||
(std::isnan(static_cast<float>(res)) ^ std::isnan(static_cast<float>(ref)))) {
|
|
||||||
IE_THROW() << "Relative comparison of values expected: " << std::to_string(ref) << " and actual: " << std::to_string(res)
|
|
||||||
<< " at index " << i << " with threshold " << threshold
|
|
||||||
<< " failed";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T_IE>
|
|
||||||
void callCompare(const std::pair<ngraph::element::Type, std::vector<std::uint8_t>> &expected,
|
|
||||||
const T_IE* actualBuffer, size_t size, float threshold, float abs_threshold) {
|
|
||||||
auto expectedBuffer = expected.second.data();
|
|
||||||
switch (expected.first) {
|
|
||||||
case ngraph::element::Type_t::i64:
|
|
||||||
Compare<T_IE, int64_t>(reinterpret_cast<const int64_t *>(expectedBuffer),
|
|
||||||
actualBuffer, size, threshold, abs_threshold);
|
|
||||||
break;
|
|
||||||
case ngraph::element::Type_t::i32:
|
|
||||||
Compare<T_IE, int32_t>(reinterpret_cast<const int32_t *>(expectedBuffer),
|
|
||||||
actualBuffer, size, threshold, abs_threshold);
|
|
||||||
break;
|
|
||||||
case ngraph::element::Type_t::i16:
|
|
||||||
Compare<T_IE, int16_t>(reinterpret_cast<const int16_t *>(expectedBuffer),
|
|
||||||
actualBuffer, size, threshold, abs_threshold);
|
|
||||||
break;
|
|
||||||
case ngraph::element::Type_t::i8:
|
|
||||||
Compare<T_IE, int8_t>(reinterpret_cast<const int8_t *>(expectedBuffer),
|
|
||||||
actualBuffer, size, threshold, abs_threshold);
|
|
||||||
break;
|
|
||||||
case ngraph::element::Type_t::u64:
|
|
||||||
Compare<T_IE, uint64_t>(reinterpret_cast<const uint64_t *>(expectedBuffer),
|
|
||||||
actualBuffer, size, threshold, abs_threshold);
|
|
||||||
break;
|
|
||||||
case ngraph::element::Type_t::u32:
|
|
||||||
Compare<T_IE, uint32_t>(reinterpret_cast<const uint32_t *>(expectedBuffer),
|
|
||||||
actualBuffer, size, threshold, abs_threshold);
|
|
||||||
break;
|
|
||||||
case ngraph::element::Type_t::u16:
|
|
||||||
Compare<T_IE, uint16_t>(reinterpret_cast<const uint16_t *>(expectedBuffer),
|
|
||||||
actualBuffer, size, threshold, abs_threshold);
|
|
||||||
break;
|
|
||||||
case ngraph::element::Type_t::boolean:
|
|
||||||
case ngraph::element::Type_t::u8:
|
|
||||||
Compare<T_IE, uint8_t>(reinterpret_cast<const uint8_t *>(expectedBuffer),
|
|
||||||
actualBuffer, size, threshold, abs_threshold);
|
|
||||||
break;
|
|
||||||
case ngraph::element::Type_t::f64:
|
|
||||||
Compare<T_IE, double>(reinterpret_cast<const double *>(expectedBuffer),
|
|
||||||
actualBuffer, size, threshold, abs_threshold);
|
|
||||||
break;
|
|
||||||
case ngraph::element::Type_t::f32:
|
|
||||||
Compare<T_IE, float>(reinterpret_cast<const float *>(expectedBuffer),
|
|
||||||
actualBuffer, size, threshold, abs_threshold);
|
|
||||||
break;
|
|
||||||
case ngraph::element::Type_t::f16:
|
|
||||||
Compare<T_IE, ngraph::float16>(reinterpret_cast<const ngraph::float16 *>(expectedBuffer),
|
|
||||||
actualBuffer, size, threshold, abs_threshold);
|
|
||||||
break;
|
|
||||||
case ngraph::element::Type_t::bf16:
|
|
||||||
Compare<T_IE, ngraph::bfloat16>(reinterpret_cast<const ngraph::bfloat16 *>(expectedBuffer),
|
|
||||||
actualBuffer, size, threshold, abs_threshold);
|
|
||||||
break;
|
|
||||||
// case ngraph::element::Type_t::i4: {
|
|
||||||
// auto expectedOut = ngraph::helpers::convertOutputPrecision(
|
|
||||||
// expected.second,
|
|
||||||
// expected.first,
|
|
||||||
// ngraph::element::Type_t::i8,
|
|
||||||
// size);
|
|
||||||
// Compare<T_IE, int8_t>(reinterpret_cast<const int8_t *>(expectedOut.data()),
|
|
||||||
// actualBuffer, size, threshold, abs_threshold);
|
|
||||||
// break;
|
|
||||||
// }
|
|
||||||
// case ngraph::element::Type_t::u4: {
|
|
||||||
// auto expectedOut = ngraph::helpers::convertOutputPrecision(
|
|
||||||
// expected.second,
|
|
||||||
// expected.first,
|
|
||||||
// ngraph::element::Type_t::u8,
|
|
||||||
// size);
|
|
||||||
// Compare<T_IE, uint8_t>(reinterpret_cast<const uint8_t *>(expectedOut.data()),
|
|
||||||
// actualBuffer, size, threshold, abs_threshold);
|
|
||||||
// break;
|
|
||||||
// }
|
|
||||||
case ngraph::element::Type_t::dynamic:
|
|
||||||
case ngraph::element::Type_t::undefined:
|
|
||||||
Compare<T_IE, T_IE>(reinterpret_cast<const T_IE *>(expectedBuffer), actualBuffer, size, threshold, abs_threshold);
|
|
||||||
break;
|
|
||||||
default: FAIL() << "Comparator for " << expected.first << " precision isn't supported";
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
void Compare(const std::pair<ngraph::element::Type, std::vector<std::uint8_t>> &expected,
|
|
||||||
const InferenceEngine::Blob::Ptr &actual,
|
|
||||||
float threshold,
|
|
||||||
float abs_threshold) {
|
|
||||||
const auto &precision = actual->getTensorDesc().getPrecision();
|
|
||||||
auto k = static_cast<float>(expected.first.size()) / precision.size();
|
|
||||||
// W/A for int4, uint4
|
|
||||||
if (expected.first == ngraph::element::Type_t::u4 || expected.first == ngraph::element::Type_t::i4) {
|
|
||||||
k /= 2;
|
|
||||||
} else if (expected.first == ngraph::element::Type_t::undefined || expected.first == ngraph::element::Type_t::dynamic) {
|
|
||||||
k = 1;
|
|
||||||
}
|
|
||||||
ASSERT_EQ(expected.second.size(), actual->byteSize() * k);
|
|
||||||
|
|
||||||
auto memory = InferenceEngine::as<InferenceEngine::MemoryBlob>(actual);
|
|
||||||
IE_ASSERT(memory);
|
|
||||||
const auto lockedMemory = memory->wmap();
|
|
||||||
const auto actualBuffer = lockedMemory.as<const std::uint8_t *>();
|
|
||||||
|
|
||||||
const auto &size = actual->size();
|
|
||||||
switch (precision) {
|
|
||||||
case InferenceEngine::Precision::FP32:
|
|
||||||
callCompare<float>(expected, reinterpret_cast<const float *>(actualBuffer), size, threshold, abs_threshold);
|
|
||||||
break;
|
|
||||||
case InferenceEngine::Precision::I32:
|
|
||||||
callCompare<int32_t>(expected, reinterpret_cast<const int32_t *>(actualBuffer), size, threshold, abs_threshold);
|
|
||||||
break;
|
|
||||||
case InferenceEngine::Precision::I64:
|
|
||||||
callCompare<int64_t>(expected, reinterpret_cast<const int64_t *>(actualBuffer), size, threshold, abs_threshold);
|
|
||||||
break;
|
|
||||||
case InferenceEngine::Precision::I8:
|
|
||||||
callCompare<int8_t>(expected, reinterpret_cast<const int8_t *>(actualBuffer), size, threshold, abs_threshold);
|
|
||||||
break;
|
|
||||||
case InferenceEngine::Precision::U16:
|
|
||||||
callCompare<uint16_t>(expected, reinterpret_cast<const uint16_t *>(actualBuffer), size, threshold, abs_threshold);
|
|
||||||
break;
|
|
||||||
case InferenceEngine::Precision::I16:
|
|
||||||
callCompare<int16_t>(expected, reinterpret_cast<const int16_t *>(actualBuffer), size, threshold, abs_threshold);
|
|
||||||
break;
|
|
||||||
case InferenceEngine::Precision::BOOL:
|
|
||||||
case InferenceEngine::Precision::U8:
|
|
||||||
callCompare<uint8_t>(expected, reinterpret_cast<const uint8_t *>(actualBuffer), size, threshold, abs_threshold);
|
|
||||||
break;
|
|
||||||
case InferenceEngine::Precision::U64:
|
|
||||||
callCompare<uint64_t>(expected, reinterpret_cast<const uint64_t *>(actualBuffer), size, threshold, abs_threshold);
|
|
||||||
break;
|
|
||||||
case InferenceEngine::Precision::BF16:
|
|
||||||
callCompare<ngraph::bfloat16>(expected, reinterpret_cast<const ngraph::bfloat16 *>(actualBuffer), size, threshold, abs_threshold);
|
|
||||||
break;
|
|
||||||
case InferenceEngine::Precision::FP16:
|
|
||||||
callCompare<ngraph::float16>(expected, reinterpret_cast<const ngraph::float16 *>(actualBuffer), size, threshold, abs_threshold);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
FAIL() << "Comparator for " << precision << " precision isn't supported";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void Compare(const std::vector<std::pair<ngraph::element::Type, std::vector<std::uint8_t>>> &expectedOutputs,
|
|
||||||
const std::vector<InferenceEngine::Blob::Ptr> &actualOutputs,
|
|
||||||
float threshold, float abs_threshold) {
|
|
||||||
for (std::size_t outputIndex = 0; outputIndex < expectedOutputs.size(); ++outputIndex) {
|
|
||||||
const auto &expected = expectedOutputs[outputIndex];
|
|
||||||
const auto &actual = actualOutputs[outputIndex];
|
|
||||||
Compare(expected, actual, threshold, abs_threshold);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
void TransformationTestsF::accuracy_check(std::shared_ptr<ov::Model> ref_function,
|
|
||||||
std::shared_ptr<ov::Model> cur_function) {
|
|
||||||
try {
|
|
||||||
if (ref_function->is_dynamic() || cur_function->is_dynamic()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
std::vector<std::vector<uint8_t>> input_data;
|
|
||||||
ngraph::element::TypeVector types;
|
|
||||||
for (auto param : ref_function->get_parameters()) {
|
|
||||||
types.push_back(param->get_element_type());
|
|
||||||
|
|
||||||
auto layout = InferenceEngine::Layout::ANY;
|
|
||||||
if (ov::is_scalar(param->get_shape())) {
|
|
||||||
layout = InferenceEngine::Layout::SCALAR;
|
|
||||||
}
|
|
||||||
InferenceEngine::TensorDesc td(InferenceEngine::Precision::FP32, param->get_shape(), layout);
|
|
||||||
const auto &input = FuncTestUtils::createAndFillBlob(td);
|
|
||||||
const auto &input_size = input->byteSize();
|
|
||||||
|
|
||||||
std::vector<uint8_t> data;
|
|
||||||
data.resize(input_size);
|
|
||||||
|
|
||||||
auto memory = InferenceEngine::as<InferenceEngine::MemoryBlob>(input);
|
|
||||||
IE_ASSERT(memory);
|
|
||||||
|
|
||||||
const auto lockedMemory = memory->wmap();
|
|
||||||
const auto buffer = lockedMemory.as<const std::uint8_t *>();
|
|
||||||
std::copy(buffer, buffer + input_size, data.data());
|
|
||||||
|
|
||||||
input_data.push_back(std::move(data));
|
|
||||||
}
|
|
||||||
|
|
||||||
auto ref_outputs = ngraph::helpers::interpreterFunction(ref_function, input_data, types);
|
|
||||||
auto outputs = ngraph::helpers::interpreterFunction(cur_function, input_data, types);
|
|
||||||
|
|
||||||
IE_ASSERT(ref_outputs.size() == outputs.size());
|
|
||||||
|
|
||||||
for (size_t i = 0; i < ref_outputs.size(); ++i) {
|
|
||||||
IE_ASSERT(ref_outputs[i].second.size() == outputs[i].second.size());
|
|
||||||
auto * ref = reinterpret_cast<float *>(ref_outputs[i].second.data());
|
|
||||||
auto * out = reinterpret_cast<float *>(outputs[i].second.data());
|
|
||||||
size_t size = ref_outputs[i].second.size() / sizeof(float);
|
|
||||||
IE_ASSERT(size > 0);
|
|
||||||
Compare<float, float>(ref, out, size, 1e-5);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
catch (const std::runtime_error &re) {
|
|
||||||
GTEST_FATAL_FAILURE_(re.what());
|
|
||||||
} catch (const std::exception &ex) {
|
|
||||||
GTEST_FATAL_FAILURE_(ex.what());
|
|
||||||
} catch (...) {
|
|
||||||
GTEST_FATAL_FAILURE_("Unknown failure occurred.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void init_unique_names(std::shared_ptr<ngraph::Function> f, const std::shared_ptr<ngraph::pass::UniqueNamesHolder>& unh) {
|
void init_unique_names(std::shared_ptr<ngraph::Function> f, const std::shared_ptr<ngraph::pass::UniqueNamesHolder>& unh) {
|
||||||
ngraph::pass::Manager manager;
|
ngraph::pass::Manager manager;
|
||||||
|
@ -32,6 +32,7 @@ class TransformationTestsF : public CommonTestUtils::TestsCommon {
|
|||||||
public:
|
public:
|
||||||
TransformationTestsF() : comparator(FunctionsComparator::no_default()) {
|
TransformationTestsF() : comparator(FunctionsComparator::no_default()) {
|
||||||
m_unh = std::make_shared<ngraph::pass::UniqueNamesHolder>();
|
m_unh = std::make_shared<ngraph::pass::UniqueNamesHolder>();
|
||||||
|
comparator.enable(FunctionsComparator::CmpValues::NODES);
|
||||||
comparator.enable(FunctionsComparator::CmpValues::PRECISIONS);
|
comparator.enable(FunctionsComparator::CmpValues::PRECISIONS);
|
||||||
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
|
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
|
||||||
// TODO: enable attributes and constant values comparison by default XXX-68694
|
// TODO: enable attributes and constant values comparison by default XXX-68694
|
||||||
@ -56,12 +57,16 @@ public:
|
|||||||
ASSERT_NO_THROW(check_rt_info(function));
|
ASSERT_NO_THROW(check_rt_info(function));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (comparator.should_compare(FunctionsComparator::ACCURACY)) {
|
||||||
|
auto acc_comparator = FunctionsComparator::no_default();
|
||||||
|
acc_comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
|
auto res = acc_comparator.compare(function, cloned_function);
|
||||||
|
ASSERT_TRUE(res.valid) << res.message;
|
||||||
|
comparator.disable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
|
}
|
||||||
|
|
||||||
auto res = comparator.compare(function, function_ref);
|
auto res = comparator.compare(function, function_ref);
|
||||||
ASSERT_TRUE(res.valid) << res.message;
|
ASSERT_TRUE(res.valid) << res.message;
|
||||||
|
|
||||||
if (m_enable_accuracy_check) {
|
|
||||||
accuracy_check(cloned_function, function);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: this is temporary solution to disable rt info checks that must be applied by default
|
// TODO: this is temporary solution to disable rt info checks that must be applied by default
|
||||||
@ -74,12 +79,6 @@ public:
|
|||||||
m_soft_names_comparison = true;
|
m_soft_names_comparison = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void enable_accuracy_check() {
|
|
||||||
m_enable_accuracy_check = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void accuracy_check(std::shared_ptr<ov::Model> ref_function, std::shared_ptr<ov::Model> cur_function);
|
|
||||||
|
|
||||||
std::shared_ptr<ov::Model> function, function_ref;
|
std::shared_ptr<ov::Model> function, function_ref;
|
||||||
ngraph::pass::Manager manager;
|
ngraph::pass::Manager manager;
|
||||||
FunctionsComparator comparator;
|
FunctionsComparator comparator;
|
||||||
@ -88,7 +87,6 @@ private:
|
|||||||
std::shared_ptr<ngraph::pass::UniqueNamesHolder> m_unh;
|
std::shared_ptr<ngraph::pass::UniqueNamesHolder> m_unh;
|
||||||
bool m_disable_rt_info_check{false};
|
bool m_disable_rt_info_check{false};
|
||||||
bool m_soft_names_comparison{true};
|
bool m_soft_names_comparison{true};
|
||||||
bool m_enable_accuracy_check{false};
|
|
||||||
};
|
};
|
||||||
|
|
||||||
void init_unique_names(std::shared_ptr<ngraph::Function> f, const std::shared_ptr<ngraph::pass::UniqueNamesHolder>& unh);
|
void init_unique_names(std::shared_ptr<ngraph::Function> f, const std::shared_ptr<ngraph::pass::UniqueNamesHolder>& unh);
|
||||||
|
Loading…
Reference in New Issue
Block a user