[CPU] FakeQuantize: new cases support (#5497)

Vladislav Golubev 2021-05-18 14:40:13 +03:00 committed by GitHub
parent 0face0e7cb
commit 49a8714ee5
2 changed files with 36 additions and 13 deletions


@@ -10,6 +10,7 @@
#include <mkldnn_types.h>
#include <mkldnn_extension_utils.h>
#include "utils/general_utils.h"
#include "utils/cpu_utils.hpp"
#include <algorithm>
#include <set>
@@ -841,7 +842,7 @@ bool MKLDNNFakeQuantizeNode::isSupportedOperation(const std::shared_ptr<const ng
}
for (size_t i = 1; i < fq->get_input_size(); i++) {
size_t count_not_unit_axis = 0;
-auto shape = fq->get_input_shape(i);
+auto shape = getNormalizedDimsBySize(fq->get_input_shape(i), fq->get_input_shape(0).size());
if (ngraph::shape_size(shape) != 1) {
size_t not_unit_axis = 0;
@@ -885,9 +886,7 @@ MKLDNNFakeQuantizeNode::MKLDNNFakeQuantizeNode(const std::shared_ptr<ngraph::Nod
if (fq->get_output_size() != 1)
IE_THROW() << errorPrefix << "has incorrect number of output edges: " << fq->get_output_size();
-auto initAxisIdx = [&](size_t edgeIdx) {
-const auto &inputDims = fq->get_input_shape(edgeIdx);
+auto initAxisIdx = [&](const ngraph::Shape& inputDims) {
size_t axisIdx = 0;
for (int i = 1; i < inputDims.size(); i++) {
if (inputDims[i] > 1) {
@@ -898,35 +897,36 @@ MKLDNNFakeQuantizeNode::MKLDNNFakeQuantizeNode(const std::shared_ptr<ngraph::Nod
return axisIdx;
};
-axis = fq->get_input_shape(0).size() == 1 ? 0 : 1;
+const size_t dataNDims = fq->get_input_shape(0).size();
+axis = dataNDims == 1 ? 0 : 1;
int axisSize = -1;
-auto inputLowAxis = initAxisIdx(1);
-const auto ilShape = fq->get_input_shape(1);
+const auto ilShape = getNormalizedDimsBySize(fq->get_input_shape(1), dataNDims);
+auto inputLowAxis = initAxisIdx(ilShape);
isInputLowBroadcasted = (ngraph::is_scalar(ilShape) || ilShape[inputLowAxis] == 1);
if (!isInputLowBroadcasted) {
axis = inputLowAxis;
axisSize = ilShape[inputLowAxis];
}
-auto inputHighAxis = initAxisIdx(2);
-const auto ihShape = fq->get_input_shape(2);
+const auto ihShape = getNormalizedDimsBySize(fq->get_input_shape(2), dataNDims);
+auto inputHighAxis = initAxisIdx(ihShape);
isInputHighBroadcasted = (ngraph::is_scalar(ihShape) || ihShape[inputHighAxis] == 1);
if (!isInputHighBroadcasted) {
axis = inputHighAxis;
axisSize = ihShape[inputHighAxis];
}
-auto outputLowAxis = initAxisIdx(3);
-const auto olShape = fq->get_input_shape(3);
+const auto olShape = getNormalizedDimsBySize(fq->get_input_shape(3), dataNDims);
+auto outputLowAxis = initAxisIdx(olShape);
isOutputLowBroadcasted = (ngraph::is_scalar(olShape) || olShape[outputLowAxis] == 1);
if (!isOutputLowBroadcasted) {
axis = outputLowAxis;
axisSize = olShape[outputLowAxis];
}
-auto outputHighAxis = initAxisIdx(4);
-const auto ohShape = fq->get_input_shape(4);
+const auto ohShape = getNormalizedDimsBySize(fq->get_input_shape(4), dataNDims);
+auto outputHighAxis = initAxisIdx(ohShape);
isOutputHighBroadcasted = (ngraph::is_scalar(ohShape) || ohShape[outputHighAxis] == 1);
if (!isOutputHighBroadcasted) {
axis = outputHighAxis;
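
Note on the helper used above: the change pulls in getNormalizedDimsBySize() from "utils/cpu_utils.hpp" and runs every range input (input low/high, output low/high) through it before deducing the quantization axis, so rank-deficient range shapes are compared against the data shape at equal rank. A minimal sketch of the assumed semantics, prepending leading 1s up to the data rank (the helper name normalizeDims and the exact padding rule are illustrative, not copied from cpu_utils.hpp):

#include <cstddef>
#include <vector>

// Assumed behavior of getNormalizedDimsBySize(): pad a shape with leading 1s
// until it reaches the rank of the data input, e.g. {10} -> {1, 10} for ndims = 2.
std::vector<size_t> normalizeDims(const std::vector<size_t>& dims, size_t ndims) {
    if (dims.size() >= ndims)
        return dims;  // already at (or above) the target rank
    std::vector<size_t> normalized(ndims - dims.size(), 1);  // leading 1s
    normalized.insert(normalized.end(), dims.begin(), dims.end());
    return normalized;
}

// With this normalization a per-channel range of shape {10} on a {1, 10} data input
// becomes {1, 10}: initAxisIdx() then picks axis 1 and isSupportedOperation() accepts the case.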


@@ -89,4 +89,27 @@ INSTANTIATE_TEST_CASE_P(smoke_FakeQuantizePerChannelAxis1, FakeQuantizeLayerTest
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(config)),
FakeQuantizeLayerTest::getTestCaseName);
+const std::vector<std::vector<size_t>> inputShapesPerChannel2D = {{1, 10}};
+const std::vector<std::vector<size_t>> constShapesPerChannel2D = { {10}, {1, 10}, {1} };
+const auto fqParamsPerChannel2D = ::testing::Combine(
+::testing::ValuesIn(levels),
+::testing::ValuesIn(constShapesPerChannel2D),
+::testing::Values(fqArgs),
+::testing::Values(inputParams)
+);
+INSTANTIATE_TEST_CASE_P(smoke_FakeQuantizePerChannel2D, FakeQuantizeLayerTest,
+::testing::Combine(
+fqParamsPerChannel2D,
+::testing::ValuesIn(netPrecisions),
+::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
+::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
+::testing::Values(InferenceEngine::Layout::ANY),
+::testing::Values(InferenceEngine::Layout::ANY),
+::testing::ValuesIn(inputShapesPerChannel2D),
+::testing::Values(CommonTestUtils::DEVICE_CPU),
+::testing::Values(config)),
+FakeQuantizeLayerTest::getTestCaseName);
} // namespace
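
The added smoke_FakeQuantizePerChannel2D instantiation targets exactly the case the node changes enable: a 2D data input of shape {1, 10} quantized per channel with range constants of shape {10}, {1, 10} or {1}. A hypothetical nGraph snippet for the rank-1 range case (the constant values, the opset version and the function name are illustrative, not taken from the test code):

#include <memory>
#include <vector>
#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset5.hpp>

// Builds a FakeQuantize whose per-channel ranges have rank 1 ({10}) while the
// data input is 2D ({1, 10}); before this commit the CPU plugin rejected such shapes.
std::shared_ptr<ngraph::Function> makePerChannel2DFq() {
    using namespace ngraph;
    auto data = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 10});
    std::vector<float> lo(10, 0.f), hi(10, 10.f);
    auto il = opset5::Constant::create(element::f32, Shape{10}, lo);
    auto ih = opset5::Constant::create(element::f32, Shape{10}, hi);
    auto ol = opset5::Constant::create(element::f32, Shape{10}, lo);
    auto oh = opset5::Constant::create(element::f32, Shape{10}, hi);
    auto fq = std::make_shared<opset5::FakeQuantize>(data, il, ih, ol, oh, 256);
    return std::make_shared<Function>(OutputVector{fq}, ParameterVector{data});
}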