[LPT] AddTransformation fix (#5022)
* [LPT] AddTransformation fix * [LPT][TESTS] FunctionalTests: added test-case with AddTransformation without convert
This commit is contained in:
parent
bf1b7ef19c
commit
d7a2d13152
@ -47,13 +47,13 @@ std::shared_ptr<opset1::Subtract> replaceToSubtract(const std::shared_ptr<Node>&
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto constant = fold<opset1::Negative>(add->get_input_node_shared_ptr(constBranchIndex));
|
||||
auto constant = fold<opset1::Negative>(add->input_value(constBranchIndex));
|
||||
auto constOutput = constant->output(0);
|
||||
|
||||
const auto subtract = std::make_shared<op::TypeRelaxed<DequantizationSubtract>>(
|
||||
std::vector<element::Type>{element::f32, element::f32},
|
||||
std::vector<element::Type>{ op->get_output_element_type(0) },
|
||||
ngraph::op::TemporaryReplaceOutputType(add->get_input_node_shared_ptr(dataBranchIndex), element::f32).get(),
|
||||
ngraph::op::TemporaryReplaceOutputType(add->input_value(dataBranchIndex), element::f32).get(),
|
||||
ngraph::op::TemporaryReplaceOutputType(constOutput, element::f32).get(),
|
||||
add->get_autob());
|
||||
|
||||
@ -73,13 +73,13 @@ std::shared_ptr<opset1::Subtract> fuseWithSubtract(const std::shared_ptr<Node>&
|
||||
}
|
||||
|
||||
const auto newSubConst = fold<opset1::Subtract>(
|
||||
add->get_input_node_shared_ptr(0)->get_input_node_shared_ptr(1),
|
||||
add->get_input_node_shared_ptr(1));
|
||||
add->get_input_node_shared_ptr(0)->input_value(1),
|
||||
add->input_value(1));
|
||||
|
||||
const auto newSubtract = std::make_shared<op::TypeRelaxed<DequantizationSubtract>>(
|
||||
std::vector<element::Type>{element::f32, element::f32},
|
||||
std::vector<element::Type>{ op->get_output_element_type(0) },
|
||||
ngraph::op::TemporaryReplaceOutputType(add->get_input_node_shared_ptr(0)->get_input_node_shared_ptr(0), element::f32).get(),
|
||||
ngraph::op::TemporaryReplaceOutputType(add->get_input_node_shared_ptr(0)->input_value(0), element::f32).get(),
|
||||
ngraph::op::TemporaryReplaceOutputType(newSubConst, element::f32).get());
|
||||
NetworkHelper::copyInfo(add, newSubtract);
|
||||
|
||||
@ -178,7 +178,7 @@ bool AddTransformation::transform(TransformationContext& context, ngraph::patter
|
||||
}
|
||||
|
||||
// graph update
|
||||
std::vector<Output<Node>> inputs{ {}, {} };
|
||||
OutputVector inputs{ {}, {} };
|
||||
auto fullPathInput = dequantizationFullPath.convert == nullptr ? dequantizationFullPath.data : dequantizationFullPath.convert;
|
||||
|
||||
inputs[emptyPathIndex] = dequantizationEmptyPath.data;
|
||||
|
@ -189,7 +189,7 @@ std::shared_ptr<Node> NetworkHelper::swapMultiplyAndAdd(std::shared_ptr<opset1::
|
||||
if (multiplyConst == nullptr)
|
||||
return addAfterMultiply;
|
||||
|
||||
const auto x = multiply->get_input_node_shared_ptr(multiplyInputBranch);
|
||||
const auto x = multiply->get_input_source_output(multiplyInputBranch);
|
||||
auto a = multiply->get_input_node_shared_ptr(multiplyInputBranch == 0 ? 1 : 0);
|
||||
auto b = addAfterMultiply->get_input_node_shared_ptr(multiplyBranch == 0 ? 1 : 0);
|
||||
std::shared_ptr<Node> bDivA;
|
||||
@ -228,14 +228,13 @@ std::shared_ptr<Node> NetworkHelper::swapMultiplyAndAdd(std::shared_ptr<opset1::
|
||||
bDivA = fold<opset1::Convert>(bDivA, a->get_output_element_type(0));
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<Node>> inputs{ {}, {} };
|
||||
|
||||
OutputVector inputs{ {}, {} };
|
||||
inputs[0] = x;
|
||||
inputs[1] = bDivA;
|
||||
|
||||
std::shared_ptr<opset1::Add> newAdd = std::make_shared<op::TypeRelaxed<opset1::Add>>(
|
||||
std::vector<element::Type>{element::f32, element::f32},
|
||||
std::vector<element::Type>{ x->get_output_element_type(0) },
|
||||
std::vector<element::Type>{ x.get_element_type() },
|
||||
ngraph::op::TemporaryReplaceOutputType(inputs[0], element::f32).get(),
|
||||
ngraph::op::TemporaryReplaceOutputType(inputs[1], element::f32).get());
|
||||
copyInfo(addAfterMultiply, newAdd);
|
||||
|
@ -57,7 +57,7 @@ SimpleLowPrecisionTransformer getTransformerWithTransformationByName(
|
||||
using namespace pass::low_precision;
|
||||
SimpleLowPrecisionTransformer transformer;
|
||||
|
||||
if (name == "AddTransformation") {
|
||||
if (name == "AddTransformationWithoutConcat" || name == "AddTransformationWithConcat") {
|
||||
transformer.add<AddTransformation, ngraph::opset1::Add>(params);
|
||||
return transformer;
|
||||
}
|
||||
@ -185,7 +185,8 @@ TEST_P(TransformationsAfterSplitTransformation, Run) {
|
||||
}
|
||||
|
||||
const std::vector<std::string> transformationNames = {
|
||||
"AddTransformation",
|
||||
"AddTransformationWithoutConcat",
|
||||
"AddTransformationWithConcat",
|
||||
"AvgPoolTransformation",
|
||||
"ClampTransformation",
|
||||
"ConvolutionTransformation",
|
||||
|
@ -41,7 +41,12 @@ std::shared_ptr<Function> TransformationsAfterSplitFunction::get(const std::stri
|
||||
std::shared_ptr<Node> TransformationsAfterSplitFunction::getLayerByTransformationName(
|
||||
const std::string transformationName,
|
||||
const Output<Node> parent) {
|
||||
if (transformationName == "AddTransformation") {
|
||||
if (transformationName == "AddTransformationWithoutConcat") {
|
||||
const auto dequantization = makeDequantization(parent, { {}, {}, { 3.f } });
|
||||
const auto addConstant = opset1::Constant::create(element::u8, Shape{}, { 128.f });
|
||||
return std::make_shared<opset1::Add>(dequantization, addConstant);
|
||||
}
|
||||
if (transformationName == "AddTransformationWithConcat") {
|
||||
const auto dequantization = makeDequantization(parent, { {element::f32}, {}, { 0.1f } });
|
||||
const auto addConstant = opset1::Constant::create(element::f32, Shape{}, { 128.f });
|
||||
return std::make_shared<opset1::Add>(dequantization, addConstant);
|
||||
|
Loading…
Reference in New Issue
Block a user