Disallow SoftmaxFusion when input rank is greater than 5 (#5028)

This commit is contained in:
Mateusz Tabaka 2021-03-30 14:43:35 +02:00 committed by GitHub
parent 83ec2d321a
commit acf778d655
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 15 additions and 0 deletions

View File

@ -32,6 +32,7 @@
#include <transformations/common_optimizations/lin_op_sequence_fusion.hpp>
#include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp>
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
#include "transformations/common_optimizations/softmax_fusion.hpp"
#include <transformations/op_conversions/convert_depth_to_space.hpp>
#include <transformations/op_conversions/convert_space_to_depth.hpp>
#include <transformations/op_conversions/convert_gelu.hpp>
@ -323,6 +324,11 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
return false;
});
pass_config->set_callback<ngraph::pass::SoftmaxFusion>(
[](const_node_ptr &node) -> bool {
return node->input_value(0).get_partial_shape().rank().get_length() > 5;
});
// List of enabled/disabled transformations
pass_config->disable<ngraph::pass::ConvertGELU>();
pass_config->disable<ngraph::pass::ConvertMod>();

View File

@ -35,6 +35,7 @@
#include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp>
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
#include <transformations/common_optimizations/depth_to_space_fusion.hpp>
#include <transformations/common_optimizations/softmax_fusion.hpp>
#include <transformations/op_conversions/convert_depth_to_space.hpp>
#include <transformations/op_conversions/convert_space_to_depth.hpp>
#include <transformations/op_conversions/convert_gelu.hpp>
@ -260,6 +261,11 @@ static void Transformation(CNNNetwork& clonedNetwork, const Config& conf) {
return MKLDNNMVNNode::checkAxesSuitability(node);
});
pass_config->set_callback<ngraph::pass::SoftmaxFusion>(
[](const_node_ptr &node) -> bool {
return node->input_value(0).get_partial_shape().rank().get_length() > 5;
});
// List of enabled/disabled transformations
pass_config->disable<ngraph::pass::ConvertGELU>();
pass_config->disable<ngraph::pass::Gelu7Downgrade>();

View File

@ -28,6 +28,9 @@ ngraph::pass::SoftmaxFusion::SoftmaxFusion() {
auto div_pattern = ngraph::pattern::wrap_type<opset6::Divide>({exp_pattern, reduce_sum_pattern});
ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
if (transformation_callback(m.get_match_root()))
return false;
const auto& pattern_map = m.get_pattern_value_map();
auto reduce_max_axes = std::dynamic_pointer_cast<opset6::Constant>(pattern_map.at(reduce_max_axes_pattern).get_node_shared_ptr());