[CPU] FakeQuantize: new cases support (#5497)
This commit is contained in:
parent
0face0e7cb
commit
49a8714ee5
@ -10,6 +10,7 @@
|
||||
#include <mkldnn_types.h>
|
||||
#include <mkldnn_extension_utils.h>
|
||||
#include "utils/general_utils.h"
|
||||
#include "utils/cpu_utils.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
@ -841,7 +842,7 @@ bool MKLDNNFakeQuantizeNode::isSupportedOperation(const std::shared_ptr<const ng
|
||||
}
|
||||
for (size_t i = 1; i < fq->get_input_size(); i++) {
|
||||
size_t count_not_unit_axis = 0;
|
||||
auto shape = fq->get_input_shape(i);
|
||||
auto shape = getNormalizedDimsBySize(fq->get_input_shape(i), fq->get_input_shape(0).size());
|
||||
|
||||
if (ngraph::shape_size(shape) != 1) {
|
||||
size_t not_unit_axis = 0;
|
||||
@ -885,9 +886,7 @@ MKLDNNFakeQuantizeNode::MKLDNNFakeQuantizeNode(const std::shared_ptr<ngraph::Nod
|
||||
if (fq->get_output_size() != 1)
|
||||
IE_THROW() << errorPrefix << "has incorrect number of output edges: " << fq->get_output_size();
|
||||
|
||||
auto initAxisIdx = [&](size_t edgeIdx) {
|
||||
const auto &inputDims = fq->get_input_shape(edgeIdx);
|
||||
|
||||
auto initAxisIdx = [&](const ngraph::Shape& inputDims) {
|
||||
size_t axisIdx = 0;
|
||||
for (int i = 1; i < inputDims.size(); i++) {
|
||||
if (inputDims[i] > 1) {
|
||||
@ -898,35 +897,36 @@ MKLDNNFakeQuantizeNode::MKLDNNFakeQuantizeNode(const std::shared_ptr<ngraph::Nod
|
||||
return axisIdx;
|
||||
};
|
||||
|
||||
axis = fq->get_input_shape(0).size() == 1 ? 0 : 1;
|
||||
const size_t dataNDims = fq->get_input_shape(0).size();
|
||||
axis = dataNDims == 1 ? 0 : 1;
|
||||
int axisSize = -1;
|
||||
|
||||
auto inputLowAxis = initAxisIdx(1);
|
||||
const auto ilShape = fq->get_input_shape(1);
|
||||
const auto ilShape = getNormalizedDimsBySize(fq->get_input_shape(1), dataNDims);
|
||||
auto inputLowAxis = initAxisIdx(ilShape);
|
||||
isInputLowBroadcasted = (ngraph::is_scalar(ilShape) || ilShape[inputLowAxis] == 1);
|
||||
if (!isInputLowBroadcasted) {
|
||||
axis = inputLowAxis;
|
||||
axisSize = ilShape[inputLowAxis];
|
||||
}
|
||||
|
||||
auto inputHighAxis = initAxisIdx(2);
|
||||
const auto ihShape = fq->get_input_shape(2);
|
||||
const auto ihShape = getNormalizedDimsBySize(fq->get_input_shape(2), dataNDims);
|
||||
auto inputHighAxis = initAxisIdx(ihShape);
|
||||
isInputHighBroadcasted = (ngraph::is_scalar(ihShape) || ihShape[inputHighAxis] == 1);
|
||||
if (!isInputHighBroadcasted) {
|
||||
axis = inputHighAxis;
|
||||
axisSize = ihShape[inputHighAxis];
|
||||
}
|
||||
|
||||
auto outputLowAxis = initAxisIdx(3);
|
||||
const auto olShape = fq->get_input_shape(3);
|
||||
const auto olShape = getNormalizedDimsBySize(fq->get_input_shape(3), dataNDims);
|
||||
auto outputLowAxis = initAxisIdx(olShape);
|
||||
isOutputLowBroadcasted = (ngraph::is_scalar(olShape) || olShape[outputLowAxis] == 1);
|
||||
if (!isOutputLowBroadcasted) {
|
||||
axis = outputLowAxis;
|
||||
axisSize = olShape[outputLowAxis];
|
||||
}
|
||||
|
||||
auto outputHighAxis = initAxisIdx(4);
|
||||
const auto ohShape = fq->get_input_shape(4);
|
||||
const auto ohShape = getNormalizedDimsBySize(fq->get_input_shape(4), dataNDims);
|
||||
auto outputHighAxis = initAxisIdx(ohShape);
|
||||
isOutputHighBroadcasted = (ngraph::is_scalar(ohShape) || ohShape[outputHighAxis] == 1);
|
||||
if (!isOutputHighBroadcasted) {
|
||||
axis = outputHighAxis;
|
||||
|
@ -89,4 +89,27 @@ INSTANTIATE_TEST_CASE_P(smoke_FakeQuantizePerChannelAxis1, FakeQuantizeLayerTest
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU),
|
||||
::testing::Values(config)),
|
||||
FakeQuantizeLayerTest::getTestCaseName);
|
||||
|
||||
const std::vector<std::vector<size_t>> inputShapesPerChannel2D = {{1, 10}};
|
||||
const std::vector<std::vector<size_t>> constShapesPerChannel2D = { {10}, {1, 10}, {1} };
|
||||
const auto fqParamsPerChannel2D = ::testing::Combine(
|
||||
::testing::ValuesIn(levels),
|
||||
::testing::ValuesIn(constShapesPerChannel2D),
|
||||
::testing::Values(fqArgs),
|
||||
::testing::Values(inputParams)
|
||||
);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_FakeQuantizePerChannel2D, FakeQuantizeLayerTest,
|
||||
::testing::Combine(
|
||||
fqParamsPerChannel2D,
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
|
||||
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
|
||||
::testing::Values(InferenceEngine::Layout::ANY),
|
||||
::testing::Values(InferenceEngine::Layout::ANY),
|
||||
::testing::ValuesIn(inputShapesPerChannel2D),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU),
|
||||
::testing::Values(config)),
|
||||
FakeQuantizeLayerTest::getTestCaseName);
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user