[IE CLDNN] Fixed MVN-6 with negative axes (#4523)
This commit is contained in:
@@ -278,6 +278,9 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
|
||||
if (auto axesNode = dynamic_cast<ngraph::op::v0::Constant*>(mvn->get_input_node_ptr(1))) {
|
||||
auto axesVal = axesNode->cast_vector<int>();
|
||||
auto& mvnShape = mvn->get_output_shape(0);
|
||||
for (int32_t& axis : axesVal)
|
||||
axis = axis < 0 ? axis + mvnShape.size() : axis;
|
||||
std::sort(axesVal.begin(), axesVal.end());
|
||||
if (mvnShape.size() == 1)
|
||||
return false;
|
||||
if (mvnShape.size() > 5 || (mvnShape.size() != axesVal.size() + 1 && mvnShape.size() != axesVal.size() + 2))
|
||||
|
||||
@@ -46,7 +46,10 @@ void CreateMVNOp(Program& p, const std::shared_ptr<ngraph::op::v6::MVN>& op) {
|
||||
if (!inConst)
|
||||
THROW_IE_EXCEPTION << "Unsupported parameter nodes type in " << op->get_friendly_name() << " (" << op->get_type_name() << ")";
|
||||
|
||||
auto& mvnShape = op->get_output_shape(0);
|
||||
std::vector<int32_t> axes = inConst->cast_vector<int32_t>();
|
||||
for (int32_t& axis : axes)
|
||||
axis = axis < 0 ? axis + mvnShape.size() : axis;
|
||||
|
||||
const size_t chanelAxis = 1;
|
||||
bool across_channels = std::find(axes.begin(), axes.end(), chanelAxis) != axes.end();
|
||||
|
||||
@@ -73,7 +73,7 @@ INSTANTIATE_TEST_CASE_P(smoke_MVN_5D, Mvn6LayerTest,
|
||||
::testing::ValuesIn(std::vector<std::vector<size_t>>{{1, 10, 5, 7, 8}, {1, 3, 8, 9, 49}}),
|
||||
::testing::ValuesIn(dataPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::ValuesIn(std::vector<std::vector<int>>{{1, 2, 3, 4}, {2, 3, 4}}),
|
||||
::testing::ValuesIn(std::vector<std::vector<int>>{{1, 2, 3, 4}, {2, 3, 4}, {-3, -2, -1}, {-1, -4, -2, -3}}),
|
||||
::testing::ValuesIn(normalizeVariance),
|
||||
::testing::ValuesIn(epsilonF),
|
||||
::testing::ValuesIn(epsMode),
|
||||
@@ -85,7 +85,7 @@ INSTANTIATE_TEST_CASE_P(smoke_MVN_4D, Mvn6LayerTest,
|
||||
::testing::ValuesIn(std::vector<std::vector<size_t>>{{1, 10, 5, 17}, {1, 3, 8, 9}}),
|
||||
::testing::ValuesIn(dataPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::ValuesIn(std::vector<std::vector<int>>{{1, 2, 3}, {2, 3}}),
|
||||
::testing::ValuesIn(std::vector<std::vector<int>>{{1, 2, 3}, {2, 3}, {-2, -1}, {-2, -1, -3}}),
|
||||
::testing::ValuesIn(normalizeVariance),
|
||||
::testing::ValuesIn(epsilonF),
|
||||
::testing::ValuesIn(epsMode),
|
||||
@@ -97,7 +97,7 @@ INSTANTIATE_TEST_CASE_P(smoke_MVN_3D, Mvn6LayerTest,
|
||||
::testing::ValuesIn(std::vector<std::vector<size_t>>{{1, 32, 17}, {1, 37, 9}}),
|
||||
::testing::ValuesIn(dataPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::ValuesIn(std::vector<std::vector<int>>{{1, 2}, {2}}),
|
||||
::testing::ValuesIn(std::vector<std::vector<int>>{{1, 2}, {2}, {-1}, {-1, -2}}),
|
||||
::testing::ValuesIn(normalizeVariance),
|
||||
::testing::ValuesIn(epsilonF),
|
||||
::testing::ValuesIn(epsMode),
|
||||
|
||||
Reference in New Issue
Block a user