[LPT] foldFakeQuantize extending to support empty shapes (#10116)

This commit is contained in:
Edward Shogulin 2022-02-03 23:01:27 +03:00 committed by GitHub
parent 64aabc74d1
commit e8b88b9021
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 67 additions and 30 deletions

View File

@ -746,7 +746,7 @@ std::shared_ptr<Node> NetworkHelper::foldFakeQuantize(
assert(constPShape.is_static());
const Shape constShape = constPShape.to_shape();
if (constShape.empty() || constShape.size() > 5lu) {
if (constShape.size() > 5lu) {
THROW_IE_LPT_EXCEPTION(*fq) << "Unexpected dimensions count " << constShape.size();
}
if (outChannelsShapeIndex != 0 && outChannelsShapeIndex != 1) {
@ -756,8 +756,8 @@ std::shared_ptr<Node> NetworkHelper::foldFakeQuantize(
size_t OC;
size_t IC;
// OIDHW or IODHW
if (constShape.size() == 1) {
OC = constShape[0];
if (constShape.size() <= 1) {
OC = constShape.empty() ? 1ul : constShape[0];
IC = 1;
} else {
OC = constShape[outChannelsShapeIndex];

View File

@ -16,13 +16,23 @@ const std::vector<ngraph::element::Type> netPrecisions = {
};
const std::vector<LayerTestsDefinitions::MultiplyTestValues> params = {
{
false,
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } },
false,
{ 256ul, ngraph::Shape {}, { -1.28f }, { 1.27f }, { -1.28f }, { 1.27f } },
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { -1.28f }, { 1.27f }, { -1.28f }, { 1.27f } },
ngraph::element::f32,
true
},
{
false,
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } },
false,
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { -1.28f }, { 1.27f }, { -1.28f }, { 1.27f } },
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { -1.28f }, { 1.27f }, { -1.28f }, { 1.27f } },
ngraph::element::i8
ngraph::element::i8,
false
},
{
false,
@ -30,7 +40,8 @@ const std::vector<LayerTestsDefinitions::MultiplyTestValues> params = {
false,
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } },
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } },
ngraph::element::u8
ngraph::element::u8,
false
},
{
true,
@ -38,7 +49,8 @@ const std::vector<LayerTestsDefinitions::MultiplyTestValues> params = {
false,
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } },
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } },
ngraph::element::u8
ngraph::element::u8,
false
},
{
true,
@ -46,7 +58,8 @@ const std::vector<LayerTestsDefinitions::MultiplyTestValues> params = {
false,
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { -1.28f }, { 1.27f }, { -1.28f }, { 1.27f } },
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { -1.28f }, { 1.27f }, { -1.28f }, { 1.27f } },
ngraph::element::i8
ngraph::element::i8,
false
},
{
false,
@ -54,7 +67,8 @@ const std::vector<LayerTestsDefinitions::MultiplyTestValues> params = {
true,
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } },
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { -1.28f }, { 1.27f }, { -1.28f }, { 1.27f } },
ngraph::element::i8
ngraph::element::i8,
false
},
{
false,
@ -62,7 +76,8 @@ const std::vector<LayerTestsDefinitions::MultiplyTestValues> params = {
false,
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } },
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } },
ngraph::element::u8
ngraph::element::u8,
false
},
{
false,
@ -70,10 +85,11 @@ const std::vector<LayerTestsDefinitions::MultiplyTestValues> params = {
true,
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { -1.27f }, { 1.28f }, { -1.27f }, { 1.28f } },
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } },
ngraph::element::u8
ngraph::element::u8,
false
},
{ false, {}, false, {}, {}, ngraph::element::f32 },
{ true, {}, true, {}, {}, ngraph::element::f32 },
{ false, {}, false, {}, {}, ngraph::element::f32, false },
{ true, {}, true, {}, {}, ngraph::element::f32, false },
};
INSTANTIATE_TEST_SUITE_P(smoke_LPT, MultiplyTransformation,

View File

@ -22,7 +22,8 @@ const std::vector<LayerTestsDefinitions::MultiplyTestValues> params = {
false,
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { -1.28f }, { 1.27f }, { -1.28f }, { 1.27f } },
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { -1.28f }, { 1.27f }, { -1.28f }, { 1.27f } },
ngraph::element::undefined // ngraph::element::i8
ngraph::element::undefined, // ngraph::element::i8
false
},
{
false,
@ -30,7 +31,8 @@ const std::vector<LayerTestsDefinitions::MultiplyTestValues> params = {
false,
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } },
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } },
ngraph::element::undefined // ngraph::element::u8
ngraph::element::undefined, // ngraph::element::u8
false
},
{
true,
@ -38,7 +40,8 @@ const std::vector<LayerTestsDefinitions::MultiplyTestValues> params = {
false,
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } },
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } },
ngraph::element::undefined //ngraph::element::u8
ngraph::element::undefined, //ngraph::element::u8
false
},
{
true,
@ -46,7 +49,8 @@ const std::vector<LayerTestsDefinitions::MultiplyTestValues> params = {
false,
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { -1.28f }, { 1.27f }, { -1.28f }, { 1.27f } },
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { -1.28f }, { 1.27f }, { -1.28f }, { 1.27f } },
ngraph::element::undefined // ngraph::element::i8
ngraph::element::undefined, // ngraph::element::i8
false
},
{
false,
@ -54,7 +58,8 @@ const std::vector<LayerTestsDefinitions::MultiplyTestValues> params = {
true,
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } },
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { -1.28f }, { 1.27f }, { -1.28f }, { 1.27f } },
ngraph::element::undefined // ngraph::element::i8
ngraph::element::undefined, // ngraph::element::i8
false
},
{
false,
@ -62,7 +67,8 @@ const std::vector<LayerTestsDefinitions::MultiplyTestValues> params = {
false,
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } },
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } },
ngraph::element::undefined // ngraph::element::u8
ngraph::element::undefined, // ngraph::element::u8
false
},
{
false,
@ -70,10 +76,11 @@ const std::vector<LayerTestsDefinitions::MultiplyTestValues> params = {
true,
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { -1.27f }, { 1.28f }, { -1.27f }, { 1.28f } },
{ 256ul, ngraph::Shape { 1, 1, 1, 1 }, { 0.f }, { 2.55f }, { 0.f }, { 2.55f } },
ngraph::element::undefined // ngraph::element::u8
ngraph::element::undefined, // ngraph::element::u8
false
},
{ false, {}, false, {}, {}, ngraph::element::undefined /* ngraph::element::f32 */ },
{ true, {}, true, {}, {}, ngraph::element::undefined /* ngraph::element::f32 */ },
{ false, {}, false, {}, {}, ngraph::element::undefined /* ngraph::element::f32 */, false },
{ true, {}, true, {}, {}, ngraph::element::undefined /* ngraph::element::f32 */, false },
};
INSTANTIATE_TEST_SUITE_P(smoke_LPT, MultiplyTransformation,

View File

@ -20,6 +20,7 @@ public:
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize2;
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantizeAfter;
ngraph::element::Type expectedPrecisions;
bool secondInputIsConstant;
};
typedef std::tuple<

View File

@ -46,6 +46,7 @@ std::string MultiplyTransformation::getTestCaseName(const testing::TestParamInfo
param.fakeQuantize2.outputLowValues[0] << "_" <<
param.fakeQuantize2.outputHighValues[0];
}
result << "_" << param.secondInputIsConstant;
return result.str();
}
@ -62,7 +63,8 @@ void MultiplyTransformation::SetUp() {
param.fakeQuantize1,
param.broadcast2,
param.fakeQuantize2,
param.fakeQuantizeAfter);
param.fakeQuantizeAfter,
param.secondInputIsConstant);
ngraph::pass::InitNodeInfo().run_on_function(function);
}

View File

@ -51,7 +51,8 @@ public:
const ngraph::builder::subgraph::FakeQuantizeOnData& fq1,
const bool broadcast2,
const ngraph::builder::subgraph::FakeQuantizeOnData& fq2,
const ngraph::builder::subgraph::FakeQuantizeOnData& fqAfter);
const ngraph::builder::subgraph::FakeQuantizeOnData& fqAfter,
const bool secondInputIsConstant = false);
};
} // namespace subgraph

View File

@ -94,17 +94,23 @@ std::shared_ptr<ngraph::Function> MultiplyFunction::getOriginal(
const ngraph::builder::subgraph::FakeQuantizeOnData& fq1,
const bool broadcast2,
const ngraph::builder::subgraph::FakeQuantizeOnData& fq2,
const ngraph::builder::subgraph::FakeQuantizeOnData& fqAfter) {
const ngraph::builder::subgraph::FakeQuantizeOnData& fqAfter,
const bool secondInputIsConstant) {
auto inputShape1 = inputShape;
if (broadcast1) {
inputShape1[2] = 1;
inputShape1[3] = 1;
}
auto inputShape2 = inputShape;
if (broadcast2) {
inputShape2[2] = 1;
inputShape2[3] = 1;
ngraph::PartialShape inputShape2;
if (secondInputIsConstant) {
inputShape2 = {};
} else {
inputShape2 = inputShape;
if (broadcast2) {
inputShape2[2] = 1;
inputShape2[3] = 1;
}
}
const auto input1 = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape1);
@ -117,7 +123,9 @@ std::shared_ptr<ngraph::Function> MultiplyFunction::getOriginal(
fakeQuantize1->set_friendly_name("fakeQuantize1");
}
const auto input2 = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape2);
const std::shared_ptr<ngraph::Node> input2 = secondInputIsConstant ?
makeConstant(element::f32, Shape{}, std::vector<float>{0.5f}, false) :
std::make_shared<ngraph::opset1::Parameter>(precision, inputShape2);
const auto fakeQuantize2 = fq2.empty() ?
nullptr :
ngraph::builder::makeFakeQuantize(
@ -143,7 +151,9 @@ std::shared_ptr<ngraph::Function> MultiplyFunction::getOriginal(
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(result) };
std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
results,
ngraph::ParameterVector{ input1, input2 },
secondInputIsConstant ?
ngraph::ParameterVector{ input1 } :
ngraph::ParameterVector{ input1, ngraph::as_type_ptr<ngraph::opset1::Parameter>(input2) },
"MultiplyTransformation");
return function;