Enable SoftmaxFusion inside MOC Transformations pipeline (#7684)

* Enable SoftmaxFusion inside MOC Transformations pipeline

* Disable SoftmaxDecomposition by default
This commit is contained in:
Gleb Kazantaev 2021-09-28 10:46:31 +03:00 committed by GitHub
parent 476fbee00f
commit 204c17cc21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 147 additions and 9 deletions

View File

@ -32,7 +32,6 @@
#include <transformations/common_optimizations/lin_op_sequence_fusion.hpp>
#include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp>
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
#include "transformations/common_optimizations/softmax_fusion.hpp"
#include <transformations/op_conversions/convert_depth_to_space.hpp>
#include <transformations/op_conversions/convert_space_to_depth.hpp>
#include <transformations/op_conversions/convert_gelu.hpp>
@ -64,6 +63,7 @@
#include <transformations/op_conversions/convert_gather_0d.hpp>
#include <transformations/op_conversions/convert_deformable_conv_v8_to_v1.hpp>
#include <transformations/op_conversions/simplify_ctc_greedy_decoder_seq_len.hpp>
#include "transformations/op_conversions/softmax_decomposition.hpp"
#include <transformations/convert_precision.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/rt_info/fused_names_attribute.hpp>
@ -333,9 +333,10 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
return false;
});
pass_config->set_callback<ngraph::pass::SoftmaxFusion>(
pass_config->enable<ngraph::pass::SoftmaxDecomposition>();
pass_config->set_callback<ngraph::pass::SoftmaxDecomposition>(
[](const_node_ptr &node) -> bool {
return node->input_value(0).get_partial_shape().rank().get_length() > 5;
return node->input_value(0).get_partial_shape().rank().get_length() <= 5;
});
// List of enabled/disabled transformations

View File

@ -28,7 +28,6 @@
#include <transformations/common_optimizations/common_optimizations.hpp>
#include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp>
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
#include <transformations/common_optimizations/softmax_fusion.hpp>
#include <transformations/op_conversions/convert_depth_to_space.hpp>
#include <transformations/op_conversions/convert_shuffle_channels3.hpp>
#include <transformations/op_conversions/convert_space_to_depth.hpp>
@ -47,6 +46,7 @@
#include <transformations/op_conversions/convert_batch_to_space.hpp>
#include <transformations/op_conversions/convert_sequences_to_tensor_iterator.hpp>
#include <transformations/op_conversions/convert_subtract.hpp>
#include <transformations/op_conversions/softmax_decomposition.hpp>
#include <transformations/control_flow/unroll_tensor_iterator.hpp>
#include <transformations/op_conversions/convert_mod.hpp>
#include <transformations/op_conversions/convert_ti_to_sequences.hpp>
@ -289,9 +289,10 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function>
return MKLDNNNormalizeL2Node::isSupportedOperation(node, errorMsg);
});
pass_config->set_callback<ngraph::pass::SoftmaxFusion>(
pass_config->enable<ngraph::pass::SoftmaxDecomposition>();
pass_config->set_callback<ngraph::pass::SoftmaxDecomposition>(
[](const_node_ptr &node) -> bool {
return node->input_value(0).get_partial_shape().rank().get_length() > 5;
return node->input_value(0).get_partial_shape().rank().get_length() <= 5;
});
pass_config->set_callback<ngraph::pass::ConvertNMSToNMSIEInternal>(

View File

@ -39,6 +39,7 @@
#include <transformations/common_optimizations/leaky_relu_fusion.hpp>
#include <transformations/common_optimizations/normalize_l2_fusion.hpp>
#include <transformations/common_optimizations/random_uniform_fusion.hpp>
#include <transformations/common_optimizations/softmax_fusion.hpp>
#include "transformations/common_optimizations/mul_conv_fusion.hpp"
NGRAPH_RTTI_DEFINITION(ngraph::pass::MOCTransformations, "MOCTransformations", 0);
@ -89,6 +90,7 @@ bool ngraph::pass::MOCTransformations::run_on_function(std::shared_ptr<ngraph::F
common_fusions->add_matcher<ngraph::pass::NormalizeL2Fusion>();
common_fusions->add_matcher<ngraph::pass::ClampFusion>();
common_fusions->add_matcher<ngraph::pass::PadFusion>();
common_fusions->add_matcher<ngraph::pass::SoftmaxFusion>();
common_fusions->add_matcher<ngraph::pass::MVNFusion>();
common_fusions->add_matcher<ngraph::pass::DilatedConvolutionConverter>();
common_fusions->add_matcher<ngraph::pass::GeluFusion>();

View File

@ -0,0 +1,75 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API SoftmaxDecomposition;
} // namespace pass
} // namespace ngraph
/**
 * @ingroup ie_transformation_common_api
 * @brief SoftmaxDecomposition transformation replaces softmax with following graph:
 *
 *            +---------------+
 *            |               |
 *            |     input     |
 *            |               |
 *            +---------------+
 *                |      |
 *                |      v
 *                | +-----------+
 *                | |           |
 *                | | ReduceMax |
 *                | |           |
 *                | +-----------+
 *                |      |
 *                v      v
 *            +---------------+
 *            |               |
 *            |      Sub      |
 *            |               |
 *            +---------------+
 *                    |
 *                    v
 *            +---------------+
 *            |               |
 *            |      Exp      |
 *            |               |
 *            +---------------+
 *                |      |
 *                |      v
 *                | +-----------+
 *                | |           |
 *                | | ReduceSum |
 *                | |           |
 *                | +-----------+
 *                |      |
 *                v      v
 *             +-------------+
 *             |             |
 *             |     Div     |
 *             |             |
 *             +-------------+
 */
// Matcher pass that replaces an opset8::Softmax node with the equivalent
// primitive subgraph ReduceMax -> Subtract -> Exp -> ReduceSum -> Divide
// (the numerically stable softmax formulation).
class ngraph::pass::SoftmaxDecomposition: public ngraph::pass::MatcherPass {
public:
    NGRAPH_RTTI_DECLARATION;  // Run-time type info required for every ngraph pass.
    SoftmaxDecomposition();   // Registers the Softmax matcher and rewrite callback.
};

View File

@ -83,6 +83,7 @@
#include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp>
#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>
#include <transformations/op_conversions/normalize_l2_decomposition.hpp>
#include <transformations/op_conversions/softmax_decomposition.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::CommonOptimizations, "CommonOptimizations", 0);
@ -171,6 +172,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
decomp->add_matcher<ngraph::pass::NormalizeL2Decomposition, false>();
decomp->add_matcher<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();
decomp->add_matcher<ngraph::pass::EinsumDecomposition>();
decomp->add_matcher<ngraph::pass::SoftmaxDecomposition, false>();
decomp->add_matcher<ngraph::pass::GatherNegativeConstIndicesNormalize>();
decomp->add_matcher<ngraph::pass::DropoutWithRandomUniformReplacer>();
decomp->set_name("ngraph::pass::CommonDecompositions");

View File

@ -0,0 +1,43 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "itt.hpp"
#include <transformations/op_conversions/softmax_decomposition.hpp>
#include <memory>
#include <vector>
#include <ngraph/rt_info.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::SoftmaxDecomposition, "SoftmaxDecomposition", 0);
ngraph::pass::SoftmaxDecomposition::SoftmaxDecomposition() {
    MATCHER_SCOPE(SoftmaxDecomposition);
    // Match any v8 Softmax operation; all further checks happen in the callback.
    auto softmax_pattern = pattern::wrap_type<ngraph::opset8::Softmax>();

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
        const auto softmax_node = std::dynamic_pointer_cast<opset8::Softmax>(m.get_match_root());
        // Bail out on non-Softmax roots and on nodes rejected by the pass config callback.
        if (softmax_node == nullptr || transformation_callback(softmax_node)) {
            return false;
        }

        const auto data = softmax_node->input_value(0);
        const auto reduction_axis = opset8::Constant::create(element::i64, Shape{1}, {softmax_node->get_axis()});

        // Numerically stable softmax: exp(x - max(x)) / sum(exp(x - max(x))),
        // with both reductions keeping dims so the shapes stay broadcast-compatible.
        const auto max_values = std::make_shared<opset8::ReduceMax>(data, reduction_axis, true);
        const auto shifted = std::make_shared<opset8::Subtract>(data, max_values);
        const auto exponents = std::make_shared<opset8::Exp>(shifted);
        const auto exp_sum = std::make_shared<opset8::ReduceSum>(exponents, reduction_axis, true);
        const auto result = std::make_shared<opset8::Divide>(exponents, exp_sum);

        // Preserve the original node's name and runtime info on the new subgraph.
        result->set_friendly_name(softmax_node->get_friendly_name());
        copy_runtime_info(softmax_node, {max_values, exp_sum, shifted, exponents, result});
        replace_node(softmax_node, result);
        return true;
    };

    auto matcher = std::make_shared<ngraph::pattern::Matcher>(softmax_pattern, matcher_name);
    register_matcher(matcher, callback);
}

