Update Accuracy Check inside Func tests (#11314)
This commit is contained in:
parent
ee31b648d1
commit
8317493e65
@ -92,7 +92,9 @@ TEST_P(FQMulFusion, ExpectFusion) {
|
||||
manager.run_passes(m_function);
|
||||
ASSERT_NO_THROW(check_rt_info(m_function));
|
||||
|
||||
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
|
||||
auto fc = FunctionsComparator::no_default()
|
||||
.enable(FunctionsComparator::PRECISIONS)
|
||||
.enable(FunctionsComparator::NODES);
|
||||
auto res = fc.compare(m_function, m_expected_function);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
};
|
||||
|
@ -113,7 +113,9 @@ TEST_P(nGraphFQReshapeFusionTests, ReshapeMatMul) {
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
|
||||
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
|
||||
auto fc = FunctionsComparator::no_default()
|
||||
.enable(FunctionsComparator::PRECISIONS)
|
||||
.enable(FunctionsComparator::NODES);
|
||||
auto res = fc.compare(f, ref_f);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
}
|
||||
|
@ -30,7 +30,7 @@ TEST_F(TransformationTestsF, MatMulConstTransposesExtractionConstantWeights) {
|
||||
function_ref = std::make_shared<Function>(matmul, ParameterVector{data});
|
||||
}
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, MatMulConstTransposesExtractionFQOnWeights) {
|
||||
@ -57,7 +57,7 @@ TEST_F(TransformationTestsF, MatMulConstTransposesExtractionFQOnWeights) {
|
||||
function_ref = std::make_shared<Function>(matmul, ParameterVector{data});
|
||||
}
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, NegativeMatMulConstTransposesExtractionInvalidRank) {
|
||||
|
@ -998,7 +998,7 @@ TEST_P(EliminateEltwiseTests, eliminate_eltwise) {
|
||||
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
if (type == element::f32) {
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1088,5 +1088,5 @@ TEST_F(TransformationTestsF, eliminate_eltwise_dequantization_subgraph) {
|
||||
manager.register_pass<pass::NopElimination>();
|
||||
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
@ -120,11 +120,11 @@ TEST_F(TransformationTestsF, PadFusionConvolutionBackpropData) {
|
||||
auto pads_begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 1, 1});
|
||||
auto pads_end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
|
||||
auto pad = std::make_shared<opset5::Pad>(data, pads_begin, pads_end, op::PadMode::CONSTANT);
|
||||
set_tensor_name(pad, "pad");
|
||||
|
||||
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
|
||||
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(pad, filters, Strides{1, 1},
|
||||
CoordinateDiff{4, 4}, CoordinateDiff{3, 3}, Shape{1, 1});
|
||||
set_tensor_name(conv, "conv");
|
||||
|
||||
|
||||
function = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
|
||||
manager.register_pass<pass::PadFusion>();
|
||||
@ -134,7 +134,7 @@ TEST_F(TransformationTestsF, PadFusionConvolutionBackpropData) {
|
||||
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
|
||||
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(data, filters, Strides{1, 1},
|
||||
CoordinateDiff{3, 3}, CoordinateDiff{1, 1}, Shape{1, 1});
|
||||
set_tensor_name(conv, "conv");
|
||||
|
||||
|
||||
function_ref = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
|
||||
}
|
||||
@ -244,11 +244,11 @@ TEST_F(TransformationTestsF, PadFusionConvolutionBackpropDataNonConstPadValue) {
|
||||
std::shared_ptr<Node> pad_value = opset5::Constant::create(element::f16, Shape{}, {0});
|
||||
pad_value = std::make_shared<opset5::Convert>(pad_value, element::f32);
|
||||
auto pad = std::make_shared<opset5::Pad>(data, pads_begin, pads_end, pad_value, op::PadMode::CONSTANT);
|
||||
set_tensor_name(pad, "pad");
|
||||
|
||||
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
|
||||
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(pad, filters, Strides{1, 1},
|
||||
CoordinateDiff{4, 4}, CoordinateDiff{3, 3}, Shape{1, 1});
|
||||
set_tensor_name(conv, "conv");
|
||||
|
||||
|
||||
function = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
|
||||
manager.register_pass<pass::PadFusion>();
|
||||
@ -258,7 +258,7 @@ TEST_F(TransformationTestsF, PadFusionConvolutionBackpropDataNonConstPadValue) {
|
||||
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
|
||||
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(data, filters, Strides{1, 1},
|
||||
CoordinateDiff{3, 3}, CoordinateDiff{1, 1}, Shape{1, 1});
|
||||
set_tensor_name(conv, "conv");
|
||||
|
||||
|
||||
function_ref = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
|
||||
}
|
||||
@ -422,30 +422,29 @@ TEST_F(TransformationTestsF, NegativePadFusionConvolutionBackpropDataTooSmallPad
|
||||
Shape data_shape{1, 3, 14, 14};
|
||||
{
|
||||
auto data = std::make_shared<opset5::Parameter>(element::f32, data_shape);
|
||||
set_tensor_name(data, "data");
|
||||
|
||||
auto pads_begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
|
||||
auto pads_end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
|
||||
auto pad = std::make_shared<opset5::Pad>(data, pads_begin, pads_end, op::PadMode::CONSTANT);
|
||||
set_tensor_name(pad, "pad");
|
||||
|
||||
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
|
||||
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(pad, filters, Strides{1, 1},
|
||||
CoordinateDiff{1, 1}, CoordinateDiff{1, 1}, Shape{1, 1});
|
||||
set_tensor_name(conv, "conv");
|
||||
|
||||
|
||||
function = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
|
||||
manager.register_pass<pass::PadFusion>();
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<opset5::Parameter>(element::f32, data_shape);
|
||||
set_tensor_name(data, "data");
|
||||
|
||||
auto pads_begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
|
||||
auto pads_end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
|
||||
auto pad = std::make_shared<opset5::Pad>(data, pads_begin, pads_end, op::PadMode::CONSTANT);
|
||||
set_tensor_name(pad, "pad");
|
||||
|
||||
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
|
||||
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(pad, filters, Strides{1, 1},
|
||||
CoordinateDiff{1, 1}, CoordinateDiff{1, 1}, Shape{1, 1});
|
||||
set_tensor_name(conv, "conv");
|
||||
|
||||
function_ref = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
|
||||
}
|
||||
|
@ -141,7 +141,7 @@ TEST_F(TransformationTestsF, RICFusionSimple) {
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, RICFusionHard) {
|
||||
@ -198,7 +198,7 @@ TEST_F(TransformationTestsF, RICFusionHard) {
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, RICFusionDynamic) {
|
||||
@ -248,7 +248,7 @@ TEST_F(TransformationTestsF, RICFusionEltwise1) {
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, RICFusionEltwise2) {
|
||||
@ -273,7 +273,7 @@ TEST_F(TransformationTestsF, RICFusionEltwise2) {
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, RICFusionEltwise3) {
|
||||
@ -298,7 +298,7 @@ TEST_F(TransformationTestsF, RICFusionEltwise3) {
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, RICFusionEltwise4) {
|
||||
@ -324,7 +324,7 @@ TEST_F(TransformationTestsF, RICFusionEltwise4) {
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, RICFusionEltwise5) {
|
||||
@ -350,7 +350,7 @@ TEST_F(TransformationTestsF, RICFusionEltwise5) {
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, RICFusionEltwiseNegative) {
|
||||
@ -390,7 +390,7 @@ TEST_F(TransformationTestsF, RICFusionEltwiseTwoRIC) {
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, RICFusionEltwiseNegative3) {
|
||||
@ -431,7 +431,7 @@ TEST_F(TransformationTestsF, RICFusionGroupConv) {
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, RICFusionGroupConvNegative) {
|
||||
@ -473,7 +473,7 @@ TEST_F(TransformationTestsF, RICFusionTranspose) {
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, RICFusionFQOnTheWay) {
|
||||
@ -499,7 +499,7 @@ TEST_F(TransformationTestsF, RICFusionFQOnTheWay) {
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, RICFusionFQOnTheWay2) {
|
||||
@ -538,7 +538,7 @@ TEST_F(TransformationTestsF, RICFusionFQOnTheWay2) {
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, RICFusionFQOnTheWay3) {
|
||||
@ -578,7 +578,7 @@ TEST_F(TransformationTestsF, RICFusionFQOnTheWay3) {
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, RICFusionShapeOf) {
|
||||
@ -749,7 +749,7 @@ TEST_F(TransformationTestsF, FuseScaleValue) {
|
||||
}
|
||||
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, FuseScaleValues) {
|
||||
@ -779,5 +779,5 @@ TEST_F(TransformationTestsF, FuseScaleValues) {
|
||||
}
|
||||
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
@ -90,7 +90,9 @@ protected:
|
||||
};
|
||||
|
||||
TEST_P(ShuffleChannelsFusion, CompareFunctions) {
|
||||
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
|
||||
auto fc = FunctionsComparator::no_default()
|
||||
.enable(FunctionsComparator::PRECISIONS)
|
||||
.enable(FunctionsComparator::NODES);
|
||||
auto res = fc.compare(f, f_ref);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
}
|
||||
|
@ -55,7 +55,9 @@ TEST_P(SoftmaxFusionFixture, SoftmaxFusion) {
|
||||
f_ref = std::make_shared<Function>(NodeVector{softmax}, ParameterVector{data});
|
||||
}
|
||||
|
||||
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
|
||||
auto fc = FunctionsComparator::no_default()
|
||||
.enable(FunctionsComparator::PRECISIONS)
|
||||
.enable(FunctionsComparator::NODES);
|
||||
auto res = fc.compare(f, f_ref);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
}
|
||||
@ -113,7 +115,9 @@ TEST_P(NegativeSoftmaxFusionFixture, NegativeSoftmaxFusion) {
|
||||
f_ref = std::make_shared<Function>(NodeVector{div}, ParameterVector{data});
|
||||
}
|
||||
|
||||
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
|
||||
auto fc = FunctionsComparator::no_default()
|
||||
.enable(FunctionsComparator::PRECISIONS)
|
||||
.enable(FunctionsComparator::NODES);
|
||||
auto res = fc.compare(f, f_ref);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
}
|
||||
|
@ -94,6 +94,7 @@ TEST_P(TransposeSinkingFQ, TransposeFQReduce) {
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
|
||||
auto fc = FunctionsComparator::no_default()
|
||||
.enable(FunctionsComparator::NODES)
|
||||
.enable(FunctionsComparator::PRECISIONS)
|
||||
.enable(FunctionsComparator::CONST_VALUES);
|
||||
auto res = fc.compare(f, f_ref);
|
||||
@ -180,6 +181,7 @@ TEST_P(TransposeSinking, TransposeReduction) {
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
|
||||
auto fc = FunctionsComparator::no_default()
|
||||
.enable(FunctionsComparator::NODES)
|
||||
.enable(FunctionsComparator::PRECISIONS)
|
||||
.enable(FunctionsComparator::CONST_VALUES);
|
||||
|
||||
|
@ -109,7 +109,9 @@ TEST_P(TransposeToReshapeTests, CompareFunctions) {
|
||||
f->validate_nodes_and_infer_types();
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
|
||||
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
|
||||
auto fc = FunctionsComparator::no_default()
|
||||
.enable(FunctionsComparator::NODES)
|
||||
.enable(FunctionsComparator::PRECISIONS);
|
||||
auto res = fc.compare(f, f_ref);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
}
|
||||
|
@ -84,6 +84,7 @@ TEST_P(TranslateNewWeightFormatToOldOne, ReshapeMatMul) {
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
|
||||
auto fc = FunctionsComparator::no_default()
|
||||
.enable(FunctionsComparator::NODES)
|
||||
.enable(FunctionsComparator::PRECISIONS)
|
||||
.enable(FunctionsComparator::CONST_VALUES);
|
||||
auto res = fc.compare(f, f_ref);
|
||||
|
@ -127,7 +127,9 @@ TEST_P(ConvFusionTests, CompareFunctions) {
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
|
||||
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
|
||||
auto fc = FunctionsComparator::no_default()
|
||||
.enable(FunctionsComparator::NODES)
|
||||
.enable(FunctionsComparator::PRECISIONS);
|
||||
auto res = fc.compare(f, f_ref);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
}
|
||||
|
@ -157,7 +157,9 @@ TEST_P(MulAddConversionTests, CompareFunctions) {
|
||||
ASSERT_NO_THROW(manager.run_passes(f));
|
||||
f->validate_nodes_and_infer_types();
|
||||
|
||||
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
|
||||
auto fc = FunctionsComparator::no_default()
|
||||
.enable(FunctionsComparator::NODES)
|
||||
.enable(FunctionsComparator::PRECISIONS);
|
||||
auto res = fc.compare(f, f_ref);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
}
|
||||
@ -178,7 +180,9 @@ TEST_P(MulOrAddConversionTests, CompareFunctions) {
|
||||
|
||||
f->validate_nodes_and_infer_types();
|
||||
|
||||
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
|
||||
auto fc = FunctionsComparator::no_default()
|
||||
.enable(FunctionsComparator::NODES)
|
||||
.enable(FunctionsComparator::PRECISIONS);
|
||||
auto res = fc.compare(f, f_ref);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
}
|
||||
|
@ -246,7 +246,7 @@ TEST_F(TransformationTestsF, PropagateMasksBasic) {
|
||||
m.run_passes(function);
|
||||
}
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
|
||||
@ -311,7 +311,7 @@ TEST_F(TransformationTestsF, PropagateMasksDynamicConvolution) {
|
||||
m.run_passes(function);
|
||||
}
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
|
||||
@ -491,7 +491,7 @@ TEST_F(TransformationTestsF, PropagateMaskPassThrough) {
|
||||
m.run_passes(function);
|
||||
}
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
|
||||
@ -643,7 +643,7 @@ TEST_F(TransformationTestsF, PropagateMasksHardDependencies) {
|
||||
m.run_passes(function);
|
||||
}
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
|
||||
@ -761,7 +761,7 @@ TEST_F(TransformationTestsF, PropagateMasksQuantizedGroupConvolution) {
|
||||
m.run_passes(function);
|
||||
}
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
|
||||
@ -899,7 +899,7 @@ TEST_F(TransformationTestsF, PropagateMasksQuantizedGroupConvolutionWithShapeOf)
|
||||
m.run_passes(function);
|
||||
}
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
|
||||
@ -1017,7 +1017,7 @@ TEST_F(TransformationTestsF, PropagateMasksFakeQuantizePerTensor) {
|
||||
m.run_passes(function);
|
||||
}
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
|
||||
@ -1202,7 +1202,7 @@ TEST_F(TransformationTestsF, PropagateMasksFakeQuantizePerChannel) {
|
||||
m.run_passes(function);
|
||||
}
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
|
||||
@ -1299,7 +1299,7 @@ TEST_F(TransformationTestsF, TestConcatMaskPropagation) {
|
||||
m.run_passes(function);
|
||||
}
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
|
||||
@ -1406,7 +1406,7 @@ TEST_F(TransformationTestsF, TestConcatMaskPropagationUp) {
|
||||
m.run_passes(function);
|
||||
}
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
|
||||
@ -1554,7 +1554,7 @@ TEST_F(TransformationTestsF, PruneConvIsClosingAndInGroup) {
|
||||
m.run_passes(function);
|
||||
}
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
|
||||
@ -1814,7 +1814,7 @@ TEST_F(TransformationTestsF, PruneReduceLayerDown) {
|
||||
m.run_passes(function);
|
||||
}
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
|
||||
@ -1975,7 +1975,7 @@ TEST_F(TransformationTestsF, MaskPropagationReshapeUp) {
|
||||
m.run_passes(function);
|
||||
}
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
|
||||
@ -2050,7 +2050,7 @@ TEST_F(TransformationTestsF, MaskPropagationReshapeUpWithShapeOf) {
|
||||
m.run_passes(function);
|
||||
}
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
|
||||
@ -2149,7 +2149,7 @@ TEST_F(TransformationTestsF, MaskPropagationReshapeDown) {
|
||||
m.run_passes(function);
|
||||
}
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
|
||||
@ -2403,7 +2403,7 @@ TEST_F(TransformationTestsF, PruneSEBlock) {
|
||||
m.run_passes(function);
|
||||
}
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
|
||||
@ -2495,7 +2495,7 @@ TEST_F(TransformationTestsF, PropagateMasksLinear) {
|
||||
m.run_passes(function);
|
||||
}
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
|
||||
@ -2679,7 +2679,7 @@ TEST_F(TransformationTestsF, PruneMasksMatMulColsStopRowsUp) {
|
||||
m.run_passes(function);
|
||||
}
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
|
||||
@ -2760,7 +2760,7 @@ TEST_F(TransformationTestsF, PruneMasksMatMulRowsStopColsUp) {
|
||||
m.run_passes(function);
|
||||
}
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
|
||||
@ -2855,5 +2855,5 @@ TEST_F(TransformationTestsF, PropagateFlattenUp) {
|
||||
m.run_passes(function);
|
||||
}
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
@ -126,7 +126,9 @@ TEST_P(ConvertReduceToPoolingTests, CompareFunctions) {
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
|
||||
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
|
||||
auto fc = FunctionsComparator::no_default()
|
||||
.enable(FunctionsComparator::NODES)
|
||||
.enable(FunctionsComparator::PRECISIONS);
|
||||
auto res = fc.compare(f, f_ref);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
}
|
||||
|
@ -90,7 +90,9 @@ void test(std::shared_ptr<ngraph::Function> f, std::shared_ptr<ngraph::Function>
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
ASSERT_NO_THROW(manager.run_passes(f));
|
||||
|
||||
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
|
||||
auto fc = FunctionsComparator::no_default()
|
||||
.enable(FunctionsComparator::NODES)
|
||||
.enable(FunctionsComparator::PRECISIONS);
|
||||
auto res = fc.compare(f, f_ref);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
}
|
||||
|
@ -650,7 +650,9 @@ TEST(TransformationTests, DifferentPrecisionVersusAttributes) {
|
||||
///
|
||||
|
||||
{ // check precision only
|
||||
const auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
|
||||
const auto fc = FunctionsComparator::no_default()
|
||||
.enable(FunctionsComparator::NODES)
|
||||
.enable(FunctionsComparator::PRECISIONS);
|
||||
const auto res = fc.compare(f1, f2);
|
||||
EXPECT_FALSE(res.valid);
|
||||
EXPECT_THAT(res.message, HasSubstr("Different element type detected"));
|
||||
@ -660,6 +662,7 @@ TEST(TransformationTests, DifferentPrecisionVersusAttributes) {
|
||||
|
||||
{ // check precision and attributes
|
||||
const auto fc = FunctionsComparator::no_default()
|
||||
.enable(FunctionsComparator::NODES)
|
||||
.enable(FunctionsComparator::PRECISIONS)
|
||||
.enable(FunctionsComparator::ATTRIBUTES);
|
||||
const auto res = fc.compare(f1, f2);
|
||||
|
@ -62,7 +62,7 @@ class CompressQuantizeWeightsTests
|
||||
function_ref = std::make_shared<Function>(mul, ParameterVector{});
|
||||
}
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
};
|
||||
|
||||
@ -108,7 +108,7 @@ TEST_F(TransformationTestsF, CompressQuantizeWeightsWithDequantizationSubgraph)
|
||||
function_ref = std::make_shared<Function>(NodeVector{mul}, ParameterVector{});
|
||||
}
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, CompressQuantizeWeightsWithZeroPointOptimizer) {
|
||||
@ -133,7 +133,7 @@ TEST_F(TransformationTestsF, CompressQuantizeWeightsWithZeroPointOptimizer) {
|
||||
function_ref = std::make_shared<Function>(NodeVector{mul}, ParameterVector{});
|
||||
}
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, NegativeCompressQuantizeWeightsWithZeroPointOptimizer) {
|
||||
@ -159,7 +159,7 @@ TEST_F(TransformationTestsF, NegativeCompressQuantizeWeightsWithZeroPointOptimiz
|
||||
function_ref = std::make_shared<Function>(NodeVector{mul}, ParameterVector{});
|
||||
}
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, NegativeCompressQuantizeWeightsNonConstantInput) {
|
||||
@ -175,5 +175,5 @@ TEST_F(TransformationTestsF, NegativeCompressQuantizeWeightsNonConstantInput) {
|
||||
manager.register_pass<pass::ZeroPointOptimizer>();
|
||||
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
enable_accuracy_check();
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
@ -2,7 +2,10 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "graph_comparator.hpp"
|
||||
#include "common_test_utils/graph_comparator.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "ie_common.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
@ -20,6 +23,10 @@
|
||||
#include <typeinfo>
|
||||
#include <vector>
|
||||
|
||||
#include "common_test_utils/ov_tensor_utils.hpp"
|
||||
#include "ngraph_functions/utils/ngraph_helpers.hpp"
|
||||
|
||||
|
||||
namespace {
|
||||
inline namespace tools {
|
||||
bool is_type_relaxed(const std::string& type) {
|
||||
@ -31,12 +38,12 @@ bool compare_type_info(const ngraph::DiscreteTypeInfo& info1, const ngraph::Disc
|
||||
if (!is_type_relaxed(info1.name) && !is_type_relaxed(info2.name) && (info1.version != info2.version)) {
|
||||
return false;
|
||||
}
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
|
||||
const std::string info1Name =
|
||||
is_type_relaxed(info1.name) && (info1.parent != nullptr) ? info1.parent->name : info1.name;
|
||||
is_type_relaxed(info1.name) && (info1.parent != nullptr) ? info1.parent->name : info1.name;
|
||||
const std::string info2Name =
|
||||
is_type_relaxed(info2.name) && (info2.parent != nullptr) ? info2.parent->name : info2.name;
|
||||
is_type_relaxed(info2.name) && (info2.parent != nullptr) ? info2.parent->name : info2.name;
|
||||
return info1Name == info2Name;
|
||||
}
|
||||
|
||||
@ -121,9 +128,9 @@ public:
|
||||
explicit NodeAndInputDescription(const InputNode& input,
|
||||
const Parameter* parameter,
|
||||
const InputDescripton* description)
|
||||
: m_input(input),
|
||||
m_parameter(not_null(parameter)),
|
||||
m_description(not_null(description)) {}
|
||||
: m_input(input),
|
||||
m_parameter(not_null(parameter)),
|
||||
m_description(not_null(description)) {}
|
||||
|
||||
static bool equal_descriptions(const InputDescripton* lhs, const InputDescripton* rhs) {
|
||||
if (!lhs || !rhs || lhs->get_type_info() != rhs->get_type_info()) {
|
||||
@ -172,7 +179,7 @@ public:
|
||||
}
|
||||
for (size_t i = 0; i != param_shape.size(); ++i) {
|
||||
const auto expected_axis_size =
|
||||
i == slice_description->m_axis ? slice_description->m_part_size * num_iterations : param_shape[i];
|
||||
i == slice_description->m_axis ? slice_description->m_part_size * num_iterations : param_shape[i];
|
||||
if (input_shape[i] != expected_axis_size) {
|
||||
return false;
|
||||
}
|
||||
@ -215,9 +222,9 @@ public:
|
||||
explicit NodeAndOutputDescription(const OutputNode& output,
|
||||
const Result* result,
|
||||
const OutputDescription* description)
|
||||
: m_output(output),
|
||||
m_result(not_null(result)),
|
||||
m_description(not_null(description)) {}
|
||||
: m_output(output),
|
||||
m_result(not_null(result)),
|
||||
m_description(not_null(description)) {}
|
||||
|
||||
static bool equal_descriptions(const OutputDescription* lhs, const OutputDescription* rhs) {
|
||||
if (!lhs || !rhs || lhs->get_type_info() != rhs->get_type_info()) {
|
||||
@ -302,8 +309,8 @@ public:
|
||||
using Id = uint64_t;
|
||||
|
||||
explicit BackEdge(const Parameter* parameter, const Result* result)
|
||||
: m_parameter(not_null(parameter)),
|
||||
m_result(not_null(result)) {}
|
||||
: m_parameter(not_null(parameter)),
|
||||
m_result(not_null(result)) {}
|
||||
|
||||
bool result_and_parameter_match() const {
|
||||
return equal_type_and_partial_shape(m_result->output(0), *m_parameter);
|
||||
@ -539,103 +546,113 @@ Comparator::Result Comparator::compare(const std::shared_ptr<ngraph::Function>&
|
||||
* + Check node attributes by Visitor API
|
||||
*/
|
||||
|
||||
auto f_results = f->get_results();
|
||||
auto f_ref_results = f_ref->get_results();
|
||||
if (should_compare(CmpValues::NODES)) {
|
||||
auto f_results = f->get_results();
|
||||
auto f_ref_results = f_ref->get_results();
|
||||
|
||||
auto cmp = less_by_name;
|
||||
// In case if Result source output has more than one name so the Result may have any of this names as a friendly
|
||||
// name An in case of multiple names we sort Result operation using their parent node names
|
||||
if (std::any_of(f_results.begin(),
|
||||
f_results.end(),
|
||||
[](const std::shared_ptr<ngraph::Node>& node) {
|
||||
const auto& t = node->input_value(0).get_tensor_ptr();
|
||||
return t->get_names().size() > 1;
|
||||
}) ||
|
||||
std::any_of(f_ref_results.begin(), f_ref_results.end(), [](const std::shared_ptr<ngraph::Node>& node) {
|
||||
const auto& t = node->input_value(0).get_tensor_ptr();
|
||||
return t->get_names().size() > 1;
|
||||
})) {
|
||||
cmp = less_by_parent_name;
|
||||
}
|
||||
auto cmp = less_by_name;
|
||||
// In case if Result source output has more than one name so the Result may have any of this names as a friendly
|
||||
// name An in case of multiple names we sort Result operation using their parent node names
|
||||
if (std::any_of(f_results.begin(),
|
||||
f_results.end(),
|
||||
[](const std::shared_ptr<ngraph::Node> &node) {
|
||||
const auto &t = node->input_value(0).get_tensor_ptr();
|
||||
return t->get_names().size() > 1;
|
||||
}) ||
|
||||
std::any_of(f_ref_results.begin(), f_ref_results.end(), [](const std::shared_ptr<ngraph::Node> &node) {
|
||||
const auto &t = node->input_value(0).get_tensor_ptr();
|
||||
return t->get_names().size() > 1;
|
||||
})) {
|
||||
cmp = less_by_parent_name;
|
||||
}
|
||||
|
||||
std::sort(f_results.begin(), f_results.end(), cmp);
|
||||
std::sort(f_ref_results.begin(), f_ref_results.end(), cmp);
|
||||
std::sort(f_results.begin(), f_results.end(), cmp);
|
||||
std::sort(f_ref_results.begin(), f_ref_results.end(), cmp);
|
||||
|
||||
if (f_results.size() != f_ref_results.size()) {
|
||||
return Result::error("Number of results is different: " + to_str(f_results.size()) + " and " +
|
||||
to_str(f_ref_results.size()));
|
||||
}
|
||||
if (f_results.size() != f_ref_results.size()) {
|
||||
return Result::error("Number of results is different: " + to_str(f_results.size()) + " and " +
|
||||
to_str(f_ref_results.size()));
|
||||
}
|
||||
|
||||
const auto& f_sinks = f->get_sinks();
|
||||
const auto& f_ref_sinks = f_ref->get_sinks();
|
||||
if (f_sinks.size() != f_ref_sinks.size()) {
|
||||
return Result::error("Number of sinks is different: " + to_str(f_sinks.size()) + " and " +
|
||||
to_str(f_ref_sinks.size()));
|
||||
}
|
||||
const auto &f_sinks = f->get_sinks();
|
||||
const auto &f_ref_sinks = f_ref->get_sinks();
|
||||
if (f_sinks.size() != f_ref_sinks.size()) {
|
||||
return Result::error("Number of sinks is different: " + to_str(f_sinks.size()) + " and " +
|
||||
to_str(f_ref_sinks.size()));
|
||||
}
|
||||
|
||||
// Compare sinks
|
||||
if (f_sinks.size() == 1) {
|
||||
q.push({f_sinks[0].get(), f_ref_sinks[0].get()});
|
||||
used.insert(f_sinks[0].get());
|
||||
} else {
|
||||
// Cast to Assign and find those that have same variable_id suffix
|
||||
for (const auto& sink1 : f_sinks) {
|
||||
auto assign1 = std::dynamic_pointer_cast<ov::op::util::VariableExtension>(sink1);
|
||||
if (!assign1) {
|
||||
return Result::error("Sink '" + name(sink1) +
|
||||
"' is not a variable - graph comparison is not supported");
|
||||
}
|
||||
auto name1 = assign1->get_variable_id();
|
||||
std::shared_ptr<ov::op::Sink> found_sink2;
|
||||
for (const auto& sink2 : f_ref_sinks) {
|
||||
auto assign2 = std::dynamic_pointer_cast<ov::op::util::VariableExtension>(sink2);
|
||||
if (!assign2) {
|
||||
return Result::error("Sink '" + name(sink2) +
|
||||
// Compare sinks
|
||||
if (f_sinks.size() == 1) {
|
||||
q.push({f_sinks[0].get(), f_ref_sinks[0].get()});
|
||||
used.insert(f_sinks[0].get());
|
||||
} else {
|
||||
// Cast to Assign and find those that have same variable_id suffix
|
||||
for (const auto &sink1 : f_sinks) {
|
||||
auto assign1 = std::dynamic_pointer_cast<ov::op::util::VariableExtension>(sink1);
|
||||
if (!assign1) {
|
||||
return Result::error("Sink '" + name(sink1) +
|
||||
"' is not a variable - graph comparison is not supported");
|
||||
}
|
||||
auto name2 = assign2->get_variable_id();
|
||||
if (name2.find(name1) != std::string::npos || name1.find(name2) != std::string::npos) {
|
||||
found_sink2 = sink2;
|
||||
break;
|
||||
auto name1 = assign1->get_variable_id();
|
||||
std::shared_ptr<ov::op::Sink> found_sink2;
|
||||
for (const auto &sink2 : f_ref_sinks) {
|
||||
auto assign2 = std::dynamic_pointer_cast<ov::op::util::VariableExtension>(sink2);
|
||||
if (!assign2) {
|
||||
return Result::error("Sink '" + name(sink2) +
|
||||
"' is not a variable - graph comparison is not supported");
|
||||
}
|
||||
auto name2 = assign2->get_variable_id();
|
||||
if (name2.find(name1) != std::string::npos || name1.find(name2) != std::string::npos) {
|
||||
found_sink2 = sink2;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!found_sink2) {
|
||||
return Result::error("No suitable sink is found for: " + name(sink1) + ", var=" + name1);
|
||||
}
|
||||
q.push({sink1.get(), found_sink2.get()});
|
||||
used.insert(sink1.get());
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < f_results.size(); ++i) {
|
||||
if (should_compare(CmpValues::NAMES)) {
|
||||
if (name(f_results[i]->get_input_node_shared_ptr(0)) !=
|
||||
name(f_ref_results[i]->get_input_node_shared_ptr(0))) {
|
||||
return Result::error(
|
||||
"Different output node names: " + name(f_results[i]->get_input_node_shared_ptr(0)) +
|
||||
" and " +
|
||||
name(f_ref_results[i]->get_input_node_shared_ptr(0)));
|
||||
}
|
||||
}
|
||||
if (!found_sink2) {
|
||||
return Result::error("No suitable sink is found for: " + name(sink1) + ", var=" + name1);
|
||||
q.push({f_results[i].get(), f_ref_results[i].get()});
|
||||
used.insert(f_results[i].get());
|
||||
}
|
||||
|
||||
std::stringstream errors;
|
||||
|
||||
while (!q.empty()) {
|
||||
ngraph::Node *const node1 = q.front().first;
|
||||
ngraph::Node *const node2 = q.front().second;
|
||||
q.pop();
|
||||
|
||||
const auto result = compare(node1, node2, errors);
|
||||
if (!result.valid) {
|
||||
return result;
|
||||
}
|
||||
q.push({sink1.get(), found_sink2.get()});
|
||||
used.insert(sink1.get());
|
||||
|
||||
add_nodes_inputs_to_queue(node1, node2);
|
||||
}
|
||||
const auto msg = errors.str();
|
||||
return msg.empty() ? Result::ok() : Result::error(msg);
|
||||
|
||||
} else if (should_compare(CmpValues::ACCURACY)) {
|
||||
auto status = accuracy_check(f_ref, f);
|
||||
return status.status ? Result::ok() : Result::error(status.message);
|
||||
|
||||
} else {
|
||||
return Result::error("CmpValues are invalid");
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < f_results.size(); ++i) {
|
||||
if (should_compare(CmpValues::NAMES)) {
|
||||
if (name(f_results[i]->get_input_node_shared_ptr(0)) !=
|
||||
name(f_ref_results[i]->get_input_node_shared_ptr(0))) {
|
||||
return Result::error(
|
||||
"Different output node names: " + name(f_results[i]->get_input_node_shared_ptr(0)) + " and " +
|
||||
name(f_ref_results[i]->get_input_node_shared_ptr(0)));
|
||||
}
|
||||
}
|
||||
q.push({f_results[i].get(), f_ref_results[i].get()});
|
||||
used.insert(f_results[i].get());
|
||||
}
|
||||
|
||||
std::stringstream errors;
|
||||
|
||||
while (!q.empty()) {
|
||||
ngraph::Node* const node1 = q.front().first;
|
||||
ngraph::Node* const node2 = q.front().second;
|
||||
q.pop();
|
||||
|
||||
const auto result = compare(node1, node2, errors);
|
||||
if (!result.valid) {
|
||||
return result;
|
||||
}
|
||||
|
||||
add_nodes_inputs_to_queue(node1, node2);
|
||||
}
|
||||
const auto msg = errors.str();
|
||||
return msg.empty() ? Result::ok() : Result::error(msg);
|
||||
}
|
||||
|
||||
Comparator::Result Comparator::compare(ngraph::Node* node1, ngraph::Node* node2, std::ostream& err_log) {
|
||||
@ -759,7 +776,7 @@ void Comparator::compare_outputs(ngraph::Node* node1, ngraph::Node* node2, std::
|
||||
void Comparator::compare_nodes(ngraph::Node* node1, ngraph::Node* node2, std::ostream& err_log) {
|
||||
if (should_compare(CmpValues::RUNTIME_KEYS) && !compare_rt_keys(*node1, *node2, err_log)) {
|
||||
err_log << "Different runtime info detected\n" + name(node1) + " and " + name(node2) +
|
||||
" not equal runtime info.\n";
|
||||
" not equal runtime info.\n";
|
||||
}
|
||||
|
||||
if (should_compare(CmpValues::ATTRIBUTES)) {
|
||||
@ -805,123 +822,115 @@ void check_rt_info(const std::shared_ptr<ngraph::Function>& f) {
|
||||
}
|
||||
}
|
||||
|
||||
void set_tensor_name(ngraph::Output<ngraph::Node> output, const std::string& name) {
|
||||
output.get_tensor_ptr()->set_names({name});
|
||||
}
|
||||
|
||||
void set_tensor_names(ngraph::Output<ngraph::Node> output, const std::unordered_set<std::string>& names) {
|
||||
output.get_tensor_ptr()->set_names(names);
|
||||
}
|
||||
|
||||
namespace attributes {
|
||||
namespace detail {
|
||||
void ReadAndStoreAttributes::on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) {
|
||||
if (auto inputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpInputDescription>>(&adapter)) {
|
||||
insert(name, inputs->get());
|
||||
} else if (auto outputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpOutputDescription>>(&adapter)) {
|
||||
insert(name, outputs->get());
|
||||
} else if (ngraph::is_type<ngraph::AttributeAdapter<SpecialBodyPorts>>(&adapter)) {
|
||||
// drop comparison, no more info than port indexes which will be check in
|
||||
// subgraph::compare_io
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::runtime::AlignedBuffer>>>(
|
||||
&adapter)) {
|
||||
const auto beg = static_cast<unsigned char*>(a->get()->get_ptr());
|
||||
const auto end = beg + a->get()->size();
|
||||
insert(name, storage::MemoryChunk{storage::MemoryChunk::Data(beg, end)});
|
||||
} else if (auto framework_node_attr =
|
||||
ngraph::as_type<ngraph::AttributeAdapter<ov::op::util::FrameworkNodeAttrs>>(&adapter)) {
|
||||
insert(name, framework_node_attr->get());
|
||||
} else if (auto variable_ptr =
|
||||
ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::Variable>>>(&adapter)) {
|
||||
insert(name, variable_ptr->get());
|
||||
} else if (auto shape_ptr = ngraph::as_type<ngraph::AttributeAdapter<ov::PartialShape>>(&adapter)) {
|
||||
insert(name, shape_ptr->get());
|
||||
} else if (auto dim_ptr = ngraph::as_type<ngraph::AttributeAdapter<ov::Dimension>>(&adapter)) {
|
||||
insert(name, dim_ptr->get());
|
||||
} else {
|
||||
m_read_result += "store attr [ ERR ]: " + name + " [drop `void` comparison which is '" +
|
||||
adapter.get_type_info().name + "']";
|
||||
void ReadAndStoreAttributes::on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) {
|
||||
if (auto inputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpInputDescription>>(&adapter)) {
|
||||
insert(name, inputs->get());
|
||||
} else if (auto outputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpOutputDescription>>(&adapter)) {
|
||||
insert(name, outputs->get());
|
||||
} else if (ngraph::is_type<ngraph::AttributeAdapter<SpecialBodyPorts>>(&adapter)) {
|
||||
// drop comparison, no more info than port indexes which will be check in
|
||||
// subgraph::compare_io
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::runtime::AlignedBuffer>>>(
|
||||
&adapter)) {
|
||||
const auto beg = static_cast<unsigned char*>(a->get()->get_ptr());
|
||||
const auto end = beg + a->get()->size();
|
||||
insert(name, storage::MemoryChunk{storage::MemoryChunk::Data(beg, end)});
|
||||
} else if (auto framework_node_attr =
|
||||
ngraph::as_type<ngraph::AttributeAdapter<ov::op::util::FrameworkNodeAttrs>>(&adapter)) {
|
||||
insert(name, framework_node_attr->get());
|
||||
} else if (auto variable_ptr =
|
||||
ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::Variable>>>(&adapter)) {
|
||||
insert(name, variable_ptr->get());
|
||||
} else if (auto shape_ptr = ngraph::as_type<ngraph::AttributeAdapter<ov::PartialShape>>(&adapter)) {
|
||||
insert(name, shape_ptr->get());
|
||||
} else if (auto dim_ptr = ngraph::as_type<ngraph::AttributeAdapter<ov::Dimension>>(&adapter)) {
|
||||
insert(name, dim_ptr->get());
|
||||
} else {
|
||||
m_read_result += "store attr [ ERR ]: " + name + " [drop `void` comparison which is '" +
|
||||
adapter.get_type_info().name + "']";
|
||||
}
|
||||
}
|
||||
}
|
||||
template <typename AttrValue>
|
||||
void ReadAndCompareAttributes::verify(const std::string& name, const AttrValue& attr_value) {
|
||||
if (should_return()) {
|
||||
return;
|
||||
}
|
||||
m_visited_attributes.insert(name);
|
||||
const auto ref_value = m_attr_ref.get<AttrValue>(name);
|
||||
if (!ref_value) {
|
||||
m_cmp_result += "missing attribute name: '" + name + "'";
|
||||
return;
|
||||
template <typename AttrValue>
|
||||
void ReadAndCompareAttributes::verify(const std::string& name, const AttrValue& attr_value) {
|
||||
if (should_return()) {
|
||||
return;
|
||||
}
|
||||
m_visited_attributes.insert(name);
|
||||
const auto ref_value = m_attr_ref.get<AttrValue>(name);
|
||||
if (!ref_value) {
|
||||
m_cmp_result += "missing attribute name: '" + name + "'";
|
||||
return;
|
||||
}
|
||||
|
||||
if (!equal::Equal<AttrValue>::equal_value(*ref_value, attr_value)) {
|
||||
m_cmp_result += "mismatch in value: '" + name + "' : " + str::Get<AttrValue>::value(*ref_value) + " vs " +
|
||||
str::Get<AttrValue>::value(attr_value);
|
||||
}
|
||||
}
|
||||
|
||||
if (!equal::Equal<AttrValue>::equal_value(*ref_value, attr_value)) {
|
||||
m_cmp_result += "mismatch in value: '" + name + "' : " + str::Get<AttrValue>::value(*ref_value) + " vs " +
|
||||
str::Get<AttrValue>::value(attr_value);
|
||||
}
|
||||
}
|
||||
void ReadAndCompareAttributes::verify_mem_buf(const std::string& name,
|
||||
const std::shared_ptr<ngraph::runtime::AlignedBuffer>& buffer) {
|
||||
if (should_return()) {
|
||||
return;
|
||||
}
|
||||
m_visited_attributes.insert(name);
|
||||
const auto ref_value = m_attr_ref.get<storage::MemoryChunk>(name);
|
||||
if (!ref_value) {
|
||||
m_cmp_result += "missing attribute name: '" + name + "'";
|
||||
return;
|
||||
}
|
||||
|
||||
void ReadAndCompareAttributes::verify_mem_buf(const std::string& name,
|
||||
const std::shared_ptr<ngraph::runtime::AlignedBuffer>& buffer) {
|
||||
if (should_return()) {
|
||||
return;
|
||||
}
|
||||
m_visited_attributes.insert(name);
|
||||
const auto ref_value = m_attr_ref.get<storage::MemoryChunk>(name);
|
||||
if (!ref_value) {
|
||||
m_cmp_result += "missing attribute name: '" + name + "'";
|
||||
return;
|
||||
if (buffer->size() != ref_value->size() ||
|
||||
std::memcmp(ref_value->data(), buffer->get_ptr(), ref_value->size()) != 0) {
|
||||
m_cmp_result += "mismatch in value: '" + name + "' : look in to the mem buffer";
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (buffer->size() != ref_value->size() ||
|
||||
std::memcmp(ref_value->data(), buffer->get_ptr(), ref_value->size()) != 0) {
|
||||
m_cmp_result += "mismatch in value: '" + name + "' : look in to the mem buffer";
|
||||
return;
|
||||
void ReadAndCompareAttributes::verify_function(const std::string& name, ModelAccessor& adapter) {
|
||||
if (should_return()) {
|
||||
return;
|
||||
}
|
||||
m_visited_attributes.insert(name);
|
||||
const auto ref_value = m_attr_ref.get<std::shared_ptr<ngraph::Function>>(name);
|
||||
if (!ref_value) {
|
||||
m_cmp_result += "missing attribute name: '" + name + "'";
|
||||
return;
|
||||
}
|
||||
Comparator c(m_check_flags);
|
||||
const auto result = c.compare(*ref_value, adapter.get());
|
||||
if (!result.valid) {
|
||||
m_cmp_result += result.message;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ReadAndCompareAttributes::verify_function(const std::string& name, ModelAccessor& adapter) {
|
||||
if (should_return()) {
|
||||
return;
|
||||
void ReadAndCompareAttributes::verify_others(const std::string& name, ngraph::ValueAccessor<void>& adapter) {
|
||||
if (auto inputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpInputDescription>>(&adapter)) {
|
||||
verify(name, inputs->get());
|
||||
} else if (auto outputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpOutputDescription>>(&adapter)) {
|
||||
verify(name, outputs->get());
|
||||
} else if (ngraph::is_type<ngraph::AttributeAdapter<SpecialBodyPorts>>(&adapter)) {
|
||||
// drop comparison, no more info than port indexes which will be check in
|
||||
// subgraph::compare_io
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::runtime::AlignedBuffer>>>(
|
||||
&adapter)) {
|
||||
verify_mem_buf(name, a->get());
|
||||
} else if (auto attrs = ngraph::as_type<ngraph::AttributeAdapter<ov::op::util::FrameworkNodeAttrs>>(&adapter)) {
|
||||
verify(name, attrs->get());
|
||||
} else if (auto variable_ptr =
|
||||
ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::Variable>>>(&adapter)) {
|
||||
verify(name, variable_ptr->get());
|
||||
} else if (auto shape_ptr = ngraph::as_type<ngraph::AttributeAdapter<ov::PartialShape>>(&adapter)) {
|
||||
verify(name, shape_ptr->get());
|
||||
} else if (auto dim_ptr = ngraph::as_type<ngraph::AttributeAdapter<ov::Dimension>>(&adapter)) {
|
||||
verify(name, dim_ptr->get());
|
||||
} else {
|
||||
m_cmp_result += "compare attr [ ERR ]: " + name + " [drop `void` comparison which is '" +
|
||||
adapter.get_type_info().name + "']";
|
||||
}
|
||||
}
|
||||
m_visited_attributes.insert(name);
|
||||
const auto ref_value = m_attr_ref.get<std::shared_ptr<ngraph::Function>>(name);
|
||||
if (!ref_value) {
|
||||
m_cmp_result += "missing attribute name: '" + name + "'";
|
||||
return;
|
||||
}
|
||||
Comparator c(m_check_flags);
|
||||
const auto result = c.compare(*ref_value, adapter.get());
|
||||
if (!result.valid) {
|
||||
m_cmp_result += result.message;
|
||||
}
|
||||
}
|
||||
|
||||
void ReadAndCompareAttributes::verify_others(const std::string& name, ngraph::ValueAccessor<void>& adapter) {
|
||||
if (auto inputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpInputDescription>>(&adapter)) {
|
||||
verify(name, inputs->get());
|
||||
} else if (auto outputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpOutputDescription>>(&adapter)) {
|
||||
verify(name, outputs->get());
|
||||
} else if (ngraph::is_type<ngraph::AttributeAdapter<SpecialBodyPorts>>(&adapter)) {
|
||||
// drop comparison, no more info than port indexes which will be check in
|
||||
// subgraph::compare_io
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::runtime::AlignedBuffer>>>(
|
||||
&adapter)) {
|
||||
verify_mem_buf(name, a->get());
|
||||
} else if (auto attrs = ngraph::as_type<ngraph::AttributeAdapter<ov::op::util::FrameworkNodeAttrs>>(&adapter)) {
|
||||
verify(name, attrs->get());
|
||||
} else if (auto variable_ptr =
|
||||
ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::Variable>>>(&adapter)) {
|
||||
verify(name, variable_ptr->get());
|
||||
} else if (auto shape_ptr = ngraph::as_type<ngraph::AttributeAdapter<ov::PartialShape>>(&adapter)) {
|
||||
verify(name, shape_ptr->get());
|
||||
} else if (auto dim_ptr = ngraph::as_type<ngraph::AttributeAdapter<ov::Dimension>>(&adapter)) {
|
||||
verify(name, dim_ptr->get());
|
||||
} else {
|
||||
m_cmp_result += "compare attr [ ERR ]: " + name + " [drop `void` comparison which is '" +
|
||||
adapter.get_type_info().name + "']";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
@ -937,3 +946,42 @@ Comparator::Result compare(ngraph::Node* node1, ngraph::Node* node2, Comparator:
|
||||
}
|
||||
|
||||
} // namespace attributes
|
||||
|
||||
AccuracyCheckResult accuracy_check(const std::shared_ptr<ov::Model>& ref_function,
|
||||
const std::shared_ptr<ov::Model>& cur_function) {
|
||||
if (ref_function->is_dynamic() || cur_function->is_dynamic()) {
|
||||
return AccuracyCheckResult{true, ""};
|
||||
}
|
||||
try {
|
||||
IE_ASSERT(ref_function->get_parameters().size() == cur_function->get_parameters().size());
|
||||
|
||||
std::map<std::shared_ptr<ov::Node>, ov::Tensor> ref_input_data;
|
||||
std::map<std::shared_ptr<ov::Node>, ov::Tensor> cur_input_data;
|
||||
for (size_t i = 0; i < ref_function->get_parameters().size(); i++) {
|
||||
const auto &tensor = ov::test::utils::create_and_fill_tensor(
|
||||
ref_function->get_parameters()[i]->get_element_type(),
|
||||
ref_function->get_parameters()[i]->get_shape());
|
||||
ref_input_data[ref_function->get_parameters()[i]] = tensor;
|
||||
cur_input_data[cur_function->get_parameters()[i]] = tensor;
|
||||
}
|
||||
|
||||
auto ref_outputs = ngraph::helpers::interpretFunction(ref_function, ref_input_data);
|
||||
auto outputs = ngraph::helpers::interpretFunction(cur_function, cur_input_data);
|
||||
|
||||
IE_ASSERT(ref_outputs.size() == outputs.size());
|
||||
|
||||
for (int i = 0; i < ref_outputs.size(); i++) {
|
||||
ov::test::utils::compare(ref_outputs[i], outputs[i],
|
||||
std::numeric_limits<double>::max(),
|
||||
std::numeric_limits<double>::max());
|
||||
}
|
||||
}
|
||||
catch (const std::runtime_error &re) {
|
||||
return AccuracyCheckResult{false, re.what()};
|
||||
} catch (const std::exception &ex) {
|
||||
return AccuracyCheckResult{false, ex.what()};
|
||||
} catch (...) {
|
||||
return AccuracyCheckResult{false, ""};
|
||||
}
|
||||
return AccuracyCheckResult{true, ""};
|
||||
}
|
||||
|
@ -20,12 +20,14 @@ class FunctionsComparator {
|
||||
public:
|
||||
enum CmpValues {
|
||||
NONE = 0,
|
||||
CONST_VALUES = 1 << 0,
|
||||
NAMES = 1 << 1,
|
||||
RUNTIME_KEYS = 1 << 2,
|
||||
PRECISIONS = 1 << 3,
|
||||
ATTRIBUTES = 1 << 4,
|
||||
TENSOR_NAMES = 1 << 5,
|
||||
NODES = 1 << 0,
|
||||
CONST_VALUES = 1 << 1,
|
||||
NAMES = 1 << 2,
|
||||
RUNTIME_KEYS = 1 << 3,
|
||||
PRECISIONS = 1 << 4,
|
||||
ATTRIBUTES = 1 << 5,
|
||||
TENSOR_NAMES = 1 << 6,
|
||||
ACCURACY = 1 << 7
|
||||
};
|
||||
|
||||
struct Result {
|
||||
@ -45,6 +47,7 @@ public:
|
||||
}
|
||||
static FunctionsComparator with_default() noexcept {
|
||||
auto fc = FunctionsComparator::no_default();
|
||||
fc.enable(NODES);
|
||||
fc.enable(PRECISIONS);
|
||||
fc.enable(TENSOR_NAMES);
|
||||
return fc;
|
||||
@ -55,6 +58,11 @@ public:
|
||||
return *this;
|
||||
}
|
||||
|
||||
FunctionsComparator& disable(CmpValues f) noexcept {
|
||||
m_comparison_flags = static_cast<CmpValues>(m_comparison_flags & ~f);
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool should_compare(CmpValues f) const noexcept {
|
||||
return m_comparison_flags & f;
|
||||
}
|
||||
@ -85,6 +93,7 @@ inline std::pair<bool, std::string> compare_functions(const std::shared_ptr<ngra
|
||||
auto fc = FunctionsComparator::no_default();
|
||||
|
||||
using Cmp = FunctionsComparator::CmpValues;
|
||||
fc.enable(Cmp::NODES);
|
||||
if (compareConstValues)
|
||||
fc.enable(Cmp::CONST_VALUES);
|
||||
if (compareNames)
|
||||
@ -104,10 +113,6 @@ inline std::pair<bool, std::string> compare_functions(const std::shared_ptr<ngra
|
||||
|
||||
void check_rt_info(const std::shared_ptr<ngraph::Function>& f);
|
||||
|
||||
void set_tensor_name(ngraph::Output<ngraph::Node> output, const std::string& name);
|
||||
|
||||
void set_tensor_names(ngraph::Output<ngraph::Node> output, const std::unordered_set<std::string>& names);
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
class InjectionPass : public ov::pass::ModelPass {
|
||||
@ -188,8 +193,8 @@ public:
|
||||
for (auto output : node->outputs()) {
|
||||
const auto& tensor_names = output.get_names();
|
||||
if (std::any_of(tensor_names.begin(), tensor_names.end(), [&](const std::string& name) {
|
||||
return unique_tensor_names.count(name);
|
||||
})) {
|
||||
return unique_tensor_names.count(name);
|
||||
})) {
|
||||
std::stringstream ss;
|
||||
ss << "Node: " << node->get_type_info() << " with name " << node->get_friendly_name() << " ";
|
||||
ss << "has non unique tensor name.";
|
||||
@ -478,9 +483,9 @@ public:
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) override;
|
||||
|
||||
#define ON_ADAPTER(TYPE) \
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<TYPE>& adapter) override { \
|
||||
insert(name, adapter.get()); \
|
||||
}
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<TYPE>& adapter) override { \
|
||||
insert(name, adapter.get()); \
|
||||
}
|
||||
|
||||
ON_ADAPTER(bool)
|
||||
ON_ADAPTER(std::string)
|
||||
@ -711,37 +716,37 @@ struct Equal<std::shared_ptr<Constant>> {
|
||||
}
|
||||
|
||||
switch (lhs_t) {
|
||||
case ngraph::element::Type_t::u1: {
|
||||
const auto lhs_v = static_cast<const uint8_t*>(lhs->get_data_ptr());
|
||||
const auto rhs_v = static_cast<const uint8_t*>(rhs->get_data_ptr());
|
||||
const auto lhs_bit_size = shape_size(lhs->get_shape());
|
||||
const auto rhs_bit_size = shape_size(rhs->get_shape());
|
||||
return Equal<uint8_t*>::equal_value(lhs_v, rhs_v, lhs_bit_size, rhs_bit_size);
|
||||
}
|
||||
case ngraph::element::Type_t::bf16: {
|
||||
auto lhs_v = lhs->cast_vector<ngraph::bfloat16>();
|
||||
auto rhs_v = rhs->cast_vector<ngraph::bfloat16>();
|
||||
return Equal<std::vector<ngraph::bfloat16>>::equal_value(lhs_v, rhs_v);
|
||||
break;
|
||||
}
|
||||
case ngraph::element::Type_t::f16: {
|
||||
const auto& lhs_v = lhs->cast_vector<ngraph::float16>();
|
||||
const auto& rhs_v = rhs->cast_vector<ngraph::float16>();
|
||||
return Equal<std::vector<ngraph::float16>>::equal_value(lhs_v, rhs_v);
|
||||
break;
|
||||
}
|
||||
case ngraph::element::Type_t::f32: {
|
||||
const auto& lhs_v = lhs->cast_vector<float>();
|
||||
const auto& rhs_v = rhs->cast_vector<float>();
|
||||
return Equal<std::vector<float>>::equal_value(lhs_v, rhs_v);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
const auto& lhs_v = lhs->cast_vector<double>();
|
||||
const auto& rhs_v = rhs->cast_vector<double>();
|
||||
return Equal<std::vector<double>>::equal_value(lhs_v, rhs_v);
|
||||
break;
|
||||
}
|
||||
case ngraph::element::Type_t::u1: {
|
||||
const auto lhs_v = static_cast<const uint8_t*>(lhs->get_data_ptr());
|
||||
const auto rhs_v = static_cast<const uint8_t*>(rhs->get_data_ptr());
|
||||
const auto lhs_bit_size = shape_size(lhs->get_shape());
|
||||
const auto rhs_bit_size = shape_size(rhs->get_shape());
|
||||
return Equal<uint8_t*>::equal_value(lhs_v, rhs_v, lhs_bit_size, rhs_bit_size);
|
||||
}
|
||||
case ngraph::element::Type_t::bf16: {
|
||||
auto lhs_v = lhs->cast_vector<ngraph::bfloat16>();
|
||||
auto rhs_v = rhs->cast_vector<ngraph::bfloat16>();
|
||||
return Equal<std::vector<ngraph::bfloat16>>::equal_value(lhs_v, rhs_v);
|
||||
break;
|
||||
}
|
||||
case ngraph::element::Type_t::f16: {
|
||||
const auto& lhs_v = lhs->cast_vector<ngraph::float16>();
|
||||
const auto& rhs_v = rhs->cast_vector<ngraph::float16>();
|
||||
return Equal<std::vector<ngraph::float16>>::equal_value(lhs_v, rhs_v);
|
||||
break;
|
||||
}
|
||||
case ngraph::element::Type_t::f32: {
|
||||
const auto& lhs_v = lhs->cast_vector<float>();
|
||||
const auto& rhs_v = rhs->cast_vector<float>();
|
||||
return Equal<std::vector<float>>::equal_value(lhs_v, rhs_v);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
const auto& lhs_v = lhs->cast_vector<double>();
|
||||
const auto& rhs_v = rhs->cast_vector<double>();
|
||||
return Equal<std::vector<double>>::equal_value(lhs_v, rhs_v);
|
||||
break;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@ -862,18 +867,18 @@ struct Get<std::shared_ptr<ngraph::Variable>, void> {
|
||||
class ReadAndCompareAttributes : public ngraph::AttributeVisitor {
|
||||
public:
|
||||
ReadAndCompareAttributes(const ReadAndStoreAttributes& ref, Comparator::CmpValues check_flags)
|
||||
: m_attr_ref(ref),
|
||||
m_cmp_result{ref.read_result()},
|
||||
m_check_flags(check_flags) {}
|
||||
: m_attr_ref(ref),
|
||||
m_cmp_result{ref.read_result()},
|
||||
m_check_flags(check_flags) {}
|
||||
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) override {
|
||||
verify_others(name, adapter);
|
||||
}
|
||||
|
||||
#define ON_ADAPTER(TYPE) \
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<TYPE>& adapter) override { \
|
||||
verify(name, adapter.get()); \
|
||||
}
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<TYPE>& adapter) override { \
|
||||
verify(name, adapter.get()); \
|
||||
}
|
||||
|
||||
ON_ADAPTER(bool)
|
||||
ON_ADAPTER(std::string)
|
||||
@ -981,3 +986,11 @@ private:
|
||||
Comparator::Result compare(ngraph::Node* node1, ngraph::Node* node2, Comparator::CmpValues comparition_flags);
|
||||
|
||||
} // namespace attributes
|
||||
|
||||
struct AccuracyCheckResult {
|
||||
bool status;
|
||||
std::string message;
|
||||
};
|
||||
|
||||
AccuracyCheckResult accuracy_check(const std::shared_ptr<ov::Model>& ref_function,
|
||||
const std::shared_ptr<ov::Model>& cur_function);
|
||||
|
@ -2,248 +2,7 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <functional_test_utils/include/functional_test_utils/blob_utils.hpp>
|
||||
#include "ngraph_test_utils.hpp"
|
||||
#include <ngraph_functions/utils/ngraph_helpers.hpp>
|
||||
|
||||
namespace {
|
||||
|
||||
template<class T_IE, class T_NGRAPH>
|
||||
static void Compare(const T_NGRAPH *expected, const T_IE *actual, std::size_t size, float threshold, float abs_threshold = -1.f) {
|
||||
for (std::size_t i = 0; i < size; ++i) {
|
||||
const T_NGRAPH &ref = expected[i];
|
||||
const auto &res = actual[i];
|
||||
const auto absoluteDifference = CommonTestUtils::ie_abs(res - ref);
|
||||
if (abs_threshold > 0.f && absoluteDifference > abs_threshold) {
|
||||
IE_THROW() << "Absolute comparison of values expected: " << std::to_string(ref) << " and actual: " << std::to_string(res)
|
||||
<< " at index " << i << " with absolute threshold " << abs_threshold
|
||||
<< " failed";
|
||||
}
|
||||
if (absoluteDifference <= threshold) {
|
||||
continue;
|
||||
}
|
||||
double max;
|
||||
if (sizeof(T_IE) < sizeof(T_NGRAPH)) {
|
||||
max = std::max(CommonTestUtils::ie_abs(T_NGRAPH(res)), CommonTestUtils::ie_abs(ref));
|
||||
} else {
|
||||
max = std::max(CommonTestUtils::ie_abs(res), CommonTestUtils::ie_abs(T_IE(ref)));
|
||||
}
|
||||
double diff = static_cast<float>(absoluteDifference) / max;
|
||||
if (max == 0 || (diff > static_cast<float>(threshold)) ||
|
||||
(std::isnan(static_cast<float>(res)) ^ std::isnan(static_cast<float>(ref)))) {
|
||||
IE_THROW() << "Relative comparison of values expected: " << std::to_string(ref) << " and actual: " << std::to_string(res)
|
||||
<< " at index " << i << " with threshold " << threshold
|
||||
<< " failed";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T_IE>
|
||||
void callCompare(const std::pair<ngraph::element::Type, std::vector<std::uint8_t>> &expected,
|
||||
const T_IE* actualBuffer, size_t size, float threshold, float abs_threshold) {
|
||||
auto expectedBuffer = expected.second.data();
|
||||
switch (expected.first) {
|
||||
case ngraph::element::Type_t::i64:
|
||||
Compare<T_IE, int64_t>(reinterpret_cast<const int64_t *>(expectedBuffer),
|
||||
actualBuffer, size, threshold, abs_threshold);
|
||||
break;
|
||||
case ngraph::element::Type_t::i32:
|
||||
Compare<T_IE, int32_t>(reinterpret_cast<const int32_t *>(expectedBuffer),
|
||||
actualBuffer, size, threshold, abs_threshold);
|
||||
break;
|
||||
case ngraph::element::Type_t::i16:
|
||||
Compare<T_IE, int16_t>(reinterpret_cast<const int16_t *>(expectedBuffer),
|
||||
actualBuffer, size, threshold, abs_threshold);
|
||||
break;
|
||||
case ngraph::element::Type_t::i8:
|
||||
Compare<T_IE, int8_t>(reinterpret_cast<const int8_t *>(expectedBuffer),
|
||||
actualBuffer, size, threshold, abs_threshold);
|
||||
break;
|
||||
case ngraph::element::Type_t::u64:
|
||||
Compare<T_IE, uint64_t>(reinterpret_cast<const uint64_t *>(expectedBuffer),
|
||||
actualBuffer, size, threshold, abs_threshold);
|
||||
break;
|
||||
case ngraph::element::Type_t::u32:
|
||||
Compare<T_IE, uint32_t>(reinterpret_cast<const uint32_t *>(expectedBuffer),
|
||||
actualBuffer, size, threshold, abs_threshold);
|
||||
break;
|
||||
case ngraph::element::Type_t::u16:
|
||||
Compare<T_IE, uint16_t>(reinterpret_cast<const uint16_t *>(expectedBuffer),
|
||||
actualBuffer, size, threshold, abs_threshold);
|
||||
break;
|
||||
case ngraph::element::Type_t::boolean:
|
||||
case ngraph::element::Type_t::u8:
|
||||
Compare<T_IE, uint8_t>(reinterpret_cast<const uint8_t *>(expectedBuffer),
|
||||
actualBuffer, size, threshold, abs_threshold);
|
||||
break;
|
||||
case ngraph::element::Type_t::f64:
|
||||
Compare<T_IE, double>(reinterpret_cast<const double *>(expectedBuffer),
|
||||
actualBuffer, size, threshold, abs_threshold);
|
||||
break;
|
||||
case ngraph::element::Type_t::f32:
|
||||
Compare<T_IE, float>(reinterpret_cast<const float *>(expectedBuffer),
|
||||
actualBuffer, size, threshold, abs_threshold);
|
||||
break;
|
||||
case ngraph::element::Type_t::f16:
|
||||
Compare<T_IE, ngraph::float16>(reinterpret_cast<const ngraph::float16 *>(expectedBuffer),
|
||||
actualBuffer, size, threshold, abs_threshold);
|
||||
break;
|
||||
case ngraph::element::Type_t::bf16:
|
||||
Compare<T_IE, ngraph::bfloat16>(reinterpret_cast<const ngraph::bfloat16 *>(expectedBuffer),
|
||||
actualBuffer, size, threshold, abs_threshold);
|
||||
break;
|
||||
// case ngraph::element::Type_t::i4: {
|
||||
// auto expectedOut = ngraph::helpers::convertOutputPrecision(
|
||||
// expected.second,
|
||||
// expected.first,
|
||||
// ngraph::element::Type_t::i8,
|
||||
// size);
|
||||
// Compare<T_IE, int8_t>(reinterpret_cast<const int8_t *>(expectedOut.data()),
|
||||
// actualBuffer, size, threshold, abs_threshold);
|
||||
// break;
|
||||
// }
|
||||
// case ngraph::element::Type_t::u4: {
|
||||
// auto expectedOut = ngraph::helpers::convertOutputPrecision(
|
||||
// expected.second,
|
||||
// expected.first,
|
||||
// ngraph::element::Type_t::u8,
|
||||
// size);
|
||||
// Compare<T_IE, uint8_t>(reinterpret_cast<const uint8_t *>(expectedOut.data()),
|
||||
// actualBuffer, size, threshold, abs_threshold);
|
||||
// break;
|
||||
// }
|
||||
case ngraph::element::Type_t::dynamic:
|
||||
case ngraph::element::Type_t::undefined:
|
||||
Compare<T_IE, T_IE>(reinterpret_cast<const T_IE *>(expectedBuffer), actualBuffer, size, threshold, abs_threshold);
|
||||
break;
|
||||
default: FAIL() << "Comparator for " << expected.first << " precision isn't supported";
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
void Compare(const std::pair<ngraph::element::Type, std::vector<std::uint8_t>> &expected,
|
||||
const InferenceEngine::Blob::Ptr &actual,
|
||||
float threshold,
|
||||
float abs_threshold) {
|
||||
const auto &precision = actual->getTensorDesc().getPrecision();
|
||||
auto k = static_cast<float>(expected.first.size()) / precision.size();
|
||||
// W/A for int4, uint4
|
||||
if (expected.first == ngraph::element::Type_t::u4 || expected.first == ngraph::element::Type_t::i4) {
|
||||
k /= 2;
|
||||
} else if (expected.first == ngraph::element::Type_t::undefined || expected.first == ngraph::element::Type_t::dynamic) {
|
||||
k = 1;
|
||||
}
|
||||
ASSERT_EQ(expected.second.size(), actual->byteSize() * k);
|
||||
|
||||
auto memory = InferenceEngine::as<InferenceEngine::MemoryBlob>(actual);
|
||||
IE_ASSERT(memory);
|
||||
const auto lockedMemory = memory->wmap();
|
||||
const auto actualBuffer = lockedMemory.as<const std::uint8_t *>();
|
||||
|
||||
const auto &size = actual->size();
|
||||
switch (precision) {
|
||||
case InferenceEngine::Precision::FP32:
|
||||
callCompare<float>(expected, reinterpret_cast<const float *>(actualBuffer), size, threshold, abs_threshold);
|
||||
break;
|
||||
case InferenceEngine::Precision::I32:
|
||||
callCompare<int32_t>(expected, reinterpret_cast<const int32_t *>(actualBuffer), size, threshold, abs_threshold);
|
||||
break;
|
||||
case InferenceEngine::Precision::I64:
|
||||
callCompare<int64_t>(expected, reinterpret_cast<const int64_t *>(actualBuffer), size, threshold, abs_threshold);
|
||||
break;
|
||||
case InferenceEngine::Precision::I8:
|
||||
callCompare<int8_t>(expected, reinterpret_cast<const int8_t *>(actualBuffer), size, threshold, abs_threshold);
|
||||
break;
|
||||
case InferenceEngine::Precision::U16:
|
||||
callCompare<uint16_t>(expected, reinterpret_cast<const uint16_t *>(actualBuffer), size, threshold, abs_threshold);
|
||||
break;
|
||||
case InferenceEngine::Precision::I16:
|
||||
callCompare<int16_t>(expected, reinterpret_cast<const int16_t *>(actualBuffer), size, threshold, abs_threshold);
|
||||
break;
|
||||
case InferenceEngine::Precision::BOOL:
|
||||
case InferenceEngine::Precision::U8:
|
||||
callCompare<uint8_t>(expected, reinterpret_cast<const uint8_t *>(actualBuffer), size, threshold, abs_threshold);
|
||||
break;
|
||||
case InferenceEngine::Precision::U64:
|
||||
callCompare<uint64_t>(expected, reinterpret_cast<const uint64_t *>(actualBuffer), size, threshold, abs_threshold);
|
||||
break;
|
||||
case InferenceEngine::Precision::BF16:
|
||||
callCompare<ngraph::bfloat16>(expected, reinterpret_cast<const ngraph::bfloat16 *>(actualBuffer), size, threshold, abs_threshold);
|
||||
break;
|
||||
case InferenceEngine::Precision::FP16:
|
||||
callCompare<ngraph::float16>(expected, reinterpret_cast<const ngraph::float16 *>(actualBuffer), size, threshold, abs_threshold);
|
||||
break;
|
||||
default:
|
||||
FAIL() << "Comparator for " << precision << " precision isn't supported";
|
||||
}
|
||||
}
|
||||
|
||||
void Compare(const std::vector<std::pair<ngraph::element::Type, std::vector<std::uint8_t>>> &expectedOutputs,
|
||||
const std::vector<InferenceEngine::Blob::Ptr> &actualOutputs,
|
||||
float threshold, float abs_threshold) {
|
||||
for (std::size_t outputIndex = 0; outputIndex < expectedOutputs.size(); ++outputIndex) {
|
||||
const auto &expected = expectedOutputs[outputIndex];
|
||||
const auto &actual = actualOutputs[outputIndex];
|
||||
Compare(expected, actual, threshold, abs_threshold);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void TransformationTestsF::accuracy_check(std::shared_ptr<ov::Model> ref_function,
|
||||
std::shared_ptr<ov::Model> cur_function) {
|
||||
try {
|
||||
if (ref_function->is_dynamic() || cur_function->is_dynamic()) {
|
||||
return;
|
||||
}
|
||||
std::vector<std::vector<uint8_t>> input_data;
|
||||
ngraph::element::TypeVector types;
|
||||
for (auto param : ref_function->get_parameters()) {
|
||||
types.push_back(param->get_element_type());
|
||||
|
||||
auto layout = InferenceEngine::Layout::ANY;
|
||||
if (ov::is_scalar(param->get_shape())) {
|
||||
layout = InferenceEngine::Layout::SCALAR;
|
||||
}
|
||||
InferenceEngine::TensorDesc td(InferenceEngine::Precision::FP32, param->get_shape(), layout);
|
||||
const auto &input = FuncTestUtils::createAndFillBlob(td);
|
||||
const auto &input_size = input->byteSize();
|
||||
|
||||
std::vector<uint8_t> data;
|
||||
data.resize(input_size);
|
||||
|
||||
auto memory = InferenceEngine::as<InferenceEngine::MemoryBlob>(input);
|
||||
IE_ASSERT(memory);
|
||||
|
||||
const auto lockedMemory = memory->wmap();
|
||||
const auto buffer = lockedMemory.as<const std::uint8_t *>();
|
||||
std::copy(buffer, buffer + input_size, data.data());
|
||||
|
||||
input_data.push_back(std::move(data));
|
||||
}
|
||||
|
||||
auto ref_outputs = ngraph::helpers::interpreterFunction(ref_function, input_data, types);
|
||||
auto outputs = ngraph::helpers::interpreterFunction(cur_function, input_data, types);
|
||||
|
||||
IE_ASSERT(ref_outputs.size() == outputs.size());
|
||||
|
||||
for (size_t i = 0; i < ref_outputs.size(); ++i) {
|
||||
IE_ASSERT(ref_outputs[i].second.size() == outputs[i].second.size());
|
||||
auto * ref = reinterpret_cast<float *>(ref_outputs[i].second.data());
|
||||
auto * out = reinterpret_cast<float *>(outputs[i].second.data());
|
||||
size_t size = ref_outputs[i].second.size() / sizeof(float);
|
||||
IE_ASSERT(size > 0);
|
||||
Compare<float, float>(ref, out, size, 1e-5);
|
||||
}
|
||||
}
|
||||
catch (const std::runtime_error &re) {
|
||||
GTEST_FATAL_FAILURE_(re.what());
|
||||
} catch (const std::exception &ex) {
|
||||
GTEST_FATAL_FAILURE_(ex.what());
|
||||
} catch (...) {
|
||||
GTEST_FATAL_FAILURE_("Unknown failure occurred.");
|
||||
}
|
||||
}
|
||||
|
||||
void init_unique_names(std::shared_ptr<ngraph::Function> f, const std::shared_ptr<ngraph::pass::UniqueNamesHolder>& unh) {
|
||||
ngraph::pass::Manager manager;
|
||||
|
@ -32,6 +32,7 @@ class TransformationTestsF : public CommonTestUtils::TestsCommon {
|
||||
public:
|
||||
TransformationTestsF() : comparator(FunctionsComparator::no_default()) {
|
||||
m_unh = std::make_shared<ngraph::pass::UniqueNamesHolder>();
|
||||
comparator.enable(FunctionsComparator::CmpValues::NODES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::PRECISIONS);
|
||||
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
|
||||
// TODO: enable attributes and constant values comparison by default XXX-68694
|
||||
@ -56,12 +57,16 @@ public:
|
||||
ASSERT_NO_THROW(check_rt_info(function));
|
||||
}
|
||||
|
||||
if (comparator.should_compare(FunctionsComparator::ACCURACY)) {
|
||||
auto acc_comparator = FunctionsComparator::no_default();
|
||||
acc_comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
auto res = acc_comparator.compare(function, cloned_function);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
comparator.disable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
auto res = comparator.compare(function, function_ref);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
|
||||
if (m_enable_accuracy_check) {
|
||||
accuracy_check(cloned_function, function);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: this is temporary solution to disable rt info checks that must be applied by default
|
||||
@ -74,12 +79,6 @@ public:
|
||||
m_soft_names_comparison = true;
|
||||
}
|
||||
|
||||
void enable_accuracy_check() {
|
||||
m_enable_accuracy_check = true;
|
||||
}
|
||||
|
||||
void accuracy_check(std::shared_ptr<ov::Model> ref_function, std::shared_ptr<ov::Model> cur_function);
|
||||
|
||||
std::shared_ptr<ov::Model> function, function_ref;
|
||||
ngraph::pass::Manager manager;
|
||||
FunctionsComparator comparator;
|
||||
@ -88,7 +87,6 @@ private:
|
||||
std::shared_ptr<ngraph::pass::UniqueNamesHolder> m_unh;
|
||||
bool m_disable_rt_info_check{false};
|
||||
bool m_soft_names_comparison{true};
|
||||
bool m_enable_accuracy_check{false};
|
||||
};
|
||||
|
||||
void init_unique_names(std::shared_ptr<ngraph::Function> f, const std::shared_ptr<ngraph::pass::UniqueNamesHolder>& unh);
|
||||
|
Loading…
Reference in New Issue
Block a user