[GNA] Transpose bias (#6759)

* transpose bias

* removed bias transpose; added bias validation predicate to pattern

* fixed after review; added handling of the case bias_output_shape.size() == 1 and bias_output_shape.at(0) > 1

* moved bias shape size check to matcher pattern; replaced loop with algorithm
This commit is contained in:
Dmitrii Khurtin 2021-08-11 14:14:47 +03:00 committed by GitHub
parent 51d511c8ac
commit c986ce09ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 126 additions and 58 deletions

View File

@ -19,6 +19,20 @@ NGRAPH_RTTI_DEFINITION(ConvertMatmulToPointWiseConvolution, "ConvertMatmulToPoin
NGRAPH_RTTI_DEFINITION(ConvertMatmulWithBiasToPointWiseConvolution, "ConvertMatmulWithBiasToPointWiseConvolution", 0);
NGRAPH_RTTI_DEFINITION(ConvertMatmulWithFqToPointWiseConvolution, "ConvertMatmulWithFqToPointWiseConvolution", 0);
static bool BiasValidation(const ngraph::Output<ngraph::Node>& output) {
auto bias_output_shape = output.get_node()->get_output_shape(0);
if (bias_output_shape.size() > 4) {
gnalog() << "bias output shape (" << output.get_node()->get_friendly_name() << ") is more than 4\n";
return false;
}
if (bias_output_shape.size() == 1) {
return true;
}
return std::count_if(bias_output_shape.begin(), bias_output_shape.end(), [](size_t el){ return el > 1; }) < 2;
}
static std::tuple<bool, uint32_t, uint32_t, uint32_t> VerifyAndGetConvParams(std::shared_ptr<ngraph::Node> matmul_node) {
auto input1_shape = matmul_node->get_input_shape(0);
auto input2_shape = matmul_node->get_input_shape(1);
@ -83,10 +97,24 @@ static bool Convert(std::shared_ptr<ngraph::Node> matmul_node,
ngraph::copy_runtime_info(transpose_before, conv_node);
std::shared_ptr<ngraph::Node> root_node = matmul_node;
if (bias != nullptr) {
conv_node = std::make_shared<ngraph::opset7::Add>(conv_node, bias);
ngraph::copy_runtime_info(transpose_before, conv_node);
root_node = add;
if (bias) {
auto bias_output_shape = bias->get_output_shape(0);
std::shared_ptr<ngraph::Node> new_bias = bias;
if (bias_output_shape.size() > 1 || bias_output_shape.at(0) > 1) {
std::vector<size_t> axes(4, 1);
auto iter = std::find_if(bias_output_shape.begin(), bias_output_shape.end(), [](size_t value) { return value > 1; });
if (iter != bias_output_shape.end()) {
axes.at(1) = *iter;
}
new_bias = std::make_shared<ngraph::opset7::Constant>(
bias->get_output_element_type(0),
axes,
std::dynamic_pointer_cast<ngraph::opset7::Constant>(bias)->get_data_ptr());
}
conv_node = std::make_shared<ngraph::opset7::Add>(conv_node, new_bias);
ngraph::copy_runtime_info(transpose_before, conv_node);
root_node = add;
}
if (fq != nullptr) {
@ -146,7 +174,7 @@ ConvertMatmulWithBiasToPointWiseConvolution::ConvertMatmulWithBiasToPointWiseCon
ngraph::pattern::wrap_type<ngraph::opset7::Constant>()});
auto second_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{const_input, const_fq});
auto matmul = ngraph::pattern::wrap_type<ngraph::opset7::MatMul>({ngraph::pattern::any_input(), second_input});
auto bias = ngraph::pattern::wrap_type<ngraph::opset7::Constant>();
auto bias = ngraph::pattern::wrap_type<ngraph::opset7::Constant>(BiasValidation);
auto add = ngraph::pattern::wrap_type<ngraph::opset7::Add>({matmul, bias});
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
@ -169,7 +197,7 @@ ConvertMatmulWithFqToPointWiseConvolution::ConvertMatmulWithFqToPointWiseConvolu
ngraph::pattern::wrap_type<ngraph::opset7::Constant>()});
auto second_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{const_input, const_fq});
auto matmul = ngraph::pattern::wrap_type<ngraph::opset7::MatMul>({ngraph::pattern::any_input(), second_input});
auto bias = ngraph::pattern::wrap_type<ngraph::opset7::Constant>();
auto bias = ngraph::pattern::wrap_type<ngraph::opset7::Constant>(BiasValidation);
auto add = ngraph::pattern::wrap_type<ngraph::opset7::Add>({matmul, bias});
auto matmul_out = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{add, matmul});
auto out_fq = ngraph::pattern::wrap_type<ngraph::opset7::FakeQuantize>({matmul_out,
@ -190,4 +218,4 @@ ConvertMatmulWithFqToPointWiseConvolution::ConvertMatmulWithFqToPointWiseConvolu
auto m = std::make_shared<ngraph::pattern::Matcher>(out_fq, matcher_name);
this->register_matcher(m, callback);
}
}

