Floats comparison implementation in IE_Engines (#4342)

This commit is contained in:
Tomasz Dołbniak
2021-02-24 12:10:00 +01:00
committed by GitHub
parent 5fa1e2140e
commit 4d9ede42ca
6 changed files with 119 additions and 35 deletions

View File

@@ -1583,15 +1583,10 @@ IE_GPU.onnx_model_gather_elements_int32_axis_0
IE_GPU.onnx_model_gather_elements_int8_axis_1
IE_GPU.onnx_model_gather_elements_float_3D_axis_2
IE_CPU.evaluate_ctc_greedy_decoder_seq_len
IE_GPU.evaluate_ctc_greedy_decoder_seq_len
IE_CPU.evaluate_ctc_greedy_decoder_seq_len_f16
IE_GPU.evaluate_ctc_greedy_decoder_seq_len_f16
IE_CPU.evaluate_ctc_greedy_decoder_seq_len_merge
IE_GPU.evaluate_ctc_greedy_decoder_seq_len_merge
IE_CPU.evaluate_ctc_greedy_decoder_seq_len_multiple_batches
IE_GPU.evaluate_ctc_greedy_decoder_seq_len_multiple_batches
IE_CPU.evaluate_ctc_greedy_decoder_seq_len_multiple_batches2
IE_GPU.evaluate_ctc_greedy_decoder_seq_len_multiple_batches2
# incorrect result for Minimum if u16 type is used

View File

@@ -22,6 +22,7 @@ set (SRC
all_close_f.cpp
engine/ie_engines.cpp
engine/interpreter_engine.cpp
engine/shared_utils.cpp
float_util.cpp
test_tools.cpp
test_control.cpp

View File

@@ -19,6 +19,7 @@
#include "ngraph/opsets/opset.hpp"
#include "ngraph/pass/manager.hpp"
#include "pass/opset1_upgrade.hpp"
#include "shared_utils.hpp"
using namespace ngraph;
@@ -121,7 +122,7 @@ namespace
default: THROW_IE_EXCEPTION << "Not implemented yet";
}
}
};
}; // namespace
namespace
{
@@ -156,7 +157,7 @@ namespace
#endif
throw std::runtime_error("unsupported type");
}
}
} // namespace
test::IE_Engine::IE_Engine(const std::shared_ptr<Function> function, const char* device)
: m_function{function}
@@ -244,6 +245,36 @@ testing::AssertionResult
test::IE_Engine::compare_results_with_tolerance_as_fp(const float tolerance)
{
auto comparison_result = testing::AssertionSuccess();
for (const auto& output : m_network_outputs)
{
if (comparison_result == testing::AssertionFailure())
{
break;
}
InferenceEngine::MemoryBlob::CPtr computed_output_blob =
InferenceEngine::as<InferenceEngine::MemoryBlob>(m_inference_req.GetBlob(output.first));
const auto& expected_output_blob = m_expected_outputs[output.first];
switch (expected_output_blob->getTensorDesc().getPrecision())
{
case InferenceEngine::Precision::FP32:
{
const auto test_results =
extract_test_results<float>(computed_output_blob, expected_output_blob);
comparison_result =
test::compare_with_tolerance(test_results.first, test_results.second, tolerance);
break;
}
default:
comparison_result = testing::AssertionFailure()
<< "Unsupported data type encountered in "
"'compare_results_with_tolerance_as_fp' method";
}
}
return comparison_result;
}
@@ -308,4 +339,4 @@ namespace InferenceEngine
template class TBlob<ngraph::float16>;
template class TBlob<char>;
#endif
}
} // namespace InferenceEngine

View File

@@ -14,11 +14,13 @@
// limitations under the License.
//*****************************************************************************
#include "interpreter_engine.hpp"
#include <cmath>
#include <iomanip>
#include <sstream>
#include "interpreter_engine.hpp"
#include "shared_utils.hpp"
using namespace ngraph;
namespace
@@ -45,32 +47,7 @@ namespace
const auto expected = expected_results->get_vector<float>();
const auto result = read_vector<float>(results);
Shape out_shape = expected_results->get_shape();
size_t num_of_elems = shape_size(out_shape);
std::stringstream msg;
msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1);
bool rc = true;
for (std::size_t j = 0; j < num_of_elems; ++j)
{
float diff = std::abs(result[j] - expected[j]);
if (diff > tolerance)
{
msg << expected[j] << " is not close to " << result[j] << " at index " << j << "\n";
rc = false;
}
}
if (!rc)
{
comparison_result = testing::AssertionFailure();
}
comparison_result << msg.str();
return comparison_result;
return ngraph::test::compare_with_tolerance(expected, result, tolerance);
}
template <typename T>

View File

@@ -0,0 +1,50 @@
// Copyright 2017-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cmath>
#include <sstream>
#include "shared_utils.hpp"
testing::AssertionResult ngraph::test::compare_with_tolerance(const std::vector<float>& expected,
const std::vector<float>& results,
const float tolerance)
{
auto comparison_result = testing::AssertionSuccess();
std::stringstream msg;
msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1);
bool rc = true;
for (std::size_t j = 0; j < expected.size(); ++j)
{
float diff = std::fabs(results[j] - expected[j]);
if (diff > tolerance)
{
msg << expected[j] << " is not close to " << results[j] << " at index " << j << "\n";
rc = false;
}
}
if (!rc)
{
comparison_result = testing::AssertionFailure();
comparison_result << msg.str();
}
return comparison_result;
}

View File

@@ -0,0 +1,30 @@
//*****************************************************************************
// Copyright 2017-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <gtest/gtest.h>
#include <vector>
namespace ngraph
{
namespace test
{
testing::AssertionResult compare_with_tolerance(const std::vector<float>& expected_results,
const std::vector<float>& results,
const float tolerance);
}
} // namespace ngraph