GNA convert matmul to pointwise convolution transformation unit tests (#6524)
* ConvertMatmulToPointWiseConvolutionTest first test * add ConvertMatmulToPointWiseConvolutionFqTest * use general functions to create test subgraphs * use general funstion to append node; add ConvertMatmulWithBiasToPointWiseConvolutionTest * add ConvertMatmulWithBiasToPointWiseConvolutionFqTest * use decorator instead of bool function arguments * remove unused functions * cleanup * add ConvertMatmulWithFqToPointWiseConvolutionTest * add ConvertMatmulWithFqToPointWiseConvolutionFqTest * add ConvertMatmulWithFqToPointWiseConvolutionTestNoAddNode * remove debug * add ConvertMatmulToPointWiseConvolutionTestInputRank3 * use TEST_P for ConvertMatmulToPointWiseConvolution tests * use testing::values fixture instead of multiple tests * cleanup * use combine tests for invalid inputs * code style cleanup * fix unique_ptr build under Windows * code review fixes: function template params * code review fixes: remove duplicated test entry * fix function arguments alignments
This commit is contained in:
parent
f48ea5d1cc
commit
c64b809e87
@ -8,6 +8,7 @@
|
|||||||
#include <ngraph/opsets/opset7.hpp>
|
#include <ngraph/opsets/opset7.hpp>
|
||||||
#include <ngraph/pattern/op/or.hpp>
|
#include <ngraph/pattern/op/or.hpp>
|
||||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||||
|
#include <ngraph/rt_info.hpp>
|
||||||
|
|
||||||
#include "layers/gna_permute.hpp"
|
#include "layers/gna_permute.hpp"
|
||||||
#include "backend/gna_limitations.hpp"
|
#include "backend/gna_limitations.hpp"
|
||||||
@ -62,30 +63,36 @@ static bool Convert(std::shared_ptr<ngraph::Node> matmul_node,
|
|||||||
ngraph::Shape{1, 1, width, in_channels});
|
ngraph::Shape{1, 1, width, in_channels});
|
||||||
auto reshape_before = std::make_shared<ngraph::opset7::Reshape>(input_node, reshape_const_before, false);
|
auto reshape_before = std::make_shared<ngraph::opset7::Reshape>(input_node, reshape_const_before, false);
|
||||||
reshape_before->set_friendly_name(base_name + "/reshape_in");
|
reshape_before->set_friendly_name(base_name + "/reshape_in");
|
||||||
|
ngraph::copy_runtime_info(input_node, reshape_before);
|
||||||
|
|
||||||
auto transpose_before = std::make_shared<ngraph::opset7::Transpose>(reshape_before,
|
auto transpose_before = std::make_shared<ngraph::opset7::Transpose>(reshape_before,
|
||||||
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{4},
|
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{4},
|
||||||
GetPermuteOrder(InferenceEngine::Layout::NHWC, InferenceEngine::Layout::NCHW)));
|
GetPermuteOrder(InferenceEngine::Layout::NHWC, InferenceEngine::Layout::NCHW)));
|
||||||
transpose_before->set_friendly_name(base_name + "/transpose_in");
|
transpose_before->set_friendly_name(base_name + "/transpose_in");
|
||||||
|
ngraph::copy_runtime_info(matmul_node, transpose_before);
|
||||||
|
|
||||||
auto weights_reshape_const = std::make_shared<ngraph::opset7::Constant>(ngraph::element::Type_t::i64,
|
auto weights_reshape_const = std::make_shared<ngraph::opset7::Constant>(ngraph::element::Type_t::i64,
|
||||||
ngraph::Shape{4}, ngraph::Shape{out_channels, in_channels, 1, 1});
|
ngraph::Shape{4}, ngraph::Shape{out_channels, in_channels, 1, 1});
|
||||||
auto weights_reshaped = std::make_shared<ngraph::opset7::Reshape>(weights_node, weights_reshape_const, false);
|
auto weights_reshaped = std::make_shared<ngraph::opset7::Reshape>(weights_node, weights_reshape_const, false);
|
||||||
|
ngraph::copy_runtime_info(weights_node, weights_reshaped);
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Node> conv_node = std::make_shared<ngraph::opset7::Convolution>(transpose_before, weights_reshaped,
|
std::shared_ptr<ngraph::Node> conv_node = std::make_shared<ngraph::opset7::Convolution>(transpose_before, weights_reshaped,
|
||||||
ngraph::Strides{1, 1}, ngraph::CoordinateDiff{0, 0}, ngraph::CoordinateDiff{0, 0},
|
ngraph::Strides{1, 1}, ngraph::CoordinateDiff{0, 0}, ngraph::CoordinateDiff{0, 0},
|
||||||
ngraph::Strides{1, 1}, ngraph::op::PadType::VALID);
|
ngraph::Strides{1, 1}, ngraph::op::PadType::VALID);
|
||||||
conv_node->set_friendly_name(base_name + "/conv");
|
conv_node->set_friendly_name(base_name + "/conv");
|
||||||
|
ngraph::copy_runtime_info(transpose_before, conv_node);
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Node> root_node = matmul_node;
|
std::shared_ptr<ngraph::Node> root_node = matmul_node;
|
||||||
if (bias != nullptr) {
|
if (bias != nullptr) {
|
||||||
conv_node = std::make_shared<ngraph::opset7::Add>(conv_node, bias);
|
conv_node = std::make_shared<ngraph::opset7::Add>(conv_node, bias);
|
||||||
|
ngraph::copy_runtime_info(transpose_before, conv_node);
|
||||||
root_node = add;
|
root_node = add;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (fq != nullptr) {
|
if (fq != nullptr) {
|
||||||
conv_node = fq->clone_with_new_inputs({conv_node, fq->input_value(1), fq->input_value(2),
|
conv_node = fq->clone_with_new_inputs({conv_node, fq->input_value(1), fq->input_value(2),
|
||||||
fq->input_value(3), fq->input_value(4)});
|
fq->input_value(3), fq->input_value(4)});
|
||||||
|
ngraph::copy_runtime_info(fq, conv_node);
|
||||||
root_node = fq;
|
root_node = fq;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -93,6 +100,7 @@ static bool Convert(std::shared_ptr<ngraph::Node> matmul_node,
|
|||||||
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{4},
|
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{4},
|
||||||
GetPermuteOrder(InferenceEngine::Layout::NCHW, InferenceEngine::Layout::NHWC)));
|
GetPermuteOrder(InferenceEngine::Layout::NCHW, InferenceEngine::Layout::NHWC)));
|
||||||
transpose_after->set_friendly_name(base_name + "/transpose_out");
|
transpose_after->set_friendly_name(base_name + "/transpose_out");
|
||||||
|
ngraph::copy_runtime_info(conv_node, transpose_after);
|
||||||
|
|
||||||
auto output_shape = matmul_node->get_output_shape(0);
|
auto output_shape = matmul_node->get_output_shape(0);
|
||||||
output_shape[output_shape.size() - 1] = out_channels;
|
output_shape[output_shape.size() - 1] = out_channels;
|
||||||
@ -102,6 +110,7 @@ static bool Convert(std::shared_ptr<ngraph::Node> matmul_node,
|
|||||||
output_shape);
|
output_shape);
|
||||||
auto reshape_after = std::make_shared<ngraph::opset7::Reshape>(transpose_after, reshape_const_after, false);
|
auto reshape_after = std::make_shared<ngraph::opset7::Reshape>(transpose_after, reshape_const_after, false);
|
||||||
reshape_after->set_friendly_name(base_name);
|
reshape_after->set_friendly_name(base_name);
|
||||||
|
ngraph::copy_runtime_info(transpose_after, reshape_after);
|
||||||
|
|
||||||
ngraph::replace_node(root_node, reshape_after);
|
ngraph::replace_node(root_node, reshape_after);
|
||||||
return true;
|
return true;
|
||||||
|
@ -0,0 +1,417 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <tuple>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "transformations/convert_matmul_to_pointwise_convolution.hpp"
|
||||||
|
|
||||||
|
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||||
|
#include <ngraph/function.hpp>
|
||||||
|
#include <ngraph/opsets/opset7.hpp>
|
||||||
|
#include <ngraph/pass/manager.hpp>
|
||||||
|
#include <transformations/init_node_info.hpp>
|
||||||
|
|
||||||
|
namespace testing {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct Graph {
|
||||||
|
std::shared_ptr<ngraph::Function> createFunction();
|
||||||
|
|
||||||
|
std::shared_ptr<ngraph::opset7::Parameter> input_params;
|
||||||
|
std::shared_ptr<ngraph::op::Op> output;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::shared_ptr<ngraph::Function> Graph::createFunction() {
|
||||||
|
auto result = std::make_shared<ngraph::opset7::Result>(output);
|
||||||
|
return std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
|
||||||
|
ngraph::ParameterVector{input_params});
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// TODO: use std::make_unique when C++14 will be available
|
||||||
|
template <typename T, typename... Args>
|
||||||
|
std::unique_ptr<T> createUnique(Args&&... args) {
|
||||||
|
return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
|
||||||
|
}
|
||||||
|
|
||||||
|
class CreateGraphDecorator {
|
||||||
|
public:
|
||||||
|
CreateGraphDecorator(std::unique_ptr<CreateGraphDecorator> prev_builder = nullptr) : prev_builder_(std::move(prev_builder)) {}
|
||||||
|
virtual ~CreateGraphDecorator() = default;
|
||||||
|
virtual Graph build() {
|
||||||
|
Graph graph;
|
||||||
|
if (prev_builder_)
|
||||||
|
graph = prev_builder_->build();
|
||||||
|
updateGraph(graph);
|
||||||
|
return graph;
|
||||||
|
}
|
||||||
|
protected:
|
||||||
|
virtual void updateGraph(Graph&) = 0;
|
||||||
|
private:
|
||||||
|
CreateGraphDecorator(const CreateGraphDecorator&) = delete;
|
||||||
|
CreateGraphDecorator& operator=(const CreateGraphDecorator&) = delete;
|
||||||
|
private:
|
||||||
|
std::unique_ptr<CreateGraphDecorator> prev_builder_;
|
||||||
|
};
|
||||||
|
|
||||||
|
using CreateGraphDecoratorPtr = std::unique_ptr<CreateGraphDecorator>;
|
||||||
|
|
||||||
|
class CreateBaseDecorator : public CreateGraphDecorator {
|
||||||
|
public:
|
||||||
|
// always the first decorator => no prev_builder
|
||||||
|
CreateBaseDecorator(const ngraph::Shape& input_data_shape,
|
||||||
|
const ngraph::Shape& input_const_shape) :
|
||||||
|
CreateGraphDecorator(nullptr),
|
||||||
|
input_data_shape_(input_data_shape),
|
||||||
|
input_const_shape_(input_const_shape) {}
|
||||||
|
protected:
|
||||||
|
Graph build() override;
|
||||||
|
void updateGraph(Graph&) override {}
|
||||||
|
private:
|
||||||
|
const ngraph::Shape input_data_shape_;
|
||||||
|
const ngraph::Shape input_const_shape_;
|
||||||
|
};
|
||||||
|
|
||||||
|
Graph CreateBaseDecorator::build() {
|
||||||
|
Graph graph;
|
||||||
|
graph.input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64,
|
||||||
|
input_data_shape_);
|
||||||
|
graph.output = ngraph::opset7::Constant::create(ngraph::element::i64, input_const_shape_, {1});
|
||||||
|
return graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
class CreateFakeQuantize : public CreateGraphDecorator {
|
||||||
|
public:
|
||||||
|
CreateFakeQuantize(CreateGraphDecoratorPtr prev_builder = nullptr) : CreateGraphDecorator(std::move(prev_builder)) {}
|
||||||
|
protected:
|
||||||
|
void updateGraph(Graph&) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::shared_ptr<ngraph::opset7::FakeQuantize> createFakeQuantizeNode(std::shared_ptr<ngraph::op::Op> parent_node) {
|
||||||
|
auto input_low = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1});
|
||||||
|
auto input_high = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {20});
|
||||||
|
auto output_low = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {0});
|
||||||
|
auto output_high = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {10});
|
||||||
|
return std::make_shared<ngraph::opset7::FakeQuantize>(parent_node, input_low,
|
||||||
|
input_high, output_low,
|
||||||
|
output_high, 11);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CreateFakeQuantize::updateGraph(Graph& graph) {
|
||||||
|
graph.output = createFakeQuantizeNode(graph.output);
|
||||||
|
}
|
||||||
|
|
||||||
|
class CreateMatMul : public CreateGraphDecorator {
|
||||||
|
public:
|
||||||
|
CreateMatMul(CreateGraphDecoratorPtr prev_builder = nullptr) : CreateGraphDecorator(std::move(prev_builder)) {}
|
||||||
|
protected:
|
||||||
|
void updateGraph(Graph&) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
void CreateMatMul::updateGraph(Graph& graph) {
|
||||||
|
auto matmul_node = std::make_shared<ngraph::opset7::MatMul>(graph.input_params, graph.output);
|
||||||
|
graph.output = matmul_node;
|
||||||
|
}
|
||||||
|
|
||||||
|
class CreateAdd : public CreateGraphDecorator {
|
||||||
|
public:
|
||||||
|
CreateAdd(CreateGraphDecoratorPtr prev_builder = nullptr) : CreateGraphDecorator(std::move(prev_builder)) {}
|
||||||
|
protected:
|
||||||
|
void updateGraph(Graph&) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
void CreateAdd::updateGraph(Graph& graph) {
|
||||||
|
auto bias = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||||
|
auto add_node = std::make_shared<ngraph::opset7::Add>(graph.output, bias);
|
||||||
|
graph.output = add_node;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename DecorT, typename... DecorTs, typename std::enable_if<(sizeof...(DecorTs) == 0), bool>::type = true>
|
||||||
|
CreateGraphDecoratorPtr createBuildDecorator(const ngraph::Shape& input_data_shape = ngraph::Shape{16, 8},
|
||||||
|
const ngraph::Shape& input_const_shape = ngraph::Shape{8, 8}) {
|
||||||
|
CreateGraphDecoratorPtr build_decorator = createUnique<CreateBaseDecorator>(input_data_shape, input_const_shape);
|
||||||
|
return createUnique<DecorT>(std::move(build_decorator));
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename DecorT, typename... DecorTs, typename std::enable_if<(sizeof...(DecorTs) > 0), bool>::type = true>
|
||||||
|
CreateGraphDecoratorPtr createBuildDecorator(const ngraph::Shape& input_data_shape = ngraph::Shape{16, 8},
|
||||||
|
const ngraph::Shape& input_const_shape = ngraph::Shape{8, 8}) {
|
||||||
|
CreateGraphDecoratorPtr build_decorator = createBuildDecorator<DecorTs...>(input_data_shape, input_const_shape);
|
||||||
|
return createUnique<DecorT>(std::move(build_decorator));
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename DecorT, typename... DecorTs>
|
||||||
|
Graph createTransformedGraph(const ngraph::Shape& input_data_shape = ngraph::Shape{16, 8},
|
||||||
|
const ngraph::Shape& input_const_shape = ngraph::Shape{8, 8}) {
|
||||||
|
CreateGraphDecoratorPtr build_decorator = createBuildDecorator<DecorT, DecorTs...>(input_data_shape, input_const_shape);
|
||||||
|
return build_decorator->build();
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
Graph createReferenceGraph(bool addConstFakeQuantizeNode, bool insertAddNode, bool addOutFakeQuantizeNode) {
|
||||||
|
Graph graph;
|
||||||
|
|
||||||
|
graph.input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64,
|
||||||
|
ngraph::Shape{16, 8});
|
||||||
|
auto constant_node = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{8, 8}, {1});
|
||||||
|
|
||||||
|
auto const_reshape_before = std::make_shared<ngraph::opset7::Constant>(ngraph::element::Type_t::i64,
|
||||||
|
ngraph::Shape{4},
|
||||||
|
ngraph::Shape{1, 1, 16, 8});
|
||||||
|
auto reshape_before = std::make_shared<ngraph::opset7::Reshape>(graph.input_params, const_reshape_before, false);
|
||||||
|
|
||||||
|
auto const_transpose_before = ngraph::opset7::Constant::create(ngraph::element::i64,
|
||||||
|
ngraph::Shape{4},
|
||||||
|
ngraph::Shape{0, 3, 1, 2});
|
||||||
|
auto transpose_before = std::make_shared<ngraph::opset7::Transpose>(reshape_before, const_transpose_before);
|
||||||
|
|
||||||
|
std::shared_ptr<ngraph::op::Op> parent_node = constant_node;
|
||||||
|
if (addConstFakeQuantizeNode)
|
||||||
|
parent_node = createFakeQuantizeNode(constant_node);
|
||||||
|
|
||||||
|
auto weights_reshape_const = std::make_shared<ngraph::opset7::Constant>(ngraph::element::Type_t::i64,
|
||||||
|
ngraph::Shape{4}, ngraph::Shape{8, 8, 1, 1});
|
||||||
|
auto weights_reshaped = std::make_shared<ngraph::opset7::Reshape>(parent_node, weights_reshape_const, false);
|
||||||
|
|
||||||
|
auto conv_node = std::make_shared<ngraph::opset7::Convolution>(transpose_before,
|
||||||
|
weights_reshaped,
|
||||||
|
ngraph::Strides{1, 1},
|
||||||
|
ngraph::CoordinateDiff{0, 0},
|
||||||
|
ngraph::CoordinateDiff{0, 0},
|
||||||
|
ngraph::Strides{1, 1},
|
||||||
|
ngraph::op::PadType::VALID);
|
||||||
|
|
||||||
|
parent_node = conv_node;
|
||||||
|
|
||||||
|
if (insertAddNode) {
|
||||||
|
auto bias = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||||
|
auto add_node = std::make_shared<ngraph::opset7::Add>(parent_node, bias);
|
||||||
|
parent_node = add_node;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (addOutFakeQuantizeNode)
|
||||||
|
parent_node = createFakeQuantizeNode(parent_node);
|
||||||
|
|
||||||
|
auto const_transpose_after = ngraph::opset7::Constant::create(ngraph::element::i64,
|
||||||
|
ngraph::Shape{4},
|
||||||
|
ngraph::Shape{0, 2, 3, 1});
|
||||||
|
auto transpose_after = std::make_shared<ngraph::opset7::Transpose>(parent_node, const_transpose_after);
|
||||||
|
|
||||||
|
auto const_reshape_after = std::make_shared<ngraph::opset7::Constant>(ngraph::element::Type_t::i64,
|
||||||
|
ngraph::Shape{2},
|
||||||
|
ngraph::Shape{16, 8});
|
||||||
|
graph.output = std::make_shared<ngraph::opset7::Reshape>(transpose_after, const_reshape_after, false);
|
||||||
|
|
||||||
|
return graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class ConvertMatmulToPointWiseConvolutionFixture: public CommonTestUtils::TestsCommon,
|
||||||
|
public ::testing::WithParamInterface<std::tuple<Graph /* tranformed */,
|
||||||
|
Graph /* reference */,
|
||||||
|
ngraph::pass::Manager>> {
|
||||||
|
public:
|
||||||
|
void SetUp() override;
|
||||||
|
public:
|
||||||
|
std::shared_ptr<ngraph::Function> function, reference_function;
|
||||||
|
ngraph::pass::Manager pass_manager;
|
||||||
|
};
|
||||||
|
|
||||||
|
void ConvertMatmulToPointWiseConvolutionFixture::SetUp() {
|
||||||
|
// TODO: use auto & [transformed_graph, reference_graph] = this->GetParam() when C++17
|
||||||
|
Graph transformed_graph;
|
||||||
|
Graph reference_graph;
|
||||||
|
std::tie(transformed_graph, reference_graph, pass_manager) = this->GetParam();
|
||||||
|
|
||||||
|
function = transformed_graph.createFunction();
|
||||||
|
reference_function = reference_graph.createFunction();
|
||||||
|
}
|
||||||
|
|
||||||
|
void execute_test(std::shared_ptr<ngraph::Function> function, std::shared_ptr<ngraph::Function> reference_function, ngraph::pass::Manager& pass_manager) {
|
||||||
|
pass_manager.run_passes(function);
|
||||||
|
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
|
||||||
|
const FunctionsComparator::Result result = func_comparator(function, reference_function);
|
||||||
|
ASSERT_TRUE(result.valid);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename TransformationT>
|
||||||
|
ngraph::pass::Manager createPassManager() {
|
||||||
|
ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
manager.register_pass<TransformationT>();
|
||||||
|
return manager;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(ConvertMatmulToPointWiseConvolutionFixture, CompareFunctions) {
|
||||||
|
execute_test(function, reference_function, pass_manager);
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(ConvertMatmulToPointWiseConvolutionTestSuite, ConvertMatmulToPointWiseConvolutionFixture,
|
||||||
|
::testing::Values(std::make_tuple(createTransformedGraph<CreateMatMul>(),
|
||||||
|
createReferenceGraph(false /* addConstFakeQuantizeNode */,
|
||||||
|
false /* insertAddNode */,
|
||||||
|
false /* addOutFakeQuantizeNode */),
|
||||||
|
createPassManager<GNAPluginNS::ConvertMatmulToPointWiseConvolution>()),
|
||||||
|
std::make_tuple(createTransformedGraph<CreateMatMul, CreateFakeQuantize>(),
|
||||||
|
createReferenceGraph(true /* addConstFakeQuantizeNode */,
|
||||||
|
false /* insertAddNode */,
|
||||||
|
false /* addOutFakeQuantizeNode */),
|
||||||
|
createPassManager<GNAPluginNS::ConvertMatmulToPointWiseConvolution>()),
|
||||||
|
std::make_tuple(createTransformedGraph<CreateAdd, CreateMatMul>(),
|
||||||
|
createReferenceGraph(false /* addConstFakeQuantizeNode */,
|
||||||
|
true /* insertAddNode */,
|
||||||
|
false /* addOutFakeQuantizeNode */),
|
||||||
|
createPassManager<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution>()),
|
||||||
|
std::make_tuple(createTransformedGraph<CreateAdd, CreateMatMul, CreateFakeQuantize>(),
|
||||||
|
createReferenceGraph(true /* addConstFakeQuantizeNode */,
|
||||||
|
true /* insertAddNode */,
|
||||||
|
false /* addOutFakeQuantizeNode */),
|
||||||
|
createPassManager<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution>()),
|
||||||
|
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateAdd, CreateMatMul>(),
|
||||||
|
createReferenceGraph(false /* addConstFakeQuantizeNode */,
|
||||||
|
true /* insertAddNode */,
|
||||||
|
true /* addOutFakeQuantizeNode */),
|
||||||
|
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>()),
|
||||||
|
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateAdd, CreateMatMul, CreateFakeQuantize>(),
|
||||||
|
createReferenceGraph(true /* addConstFakeQuantizeNode */,
|
||||||
|
true /* insertAddNode */,
|
||||||
|
true /* addOutFakeQuantizeNode */),
|
||||||
|
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>()),
|
||||||
|
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateMatMul>(),
|
||||||
|
createReferenceGraph(false /* addConstFakeQuantizeNode */,
|
||||||
|
false /* insertAddNode */,
|
||||||
|
true /* addOutFakeQuantizeNode */),
|
||||||
|
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>()),
|
||||||
|
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateMatMul, CreateFakeQuantize>(),
|
||||||
|
createReferenceGraph(true /* addConstFakeQuantizeNode */,
|
||||||
|
false /* insertAddNode */,
|
||||||
|
true /* addOutFakeQuantizeNode */),
|
||||||
|
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>())));
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class ITransformedGraphFactory {
|
||||||
|
public:
|
||||||
|
virtual ~ITransformedGraphFactory() = default;
|
||||||
|
virtual Graph createGraph(const ngraph::Shape& input_data_shape,
|
||||||
|
const ngraph::Shape& input_const_shape) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename DecorT, typename... DecorTs>
|
||||||
|
class TransformedGraphFactory : public ITransformedGraphFactory {
|
||||||
|
public:
|
||||||
|
TransformedGraphFactory() = default;
|
||||||
|
|
||||||
|
Graph createGraph(const ngraph::Shape& input_data_shape,
|
||||||
|
const ngraph::Shape& input_const_shape) override {
|
||||||
|
return createTransformedGraph<DecorT, DecorTs...>(input_data_shape, input_const_shape);
|
||||||
|
}
|
||||||
|
private:
|
||||||
|
TransformedGraphFactory(const TransformedGraphFactory&) = delete;
|
||||||
|
TransformedGraphFactory& operator=(const TransformedGraphFactory&) = delete;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct FixtureData {
|
||||||
|
std::shared_ptr<ITransformedGraphFactory> graph_factory;
|
||||||
|
ngraph::pass::Manager pass_manager;
|
||||||
|
|
||||||
|
template<typename TransformationT, typename DecorT, typename... DecorTs>
|
||||||
|
static FixtureData create() {
|
||||||
|
FixtureData fixture_data;
|
||||||
|
fixture_data.graph_factory = std::make_shared<TransformedGraphFactory<DecorT, DecorTs...>>();
|
||||||
|
fixture_data.pass_manager = createPassManager<TransformationT>();
|
||||||
|
return fixture_data;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
using FixtureInputShapes = std::tuple<ngraph::Shape /* input data */, ngraph::Shape /* input const */>;
|
||||||
|
|
||||||
|
class ConvertMatmulToPointWiseConvolutionInvalidInputFixture: public CommonTestUtils::TestsCommon,
|
||||||
|
public ::testing::WithParamInterface<std::tuple<FixtureData,
|
||||||
|
FixtureInputShapes>> {
|
||||||
|
public:
|
||||||
|
void SetUp() override;
|
||||||
|
public:
|
||||||
|
std::shared_ptr<ngraph::Function> function;
|
||||||
|
ngraph::pass::Manager pass_manager;
|
||||||
|
};
|
||||||
|
|
||||||
|
void ConvertMatmulToPointWiseConvolutionInvalidInputFixture::SetUp() {
|
||||||
|
// TODO: use auto & [fixture_data, input_shapes] = this->GetParam() when C++17
|
||||||
|
FixtureData fixture_data;
|
||||||
|
FixtureInputShapes input_shapes;
|
||||||
|
std::tie(fixture_data, input_shapes) = this->GetParam();
|
||||||
|
|
||||||
|
ngraph::Shape input_data, input_const;
|
||||||
|
std::tie(input_data, input_const) = input_shapes;
|
||||||
|
|
||||||
|
function = fixture_data.graph_factory->createGraph(input_data, input_const).createFunction();
|
||||||
|
pass_manager = fixture_data.pass_manager;
|
||||||
|
}
|
||||||
|
|
||||||
|
void execute_test_cloned_function(std::shared_ptr<ngraph::Function> function,
|
||||||
|
ngraph::pass::Manager& pass_manager) {
|
||||||
|
std::shared_ptr<ngraph::Function> reference_function = ngraph::clone_function(*function);
|
||||||
|
pass_manager.run_passes(function);
|
||||||
|
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
|
||||||
|
const FunctionsComparator::Result result = func_comparator(function, reference_function);
|
||||||
|
ASSERT_TRUE(result.valid);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<FixtureData> transform_types = {
|
||||||
|
FixtureData::create<GNAPluginNS::ConvertMatmulToPointWiseConvolution,
|
||||||
|
CreateMatMul>(),
|
||||||
|
FixtureData::create<GNAPluginNS::ConvertMatmulToPointWiseConvolution,
|
||||||
|
CreateMatMul,
|
||||||
|
CreateFakeQuantize>(),
|
||||||
|
FixtureData::create<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution,
|
||||||
|
CreateAdd,
|
||||||
|
CreateMatMul>(),
|
||||||
|
FixtureData::create<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution,
|
||||||
|
CreateAdd,
|
||||||
|
CreateMatMul,
|
||||||
|
CreateFakeQuantize>(),
|
||||||
|
FixtureData::create<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution,
|
||||||
|
CreateFakeQuantize,
|
||||||
|
CreateAdd,
|
||||||
|
CreateMatMul>(),
|
||||||
|
FixtureData::create<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution,
|
||||||
|
CreateFakeQuantize,
|
||||||
|
CreateAdd,
|
||||||
|
CreateMatMul,
|
||||||
|
CreateFakeQuantize>(),
|
||||||
|
FixtureData::create<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution,
|
||||||
|
CreateFakeQuantize,
|
||||||
|
CreateMatMul>(),
|
||||||
|
FixtureData::create<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution,
|
||||||
|
CreateFakeQuantize,
|
||||||
|
CreateMatMul,
|
||||||
|
CreateFakeQuantize>()
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<FixtureInputShapes> input_shapes = {
|
||||||
|
std::make_tuple(ngraph::Shape{16, 16, 16}, ngraph::Shape{16, 16, 16}),
|
||||||
|
std::make_tuple(ngraph::Shape{16, 9}, ngraph::Shape{9, 9}),
|
||||||
|
std::make_tuple(ngraph::Shape{16, 65533}, ngraph::Shape{65533, 2}),
|
||||||
|
std::make_tuple(ngraph::Shape{16, 769}, ngraph::Shape{769, 2})
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_P(ConvertMatmulToPointWiseConvolutionInvalidInputFixture, CompareFunctions) {
|
||||||
|
execute_test_cloned_function(function, pass_manager);
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(ConvertMatmulToPointWiseConvolutionInvalidInputTestSuite, ConvertMatmulToPointWiseConvolutionInvalidInputFixture,
|
||||||
|
::testing::Combine(::testing::ValuesIn(transform_types),
|
||||||
|
::testing::ValuesIn(input_shapes)));
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
} // namespace testing
|
Loading…
Reference in New Issue
Block a user