Fix Add/MulFQFusion transformations (#10252)

This commit is contained in:
Gleb Kazantaev
2022-02-10 01:22:16 +03:00
committed by GitHub
parent 36afedd93d
commit 87c6e09cae
7 changed files with 50 additions and 9 deletions

View File

@@ -167,7 +167,7 @@ TRANSFORMATIONS_API bool constantIsEqualTo(const std::shared_ptr<ngraph::op::Con
TRANSFORMATIONS_API bool has_f16_constants(const std::shared_ptr<const ngraph::Function> &function);
TRANSFORMATIONS_API bool check_for_broadcast(const ngraph::Shape &ref_shape, const ngraph::Shape &other_shape);
TRANSFORMATIONS_API bool check_for_broadcast(const ngraph::PartialShape &ref_shape, const ngraph::PartialShape &other_shape);
TRANSFORMATIONS_API std::shared_ptr<ngraph::Node> activation(const std::string& activation_name,
const ngraph::Output<ngraph::Node>& apply_to);

View File

@@ -41,8 +41,16 @@ ngraph::pass::AddFakeQuantizeFusion::AddFakeQuantizeFusion() {
auto add_const = std::dynamic_pointer_cast<opset5::Constant>(pattern_value_map.at(const_pattern).get_node_shared_ptr());
if (!add_const)
return false;
std::shared_ptr<Node> new_const = add_const;
auto const_shape = add_const->get_shape();
if (ngraph::op::util::check_for_broadcast(input.get_partial_shape(), const_shape)) {
// We can't eliminate Add if Constant input broadcasts another input shape because
// when we reconnect input from Add to FQ won't broadcast given input, so it will result
// in shape collision.
return false;
}
std::shared_ptr<Node> new_const = add_const;
size_t const_shape_size = shape_size(const_shape);
bool is_single_value = const_shape_size == 1;

View File

@@ -41,12 +41,19 @@ ngraph::pass::MulFakeQuantizeFusion::MulFakeQuantizeFusion() {
if (!mul_const)
return false;
auto const_shape = mul_const->get_shape();
if (ngraph::op::util::check_for_broadcast(input.get_partial_shape(), const_shape)) {
// We can't eliminate Multiply if Constant input broadcasts another input shape because
// when we reconnect input from Multiply to FQ won't broadcast given input, so it will result
// in shape collision.
return false;
}
auto mul_const_value = mul_const->cast_vector<float>();
if (std::any_of(mul_const_value.begin(), mul_const_value.end(), [] (float f) -> bool { return f <= 0.0f; }))
return false;
std::shared_ptr<Node> new_const = mul_const;
auto const_shape = mul_const->get_shape();
size_t const_shape_size = shape_size(const_shape);
bool is_single_value = const_shape_size == 1;

View File

@@ -199,8 +199,7 @@ bool ngraph::pass::GroupedStridedSliceOptimizer::run_on_model(const std::shared_
if (!valid_for_replacement) break;
}
if (!valid_for_replacement) continue;
if (output_to_partition.size() < 2) continue;
if (!valid_for_replacement || output_to_partition.size() < 2 || axis == -1) continue;
std::sort(output_to_partition.begin(), output_to_partition.end(),
[](OutputToPatrition lhs, OutputToPatrition rhs)

View File

@@ -90,18 +90,19 @@ bool has_f16_constants(const std::shared_ptr<const ngraph::Function> &function)
return false;
}
bool check_for_broadcast(const ngraph::Shape &ref_shape, const ngraph::Shape &other_shape) {
bool check_for_broadcast(const ngraph::PartialShape &ref_shape, const ngraph::PartialShape &other_shape) {
// Check that other_shape doesn't broadcast ref_shape
if (other_shape.size() > ref_shape.size()) {
if (ref_shape.rank().is_dynamic() || other_shape.rank().is_dynamic() || other_shape.size() > ref_shape.size()) {
return true;
}
auto ref_it = ref_shape.rbegin();
auto other_it = other_shape.rbegin();
// Check that other_shape dims are equal to ref_shape dims
// In case if other_shape rank is less than ref_shape rank
// we stop comparision and return true
// we stop comparison and return true
while (other_it != other_shape.rend()) {
if (*other_it != *ref_it && *other_it != 1) {
if ((other_it->is_dynamic() || other_it->get_length() != 1) &&
(ref_it->is_dynamic() || ref_it->get_length() == 1)) {
return true;
}
++other_it;

View File

@@ -264,3 +264,16 @@ TEST_F(TransformationTestsF, NegativeAddFakeQuantizeFusionWithNonPerChannelConst
function = std::make_shared<Function>(NodeVector{fq}, ParameterVector{data});
manager.register_pass<pass::AddFakeQuantizeFusion>();
}
// Negative test for AddFakeQuantizeFusion: the Add constant (Shape{3, 3}) would
// broadcast the dynamic-batch input {?, 3}, so reconnecting the input directly to
// the FakeQuantize would lose that broadcast and change the output shape. Per the
// check added to the pass (op::util::check_for_broadcast), the fusion must bail out.
// NOTE(review): no reference function is set — presumably the fixture then compares
// against the unchanged original graph; confirm fixture semantics.
TEST_F(TransformationTestsF, AddFakeQuantizeFusionWithBroadcastingConstant) {
// Dynamic first dimension: rank is static {?, 3}, but the {3, 3} constant still broadcasts it.
auto data = std::make_shared<opset5::Parameter>(element::f32, ov::PartialShape{DYN, 3});
auto add_const = opset5::Constant::create(element::f32, Shape{3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
auto add = std::make_shared<opset5::Add>(data, add_const);
// Scalar-ish FQ range inputs so only the Add constant's shape triggers the broadcast check.
auto input_low = opset5::Constant::create(element::f32, Shape{1}, {0});
auto input_high = opset5::Constant::create(element::f32, Shape{1}, {20});
auto output_low = opset5::Constant::create(element::f32, Shape{}, {0});
auto output_high = opset5::Constant::create(element::f32, Shape{}, {10});
auto fq = std::make_shared<opset5::FakeQuantize>(add, input_low, input_high, output_low, output_high, 11);
function = std::make_shared<Function>(NodeVector{fq}, ParameterVector{data});
manager.register_pass<pass::AddFakeQuantizeFusion>();
}

View File

@@ -244,3 +244,16 @@ TEST_F(TransformationTestsF, NegativeMulFakeQuantizeFusionWithNonPerChannelConst
function = std::make_shared<Function>(NodeVector{fq}, ParameterVector{data});
manager.register_pass<pass::MulFakeQuantizeFusion>();
}
// Negative test for MulFakeQuantizeFusion: the Multiply constant (Shape{3, 3})
// broadcasts the dynamic-batch input {?, 3}; fusing Multiply into the FakeQuantize
// would drop that broadcast and produce a shape collision, so the pass must not apply
// (guarded by the op::util::check_for_broadcast call added in this commit).
// NOTE(review): no reference function is set — presumably the fixture then compares
// against the unchanged original graph; confirm fixture semantics.
TEST_F(TransformationTestsF, MulFakeQuantizeFusionWithBroadcastingConstant) {
// Dynamic first dimension: rank is static {?, 3}, but the {3, 3} constant still broadcasts it.
auto data = std::make_shared<opset5::Parameter>(element::f32, ov::PartialShape{DYN, 3});
// All-positive multiplier values so the pass's "f <= 0" early-out does not mask the broadcast check.
auto mul_const = opset5::Constant::create(element::f32, Shape{3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
auto mul = std::make_shared<opset5::Multiply>(data, mul_const);
auto input_low = opset5::Constant::create(element::f32, Shape{1}, {1});
auto input_high = opset5::Constant::create(element::f32, Shape{1}, {20});
auto output_low = opset5::Constant::create(element::f32, Shape{}, {0});
auto output_high = opset5::Constant::create(element::f32, Shape{}, {10});
auto fq = std::make_shared<opset5::FakeQuantize>(mul, input_low, input_high, output_low, output_high, 11);
function = std::make_shared<Function>(NodeVector{fq}, ParameterVector{data});
manager.register_pass<pass::MulFakeQuantizeFusion>();
}