[GNA] Transpose bias (#6759)
* transpose bias * removed bias transpose; added bias validation predicate to pattern * fixed after review; added handling of the case bias_output_shape.size() == 1 and bias_output_shape.at(0) > 1 * moved bias shape size check to matcher pattern; replaced loop with algorithm
This commit is contained in:
parent
51d511c8ac
commit
c986ce09ce
@ -19,6 +19,20 @@ NGRAPH_RTTI_DEFINITION(ConvertMatmulToPointWiseConvolution, "ConvertMatmulToPoin
|
||||
NGRAPH_RTTI_DEFINITION(ConvertMatmulWithBiasToPointWiseConvolution, "ConvertMatmulWithBiasToPointWiseConvolution", 0);
|
||||
NGRAPH_RTTI_DEFINITION(ConvertMatmulWithFqToPointWiseConvolution, "ConvertMatmulWithFqToPointWiseConvolution", 0);
|
||||
|
||||
static bool BiasValidation(const ngraph::Output<ngraph::Node>& output) {
|
||||
auto bias_output_shape = output.get_node()->get_output_shape(0);
|
||||
if (bias_output_shape.size() > 4) {
|
||||
gnalog() << "bias output shape (" << output.get_node()->get_friendly_name() << ") is more than 4\n";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (bias_output_shape.size() == 1) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return std::count_if(bias_output_shape.begin(), bias_output_shape.end(), [](size_t el){ return el > 1; }) < 2;
|
||||
}
|
||||
|
||||
static std::tuple<bool, uint32_t, uint32_t, uint32_t> VerifyAndGetConvParams(std::shared_ptr<ngraph::Node> matmul_node) {
|
||||
auto input1_shape = matmul_node->get_input_shape(0);
|
||||
auto input2_shape = matmul_node->get_input_shape(1);
|
||||
@ -83,10 +97,24 @@ static bool Convert(std::shared_ptr<ngraph::Node> matmul_node,
|
||||
ngraph::copy_runtime_info(transpose_before, conv_node);
|
||||
|
||||
std::shared_ptr<ngraph::Node> root_node = matmul_node;
|
||||
if (bias != nullptr) {
|
||||
conv_node = std::make_shared<ngraph::opset7::Add>(conv_node, bias);
|
||||
ngraph::copy_runtime_info(transpose_before, conv_node);
|
||||
root_node = add;
|
||||
if (bias) {
|
||||
auto bias_output_shape = bias->get_output_shape(0);
|
||||
std::shared_ptr<ngraph::Node> new_bias = bias;
|
||||
if (bias_output_shape.size() > 1 || bias_output_shape.at(0) > 1) {
|
||||
std::vector<size_t> axes(4, 1);
|
||||
auto iter = std::find_if(bias_output_shape.begin(), bias_output_shape.end(), [](size_t value) { return value > 1; });
|
||||
if (iter != bias_output_shape.end()) {
|
||||
axes.at(1) = *iter;
|
||||
}
|
||||
new_bias = std::make_shared<ngraph::opset7::Constant>(
|
||||
bias->get_output_element_type(0),
|
||||
axes,
|
||||
std::dynamic_pointer_cast<ngraph::opset7::Constant>(bias)->get_data_ptr());
|
||||
}
|
||||
|
||||
conv_node = std::make_shared<ngraph::opset7::Add>(conv_node, new_bias);
|
||||
ngraph::copy_runtime_info(transpose_before, conv_node);
|
||||
root_node = add;
|
||||
}
|
||||
|
||||
if (fq != nullptr) {
|
||||
@ -146,7 +174,7 @@ ConvertMatmulWithBiasToPointWiseConvolution::ConvertMatmulWithBiasToPointWiseCon
|
||||
ngraph::pattern::wrap_type<ngraph::opset7::Constant>()});
|
||||
auto second_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{const_input, const_fq});
|
||||
auto matmul = ngraph::pattern::wrap_type<ngraph::opset7::MatMul>({ngraph::pattern::any_input(), second_input});
|
||||
auto bias = ngraph::pattern::wrap_type<ngraph::opset7::Constant>();
|
||||
auto bias = ngraph::pattern::wrap_type<ngraph::opset7::Constant>(BiasValidation);
|
||||
auto add = ngraph::pattern::wrap_type<ngraph::opset7::Add>({matmul, bias});
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
|
||||
@ -169,7 +197,7 @@ ConvertMatmulWithFqToPointWiseConvolution::ConvertMatmulWithFqToPointWiseConvolu
|
||||
ngraph::pattern::wrap_type<ngraph::opset7::Constant>()});
|
||||
auto second_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{const_input, const_fq});
|
||||
auto matmul = ngraph::pattern::wrap_type<ngraph::opset7::MatMul>({ngraph::pattern::any_input(), second_input});
|
||||
auto bias = ngraph::pattern::wrap_type<ngraph::opset7::Constant>();
|
||||
auto bias = ngraph::pattern::wrap_type<ngraph::opset7::Constant>(BiasValidation);
|
||||
auto add = ngraph::pattern::wrap_type<ngraph::opset7::Add>({matmul, bias});
|
||||
auto matmul_out = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{add, matmul});
|
||||
auto out_fq = ngraph::pattern::wrap_type<ngraph::opset7::FakeQuantize>({matmul_out,
|
||||
@ -190,4 +218,4 @@ ConvertMatmulWithFqToPointWiseConvolution::ConvertMatmulWithFqToPointWiseConvolu
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(out_fq, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
}
|
||||
|
@ -8,6 +8,8 @@ addIeTargetTest(
|
||||
NAME ${TARGET_NAME}
|
||||
ROOT ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
LINK_LIBRARIES
|
||||
PRIVATE
|
||||
ngraphFunctions
|
||||
gmock
|
||||
commonTestUtils_s
|
||||
GNAPlugin_test_static
|
||||
|
@ -14,6 +14,7 @@
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
|
||||
namespace testing {
|
||||
|
||||
@ -119,6 +120,7 @@ void CreateMatMul::updateGraph(Graph& graph) {
|
||||
graph.output = matmul_node;
|
||||
}
|
||||
|
||||
template<bool ONE_DIMENSIONAL, bool ONE_CHANNEL>
|
||||
class CreateAdd : public CreateGraphDecorator {
|
||||
public:
|
||||
CreateAdd(CreateGraphDecoratorPtr prev_builder = nullptr) : CreateGraphDecorator(std::move(prev_builder)) {}
|
||||
@ -126,8 +128,18 @@ protected:
|
||||
void updateGraph(Graph&) override;
|
||||
};
|
||||
|
||||
void CreateAdd::updateGraph(Graph& graph) {
|
||||
auto bias = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
template<bool ONE_DIMENSIONAL, bool ONE_CHANNEL>
|
||||
void CreateAdd<ONE_DIMENSIONAL, ONE_CHANNEL>::updateGraph(Graph& graph) {
|
||||
std::vector<size_t> axes(1, 1);
|
||||
if (std::is_same<std::integral_constant<bool, ONE_CHANNEL>, std::integral_constant<bool, false>>::value) {
|
||||
auto shape = graph.output->get_output_shape(0);
|
||||
if (std::is_same<std::integral_constant<bool, ONE_DIMENSIONAL>, std::integral_constant<bool, false>>::value) {
|
||||
axes.resize(shape.size(), 1);
|
||||
}
|
||||
axes.back() = shape.back();
|
||||
}
|
||||
|
||||
auto bias = ngraph::builder::makeConstant<float>(ngraph::element::i64, axes, {}, true);
|
||||
auto add_node = std::make_shared<ngraph::opset7::Add>(graph.output, bias);
|
||||
graph.output = add_node;
|
||||
}
|
||||
@ -155,7 +167,8 @@ Graph createTransformedGraph(const ngraph::Shape& input_data_shape = ngraph::Sha
|
||||
|
||||
// ------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Graph createReferenceGraph(bool addConstFakeQuantizeNode, bool insertAddNode, bool addOutFakeQuantizeNode) {
|
||||
template<bool ADD_CONST_FAKEQUANTIZE_NODE, bool INSERT_ADD_NODE, bool ONE_CHANNEL, bool ADD_OUT_FAKEQUANTIZE_NODE>
|
||||
Graph createReferenceGraph() {
|
||||
Graph graph;
|
||||
|
||||
graph.input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64,
|
||||
@ -173,8 +186,9 @@ Graph createReferenceGraph(bool addConstFakeQuantizeNode, bool insertAddNode, bo
|
||||
auto transpose_before = std::make_shared<ngraph::opset7::Transpose>(reshape_before, const_transpose_before);
|
||||
|
||||
std::shared_ptr<ngraph::op::Op> parent_node = constant_node;
|
||||
if (addConstFakeQuantizeNode)
|
||||
if (std::is_same<std::integral_constant<bool, ADD_CONST_FAKEQUANTIZE_NODE>, std::integral_constant<bool, true>>::value) {
|
||||
parent_node = createFakeQuantizeNode(constant_node);
|
||||
}
|
||||
|
||||
auto weights_reshape_const = std::make_shared<ngraph::opset7::Constant>(ngraph::element::Type_t::i64,
|
||||
ngraph::Shape{4}, ngraph::Shape{8, 8, 1, 1});
|
||||
@ -189,15 +203,21 @@ Graph createReferenceGraph(bool addConstFakeQuantizeNode, bool insertAddNode, bo
|
||||
ngraph::op::PadType::VALID);
|
||||
|
||||
parent_node = conv_node;
|
||||
if (std::is_same<std::integral_constant<bool, INSERT_ADD_NODE>, std::integral_constant<bool, true>>::value) {
|
||||
std::vector<size_t> axes(1, 1);
|
||||
if (std::is_same<std::integral_constant<bool, ONE_CHANNEL>, std::integral_constant<bool, false>>::value) {
|
||||
axes.resize(4, 1);
|
||||
axes[1] = 8;
|
||||
}
|
||||
|
||||
if (insertAddNode) {
|
||||
auto bias = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto bias = ngraph::builder::makeConstant<float>(ngraph::element::i64, axes, {}, true);
|
||||
auto add_node = std::make_shared<ngraph::opset7::Add>(parent_node, bias);
|
||||
parent_node = add_node;
|
||||
}
|
||||
|
||||
if (addOutFakeQuantizeNode)
|
||||
if (std::is_same<std::integral_constant<bool, ADD_OUT_FAKEQUANTIZE_NODE>, std::integral_constant<bool, true>>::value) {
|
||||
parent_node = createFakeQuantizeNode(parent_node);
|
||||
}
|
||||
|
||||
auto const_transpose_after = ngraph::opset7::Constant::create(ngraph::element::i64,
|
||||
ngraph::Shape{4},
|
||||
@ -254,47 +274,65 @@ TEST_P(ConvertMatmulToPointWiseConvolutionFixture, CompareFunctions) {
|
||||
execute_test(function, reference_function, pass_manager);
|
||||
}
|
||||
|
||||
namespace {
|
||||
constexpr bool AddConstFakeQuantizeNode = true;
|
||||
constexpr bool InsertAddNode = true;
|
||||
constexpr bool OneDimensional = true;
|
||||
constexpr bool OneChannel = true;
|
||||
constexpr bool AddOutFakeQuantizeNode = true;
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(ConvertMatmulToPointWiseConvolutionTestSuite, ConvertMatmulToPointWiseConvolutionFixture,
|
||||
::testing::Values(std::make_tuple(createTransformedGraph<CreateMatMul>(),
|
||||
createReferenceGraph(false /* addConstFakeQuantizeNode */,
|
||||
false /* insertAddNode */,
|
||||
false /* addOutFakeQuantizeNode */),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateMatMul, CreateFakeQuantize>(),
|
||||
createReferenceGraph(true /* addConstFakeQuantizeNode */,
|
||||
false /* insertAddNode */,
|
||||
false /* addOutFakeQuantizeNode */),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateAdd, CreateMatMul>(),
|
||||
createReferenceGraph(false /* addConstFakeQuantizeNode */,
|
||||
true /* insertAddNode */,
|
||||
false /* addOutFakeQuantizeNode */),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateAdd, CreateMatMul, CreateFakeQuantize>(),
|
||||
createReferenceGraph(true /* addConstFakeQuantizeNode */,
|
||||
true /* insertAddNode */,
|
||||
false /* addOutFakeQuantizeNode */),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateAdd, CreateMatMul>(),
|
||||
createReferenceGraph(false /* addConstFakeQuantizeNode */,
|
||||
true /* insertAddNode */,
|
||||
true /* addOutFakeQuantizeNode */),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateAdd, CreateMatMul, CreateFakeQuantize>(),
|
||||
createReferenceGraph(true /* addConstFakeQuantizeNode */,
|
||||
true /* insertAddNode */,
|
||||
true /* addOutFakeQuantizeNode */),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateMatMul>(),
|
||||
createReferenceGraph(false /* addConstFakeQuantizeNode */,
|
||||
false /* insertAddNode */,
|
||||
true /* addOutFakeQuantizeNode */),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateMatMul, CreateFakeQuantize>(),
|
||||
createReferenceGraph(true /* addConstFakeQuantizeNode */,
|
||||
false /* insertAddNode */,
|
||||
true /* addOutFakeQuantizeNode */),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>())));
|
||||
::testing::Values(
|
||||
std::make_tuple(
|
||||
createTransformedGraph<CreateMatMul>(),
|
||||
createReferenceGraph<!AddConstFakeQuantizeNode, !InsertAddNode, !OneChannel, !AddOutFakeQuantizeNode>(),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateMatMul, CreateFakeQuantize>(),
|
||||
createReferenceGraph<AddConstFakeQuantizeNode, !InsertAddNode, !OneChannel, !AddOutFakeQuantizeNode>(),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateAdd<OneDimensional, OneChannel>, CreateMatMul>(),
|
||||
createReferenceGraph<!AddConstFakeQuantizeNode, InsertAddNode, OneChannel, !AddOutFakeQuantizeNode>(),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateAdd<OneDimensional, !OneChannel>, CreateMatMul>(),
|
||||
createReferenceGraph<!AddConstFakeQuantizeNode, InsertAddNode, !OneChannel, !AddOutFakeQuantizeNode>(),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateAdd<!OneDimensional, !OneChannel>, CreateMatMul>(),
|
||||
createReferenceGraph<!AddConstFakeQuantizeNode, InsertAddNode, !OneChannel, !AddOutFakeQuantizeNode>(),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateAdd<OneDimensional, OneChannel>, CreateMatMul, CreateFakeQuantize>(),
|
||||
createReferenceGraph<AddConstFakeQuantizeNode, InsertAddNode, OneChannel, !AddOutFakeQuantizeNode>(),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateAdd<OneDimensional, !OneChannel>, CreateMatMul, CreateFakeQuantize>(),
|
||||
createReferenceGraph<AddConstFakeQuantizeNode, InsertAddNode, !OneChannel, !AddOutFakeQuantizeNode>(),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateAdd<!OneDimensional, !OneChannel>, CreateMatMul, CreateFakeQuantize>(),
|
||||
createReferenceGraph<AddConstFakeQuantizeNode, InsertAddNode, !OneChannel, !AddOutFakeQuantizeNode>(),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateAdd<OneDimensional, OneChannel>, CreateMatMul>(),
|
||||
createReferenceGraph<!AddConstFakeQuantizeNode, InsertAddNode, OneChannel, AddOutFakeQuantizeNode>(),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateAdd<OneDimensional, !OneChannel>, CreateMatMul>(),
|
||||
createReferenceGraph<!AddConstFakeQuantizeNode, InsertAddNode, !OneChannel, AddOutFakeQuantizeNode>(),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateAdd<!OneDimensional, !OneChannel>, CreateMatMul>(),
|
||||
createReferenceGraph<!AddConstFakeQuantizeNode, InsertAddNode, !OneChannel, AddOutFakeQuantizeNode>(),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateAdd<OneDimensional, OneChannel>, CreateMatMul, CreateFakeQuantize>(),
|
||||
createReferenceGraph<AddConstFakeQuantizeNode, InsertAddNode, OneChannel, AddOutFakeQuantizeNode>(),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateAdd<OneDimensional, !OneChannel>, CreateMatMul, CreateFakeQuantize>(),
|
||||
createReferenceGraph<AddConstFakeQuantizeNode, InsertAddNode, !OneChannel, AddOutFakeQuantizeNode>(),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateAdd<!OneDimensional, !OneChannel>, CreateMatMul, CreateFakeQuantize>(),
|
||||
createReferenceGraph<AddConstFakeQuantizeNode, InsertAddNode, !OneChannel, AddOutFakeQuantizeNode>(),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateMatMul>(),
|
||||
createReferenceGraph<!AddConstFakeQuantizeNode, !InsertAddNode, !OneChannel, AddOutFakeQuantizeNode>(),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateMatMul, CreateFakeQuantize>(),
|
||||
createReferenceGraph<AddConstFakeQuantizeNode, !InsertAddNode, !OneChannel, AddOutFakeQuantizeNode>(),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>())));
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------
|
||||
|
||||
@ -373,19 +411,19 @@ std::vector<FixtureData> transform_types = {
|
||||
CreateMatMul,
|
||||
CreateFakeQuantize>(),
|
||||
FixtureData::create<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution,
|
||||
CreateAdd,
|
||||
CreateAdd<false, false>,
|
||||
CreateMatMul>(),
|
||||
FixtureData::create<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution,
|
||||
CreateAdd,
|
||||
CreateAdd<false, false>,
|
||||
CreateMatMul,
|
||||
CreateFakeQuantize>(),
|
||||
FixtureData::create<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution,
|
||||
CreateFakeQuantize,
|
||||
CreateAdd,
|
||||
CreateAdd<false, false>,
|
||||
CreateMatMul>(),
|
||||
FixtureData::create<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution,
|
||||
CreateFakeQuantize,
|
||||
CreateAdd,
|
||||
CreateAdd<false, false>,
|
||||
CreateMatMul,
|
||||
CreateFakeQuantize>(),
|
||||
FixtureData::create<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution,
|
||||
|
Loading…
Reference in New Issue
Block a user