Disallow SoftmaxFusion when input rank is greater than 5 (#5028)
This commit is contained in:
parent
83ec2d321a
commit
acf778d655
@ -32,6 +32,7 @@
|
||||
#include <transformations/common_optimizations/lin_op_sequence_fusion.hpp>
|
||||
#include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp>
|
||||
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
|
||||
#include "transformations/common_optimizations/softmax_fusion.hpp"
|
||||
#include <transformations/op_conversions/convert_depth_to_space.hpp>
|
||||
#include <transformations/op_conversions/convert_space_to_depth.hpp>
|
||||
#include <transformations/op_conversions/convert_gelu.hpp>
|
||||
@ -323,6 +324,11 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
|
||||
return false;
|
||||
});
|
||||
|
||||
// Callback consulted by ngraph::pass::SoftmaxFusion for each matched root node.
// The pass bails out (its matcher returns false) when this callback returns true,
// so the fusion is not applied when input 0 of the node has rank greater than 5.
// NOTE(review): get_length() is called without an is_static() check; presumably
// the rank of this input is always static here. Confirm against the pass's pattern.
pass_config->set_callback<ngraph::pass::SoftmaxFusion>(
|
||||
[](const_node_ptr &node) -> bool {
|
||||
return node->input_value(0).get_partial_shape().rank().get_length() > 5;
|
||||
});
|
||||
|
||||
// List of enabled/disabled transformations
|
||||
pass_config->disable<ngraph::pass::ConvertGELU>();
|
||||
pass_config->disable<ngraph::pass::ConvertMod>();
|
||||
|
@ -35,6 +35,7 @@
|
||||
#include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp>
|
||||
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
|
||||
#include <transformations/common_optimizations/depth_to_space_fusion.hpp>
|
||||
#include <transformations/common_optimizations/softmax_fusion.hpp>
|
||||
#include <transformations/op_conversions/convert_depth_to_space.hpp>
|
||||
#include <transformations/op_conversions/convert_space_to_depth.hpp>
|
||||
#include <transformations/op_conversions/convert_gelu.hpp>
|
||||
@ -260,6 +261,11 @@ static void Transformation(CNNNetwork& clonedNetwork, const Config& conf) {
|
||||
return MKLDNNMVNNode::checkAxesSuitability(node);
|
||||
});
|
||||
|
||||
// Callback consulted by ngraph::pass::SoftmaxFusion for each matched root node.
// The pass bails out (its matcher returns false) when this callback returns true,
// so the fusion is not applied when input 0 of the node has rank greater than 5.
// NOTE(review): get_length() is called without an is_static() check; presumably
// the rank of this input is always static here. Confirm against the pass's pattern.
pass_config->set_callback<ngraph::pass::SoftmaxFusion>(
|
||||
[](const_node_ptr &node) -> bool {
|
||||
return node->input_value(0).get_partial_shape().rank().get_length() > 5;
|
||||
});
|
||||
|
||||
// List of enabled/disabled transformations
|
||||
pass_config->disable<ngraph::pass::ConvertGELU>();
|
||||
pass_config->disable<ngraph::pass::Gelu7Downgrade>();
|
||||
|
@ -28,6 +28,9 @@ ngraph::pass::SoftmaxFusion::SoftmaxFusion() {
|
||||
auto div_pattern = ngraph::pattern::wrap_type<opset6::Divide>({exp_pattern, reduce_sum_pattern});
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
if (transformation_callback(m.get_match_root()))
|
||||
return false;
|
||||
|
||||
const auto& pattern_map = m.get_pattern_value_map();
|
||||
|
||||
auto reduce_max_axes = std::dynamic_pointer_cast<opset6::Constant>(pattern_map.at(reduce_max_axes_pattern).get_node_shared_ptr());
|
||||
|
Loading…
Reference in New Issue
Block a user