[LPT] AddTransformation fix (#5022)

* [LPT] AddTransformation fix

* [LPT][TESTS] FunctionalTests: added test-case with AddTransformation without convert
This commit is contained in:
Vladislav Golubev 2021-04-02 12:41:53 +03:00 committed by GitHub
parent bf1b7ef19c
commit d7a2d13152
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 18 additions and 13 deletions

View File

@ -47,13 +47,13 @@ std::shared_ptr<opset1::Subtract> replaceToSubtract(const std::shared_ptr<Node>&
return nullptr;
}
auto constant = fold<opset1::Negative>(add->get_input_node_shared_ptr(constBranchIndex));
auto constant = fold<opset1::Negative>(add->input_value(constBranchIndex));
auto constOutput = constant->output(0);
const auto subtract = std::make_shared<op::TypeRelaxed<DequantizationSubtract>>(
std::vector<element::Type>{element::f32, element::f32},
std::vector<element::Type>{ op->get_output_element_type(0) },
ngraph::op::TemporaryReplaceOutputType(add->get_input_node_shared_ptr(dataBranchIndex), element::f32).get(),
ngraph::op::TemporaryReplaceOutputType(add->input_value(dataBranchIndex), element::f32).get(),
ngraph::op::TemporaryReplaceOutputType(constOutput, element::f32).get(),
add->get_autob());
@ -73,13 +73,13 @@ std::shared_ptr<opset1::Subtract> fuseWithSubtract(const std::shared_ptr<Node>&
}
const auto newSubConst = fold<opset1::Subtract>(
add->get_input_node_shared_ptr(0)->get_input_node_shared_ptr(1),
add->get_input_node_shared_ptr(1));
add->get_input_node_shared_ptr(0)->input_value(1),
add->input_value(1));
const auto newSubtract = std::make_shared<op::TypeRelaxed<DequantizationSubtract>>(
std::vector<element::Type>{element::f32, element::f32},
std::vector<element::Type>{ op->get_output_element_type(0) },
ngraph::op::TemporaryReplaceOutputType(add->get_input_node_shared_ptr(0)->get_input_node_shared_ptr(0), element::f32).get(),
ngraph::op::TemporaryReplaceOutputType(add->get_input_node_shared_ptr(0)->input_value(0), element::f32).get(),
ngraph::op::TemporaryReplaceOutputType(newSubConst, element::f32).get());
NetworkHelper::copyInfo(add, newSubtract);
@ -178,7 +178,7 @@ bool AddTransformation::transform(TransformationContext& context, ngraph::patter
}
// graph update
std::vector<Output<Node>> inputs{ {}, {} };
OutputVector inputs{ {}, {} };
auto fullPathInput = dequantizationFullPath.convert == nullptr ? dequantizationFullPath.data : dequantizationFullPath.convert;
inputs[emptyPathIndex] = dequantizationEmptyPath.data;

View File

@ -189,7 +189,7 @@ std::shared_ptr<Node> NetworkHelper::swapMultiplyAndAdd(std::shared_ptr<opset1::
if (multiplyConst == nullptr)
return addAfterMultiply;
const auto x = multiply->get_input_node_shared_ptr(multiplyInputBranch);
const auto x = multiply->get_input_source_output(multiplyInputBranch);
auto a = multiply->get_input_node_shared_ptr(multiplyInputBranch == 0 ? 1 : 0);
auto b = addAfterMultiply->get_input_node_shared_ptr(multiplyBranch == 0 ? 1 : 0);
std::shared_ptr<Node> bDivA;
@ -228,14 +228,13 @@ std::shared_ptr<Node> NetworkHelper::swapMultiplyAndAdd(std::shared_ptr<opset1::
bDivA = fold<opset1::Convert>(bDivA, a->get_output_element_type(0));
}
std::vector<std::shared_ptr<Node>> inputs{ {}, {} };
OutputVector inputs{ {}, {} };
inputs[0] = x;
inputs[1] = bDivA;
std::shared_ptr<opset1::Add> newAdd = std::make_shared<op::TypeRelaxed<opset1::Add>>(
std::vector<element::Type>{element::f32, element::f32},
std::vector<element::Type>{ x->get_output_element_type(0) },
std::vector<element::Type>{ x.get_element_type() },
ngraph::op::TemporaryReplaceOutputType(inputs[0], element::f32).get(),
ngraph::op::TemporaryReplaceOutputType(inputs[1], element::f32).get());
copyInfo(addAfterMultiply, newAdd);

View File

@ -57,7 +57,7 @@ SimpleLowPrecisionTransformer getTransformerWithTransformationByName(
using namespace pass::low_precision;
SimpleLowPrecisionTransformer transformer;
if (name == "AddTransformation") {
if (name == "AddTransformationWithoutConcat" || name == "AddTransformationWithConcat") {
transformer.add<AddTransformation, ngraph::opset1::Add>(params);
return transformer;
}
@ -185,7 +185,8 @@ TEST_P(TransformationsAfterSplitTransformation, Run) {
}
const std::vector<std::string> transformationNames = {
"AddTransformation",
"AddTransformationWithoutConcat",
"AddTransformationWithConcat",
"AvgPoolTransformation",
"ClampTransformation",
"ConvolutionTransformation",

View File

@ -41,7 +41,12 @@ std::shared_ptr<Function> TransformationsAfterSplitFunction::get(const std::stri
std::shared_ptr<Node> TransformationsAfterSplitFunction::getLayerByTransformationName(
const std::string transformationName,
const Output<Node> parent) {
if (transformationName == "AddTransformation") {
if (transformationName == "AddTransformationWithoutConcat") {
const auto dequantization = makeDequantization(parent, { {}, {}, { 3.f } });
const auto addConstant = opset1::Constant::create(element::u8, Shape{}, { 128.f });
return std::make_shared<opset1::Add>(dequantization, addConstant);
}
if (transformationName == "AddTransformationWithConcat") {
const auto dequantization = makeDequantization(parent, { {element::f32}, {}, { 0.1f } });
const auto addConstant = opset1::Constant::create(element::f32, Shape{}, { 128.f });
return std::make_shared<opset1::Add>(dequantization, addConstant);