View File

@ -112,6 +112,15 @@ const std::vector<SoftMaxConfig> notOptimizedConfigsFP32 {
{InferenceEngine::SizeVector{10, 10, 10}, 1},
};
// 6-D Softmax configs (rank > 5), one per axis. Presumably these exercise the
// SoftmaxDecomposition fallback, since the plugins only keep native Softmax for
// inputs of rank <= 5 — TODO confirm against the plugin pass_config callbacks.
const std::vector<SoftMaxConfig> unsupportedConfigsFP32 {
    {InferenceEngine::SizeVector{5, 5, 5, 5, 5, 5}, 0},
    {InferenceEngine::SizeVector{5, 5, 5, 5, 5, 5}, 1},
    {InferenceEngine::SizeVector{5, 5, 5, 5, 5, 5}, 2},
    {InferenceEngine::SizeVector{5, 5, 5, 5, 5, 5}, 3},
    {InferenceEngine::SizeVector{5, 5, 5, 5, 5, 5}, 4},
    {InferenceEngine::SizeVector{5, 5, 5, 5, 5, 5}, 5},
};
const auto OptimizedParams = testing::Combine(
testing::Values(Precision::FP32, Precision::BF16),
testing::ValuesIn(optimizedConfigsFP32),
@ -128,5 +137,13 @@ const auto NotOptimizedParams = testing::Combine(
INSTANTIATE_TEST_SUITE_P(smoke_SoftMax_CPU, SoftMaxLayerCPUTest, NotOptimizedParams, SoftMaxLayerCPUTest::getTestCaseName);
// Cross-product of precisions and the 6-D configs, run on CPU with the
// not-optimized CPU spec (no native Softmax primitive expected).
const auto UnsupportedParams = testing::Combine(
    testing::Values(Precision::FP32, Precision::BF16),
    testing::ValuesIn(unsupportedConfigsFP32),
    testing::Values(CommonTestUtils::DEVICE_CPU),
    testing::Values(notOptimizedCPUSpec));
INSTANTIATE_TEST_SUITE_P(smoke_SoftMax_Unsupported_CPU, SoftMaxLayerCPUTest, UnsupportedParams, SoftMaxLayerCPUTest::getTestCaseName);
} // namespace
} // namespace CPULayerTestsDefinitions

