[LPT] isPrecisionPreserved/canBeTransformed/isQuantized: handling unexpected layers tests (#3139)

This commit is contained in:
Vladislav Golubev 2020-12-03 16:26:24 +03:00 committed by GitHub
parent c2e1f488e4
commit f2c2636bb5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 60 additions and 1 deletions

View File

@ -90,7 +90,7 @@ void AddTransformation::registerMatcherIn(GraphRewrite &pass, TransformationCont
bool AddTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) const {
std::shared_ptr<opset1::Add> op = as_type_ptr<opset1::Add>(m.get_match_root());
if (!canBeTransformed(context, op)) {
if ((op == nullptr) || (!canBeTransformed(context, op))) {
return false;
}

View File

@ -46,6 +46,9 @@ bool MVNTransformation::canBeTransformed(const TransformationContext& context, s
}
auto mvn = as_type_ptr<op::MVN>(operation);
if (mvn == nullptr) {
return false;
}
const std::shared_ptr<Node> multiply = mvn->get_input_node_shared_ptr(0);
auto scalesConst = as_type_ptr<ngraph::opset1::Constant>(multiply->get_input_node_shared_ptr(1));

View File

@ -0,0 +1,56 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <string>
#include <sstream>
#include <memory>
#include <gtest/gtest.h>
#include <transformations/utils/utils.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "low_precision/transformer.hpp"
using namespace testing;
using namespace ngraph;
using namespace ngraph::pass;
TEST(LPT, isPrecisionPreservedTransformation) {
const auto layer = std::make_shared<opset1::Parameter>(element::f32, Shape{ 1, 3, 16, 16 });
const auto transformations = low_precision::LowPrecisionTransformer::getAllTransformations();
for (const auto& transformation : transformations.transformations) {
ASSERT_NO_THROW(transformation.second->isPrecisionPreserved(layer));
}
}
TEST(LPT, canBeTransformedTransformation) {
const auto input = std::make_shared<opset1::Parameter>(element::f32, Shape{ 1, 3, 16, 16 });
const auto mulConst = op::v0::Constant::create(element::f32, Shape{}, { 1.f });
const auto mul = std::make_shared<ngraph::opset1::Multiply>(input, mulConst);
const auto shapeConst = op::v0::Constant::create(ngraph::element::i64, ngraph::Shape{ 4 }, { 1, 3, 16, 16 });
const auto layer = std::make_shared<opset1::Reshape>(mul, shapeConst, true);
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(layer) };
const auto function = std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "TestFunction");
const auto transformations = low_precision::LowPrecisionTransformer::getAllTransformations();
for (const auto& transformation : transformations.transformations) {
ASSERT_NO_THROW(transformation.second->canBeTransformed(low_precision::TransformationContext(function), layer));
}
}
TEST(LPT, isQuantizedTransformation) {
const auto input = std::make_shared<opset1::Parameter>(element::f32, Shape{ 1, 3, 16, 16 });
const auto mulConst = op::v0::Constant::create(element::f32, Shape{}, { 1.f });
const auto mul = std::make_shared<ngraph::opset1::Multiply>(input, mulConst);
const auto shapeConst = op::v0::Constant::create(ngraph::element::i64, ngraph::Shape{ 4 }, { 1, 3, 16, 16 });
const auto layer = std::make_shared<opset1::Reshape>(mul, shapeConst, true);
const auto transformations = low_precision::LowPrecisionTransformer::getAllTransformations();
for (const auto& transformation : transformations.transformations) {
ASSERT_NO_THROW(transformation.second->isQuantized(layer));
}
}