[LPT] matmul with qdq on weights tests (#4283)
This commit is contained in:
parent
c206830e7b
commit
a719534889
@ -19,6 +19,7 @@
|
||||
#include "ngraph_functions/subgraph_builders.hpp"
|
||||
#include "simple_low_precision_transformer.hpp"
|
||||
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
|
||||
#include "lpt_ngraph_functions/common/constant.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
@ -31,25 +32,25 @@ public:
|
||||
public:
|
||||
ngraph::Shape inputShape;
|
||||
ngraph::element::Type precisionBeforeDequantization;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantization;
|
||||
ngraph::Shape weightsConstShape;
|
||||
std::vector<float> weightsConstValues;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantizationOnData;
|
||||
|
||||
ngraph::builder::subgraph::Constant weights;
|
||||
ngraph::builder::subgraph::FakeQuantizeOnWeights fqOnWeights;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantizationOnWeights;
|
||||
};
|
||||
|
||||
class Expected {
|
||||
public:
|
||||
ngraph::Shape inputShape;
|
||||
ngraph::element::Type precisionBeforeDequantization;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantization;
|
||||
ngraph::element::Type weightsConstPrecision;
|
||||
ngraph::Shape weightsConstShape;
|
||||
std::vector<float> weightsConstValues;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantizationOnData;
|
||||
ngraph::builder::subgraph::Constant weights;
|
||||
|
||||
ngraph::element::Type precisionBeforeOperation;
|
||||
ngraph::builder::subgraph::DequantizationOperations resultDequantization;
|
||||
|
||||
ngraph::builder::subgraph::FakeQuantizeOnWeights fqOnWeights;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantizationOnWeights;
|
||||
};
|
||||
|
||||
ngraph::pass::low_precision::LayerTransformation::Params params;
|
||||
@ -61,18 +62,20 @@ inline std::ostream& operator << (std::ostream& out, const MatMullTransformation
|
||||
return out << "_" <<
|
||||
actual.inputShape << "_" <<
|
||||
actual.precisionBeforeDequantization << "_" <<
|
||||
actual.dequantization << "_" <<
|
||||
actual.weightsConstShape << "_" <<
|
||||
actual.fqOnWeights;
|
||||
actual.dequantizationOnData << "_" <<
|
||||
actual.weights.shape << "_" <<
|
||||
actual.fqOnWeights << "_" <<
|
||||
actual.dequantizationOnWeights;
|
||||
}
|
||||
|
||||
inline std::ostream& operator << (std::ostream& out, const MatMullTransformationTestValues::Expected& expected) {
|
||||
return out << "_" <<
|
||||
expected.weightsConstShape <<"_" <<
|
||||
expected.dequantization << "_" <<
|
||||
expected.weights.shape <<"_" <<
|
||||
expected.dequantizationOnData << "_" <<
|
||||
expected.precisionBeforeOperation << "_" <<
|
||||
expected.resultDequantization << "_" <<
|
||||
expected.fqOnWeights;
|
||||
expected.fqOnWeights << "_" <<
|
||||
expected.dequantizationOnWeights;
|
||||
}
|
||||
|
||||
inline std::ostream& operator << (std::ostream& out, const MatMullTransformationTestValues& values) {
|
||||
@ -94,37 +97,36 @@ public:
|
||||
testValues.actual.inputShape[0] = batch;
|
||||
testValues.expected.inputShape[0] = batch;
|
||||
|
||||
|
||||
actualFunction = ngraph::builder::subgraph::MatMulFunction::getOriginal(
|
||||
precision,
|
||||
testValues.actual.inputShape,
|
||||
testValues.actual.precisionBeforeDequantization,
|
||||
testValues.actual.dequantization,
|
||||
testValues.actual.weightsConstShape,
|
||||
testValues.actual.weightsConstValues,
|
||||
testValues.actual.fqOnWeights);
|
||||
testValues.actual.dequantizationOnData,
|
||||
testValues.actual.weights,
|
||||
testValues.actual.fqOnWeights,
|
||||
testValues.actual.dequantizationOnWeights);
|
||||
|
||||
SimpleLowPrecisionTransformer transformer;
|
||||
transformer.add<ngraph::pass::low_precision::MatMulTransformation, ngraph::opset1::MatMul>(testValues.params);
|
||||
transformer.transform(actualFunction);
|
||||
|
||||
referenceFunction = testValues.expected.fqOnWeights.empty() ?
|
||||
referenceFunction = (testValues.expected.fqOnWeights.empty() && testValues.expected.dequantizationOnWeights.empty()) ?
|
||||
ngraph::builder::subgraph::MatMulFunction::getReference(
|
||||
precision,
|
||||
testValues.expected.inputShape,
|
||||
testValues.expected.precisionBeforeDequantization,
|
||||
testValues.expected.dequantization,
|
||||
testValues.expected.weightsConstPrecision,
|
||||
testValues.expected.weightsConstShape,
|
||||
testValues.expected.weightsConstValues,
|
||||
testValues.expected.dequantizationOnData,
|
||||
testValues.expected.weights,
|
||||
testValues.expected.resultDequantization) :
|
||||
ngraph::builder::subgraph::MatMulFunction::getOriginal(
|
||||
precision,
|
||||
testValues.expected.inputShape,
|
||||
testValues.expected.precisionBeforeDequantization,
|
||||
testValues.expected.dequantization,
|
||||
testValues.expected.weightsConstShape,
|
||||
testValues.expected.weightsConstValues,
|
||||
testValues.expected.fqOnWeights);
|
||||
testValues.expected.dequantizationOnData,
|
||||
testValues.expected.weights,
|
||||
testValues.expected.fqOnWeights,
|
||||
testValues.expected.dequantizationOnWeights);
|
||||
}
|
||||
|
||||
static std::string getTestCaseName(testing::TestParamInfo<MatMulTransformationParams> obj) {
|
||||
@ -160,19 +162,41 @@ std::vector<MatMullTransformationTestValues> testValues = {
|
||||
{ 1, 384, 1024 },
|
||||
ngraph::element::u8,
|
||||
{ ngraph::element::f32, {}, { 0.02f } },
|
||||
{ 1024, 1024 },
|
||||
std::vector<float>(1024 * 1024, 1.f),
|
||||
{ std::vector<float>(1024 * 1024, 1.f), ngraph::element::f32, ngraph::Shape{ 1024, 1024 } },
|
||||
{ 255, { 1, 1 }, {0.f}, {254.f}, {-12.7f}, {12.7} },
|
||||
{}
|
||||
},
|
||||
{
|
||||
{ 1, 384, 1024 },
|
||||
ngraph::element::u8,
|
||||
{ {}, {}, {} },
|
||||
ngraph::element::i8,
|
||||
{ 1024, 1024 },
|
||||
std::vector<float>(1024 * 1024, -126),
|
||||
ngraph::element::i8,
|
||||
{},
|
||||
{ std::vector<float>(1024 * 1024, -126.f), ngraph::element::i8, ngraph::Shape{ 1024, 1024 } },
|
||||
ngraph::element::u8,
|
||||
{ {}, {}, { 0.02f * 0.1f } },
|
||||
{},
|
||||
{}
|
||||
}
|
||||
},
|
||||
|
||||
// supported 3D: U8 & I8 with Dq on weights
|
||||
{
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
{ 1, 384, 1024 },
|
||||
ngraph::element::u8,
|
||||
{ ngraph::element::f32, {}, { 0.02f } },
|
||||
{ std::vector<float>(1024 * 1024, 1.f), ngraph::element::i8, ngraph::Shape{ 1024, 1024 } },
|
||||
{},
|
||||
{ ngraph::element::f32, {}, { 0.1f } },
|
||||
},
|
||||
{
|
||||
{ 1, 384, 1024 },
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
{ std::vector<float>(1024 * 1024, 1.f), ngraph::element::i8, ngraph::Shape{ 1024, 1024 } },
|
||||
ngraph::element::u8,
|
||||
{ {}, {}, { 0.02f * 0.1f } },
|
||||
{},
|
||||
{}
|
||||
}
|
||||
},
|
||||
@ -184,20 +208,42 @@ std::vector<MatMullTransformationTestValues> testValues = {
|
||||
{ 1, 3, 4 },
|
||||
ngraph::element::u8,
|
||||
{ ngraph::element::f32, {}, { {0.01f, 0.02f, 0.03f} } },
|
||||
{ 4, 4 },
|
||||
std::vector<float>(4 * 4, 1.f),
|
||||
{ std::vector<float>(4 * 4, 1.f), ngraph::element::f32, ngraph::Shape{ 4, 4 } },
|
||||
{ 255, { 1, 1 }, {0.f}, {254.f}, {-12.7f}, {12.7} },
|
||||
{}
|
||||
},
|
||||
{
|
||||
{ 1, 3, 4 },
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
{ std::vector<float>(4 * 4, -126.f), ngraph::element::i8, ngraph::Shape{ 4, 4 } },
|
||||
ngraph::element::u8,
|
||||
{ {}, {}, { {0.001f, 0.002f, 0.003f} } },
|
||||
{},
|
||||
{}
|
||||
}
|
||||
},
|
||||
|
||||
// 3D: U8 & I8 with Dq on weights
|
||||
{
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
{ 1, 3, 4 },
|
||||
ngraph::element::u8,
|
||||
{ ngraph::element::f32, {}, { {0.01f, 0.02f, 0.03f} } },
|
||||
{ std::vector<float>(4 * 4, 1.f), ngraph::element::i8, ngraph::Shape{ 4, 4 } },
|
||||
{},
|
||||
{ ngraph::element::f32, {}, { 0.1f } }
|
||||
},
|
||||
{
|
||||
{ 1, 3, 4 },
|
||||
ngraph::element::u8,
|
||||
{ {}, {}, {} },
|
||||
ngraph::element::i8,
|
||||
{4, 4},
|
||||
std::vector<float>(4 * 4, -126.f),
|
||||
ngraph::element::f32,
|
||||
{ std::vector<float>(4 * 4, 1.f), ngraph::element::i8, ngraph::Shape{ 4, 4 } },
|
||||
ngraph::element::u8,
|
||||
{ {}, {}, { {0.001f, 0.002f, 0.003f} } },
|
||||
{},
|
||||
{}
|
||||
}
|
||||
},
|
||||
|
||||
@ -208,8 +254,7 @@ std::vector<MatMullTransformationTestValues> testValues = {
|
||||
{ 1, 3, 4 },
|
||||
ngraph::element::u8,
|
||||
{ ngraph::element::f32, {}, { 0.02f } },
|
||||
{ 4, 4 },
|
||||
std::vector<float>(4 * 4, 1.f),
|
||||
{ std::vector<float>(4 * 4, 1.f), ngraph::element::f32, ngraph::Shape{ 4, 4 } },
|
||||
{
|
||||
255,
|
||||
{ 1, 4 },
|
||||
@ -218,17 +263,40 @@ std::vector<MatMullTransformationTestValues> testValues = {
|
||||
{-127.f, -12.7f, -1.27f , -0.127f},
|
||||
{127.f, 12.7f, 1.27f , 0.127f},
|
||||
},
|
||||
{}
|
||||
},
|
||||
{
|
||||
{ 1, 3, 4 },
|
||||
ngraph::element::u8,
|
||||
{ {}, {}, {} },
|
||||
ngraph::element::i8,
|
||||
{ 4, 4 },
|
||||
std::vector<float>(4 * 4, -126.f),
|
||||
ngraph::element::f32,
|
||||
{ std::vector<float>(4 * 4, -126.f), ngraph::element::i8, ngraph::Shape{ 4, 4 } },
|
||||
ngraph::element::u8,
|
||||
{ {}, {}, {{ 0.02f, 0.002f, 0.0002f, 0.00002f }, ngraph::element::f32, ngraph::Shape{ 1, 1, 4 }}},
|
||||
{},
|
||||
{}
|
||||
}
|
||||
},
|
||||
|
||||
// 3D: U8 & I8 with Dq on weights with different values
|
||||
{
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
{ 1, 3, 4 },
|
||||
ngraph::element::u8,
|
||||
{ ngraph::element::f32, {}, { 0.02f } },
|
||||
{ std::vector<float>(4 * 4, 1.f), ngraph::element::i8, ngraph::Shape{ 4, 4 } },
|
||||
{},
|
||||
{ ngraph::element::f32, {}, { {1.f, 0.1f, 0.01f, 0.001f} } }
|
||||
},
|
||||
{
|
||||
{ 1, 3, 4 },
|
||||
ngraph::element::u8,
|
||||
{ {}, {}, {} },
|
||||
{ std::vector<float>(4 * 4, 1.f), ngraph::element::i8, ngraph::Shape{ 4, 4 } },
|
||||
ngraph::element::u8,
|
||||
{ {}, {}, {{ 0.02f, 0.002f, 0.0002f, 0.00002f }, ngraph::element::f32, ngraph::Shape{ 1, 1, 4 }}},
|
||||
{},
|
||||
{}
|
||||
}
|
||||
},
|
||||
|
||||
@ -239,20 +307,19 @@ std::vector<MatMullTransformationTestValues> testValues = {
|
||||
{ 1, 3, 4 },
|
||||
ngraph::element::u8,
|
||||
{ ngraph::element::f32, {}, { {0.01f, 0.02f, 0.03f, 0.01f}, ngraph::element::f32, ngraph::Shape{1, 1, 4} } },
|
||||
{ 4, 4 },
|
||||
std::vector<float>(4 * 4, 1.f),
|
||||
{ std::vector<float>(4 * 4, 1.f), ngraph::element::f32, ngraph::Shape{ 4, 4 } },
|
||||
{ 255, { 1, 1 }, {0.f}, {254.f}, {-12.7f}, {12.7} },
|
||||
{}
|
||||
},
|
||||
{
|
||||
{ 1, 3, 4 },
|
||||
ngraph::element::u8,
|
||||
{ ngraph::element::f32, {}, { {0.01f, 0.02f, 0.03f, 0.01f}, ngraph::element::f32, ngraph::Shape{1, 1, 4} } },
|
||||
{ std::vector<float>(4 * 4, 1.f), ngraph::element::f32, ngraph::Shape{ 4, 4 } },
|
||||
ngraph::element::f32,
|
||||
{ 4, 4 },
|
||||
std::vector<float>(4 * 4, 1.f),
|
||||
ngraph::element::f32,
|
||||
{{}, {}, {}},
|
||||
{},
|
||||
{ 255, { 1, 1 }, {0.f}, {254.f}, {-12.7f}, {12.7} },
|
||||
{}
|
||||
}
|
||||
},
|
||||
|
||||
@ -263,8 +330,7 @@ std::vector<MatMullTransformationTestValues> testValues = {
|
||||
{ 1, 3, 4 },
|
||||
ngraph::element::u8,
|
||||
{ ngraph::element::f32, {}, { 0.02f } },
|
||||
{ 4, 4 },
|
||||
std::vector<float>(4 * 4, 1.f),
|
||||
{ std::vector<float>(4 * 4, 1.f), ngraph::element::f32, ngraph::Shape{ 4, 4 } },
|
||||
{
|
||||
255,
|
||||
{ 4, 1 },
|
||||
@ -273,16 +339,15 @@ std::vector<MatMullTransformationTestValues> testValues = {
|
||||
{-127.f, -12.7f, -1.27f , -0.127f},
|
||||
{127.f, 12.7f, 1.27f , 0.127f},
|
||||
},
|
||||
{}
|
||||
},
|
||||
{
|
||||
{ 1, 3, 4 },
|
||||
ngraph::element::u8,
|
||||
{ ngraph::element::f32, {}, { 0.02f } },
|
||||
{ std::vector<float>(4 * 4, 1.f), ngraph::element::f32, ngraph::Shape{ 4, 4 } },
|
||||
ngraph::element::f32,
|
||||
{ 4, 4 },
|
||||
std::vector<float>(4 * 4, 1.f),
|
||||
ngraph::element::f32,
|
||||
{ {}, {}, { 0.02f * 0.1f } },
|
||||
{},
|
||||
{
|
||||
255,
|
||||
{ 4, 1 },
|
||||
@ -291,6 +356,30 @@ std::vector<MatMullTransformationTestValues> testValues = {
|
||||
{-127.f, -12.7f, -1.27f , -0.127f},
|
||||
{127.f, 12.7f, 1.27f , 0.127f},
|
||||
},
|
||||
{}
|
||||
},
|
||||
},
|
||||
|
||||
// U8 & I8: dequantization by rows in second input: can't be transformed (Dq on weights)
|
||||
{
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
{ 1, 3, 4 },
|
||||
ngraph::element::u8,
|
||||
{ ngraph::element::f32, {}, { 0.02f } },
|
||||
{ std::vector<float>(4 * 4, 1.f), ngraph::element::f32, ngraph::Shape{ 4, 4 } },
|
||||
{},
|
||||
{ ngraph::element::f32, {}, { {0.01f, 0.02f, 0.03f, 0.01f}, ngraph::element::f32, ngraph::Shape{4, 1} } },
|
||||
},
|
||||
{
|
||||
{ 1, 3, 4 },
|
||||
ngraph::element::u8,
|
||||
{ ngraph::element::f32, {}, { 0.02f } },
|
||||
{ std::vector<float>(4 * 4, 1.f), ngraph::element::f32, ngraph::Shape{ 4, 4 } },
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
{},
|
||||
{ ngraph::element::f32, {}, { {0.01f, 0.02f, 0.03f, 0.01f}, ngraph::element::f32, ngraph::Shape{4, 1} } },
|
||||
},
|
||||
},
|
||||
|
||||
@ -301,22 +390,45 @@ std::vector<MatMullTransformationTestValues> testValues = {
|
||||
{ 1, 2048 },
|
||||
ngraph::element::u8,
|
||||
{ ngraph::element::f32, {}, { 0.02f } },
|
||||
{ 2048, 1000 },
|
||||
std::vector<float>(2048 * 1000, 1.f),
|
||||
{ std::vector<float>(2048 * 1000, 1.f), ngraph::element::f32, ngraph::Shape{ 2048, 1000 } },
|
||||
{ 255, { 1, 1 }, {0.f}, {254.f}, {-12.7f}, {12.7} },
|
||||
{}
|
||||
},
|
||||
{
|
||||
{ 1, 2048 },
|
||||
ngraph::element::u8,
|
||||
{ {}, {}, {} },
|
||||
ngraph::element::i8,
|
||||
{2048, 1000},
|
||||
std::vector<float>(2048 * 1000, -126),
|
||||
ngraph::element::i8,
|
||||
{ std::vector<float>(2048 * 1000, -126.f), ngraph::element::i8, ngraph::Shape{ 2048, 1000 } },
|
||||
ngraph::element::u8,
|
||||
{ {}, {}, { 0.02f * 0.1f } },
|
||||
{},
|
||||
{}
|
||||
}
|
||||
},
|
||||
|
||||
// 2D: U8 & I8 with Dq on weights
|
||||
{
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
{ 1, 2048 },
|
||||
ngraph::element::u8,
|
||||
{ ngraph::element::f32, {}, { 0.2f } },
|
||||
{ std::vector<float>(2048 * 1000, 1.f), ngraph::element::i8, ngraph::Shape{ 2048, 1000 } },
|
||||
{},
|
||||
{ ngraph::element::f32, {}, { 0.2f } }
|
||||
},
|
||||
{
|
||||
{ 1, 2048 },
|
||||
ngraph::element::u8,
|
||||
{ {}, {}, {} },
|
||||
{ std::vector<float>(2048 * 1000, 1.f), ngraph::element::i8, ngraph::Shape{ 2048, 1000 } },
|
||||
ngraph::element::u8,
|
||||
{ {}, {}, { 0.2f * 0.2f } },
|
||||
{},
|
||||
{}
|
||||
}
|
||||
},
|
||||
|
||||
// 2D: I8 & I8
|
||||
{
|
||||
LayerTransformation::createParamsI8I8(),
|
||||
@ -324,42 +436,110 @@ std::vector<MatMullTransformationTestValues> testValues = {
|
||||
{ 1, 2048 },
|
||||
ngraph::element::i8,
|
||||
{ ngraph::element::f32, {}, { 0.02f } },
|
||||
{ 2048, 1000 },
|
||||
std::vector<float>(2048 * 1000, 1.f),
|
||||
{ std::vector<float>(2048 * 1000, 1.f), ngraph::element::f32, ngraph::Shape{ 2048, 1000 } },
|
||||
{ 255, { 1, 1 }, {0.f}, {254.f}, {-12.7f}, {12.7} },
|
||||
{}
|
||||
},
|
||||
{
|
||||
{ 1, 2048 },
|
||||
ngraph::element::i8,
|
||||
{ {}, {}, {} },
|
||||
ngraph::element::i8,
|
||||
{2048, 1000},
|
||||
std::vector<float>(2048 * 1000, -126),
|
||||
{ std::vector<float>(2048 * 1000, -126.f), ngraph::element::i8, ngraph::Shape{ 2048, 1000 } },
|
||||
ngraph::element::i8,
|
||||
{ {}, {}, { 0.02f * 0.1f } },
|
||||
{},
|
||||
{}
|
||||
}
|
||||
},
|
||||
// 2D: FP32 & FP328
|
||||
|
||||
// 2D: I8 & I8 with Dq on weights with small subtract values
|
||||
{
|
||||
LayerTransformation::createParamsI8I8(),
|
||||
{
|
||||
{ 1, 2048 },
|
||||
ngraph::element::i8,
|
||||
{ ngraph::element::f32, {}, { 0.02f } },
|
||||
{ std::vector<float>(2048 * 1000, 1.f), ngraph::element::i8, ngraph::Shape{ 2048, 1000 } },
|
||||
{},
|
||||
{ ngraph::element::f32, { 1e-7f }, { 0.02f } }
|
||||
},
|
||||
{
|
||||
{ 1, 2048 },
|
||||
ngraph::element::i8,
|
||||
{ {}, {}, {} },
|
||||
{ std::vector<float>(2048 * 1000, 1.f), ngraph::element::i8, ngraph::Shape{ 2048, 1000 } },
|
||||
ngraph::element::i8,
|
||||
{ {}, {}, { 0.02f * 0.02f } },
|
||||
{},
|
||||
{}
|
||||
}
|
||||
},
|
||||
|
||||
// 2D: I8 & I8 with Dq on weights with zero subtract values
|
||||
{
|
||||
LayerTransformation::createParamsI8I8(),
|
||||
{
|
||||
{ 1, 2048 },
|
||||
ngraph::element::i8,
|
||||
{ ngraph::element::f32, {}, { 0.02f } },
|
||||
{ std::vector<float>(2048 * 1000, 1.f), ngraph::element::i8, ngraph::Shape{ 2048, 1000 } },
|
||||
{},
|
||||
{ ngraph::element::f32, { 0.f }, { 0.02f } }
|
||||
},
|
||||
{
|
||||
{ 1, 2048 },
|
||||
ngraph::element::i8,
|
||||
{ {}, {}, {} },
|
||||
{ std::vector<float>(2048 * 1000, 1.f), ngraph::element::i8, ngraph::Shape{ 2048, 1000 } },
|
||||
ngraph::element::i8,
|
||||
{ {}, {}, { 0.02f * 0.02f } },
|
||||
{},
|
||||
{}
|
||||
}
|
||||
},
|
||||
|
||||
// 2D: FP32 & FP32
|
||||
{
|
||||
LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
|
||||
{
|
||||
{ 1, 2048 },
|
||||
ngraph::element::f32,
|
||||
{ {}, {}, { 0.02f } },
|
||||
{ 2048, 1000 },
|
||||
std::vector<float>(2048 * 1000, 1.f),
|
||||
{ std::vector<float>(2048 * 1000, 1.f), ngraph::element::f32, ngraph::Shape{ 2048, 1000 } },
|
||||
{ 255, { 1, 1 }, {0.f}, {254.f}, {-12.7f}, {12.7} },
|
||||
{}
|
||||
},
|
||||
{
|
||||
{ 1, 2048 },
|
||||
ngraph::element::f32,
|
||||
{ {}, {}, {} },
|
||||
ngraph::element::f32,
|
||||
{2048, 1000},
|
||||
std::vector<float>(2048 * 1000, -126),
|
||||
{ std::vector<float>(2048 * 1000, -126.f), ngraph::element::f32, ngraph::Shape{ 2048, 1000 } },
|
||||
ngraph::element::f32,
|
||||
{ {}, {}, { 0.02f * 0.1f } },
|
||||
{},
|
||||
{}
|
||||
}
|
||||
},
|
||||
|
||||
// 2D: FP32 & FP32 with Dq on weights
|
||||
{
|
||||
LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
|
||||
{
|
||||
{ 1, 2048 },
|
||||
ngraph::element::f32,
|
||||
{ {}, {}, { 0.02f } },
|
||||
{ std::vector<float>(2048 * 1000, 1.f), ngraph::element::f32, ngraph::Shape{ 2048, 1000 } },
|
||||
{},
|
||||
{ ngraph::element::f32, {}, { 0.02f } }
|
||||
},
|
||||
{
|
||||
{ 1, 2048 },
|
||||
ngraph::element::f32,
|
||||
{ {}, {}, {} },
|
||||
{ std::vector<float>(2048 * 1000, 1.f), ngraph::element::f32, ngraph::Shape{ 2048, 1000 } },
|
||||
ngraph::element::f32,
|
||||
{ {}, {}, { 0.02f * 0.02f } },
|
||||
{},
|
||||
{}
|
||||
}
|
||||
},
|
||||
|
@ -18,9 +18,19 @@ std::vector<MatMulWithConstantTransformationTestValues> testValues = {
|
||||
{
|
||||
{ 2, 3, 4 },
|
||||
{ 256ul, {{1, 1, 1}, {1, 1, 1}, {1, 3, 1}, {1, 3, 1}}, {0.f}, {255.f}, {0.f, 0.f, 0.f}, {255.f, 25.5f, 255.f} },
|
||||
{ 2, 4 },
|
||||
std::vector<float>(4 * 2, 2.f),
|
||||
{ std::vector<float>(4 * 2, 2.f), ngraph::element::f32, ngraph::Shape{ 2, 4 } },
|
||||
{ 256ul, {{1}, {1}, {2, 1}, {2, 1}}, {-128.f}, {127.f}, {-128.f, -12.8f}, {127.f, 12.7f} },
|
||||
{ {}, {}, {} },
|
||||
"matMul/FC",
|
||||
"U8"
|
||||
},
|
||||
// 3D with dequantize on weights
|
||||
{
|
||||
{ 2, 3, 4 },
|
||||
{ 256ul, {{1, 1, 1}, {1, 1, 1}, {1, 3, 1}, {1, 3, 1}}, {0.f}, {255.f}, {0.f, 0.f, 0.f}, {255.f, 25.5f, 255.f} },
|
||||
{ std::vector<float>(4 * 2, 2.f), ngraph::element::i8, ngraph::Shape{ 2, 4 } },
|
||||
{},
|
||||
{ ngraph::element::f32, {}, {0.1f} },
|
||||
"matMul/FC",
|
||||
"U8"
|
||||
},
|
||||
@ -28,9 +38,9 @@ std::vector<MatMulWithConstantTransformationTestValues> testValues = {
|
||||
{
|
||||
{ 1, 3, 4 },
|
||||
{ 256ul, {{1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}, {-10.5f}, {4.5f}, {-10.5f}, {4.5f} },
|
||||
{ 2, 4 },
|
||||
std::vector<float>(4 * 2, 2.f),
|
||||
{ std::vector<float>(4 * 2, 2.f), ngraph::element::f32, ngraph::Shape{ 2, 4 } },
|
||||
{ 256ul, {{1}, {1}, {2, 1}, {2, 1}}, {-128.f}, {127.f}, {-128.f, -12.8f}, {127.f, 12.7f} },
|
||||
{ {}, {}, {} },
|
||||
"matMul/FC",
|
||||
"U8"
|
||||
},
|
||||
@ -38,9 +48,19 @@ std::vector<MatMulWithConstantTransformationTestValues> testValues = {
|
||||
{
|
||||
{ 1, 1, 3, 4 },
|
||||
{ 256ul, {{1, 1, 1}, {1, 1, 1}, {1, 3, 1}, {1, 3, 1}}, {0.f}, {255.f}, {0.f, 0.f, 0.f}, {255.f, 25.5f, 255.f} },
|
||||
{ 2, 4 },
|
||||
std::vector<float>(4 * 2, 2.f),
|
||||
{ std::vector<float>(4 * 2, 2.f), ngraph::element::f32, ngraph::Shape{ 2, 4 } },
|
||||
{ 256ul, {{1}, {1}, {2, 1}, {2, 1}}, {-128.f}, {127.f}, {-128.f, -12.8f}, {127.f, 12.7f} },
|
||||
{ {}, {}, {} },
|
||||
"matMul/FC",
|
||||
"U8"
|
||||
},
|
||||
// 4D with Dq on weights
|
||||
{
|
||||
{ 1, 1, 3, 4 },
|
||||
{ 256ul, {{1, 1, 1}, {1, 1, 1}, {1, 3, 1}, {1, 3, 1}}, {0.f}, {255.f}, {0.f, 0.f, 0.f}, {255.f, 25.5f, 255.f} },
|
||||
{ std::vector<float>(4 * 2, 2.f), ngraph::element::i8, ngraph::Shape{ 2, 4 } },
|
||||
{},
|
||||
{ ngraph::element::f32, {}, {{0.1f, 0.01}, ngraph::element::f32, ngraph::Shape{ 2, 1 }} },
|
||||
"matMul/FC",
|
||||
"U8"
|
||||
},
|
||||
@ -48,9 +68,9 @@ std::vector<MatMulWithConstantTransformationTestValues> testValues = {
|
||||
{
|
||||
{ 1, 3, 4 },
|
||||
{ 256ul, {{1}, {1}, {1}, {1}}, {0.f}, {255.f}, {0.f}, {25.5f} },
|
||||
{ 4, 4 },
|
||||
std::vector<float>(4 * 4, 2.f),
|
||||
{ std::vector<float>(4 * 4, 2.f), ngraph::element::f32, ngraph::Shape{ 4, 4 } },
|
||||
{ 256ul, {{1}, {1}, {1}, {1}}, {-128.f}, {127.f}, {-128.f}, {127.f} },
|
||||
{ {}, {}, {} },
|
||||
"matMul/FC",
|
||||
"U8"
|
||||
},
|
||||
@ -58,9 +78,19 @@ std::vector<MatMulWithConstantTransformationTestValues> testValues = {
|
||||
{
|
||||
{ 2, 3 },
|
||||
{ 256ul, {{1}, {1}, {2, 1}, {2, 1}}, {-10.f}, {5.f}, {-10.f, -5.f}, {5.f, 5.f} },
|
||||
{ 2, 3 },
|
||||
std::vector<float>{1, 2, 3, 4, 5, 6},
|
||||
{ std::vector<float>{1, 2, 3, 4, 5, 6}, ngraph::element::f32, ngraph::Shape{ 2, 3 } },
|
||||
{ 256ul, {{1}, {1}, {1}, {1}}, {-128.f}, {127.f}, {-12.8f}, {12.7f} },
|
||||
{ {}, {}, {} },
|
||||
"matMul/1",
|
||||
"U8"
|
||||
},
|
||||
// 2D with subtract on activations & Dq on weights
|
||||
{
|
||||
{ 2, 3 },
|
||||
{ 256ul, {{1}, {1}, {2, 1}, {2, 1}}, {-10.f}, {5.f}, {-10.f, -5.f}, {5.f, 5.f} },
|
||||
{ std::vector<float>{1, 2, 3, 4, 5, 6}, ngraph::element::i8, ngraph::Shape{ 2, 3 } },
|
||||
{},
|
||||
{ ngraph::element::f32, {}, {0.1f} },
|
||||
"matMul/1",
|
||||
"U8"
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// Copyright (C) 2020-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
@ -14,38 +14,62 @@ const std::vector<ngraph::element::Type> precisions = { ngraph::element::f32 };
|
||||
|
||||
//transpose_a = false, transpose_b = true
|
||||
std::vector<MatMulWithConstantTransformationTestValues> testValues = {
|
||||
// 3D with different values
|
||||
{
|
||||
{ 2, 3, 4 },
|
||||
{ 256ul, {{1, 1, 1}, {1, 1, 1}, {1, 3, 1}, {1, 3, 1}}, {0.f}, {255.f}, {0.f, 0.f, 0.f}, {255.f, 25.5f, 255.f} },
|
||||
{ 2, 4 },
|
||||
std::vector<float>(4 * 2, 2.f),
|
||||
{ std::vector<float>(4 * 2, 2.f), ngraph::element::f32, ngraph::Shape{ 2, 4 } },
|
||||
{ 256ul, {{1}, {1}, {2, 1}, {2, 1}}, {-128.f}, {127.f}, {-128.f, -12.8f}, {127.f, 12.7f} },
|
||||
{ {}, {}, {} },
|
||||
},
|
||||
{
|
||||
{ 2, 3, 4 },
|
||||
{ 256ul, {{1, 1, 1}, {1, 1, 1}, {1, 3, 1}, {1, 3, 1}}, {0.f}, {255.f}, {0.f, 0.f, 0.f}, {255.f, 25.5f, 255.f} },
|
||||
{ std::vector<float>(4 * 2, 2.f), ngraph::element::i8, ngraph::Shape{ 2, 4 } },
|
||||
{},
|
||||
{ ngraph::element::f32, {}, {0.1f} },
|
||||
},
|
||||
// 3D with different values
|
||||
{
|
||||
{ 1, 3, 4 },
|
||||
{ 256ul, {{1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}, {-10.5f}, {4.5f}, {-10.5f}, {4.5f} },
|
||||
{ 2, 4 },
|
||||
std::vector<float>(4 * 2, 2.f),
|
||||
{ std::vector<float>(4 * 2, 2.f), ngraph::element::f32, ngraph::Shape{ 2, 4 } },
|
||||
{ 256ul, {{1}, {1}, {2, 1}, {2, 1}}, {-128.f}, {127.f}, {-128.f, -12.8f}, {127.f, 12.7f} },
|
||||
{ {}, {}, {} },
|
||||
},
|
||||
{
|
||||
{ 1, 1, 3, 4 },
|
||||
{ 256ul, {{1, 1, 1}, {1, 1, 1}, {1, 3, 1}, {1, 3, 1}}, {0.f}, {255.f}, {0.f, 0.f, 0.f}, {255.f, 25.5f, 255.f} },
|
||||
{ std::vector<float>(4 * 2, 2.f), ngraph::element::f32, ngraph::Shape{ 2, 4 } },
|
||||
{ 256ul, {{1}, {1}, {2, 1}, {2, 1}}, {-128.f}, {127.f}, {-128.f, -12.8f}, {127.f, 12.7f} },
|
||||
{ {}, {}, {} },
|
||||
},
|
||||
{
|
||||
{ 1, 1, 3, 4 },
|
||||
{ 256ul, {{1, 1, 1}, {1, 1, 1}, {1, 3, 1}, {1, 3, 1}}, {0.f}, {255.f}, {0.f, 0.f, 0.f}, {255.f, 25.5f, 255.f} },
|
||||
{ std::vector<float>(4 * 2, 2.f), ngraph::element::i8, ngraph::Shape{ 2, 4 } },
|
||||
{},
|
||||
{ ngraph::element::f32, {}, {{0.1f, 0.01}, ngraph::element::f32, ngraph::Shape{ 2, 1 }} },
|
||||
},
|
||||
// 3D with the same values
|
||||
{
|
||||
{ 1, 3, 4 },
|
||||
{ 256ul, {{1}, {1}, {1}, {1}}, {0.f}, {255.f}, {0.f}, {25.5f} },
|
||||
{ 4, 4 },
|
||||
std::vector<float>(4 * 4, 2.f),
|
||||
{ std::vector<float>(4 * 4, 2.f), ngraph::element::f32, ngraph::Shape{ 4, 4 } },
|
||||
{ 256ul, {{1}, {1}, {1}, {1}}, {-128.f}, {127.f}, {-128.f}, {127.f} },
|
||||
{ {}, {}, {} },
|
||||
},
|
||||
// 2D with subtract on activations
|
||||
{
|
||||
{ 2, 3 },
|
||||
{ 256ul, {{1}, {1}, {2, 1}, {2, 1}}, {-10.f}, {5.f}, {-10.f, -5.f}, {5.f, 5.f} },
|
||||
{ 2, 3 },
|
||||
std::vector<float>{1, 2, 3, 4, 5, 6},
|
||||
{ std::vector<float>{1, 2, 3, 4, 5, 6}, ngraph::element::f32, ngraph::Shape{ 2, 3 } },
|
||||
{ 256ul, {{1}, {1}, {1}, {1}}, {-128.f}, {127.f}, {-12.8f}, {12.7f} },
|
||||
{ {}, {}, {} },
|
||||
},
|
||||
{
|
||||
{ 2, 3 },
|
||||
{ 256ul, {{1}, {1}, {2, 1}, {2, 1}}, {-10.f}, {5.f}, {-10.f, -5.f}, {5.f, 5.f} },
|
||||
{ std::vector<float>{1, 2, 3, 4, 5, 6}, ngraph::element::i8, ngraph::Shape{ 2, 3 } },
|
||||
{},
|
||||
{ ngraph::element::f32, {}, {0.1f} },
|
||||
}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_LPT, MatMulWithConstantTransformation,
|
||||
|
@ -9,6 +9,9 @@
|
||||
|
||||
#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp"
|
||||
#include "lpt_ngraph_functions/common/fake_quantize_on_weights.hpp"
|
||||
#include "lpt_ngraph_functions/common/constant.hpp"
|
||||
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
|
||||
|
||||
#include "lpt_ngraph_functions/mat_mul_function.hpp"
|
||||
#include "shared_test_classes/base/low_precision_transformations/layer_transformation.hpp"
|
||||
|
||||
@ -18,9 +21,11 @@ class MatMulWithConstantTransformationTestValues {
|
||||
public:
|
||||
ngraph::Shape inputShape;
|
||||
ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fqOnData;
|
||||
ngraph::Shape weightsConstShape;
|
||||
std::vector<float> weightsConstValues;
|
||||
|
||||
ngraph::builder::subgraph::Constant weights;
|
||||
ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fqOnWeights;
|
||||
ngraph::builder::subgraph::DequantizationOperations deqOnWeights;
|
||||
|
||||
std::string layerName;
|
||||
std::string expectedKernelType;
|
||||
};
|
||||
|
@ -31,7 +31,8 @@ std::string MatMulWithConstantTransformation::getTestCaseName(testing::TestParam
|
||||
precision << "_" <<
|
||||
targetDevice << "_" <<
|
||||
testValues.fqOnData << "_" <<
|
||||
testValues.fqOnWeights;
|
||||
testValues.fqOnWeights << "_" <<
|
||||
testValues.deqOnWeights;
|
||||
|
||||
return result.str();
|
||||
}
|
||||
@ -65,12 +66,15 @@ void MatMulWithConstantTransformation::SetUp() {
|
||||
precision,
|
||||
testValues.inputShape,
|
||||
testValues.fqOnData,
|
||||
testValues.weightsConstShape,
|
||||
testValues.weightsConstValues,
|
||||
testValues.fqOnWeights);
|
||||
testValues.weights,
|
||||
testValues.fqOnWeights,
|
||||
testValues.deqOnWeights);
|
||||
|
||||
ngraph::pass::InitNodeInfo().run_on_function(function);
|
||||
|
||||
if (testValues.deqOnWeights.empty()) {
|
||||
validate();
|
||||
}
|
||||
}
|
||||
|
||||
void MatMulWithConstantTransformation::validate() {
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
|
||||
#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp"
|
||||
#include "lpt_ngraph_functions/common/fake_quantize_on_weights.hpp"
|
||||
#include "lpt_ngraph_functions/common/constant.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace builder {
|
||||
@ -49,10 +50,10 @@ public:
|
||||
const ngraph::element::Type precision,
|
||||
const ngraph::Shape& inputShape,
|
||||
const ngraph::element::Type precisionBeforeDequantization,
|
||||
const DequantizationOperations& dequantization,
|
||||
const ngraph::Shape& weightsConstShape,
|
||||
const std::vector<float>& weightsConstValues,
|
||||
const FakeQuantizeOnWeights& fqOnWeights);
|
||||
const DequantizationOperations& deqOnData,
|
||||
const Constant& weights,
|
||||
const FakeQuantizeOnWeights& fqOnWeights,
|
||||
const DequantizationOperations& deqOnWeights);
|
||||
|
||||
static std::shared_ptr<ngraph::Function> getReference(
|
||||
const ngraph::element::Type precision,
|
||||
@ -69,18 +70,16 @@ public:
|
||||
const ngraph::Shape& inputShape,
|
||||
const ngraph::element::Type precisionBeforeDequantization,
|
||||
const DequantizationOperations& dequantization,
|
||||
const ngraph::element::Type weightsConstPrecision,
|
||||
const ngraph::Shape& weightsConstShape,
|
||||
const std::vector<float>& weightsConstValues,
|
||||
const Constant& weights,
|
||||
const DequantizationOperations& resultDequantization);
|
||||
|
||||
static std::shared_ptr<ngraph::Function> getOriginal(
|
||||
const ngraph::element::Type precision,
|
||||
const ngraph::Shape& inputShape,
|
||||
const FakeQuantizeOnDataWithConstant& fqOnData,
|
||||
const ngraph::Shape& weightsConstShape,
|
||||
const std::vector<float>& weightsConstValues,
|
||||
const FakeQuantizeOnDataWithConstant& fqOnWeights);
|
||||
const Constant& weights,
|
||||
const FakeQuantizeOnDataWithConstant& fqOnWeights,
|
||||
const DequantizationOperations& deqOnWeights);
|
||||
};
|
||||
|
||||
} // namespace subgraph
|
||||
|
@ -168,39 +168,40 @@ std::shared_ptr<ngraph::Function> MatMulFunction::getOriginal(
|
||||
const ngraph::element::Type precision,
|
||||
const ngraph::Shape& inputShape,
|
||||
const ngraph::element::Type precisionBeforeDequantization,
|
||||
const DequantizationOperations& dequantizationOperations,
|
||||
const ngraph::Shape& weightsConstShape,
|
||||
const std::vector<float>& weightsConstValues,
|
||||
const FakeQuantizeOnWeights& fqOnWeights) {
|
||||
const DequantizationOperations& deqOnData,
|
||||
const Constant& weights,
|
||||
const FakeQuantizeOnWeights& fqOnWeights,
|
||||
const DequantizationOperations& deqOnWeights) {
|
||||
const std::shared_ptr<ngraph::opset1::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
|
||||
precisionBeforeDequantization,
|
||||
inputShape);
|
||||
input->set_friendly_name("input1");
|
||||
|
||||
auto lastDequantization = makeDequantization(input, dequantizationOperations);
|
||||
const auto dequantizationOnData = makeDequantization(input, deqOnData);
|
||||
|
||||
const std::shared_ptr<ngraph::opset1::Constant> weightsConst = std::make_shared<ngraph::opset1::Constant>(
|
||||
precision,
|
||||
weightsConstShape,
|
||||
weightsConstValues);
|
||||
const std::shared_ptr<ngraph::Node> weightsConst = std::make_shared<ngraph::opset1::Constant>(
|
||||
weights.outPrecision,
|
||||
weights.shape,
|
||||
weights.values);
|
||||
|
||||
auto fakeQuantize = makeFakeQuantize(weightsConst, precision, fqOnWeights);
|
||||
const std::shared_ptr<ngraph::Node> fakeQuantize = fqOnWeights.empty() ? nullptr : makeFakeQuantize(weightsConst, precision, fqOnWeights);
|
||||
const auto dequantizationOnWeights = makeDequantization(fakeQuantize == nullptr ? weightsConst : fakeQuantize, deqOnWeights);
|
||||
|
||||
const std::shared_ptr<ngraph::opset1::MatMul> matMul = std::make_shared<ngraph::opset1::MatMul>(
|
||||
lastDequantization,
|
||||
fakeQuantize,
|
||||
const auto matMul = std::make_shared<ngraph::opset1::MatMul>(
|
||||
dequantizationOnData,
|
||||
dequantizationOnWeights,
|
||||
false,
|
||||
false);
|
||||
matMul->set_friendly_name("matMul");
|
||||
auto& rtInfo = matMul->get_rt_info();
|
||||
rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("matMul");
|
||||
|
||||
std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(matMul);
|
||||
|
||||
const auto result = std::make_shared<ngraph::opset1::Result>(matMul);
|
||||
std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
|
||||
ngraph::ResultVector{ result },
|
||||
std::vector<std::shared_ptr<ngraph::op::Parameter>> { input },
|
||||
"MatMulTransformation");
|
||||
|
||||
return function;
|
||||
}
|
||||
|
||||
@ -262,9 +263,7 @@ std::shared_ptr<ngraph::Function> MatMulFunction::getReference(
|
||||
const ngraph::Shape& inputShape,
|
||||
const ngraph::element::Type precisionBeforeDequantization,
|
||||
const DequantizationOperations& dequantization,
|
||||
const ngraph::element::Type weightsConstPrecision,
|
||||
const ngraph::Shape& weightsConstShape,
|
||||
const std::vector<float>& weightsConstValues,
|
||||
const Constant& weights,
|
||||
const DequantizationOperations& resultDequantization) {
|
||||
const std::shared_ptr<ngraph::opset1::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
|
||||
precisionBeforeDequantization,
|
||||
@ -274,9 +273,9 @@ std::shared_ptr<ngraph::Function> MatMulFunction::getReference(
|
||||
const std::shared_ptr<ngraph::Node> lastDequantizationBefore = makeDequantization(input, dequantization);
|
||||
|
||||
const std::shared_ptr<ngraph::opset1::Constant> weightsConst = std::make_shared<ngraph::opset1::Constant>(
|
||||
weightsConstPrecision,
|
||||
weightsConstShape,
|
||||
weightsConstValues);
|
||||
weights.outPrecision,
|
||||
weights.shape,
|
||||
weights.values);
|
||||
|
||||
const std::shared_ptr<ngraph::opset1::MatMul> matMul = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::MatMul>>(
|
||||
std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{},
|
||||
@ -305,36 +304,35 @@ std::shared_ptr<ngraph::Function> MatMulFunction::getOriginal(
|
||||
const ngraph::element::Type precision,
|
||||
const ngraph::Shape& inputShape,
|
||||
const FakeQuantizeOnDataWithConstant& fqOnData,
|
||||
const ngraph::Shape& weightsConstShape,
|
||||
const std::vector<float>& weightsConstValues,
|
||||
const FakeQuantizeOnDataWithConstant& fqOnWeights) {
|
||||
const std::shared_ptr<ngraph::opset1::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
|
||||
precision,
|
||||
inputShape);
|
||||
const Constant& weights,
|
||||
const FakeQuantizeOnDataWithConstant& fqOnWeights,
|
||||
const DequantizationOperations& deqOnWeights) {
|
||||
const auto input = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape);
|
||||
input->set_friendly_name("input1");
|
||||
|
||||
auto lastDequantization = makeFakeQuantize(input, precision, fqOnData);
|
||||
const auto dequantizationOnData = makeFakeQuantize(input, precision, fqOnData);
|
||||
|
||||
const std::shared_ptr<ngraph::opset1::Constant> weightsConst = std::make_shared<ngraph::opset1::Constant>(
|
||||
precision,
|
||||
weightsConstShape,
|
||||
weightsConstValues);
|
||||
const std::shared_ptr<ngraph::Node> weightsConst = std::make_shared<ngraph::opset1::Constant>(
|
||||
weights.outPrecision,
|
||||
weights.shape,
|
||||
weights.values);
|
||||
|
||||
auto fakeQuantize = makeFakeQuantize(weightsConst, precision, fqOnWeights);
|
||||
const std::shared_ptr<ngraph::Node> fakeQuantize = fqOnWeights.empty() ? nullptr : makeFakeQuantize(weightsConst, precision, fqOnWeights);
|
||||
const auto dequantizationOnWeights = makeDequantization(fakeQuantize == nullptr ? weightsConst : fakeQuantize, deqOnWeights);
|
||||
|
||||
const std::shared_ptr<ngraph::opset1::MatMul> matMul = std::make_shared<ngraph::opset1::MatMul>(
|
||||
lastDequantization,
|
||||
fakeQuantize,
|
||||
dequantizationOnData,
|
||||
dequantizationOnWeights,
|
||||
false,
|
||||
true);
|
||||
matMul->set_friendly_name("matMul");
|
||||
|
||||
std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(matMul);
|
||||
|
||||
const auto result = std::make_shared<ngraph::opset1::Result>(matMul);
|
||||
std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
|
||||
ngraph::ResultVector{ result },
|
||||
std::vector<std::shared_ptr<ngraph::op::Parameter>> { input },
|
||||
"MatMulTransformation");
|
||||
|
||||
return function;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user