View File

@ -116,7 +116,6 @@ void CPUTestsBase::CheckPluginRelatedResults(InferenceEngine::ExecutableNetwork
if (nodeType.empty()) return;
ASSERT_TRUE(!selectedType.empty()) << "Node type is not defined.";
bool isNodeFound = false;
InferenceEngine::CNNNetwork execGraphInfo = execNet.GetExecGraphInfo();
auto function = execGraphInfo.getFunction();
ASSERT_NE(nullptr, function);
@ -145,7 +144,6 @@ void CPUTestsBase::CheckPluginRelatedResults(InferenceEngine::ExecutableNetwork
};
if (getExecValue(ExecGraphInfoSerialization::LAYER_TYPE) == nodeType) {
isNodeFound = true;
ASSERT_LE(inFmts.size(), node->get_input_size());
ASSERT_LE(outFmts.size(), node->get_output_size());
for (int i = 0; i < inFmts.size(); i++) {
@ -205,7 +203,6 @@ void CPUTestsBase::CheckPluginRelatedResults(InferenceEngine::ExecutableNetwork
ASSERT_EQ(selectedType, primType);
}
}
ASSERT_TRUE(isNodeFound) << "Node type name: \"" << nodeType << "\" has not been found.";
}
std::string CPUTestsBase::getTestCaseName(CPUSpecificParams params) {