View File

@ -8,6 +8,8 @@ addIeTargetTest(
NAME ${TARGET_NAME}
ROOT ${CMAKE_CURRENT_SOURCE_DIR}
LINK_LIBRARIES
PRIVATE
ngraphFunctions
gmock
commonTestUtils_s
GNAPlugin_test_static

View File

@ -14,6 +14,7 @@
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/init_node_info.hpp>
#include "ngraph_functions/builders.hpp"
namespace testing {
@ -119,6 +120,7 @@ void CreateMatMul::updateGraph(Graph& graph) {
graph.output = matmul_node;
}
template<bool ONE_DIMENSIONAL, bool ONE_CHANNEL>
class CreateAdd : public CreateGraphDecorator {
public:
CreateAdd(CreateGraphDecoratorPtr prev_builder = nullptr) : CreateGraphDecorator(std::move(prev_builder)) {}
@ -126,8 +128,18 @@ protected:
void updateGraph(Graph&) override;
};
void CreateAdd::updateGraph(Graph& graph) {
auto bias = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
template<bool ONE_DIMENSIONAL, bool ONE_CHANNEL>
void CreateAdd<ONE_DIMENSIONAL, ONE_CHANNEL>::updateGraph(Graph& graph) {
std::vector<size_t> axes(1, 1);
if (std::is_same<std::integral_constant<bool, ONE_CHANNEL>, std::integral_constant<bool, false>>::value) {
auto shape = graph.output->get_output_shape(0);
if (std::is_same<std::integral_constant<bool, ONE_DIMENSIONAL>, std::integral_constant<bool, false>>::value) {
axes.resize(shape.size(), 1);
}
axes.back() = shape.back();
}
auto bias = ngraph::builder::makeConstant<float>(ngraph::element::i64, axes, {}, true);
auto add_node = std::make_shared<ngraph::opset7::Add>(graph.output, bias);
graph.output = add_node;
}
@ -155,7 +167,8 @@ Graph createTransformedGraph(const ngraph::Shape& input_data_shape = ngraph::Sha
// ------------------------------------------------------------------------------------------------------------
Graph createReferenceGraph(bool addConstFakeQuantizeNode, bool insertAddNode, bool addOutFakeQuantizeNode) {
template<bool ADD_CONST_FAKEQUANTIZE_NODE, bool INSERT_ADD_NODE, bool ONE_CHANNEL, bool ADD_OUT_FAKEQUANTIZE_NODE>
Graph createReferenceGraph() {
Graph graph;
graph.input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64,
@ -173,8 +186,9 @@ Graph createReferenceGraph(bool addConstFakeQuantizeNode, bool insertAddNode, bo
auto transpose_before = std::make_shared<ngraph::opset7::Transpose>(reshape_before, const_transpose_before);
std::shared_ptr<ngraph::op::Op> parent_node = constant_node;
if (addConstFakeQuantizeNode)
if (std::is_same<std::integral_constant<bool, ADD_CONST_FAKEQUANTIZE_NODE>, std::integral_constant<bool, true>>::value) {
parent_node = createFakeQuantizeNode(constant_node);
}
auto weights_reshape_const = std::make_shared<ngraph::opset7::Constant>(ngraph::element::Type_t::i64,
ngraph::Shape{4}, ngraph::Shape{8, 8, 1, 1});
@ -189,15 +203,21 @@ Graph createReferenceGraph(bool addConstFakeQuantizeNode, bool insertAddNode, bo
ngraph::op::PadType::VALID);
parent_node = conv_node;
if (std::is_same<std::integral_constant<bool, INSERT_ADD_NODE>, std::integral_constant<bool, true>>::value) {
std::vector<size_t> axes(1, 1);
if (std::is_same<std::integral_constant<bool, ONE_CHANNEL>, std::integral_constant<bool, false>>::value) {
axes.resize(4, 1);
axes[1] = 8;
}
if (insertAddNode) {
auto bias = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto bias = ngraph::builder::makeConstant<float>(ngraph::element::i64, axes, {}, true);
auto add_node = std::make_shared<ngraph::opset7::Add>(parent_node, bias);
parent_node = add_node;
}
if (addOutFakeQuantizeNode)
if (std::is_same<std::integral_constant<bool, ADD_OUT_FAKEQUANTIZE_NODE>, std::integral_constant<bool, true>>::value) {
parent_node = createFakeQuantizeNode(parent_node);
}
auto const_transpose_after = ngraph::opset7::Constant::create(ngraph::element::i64,
ngraph::Shape{4},
@ -254,47 +274,65 @@ TEST_P(ConvertMatmulToPointWiseConvolutionFixture, CompareFunctions) {
execute_test(function, reference_function, pass_manager);
}
namespace {
constexpr bool AddConstFakeQuantizeNode = true;
constexpr bool InsertAddNode = true;
constexpr bool OneDimensional = true;
constexpr bool OneChannel = true;
constexpr bool AddOutFakeQuantizeNode = true;
}
INSTANTIATE_TEST_SUITE_P(ConvertMatmulToPointWiseConvolutionTestSuite, ConvertMatmulToPointWiseConvolutionFixture,
::testing::Values(std::make_tuple(createTransformedGraph<CreateMatMul>(),
createReferenceGraph(false /* addConstFakeQuantizeNode */,
false /* insertAddNode */,
false /* addOutFakeQuantizeNode */),
createPassManager<GNAPluginNS::ConvertMatmulToPointWiseConvolution>()),
std::make_tuple(createTransformedGraph<CreateMatMul, CreateFakeQuantize>(),
createReferenceGraph(true /* addConstFakeQuantizeNode */,
false /* insertAddNode */,
false /* addOutFakeQuantizeNode */),
createPassManager<GNAPluginNS::ConvertMatmulToPointWiseConvolution>()),
std::make_tuple(createTransformedGraph<CreateAdd, CreateMatMul>(),
createReferenceGraph(false /* addConstFakeQuantizeNode */,
true /* insertAddNode */,
false /* addOutFakeQuantizeNode */),
createPassManager<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution>()),
std::make_tuple(createTransformedGraph<CreateAdd, CreateMatMul, CreateFakeQuantize>(),
createReferenceGraph(true /* addConstFakeQuantizeNode */,
true /* insertAddNode */,
false /* addOutFakeQuantizeNode */),
createPassManager<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution>()),
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateAdd, CreateMatMul>(),
createReferenceGraph(false /* addConstFakeQuantizeNode */,
true /* insertAddNode */,
true /* addOutFakeQuantizeNode */),
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>()),
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateAdd, CreateMatMul, CreateFakeQuantize>(),
createReferenceGraph(true /* addConstFakeQuantizeNode */,
true /* insertAddNode */,
true /* addOutFakeQuantizeNode */),
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>()),
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateMatMul>(),
createReferenceGraph(false /* addConstFakeQuantizeNode */,
false /* insertAddNode */,
true /* addOutFakeQuantizeNode */),
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>()),
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateMatMul, CreateFakeQuantize>(),
createReferenceGraph(true /* addConstFakeQuantizeNode */,
false /* insertAddNode */,
true /* addOutFakeQuantizeNode */),
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>())));
::testing::Values(
std::make_tuple(
createTransformedGraph<CreateMatMul>(),
createReferenceGraph<!AddConstFakeQuantizeNode, !InsertAddNode, !OneChannel, !AddOutFakeQuantizeNode>(),
createPassManager<GNAPluginNS::ConvertMatmulToPointWiseConvolution>()),
std::make_tuple(createTransformedGraph<CreateMatMul, CreateFakeQuantize>(),
createReferenceGraph<AddConstFakeQuantizeNode, !InsertAddNode, !OneChannel, !AddOutFakeQuantizeNode>(),
createPassManager<GNAPluginNS::ConvertMatmulToPointWiseConvolution>()),
std::make_tuple(createTransformedGraph<CreateAdd<OneDimensional, OneChannel>, CreateMatMul>(),
createReferenceGraph<!AddConstFakeQuantizeNode, InsertAddNode, OneChannel, !AddOutFakeQuantizeNode>(),
createPassManager<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution>()),
std::make_tuple(createTransformedGraph<CreateAdd<OneDimensional, !OneChannel>, CreateMatMul>(),
createReferenceGraph<!AddConstFakeQuantizeNode, InsertAddNode, !OneChannel, !AddOutFakeQuantizeNode>(),
createPassManager<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution>()),
std::make_tuple(createTransformedGraph<CreateAdd<!OneDimensional, !OneChannel>, CreateMatMul>(),
createReferenceGraph<!AddConstFakeQuantizeNode, InsertAddNode, !OneChannel, !AddOutFakeQuantizeNode>(),
createPassManager<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution>()),
std::make_tuple(createTransformedGraph<CreateAdd<OneDimensional, OneChannel>, CreateMatMul, CreateFakeQuantize>(),
createReferenceGraph<AddConstFakeQuantizeNode, InsertAddNode, OneChannel, !AddOutFakeQuantizeNode>(),
createPassManager<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution>()),
std::make_tuple(createTransformedGraph<CreateAdd<OneDimensional, !OneChannel>, CreateMatMul, CreateFakeQuantize>(),
createReferenceGraph<AddConstFakeQuantizeNode, InsertAddNode, !OneChannel, !AddOutFakeQuantizeNode>(),
createPassManager<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution>()),
std::make_tuple(createTransformedGraph<CreateAdd<!OneDimensional, !OneChannel>, CreateMatMul, CreateFakeQuantize>(),
createReferenceGraph<AddConstFakeQuantizeNode, InsertAddNode, !OneChannel, !AddOutFakeQuantizeNode>(),
createPassManager<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution>()),
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateAdd<OneDimensional, OneChannel>, CreateMatMul>(),
createReferenceGraph<!AddConstFakeQuantizeNode, InsertAddNode, OneChannel, AddOutFakeQuantizeNode>(),
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>()),
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateAdd<OneDimensional, !OneChannel>, CreateMatMul>(),
createReferenceGraph<!AddConstFakeQuantizeNode, InsertAddNode, !OneChannel, AddOutFakeQuantizeNode>(),
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>()),
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateAdd<!OneDimensional, !OneChannel>, CreateMatMul>(),
createReferenceGraph<!AddConstFakeQuantizeNode, InsertAddNode, !OneChannel, AddOutFakeQuantizeNode>(),
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>()),
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateAdd<OneDimensional, OneChannel>, CreateMatMul, CreateFakeQuantize>(),
createReferenceGraph<AddConstFakeQuantizeNode, InsertAddNode, OneChannel, AddOutFakeQuantizeNode>(),
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>()),
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateAdd<OneDimensional, !OneChannel>, CreateMatMul, CreateFakeQuantize>(),
createReferenceGraph<AddConstFakeQuantizeNode, InsertAddNode, !OneChannel, AddOutFakeQuantizeNode>(),
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>()),
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateAdd<!OneDimensional, !OneChannel>, CreateMatMul, CreateFakeQuantize>(),
createReferenceGraph<AddConstFakeQuantizeNode, InsertAddNode, !OneChannel, AddOutFakeQuantizeNode>(),
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>()),
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateMatMul>(),
createReferenceGraph<!AddConstFakeQuantizeNode, !InsertAddNode, !OneChannel, AddOutFakeQuantizeNode>(),
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>()),
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateMatMul, CreateFakeQuantize>(),
createReferenceGraph<AddConstFakeQuantizeNode, !InsertAddNode, !OneChannel, AddOutFakeQuantizeNode>(),
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>())));
// -------------------------------------------------------------------------------------------------------
@ -373,19 +411,19 @@ std::vector<FixtureData> transform_types = {
CreateMatMul,
CreateFakeQuantize>(),
FixtureData::create<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution,
CreateAdd,
CreateAdd<false, false>,
CreateMatMul>(),
FixtureData::create<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution,
CreateAdd,
CreateAdd<false, false>,
CreateMatMul,
CreateFakeQuantize>(),
FixtureData::create<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution,
CreateFakeQuantize,
CreateAdd,
CreateAdd<false, false>,
CreateMatMul>(),
FixtureData::create<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution,
CreateFakeQuantize,
CreateAdd,
CreateAdd<false, false>,
CreateMatMul,
CreateFakeQuantize>(),
FixtureData::create<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution,