GNA convert matmul to pointwise convolution transformation unit tests (#6524)
* ConvertMatmulToPointWiseConvolutionTest first test * add ConvertMatmulToPointWiseConvolutionFqTest * use general functions to create test subgraphs * use general funstion to append node; add ConvertMatmulWithBiasToPointWiseConvolutionTest * add ConvertMatmulWithBiasToPointWiseConvolutionFqTest * use decorator instead of bool function arguments * remove unused functions * cleanup * add ConvertMatmulWithFqToPointWiseConvolutionTest * add ConvertMatmulWithFqToPointWiseConvolutionFqTest * add ConvertMatmulWithFqToPointWiseConvolutionTestNoAddNode * remove debug * add ConvertMatmulToPointWiseConvolutionTestInputRank3 * use TEST_P for ConvertMatmulToPointWiseConvolution tests * use testing::values fixture instead of multiple tests * cleanup * use combine tests for invalid inputs * code style cleanup * fix unique_ptr build under Windows * code review fixes: function template params * code review fixes: remove duplicated test entry * fix function arguments alignments
This commit is contained in:
parent
f48ea5d1cc
commit
c64b809e87
@ -8,6 +8,7 @@
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/pattern/op/or.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
|
||||
#include "layers/gna_permute.hpp"
|
||||
#include "backend/gna_limitations.hpp"
|
||||
@ -62,30 +63,36 @@ static bool Convert(std::shared_ptr<ngraph::Node> matmul_node,
|
||||
ngraph::Shape{1, 1, width, in_channels});
|
||||
auto reshape_before = std::make_shared<ngraph::opset7::Reshape>(input_node, reshape_const_before, false);
|
||||
reshape_before->set_friendly_name(base_name + "/reshape_in");
|
||||
ngraph::copy_runtime_info(input_node, reshape_before);
|
||||
|
||||
auto transpose_before = std::make_shared<ngraph::opset7::Transpose>(reshape_before,
|
||||
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{4},
|
||||
GetPermuteOrder(InferenceEngine::Layout::NHWC, InferenceEngine::Layout::NCHW)));
|
||||
transpose_before->set_friendly_name(base_name + "/transpose_in");
|
||||
ngraph::copy_runtime_info(matmul_node, transpose_before);
|
||||
|
||||
auto weights_reshape_const = std::make_shared<ngraph::opset7::Constant>(ngraph::element::Type_t::i64,
|
||||
ngraph::Shape{4}, ngraph::Shape{out_channels, in_channels, 1, 1});
|
||||
auto weights_reshaped = std::make_shared<ngraph::opset7::Reshape>(weights_node, weights_reshape_const, false);
|
||||
ngraph::copy_runtime_info(weights_node, weights_reshaped);
|
||||
|
||||
std::shared_ptr<ngraph::Node> conv_node = std::make_shared<ngraph::opset7::Convolution>(transpose_before, weights_reshaped,
|
||||
ngraph::Strides{1, 1}, ngraph::CoordinateDiff{0, 0}, ngraph::CoordinateDiff{0, 0},
|
||||
ngraph::Strides{1, 1}, ngraph::op::PadType::VALID);
|
||||
conv_node->set_friendly_name(base_name + "/conv");
|
||||
ngraph::copy_runtime_info(transpose_before, conv_node);
|
||||
|
||||
std::shared_ptr<ngraph::Node> root_node = matmul_node;
|
||||
if (bias != nullptr) {
|
||||
conv_node = std::make_shared<ngraph::opset7::Add>(conv_node, bias);
|
||||
ngraph::copy_runtime_info(transpose_before, conv_node);
|
||||
root_node = add;
|
||||
}
|
||||
|
||||
if (fq != nullptr) {
|
||||
conv_node = fq->clone_with_new_inputs({conv_node, fq->input_value(1), fq->input_value(2),
|
||||
fq->input_value(3), fq->input_value(4)});
|
||||
ngraph::copy_runtime_info(fq, conv_node);
|
||||
root_node = fq;
|
||||
}
|
||||
|
||||
@ -93,6 +100,7 @@ static bool Convert(std::shared_ptr<ngraph::Node> matmul_node,
|
||||
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{4},
|
||||
GetPermuteOrder(InferenceEngine::Layout::NCHW, InferenceEngine::Layout::NHWC)));
|
||||
transpose_after->set_friendly_name(base_name + "/transpose_out");
|
||||
ngraph::copy_runtime_info(conv_node, transpose_after);
|
||||
|
||||
auto output_shape = matmul_node->get_output_shape(0);
|
||||
output_shape[output_shape.size() - 1] = out_channels;
|
||||
@ -102,6 +110,7 @@ static bool Convert(std::shared_ptr<ngraph::Node> matmul_node,
|
||||
output_shape);
|
||||
auto reshape_after = std::make_shared<ngraph::opset7::Reshape>(transpose_after, reshape_const_after, false);
|
||||
reshape_after->set_friendly_name(base_name);
|
||||
ngraph::copy_runtime_info(transpose_after, reshape_after);
|
||||
|
||||
ngraph::replace_node(root_node, reshape_after);
|
||||
return true;
|
||||
|
@ -0,0 +1,417 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <tuple>
|
||||
#include <memory>
|
||||
|
||||
#include "transformations/convert_matmul_to_pointwise_convolution.hpp"
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
|
||||
namespace testing {
|
||||
|
||||
namespace {
|
||||
|
||||
struct Graph {
|
||||
std::shared_ptr<ngraph::Function> createFunction();
|
||||
|
||||
std::shared_ptr<ngraph::opset7::Parameter> input_params;
|
||||
std::shared_ptr<ngraph::op::Op> output;
|
||||
};
|
||||
|
||||
std::shared_ptr<ngraph::Function> Graph::createFunction() {
|
||||
auto result = std::make_shared<ngraph::opset7::Result>(output);
|
||||
return std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
|
||||
ngraph::ParameterVector{input_params});
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------------------------------------------
|
||||
|
||||
// TODO: use std::make_unique when C++14 will be available
|
||||
template <typename T, typename... Args>
|
||||
std::unique_ptr<T> createUnique(Args&&... args) {
|
||||
return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
|
||||
}
|
||||
|
||||
class CreateGraphDecorator {
|
||||
public:
|
||||
CreateGraphDecorator(std::unique_ptr<CreateGraphDecorator> prev_builder = nullptr) : prev_builder_(std::move(prev_builder)) {}
|
||||
virtual ~CreateGraphDecorator() = default;
|
||||
virtual Graph build() {
|
||||
Graph graph;
|
||||
if (prev_builder_)
|
||||
graph = prev_builder_->build();
|
||||
updateGraph(graph);
|
||||
return graph;
|
||||
}
|
||||
protected:
|
||||
virtual void updateGraph(Graph&) = 0;
|
||||
private:
|
||||
CreateGraphDecorator(const CreateGraphDecorator&) = delete;
|
||||
CreateGraphDecorator& operator=(const CreateGraphDecorator&) = delete;
|
||||
private:
|
||||
std::unique_ptr<CreateGraphDecorator> prev_builder_;
|
||||
};
|
||||
|
||||
using CreateGraphDecoratorPtr = std::unique_ptr<CreateGraphDecorator>;
|
||||
|
||||
class CreateBaseDecorator : public CreateGraphDecorator {
|
||||
public:
|
||||
// always the first decorator => no prev_builder
|
||||
CreateBaseDecorator(const ngraph::Shape& input_data_shape,
|
||||
const ngraph::Shape& input_const_shape) :
|
||||
CreateGraphDecorator(nullptr),
|
||||
input_data_shape_(input_data_shape),
|
||||
input_const_shape_(input_const_shape) {}
|
||||
protected:
|
||||
Graph build() override;
|
||||
void updateGraph(Graph&) override {}
|
||||
private:
|
||||
const ngraph::Shape input_data_shape_;
|
||||
const ngraph::Shape input_const_shape_;
|
||||
};
|
||||
|
||||
Graph CreateBaseDecorator::build() {
|
||||
Graph graph;
|
||||
graph.input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64,
|
||||
input_data_shape_);
|
||||
graph.output = ngraph::opset7::Constant::create(ngraph::element::i64, input_const_shape_, {1});
|
||||
return graph;
|
||||
}
|
||||
|
||||
class CreateFakeQuantize : public CreateGraphDecorator {
|
||||
public:
|
||||
CreateFakeQuantize(CreateGraphDecoratorPtr prev_builder = nullptr) : CreateGraphDecorator(std::move(prev_builder)) {}
|
||||
protected:
|
||||
void updateGraph(Graph&) override;
|
||||
};
|
||||
|
||||
std::shared_ptr<ngraph::opset7::FakeQuantize> createFakeQuantizeNode(std::shared_ptr<ngraph::op::Op> parent_node) {
|
||||
auto input_low = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1});
|
||||
auto input_high = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {20});
|
||||
auto output_low = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {0});
|
||||
auto output_high = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {10});
|
||||
return std::make_shared<ngraph::opset7::FakeQuantize>(parent_node, input_low,
|
||||
input_high, output_low,
|
||||
output_high, 11);
|
||||
}
|
||||
|
||||
void CreateFakeQuantize::updateGraph(Graph& graph) {
|
||||
graph.output = createFakeQuantizeNode(graph.output);
|
||||
}
|
||||
|
||||
class CreateMatMul : public CreateGraphDecorator {
|
||||
public:
|
||||
CreateMatMul(CreateGraphDecoratorPtr prev_builder = nullptr) : CreateGraphDecorator(std::move(prev_builder)) {}
|
||||
protected:
|
||||
void updateGraph(Graph&) override;
|
||||
};
|
||||
|
||||
void CreateMatMul::updateGraph(Graph& graph) {
|
||||
auto matmul_node = std::make_shared<ngraph::opset7::MatMul>(graph.input_params, graph.output);
|
||||
graph.output = matmul_node;
|
||||
}
|
||||
|
||||
class CreateAdd : public CreateGraphDecorator {
|
||||
public:
|
||||
CreateAdd(CreateGraphDecoratorPtr prev_builder = nullptr) : CreateGraphDecorator(std::move(prev_builder)) {}
|
||||
protected:
|
||||
void updateGraph(Graph&) override;
|
||||
};
|
||||
|
||||
void CreateAdd::updateGraph(Graph& graph) {
|
||||
auto bias = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto add_node = std::make_shared<ngraph::opset7::Add>(graph.output, bias);
|
||||
graph.output = add_node;
|
||||
}
|
||||
|
||||
template<typename DecorT, typename... DecorTs, typename std::enable_if<(sizeof...(DecorTs) == 0), bool>::type = true>
|
||||
CreateGraphDecoratorPtr createBuildDecorator(const ngraph::Shape& input_data_shape = ngraph::Shape{16, 8},
|
||||
const ngraph::Shape& input_const_shape = ngraph::Shape{8, 8}) {
|
||||
CreateGraphDecoratorPtr build_decorator = createUnique<CreateBaseDecorator>(input_data_shape, input_const_shape);
|
||||
return createUnique<DecorT>(std::move(build_decorator));
|
||||
}
|
||||
|
||||
template<typename DecorT, typename... DecorTs, typename std::enable_if<(sizeof...(DecorTs) > 0), bool>::type = true>
|
||||
CreateGraphDecoratorPtr createBuildDecorator(const ngraph::Shape& input_data_shape = ngraph::Shape{16, 8},
|
||||
const ngraph::Shape& input_const_shape = ngraph::Shape{8, 8}) {
|
||||
CreateGraphDecoratorPtr build_decorator = createBuildDecorator<DecorTs...>(input_data_shape, input_const_shape);
|
||||
return createUnique<DecorT>(std::move(build_decorator));
|
||||
}
|
||||
|
||||
template<typename DecorT, typename... DecorTs>
|
||||
Graph createTransformedGraph(const ngraph::Shape& input_data_shape = ngraph::Shape{16, 8},
|
||||
const ngraph::Shape& input_const_shape = ngraph::Shape{8, 8}) {
|
||||
CreateGraphDecoratorPtr build_decorator = createBuildDecorator<DecorT, DecorTs...>(input_data_shape, input_const_shape);
|
||||
return build_decorator->build();
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Graph createReferenceGraph(bool addConstFakeQuantizeNode, bool insertAddNode, bool addOutFakeQuantizeNode) {
|
||||
Graph graph;
|
||||
|
||||
graph.input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64,
|
||||
ngraph::Shape{16, 8});
|
||||
auto constant_node = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{8, 8}, {1});
|
||||
|
||||
auto const_reshape_before = std::make_shared<ngraph::opset7::Constant>(ngraph::element::Type_t::i64,
|
||||
ngraph::Shape{4},
|
||||
ngraph::Shape{1, 1, 16, 8});
|
||||
auto reshape_before = std::make_shared<ngraph::opset7::Reshape>(graph.input_params, const_reshape_before, false);
|
||||
|
||||
auto const_transpose_before = ngraph::opset7::Constant::create(ngraph::element::i64,
|
||||
ngraph::Shape{4},
|
||||
ngraph::Shape{0, 3, 1, 2});
|
||||
auto transpose_before = std::make_shared<ngraph::opset7::Transpose>(reshape_before, const_transpose_before);
|
||||
|
||||
std::shared_ptr<ngraph::op::Op> parent_node = constant_node;
|
||||
if (addConstFakeQuantizeNode)
|
||||
parent_node = createFakeQuantizeNode(constant_node);
|
||||
|
||||
auto weights_reshape_const = std::make_shared<ngraph::opset7::Constant>(ngraph::element::Type_t::i64,
|
||||
ngraph::Shape{4}, ngraph::Shape{8, 8, 1, 1});
|
||||
auto weights_reshaped = std::make_shared<ngraph::opset7::Reshape>(parent_node, weights_reshape_const, false);
|
||||
|
||||
auto conv_node = std::make_shared<ngraph::opset7::Convolution>(transpose_before,
|
||||
weights_reshaped,
|
||||
ngraph::Strides{1, 1},
|
||||
ngraph::CoordinateDiff{0, 0},
|
||||
ngraph::CoordinateDiff{0, 0},
|
||||
ngraph::Strides{1, 1},
|
||||
ngraph::op::PadType::VALID);
|
||||
|
||||
parent_node = conv_node;
|
||||
|
||||
if (insertAddNode) {
|
||||
auto bias = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto add_node = std::make_shared<ngraph::opset7::Add>(parent_node, bias);
|
||||
parent_node = add_node;
|
||||
}
|
||||
|
||||
if (addOutFakeQuantizeNode)
|
||||
parent_node = createFakeQuantizeNode(parent_node);
|
||||
|
||||
auto const_transpose_after = ngraph::opset7::Constant::create(ngraph::element::i64,
|
||||
ngraph::Shape{4},
|
||||
ngraph::Shape{0, 2, 3, 1});
|
||||
auto transpose_after = std::make_shared<ngraph::opset7::Transpose>(parent_node, const_transpose_after);
|
||||
|
||||
auto const_reshape_after = std::make_shared<ngraph::opset7::Constant>(ngraph::element::Type_t::i64,
|
||||
ngraph::Shape{2},
|
||||
ngraph::Shape{16, 8});
|
||||
graph.output = std::make_shared<ngraph::opset7::Reshape>(transpose_after, const_reshape_after, false);
|
||||
|
||||
return graph;
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------
|
||||
|
||||
class ConvertMatmulToPointWiseConvolutionFixture: public CommonTestUtils::TestsCommon,
|
||||
public ::testing::WithParamInterface<std::tuple<Graph /* tranformed */,
|
||||
Graph /* reference */,
|
||||
ngraph::pass::Manager>> {
|
||||
public:
|
||||
void SetUp() override;
|
||||
public:
|
||||
std::shared_ptr<ngraph::Function> function, reference_function;
|
||||
ngraph::pass::Manager pass_manager;
|
||||
};
|
||||
|
||||
void ConvertMatmulToPointWiseConvolutionFixture::SetUp() {
|
||||
// TODO: use auto & [transformed_graph, reference_graph] = this->GetParam() when C++17
|
||||
Graph transformed_graph;
|
||||
Graph reference_graph;
|
||||
std::tie(transformed_graph, reference_graph, pass_manager) = this->GetParam();
|
||||
|
||||
function = transformed_graph.createFunction();
|
||||
reference_function = reference_graph.createFunction();
|
||||
}
|
||||
|
||||
void execute_test(std::shared_ptr<ngraph::Function> function, std::shared_ptr<ngraph::Function> reference_function, ngraph::pass::Manager& pass_manager) {
|
||||
pass_manager.run_passes(function);
|
||||
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
|
||||
const FunctionsComparator::Result result = func_comparator(function, reference_function);
|
||||
ASSERT_TRUE(result.valid);
|
||||
}
|
||||
|
||||
template <typename TransformationT>
|
||||
ngraph::pass::Manager createPassManager() {
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<TransformationT>();
|
||||
return manager;
|
||||
}
|
||||
|
||||
TEST_P(ConvertMatmulToPointWiseConvolutionFixture, CompareFunctions) {
|
||||
execute_test(function, reference_function, pass_manager);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(ConvertMatmulToPointWiseConvolutionTestSuite, ConvertMatmulToPointWiseConvolutionFixture,
|
||||
::testing::Values(std::make_tuple(createTransformedGraph<CreateMatMul>(),
|
||||
createReferenceGraph(false /* addConstFakeQuantizeNode */,
|
||||
false /* insertAddNode */,
|
||||
false /* addOutFakeQuantizeNode */),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateMatMul, CreateFakeQuantize>(),
|
||||
createReferenceGraph(true /* addConstFakeQuantizeNode */,
|
||||
false /* insertAddNode */,
|
||||
false /* addOutFakeQuantizeNode */),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateAdd, CreateMatMul>(),
|
||||
createReferenceGraph(false /* addConstFakeQuantizeNode */,
|
||||
true /* insertAddNode */,
|
||||
false /* addOutFakeQuantizeNode */),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateAdd, CreateMatMul, CreateFakeQuantize>(),
|
||||
createReferenceGraph(true /* addConstFakeQuantizeNode */,
|
||||
true /* insertAddNode */,
|
||||
false /* addOutFakeQuantizeNode */),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateAdd, CreateMatMul>(),
|
||||
createReferenceGraph(false /* addConstFakeQuantizeNode */,
|
||||
true /* insertAddNode */,
|
||||
true /* addOutFakeQuantizeNode */),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateAdd, CreateMatMul, CreateFakeQuantize>(),
|
||||
createReferenceGraph(true /* addConstFakeQuantizeNode */,
|
||||
true /* insertAddNode */,
|
||||
true /* addOutFakeQuantizeNode */),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateMatMul>(),
|
||||
createReferenceGraph(false /* addConstFakeQuantizeNode */,
|
||||
false /* insertAddNode */,
|
||||
true /* addOutFakeQuantizeNode */),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>()),
|
||||
std::make_tuple(createTransformedGraph<CreateFakeQuantize, CreateMatMul, CreateFakeQuantize>(),
|
||||
createReferenceGraph(true /* addConstFakeQuantizeNode */,
|
||||
false /* insertAddNode */,
|
||||
true /* addOutFakeQuantizeNode */),
|
||||
createPassManager<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution>())));
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------
|
||||
|
||||
class ITransformedGraphFactory {
|
||||
public:
|
||||
virtual ~ITransformedGraphFactory() = default;
|
||||
virtual Graph createGraph(const ngraph::Shape& input_data_shape,
|
||||
const ngraph::Shape& input_const_shape) = 0;
|
||||
};
|
||||
|
||||
template<typename DecorT, typename... DecorTs>
|
||||
class TransformedGraphFactory : public ITransformedGraphFactory {
|
||||
public:
|
||||
TransformedGraphFactory() = default;
|
||||
|
||||
Graph createGraph(const ngraph::Shape& input_data_shape,
|
||||
const ngraph::Shape& input_const_shape) override {
|
||||
return createTransformedGraph<DecorT, DecorTs...>(input_data_shape, input_const_shape);
|
||||
}
|
||||
private:
|
||||
TransformedGraphFactory(const TransformedGraphFactory&) = delete;
|
||||
TransformedGraphFactory& operator=(const TransformedGraphFactory&) = delete;
|
||||
};
|
||||
|
||||
struct FixtureData {
|
||||
std::shared_ptr<ITransformedGraphFactory> graph_factory;
|
||||
ngraph::pass::Manager pass_manager;
|
||||
|
||||
template<typename TransformationT, typename DecorT, typename... DecorTs>
|
||||
static FixtureData create() {
|
||||
FixtureData fixture_data;
|
||||
fixture_data.graph_factory = std::make_shared<TransformedGraphFactory<DecorT, DecorTs...>>();
|
||||
fixture_data.pass_manager = createPassManager<TransformationT>();
|
||||
return fixture_data;
|
||||
}
|
||||
};
|
||||
|
||||
using FixtureInputShapes = std::tuple<ngraph::Shape /* input data */, ngraph::Shape /* input const */>;
|
||||
|
||||
class ConvertMatmulToPointWiseConvolutionInvalidInputFixture: public CommonTestUtils::TestsCommon,
|
||||
public ::testing::WithParamInterface<std::tuple<FixtureData,
|
||||
FixtureInputShapes>> {
|
||||
public:
|
||||
void SetUp() override;
|
||||
public:
|
||||
std::shared_ptr<ngraph::Function> function;
|
||||
ngraph::pass::Manager pass_manager;
|
||||
};
|
||||
|
||||
void ConvertMatmulToPointWiseConvolutionInvalidInputFixture::SetUp() {
|
||||
// TODO: use auto & [fixture_data, input_shapes] = this->GetParam() when C++17
|
||||
FixtureData fixture_data;
|
||||
FixtureInputShapes input_shapes;
|
||||
std::tie(fixture_data, input_shapes) = this->GetParam();
|
||||
|
||||
ngraph::Shape input_data, input_const;
|
||||
std::tie(input_data, input_const) = input_shapes;
|
||||
|
||||
function = fixture_data.graph_factory->createGraph(input_data, input_const).createFunction();
|
||||
pass_manager = fixture_data.pass_manager;
|
||||
}
|
||||
|
||||
void execute_test_cloned_function(std::shared_ptr<ngraph::Function> function,
|
||||
ngraph::pass::Manager& pass_manager) {
|
||||
std::shared_ptr<ngraph::Function> reference_function = ngraph::clone_function(*function);
|
||||
pass_manager.run_passes(function);
|
||||
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
|
||||
const FunctionsComparator::Result result = func_comparator(function, reference_function);
|
||||
ASSERT_TRUE(result.valid);
|
||||
}
|
||||
|
||||
std::vector<FixtureData> transform_types = {
|
||||
FixtureData::create<GNAPluginNS::ConvertMatmulToPointWiseConvolution,
|
||||
CreateMatMul>(),
|
||||
FixtureData::create<GNAPluginNS::ConvertMatmulToPointWiseConvolution,
|
||||
CreateMatMul,
|
||||
CreateFakeQuantize>(),
|
||||
FixtureData::create<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution,
|
||||
CreateAdd,
|
||||
CreateMatMul>(),
|
||||
FixtureData::create<GNAPluginNS::ConvertMatmulWithBiasToPointWiseConvolution,
|
||||
CreateAdd,
|
||||
CreateMatMul,
|
||||
CreateFakeQuantize>(),
|
||||
FixtureData::create<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution,
|
||||
CreateFakeQuantize,
|
||||
CreateAdd,
|
||||
CreateMatMul>(),
|
||||
FixtureData::create<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution,
|
||||
CreateFakeQuantize,
|
||||
CreateAdd,
|
||||
CreateMatMul,
|
||||
CreateFakeQuantize>(),
|
||||
FixtureData::create<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution,
|
||||
CreateFakeQuantize,
|
||||
CreateMatMul>(),
|
||||
FixtureData::create<GNAPluginNS::ConvertMatmulWithFqToPointWiseConvolution,
|
||||
CreateFakeQuantize,
|
||||
CreateMatMul,
|
||||
CreateFakeQuantize>()
|
||||
};
|
||||
|
||||
std::vector<FixtureInputShapes> input_shapes = {
|
||||
std::make_tuple(ngraph::Shape{16, 16, 16}, ngraph::Shape{16, 16, 16}),
|
||||
std::make_tuple(ngraph::Shape{16, 9}, ngraph::Shape{9, 9}),
|
||||
std::make_tuple(ngraph::Shape{16, 65533}, ngraph::Shape{65533, 2}),
|
||||
std::make_tuple(ngraph::Shape{16, 769}, ngraph::Shape{769, 2})
|
||||
};
|
||||
|
||||
TEST_P(ConvertMatmulToPointWiseConvolutionInvalidInputFixture, CompareFunctions) {
|
||||
execute_test_cloned_function(function, pass_manager);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(ConvertMatmulToPointWiseConvolutionInvalidInputTestSuite, ConvertMatmulToPointWiseConvolutionInvalidInputFixture,
|
||||
::testing::Combine(::testing::ValuesIn(transform_types),
|
||||
::testing::ValuesIn(input_shapes)));
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace testing
|
Loading…
Reference in New Issue
Block a user