Move passes to CommonOptimizations (#2442)
* Move passes to CommonOptimizations * Updated BN tests to use ranges for constant value generation * Added some decomposition passes into legacy conversion * Added WA for FQReshapeFusion pass
This commit is contained in:
parent
3a5ba5581f
commit
e3270b6b34
@ -21,13 +21,8 @@ class TRANSFORMATIONS_API ConvertGELU;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
class ngraph::pass::ConvertGELU: public ngraph::pass::GraphRewrite {
|
||||
class ngraph::pass::ConvertGELU: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertGELU() : GraphRewrite() {
|
||||
convert_gelu();
|
||||
}
|
||||
|
||||
private:
|
||||
void convert_gelu();
|
||||
ConvertGELU();
|
||||
};
|
||||
|
@ -7,21 +7,40 @@
|
||||
#include "transformations/common_optimizations/algebraic_simplification.hpp"
|
||||
#include "transformations/common_optimizations/nop_elimination.hpp"
|
||||
#include "transformations/common_optimizations/common_optimizations.hpp"
|
||||
#include "transformations/common_optimizations/conv_mul_fusion.hpp"
|
||||
#include "transformations/common_optimizations/fq_mul_fusion.hpp"
|
||||
#include "transformations/common_optimizations/fq_reshape_fusion.hpp"
|
||||
#include "transformations/depth_to_space_fusion.hpp"
|
||||
#include "transformations/optimize_strided_slice.hpp"
|
||||
#include "transformations/convert_scatter_elements_to_scatter.hpp"
|
||||
#include "transformations/convert_pad_to_group_conv.hpp"
|
||||
#include "transformations/remove_filtering_boxes_by_size.hpp"
|
||||
#include "transformations/init_node_info.hpp"
|
||||
#include "transformations/itt.hpp"
|
||||
#include "transformations/mish_fusion.hpp"
|
||||
#include "transformations/softplus_fusion.hpp"
|
||||
#include "transformations/softplus_to_mish_fusion.hpp"
|
||||
#include "transformations/swish_fusion.hpp"
|
||||
#include "transformations/hswish_fusion.hpp"
|
||||
#include "transformations/normalize_l2_fusion.hpp"
|
||||
#include "transformations/convert_quantize_dequantize.hpp"
|
||||
#include "transformations/bidirectional_sequences_decomposition.hpp"
|
||||
#include "transformations/convert_pad_to_group_conv.hpp"
|
||||
#include "transformations/convert_divide.hpp"
|
||||
#include "transformations/convert_quantize_dequantize.hpp"
|
||||
#include "transformations/convert_mod.hpp"
|
||||
#include "transformations/convert_minimum_to_power_and_max.hpp"
|
||||
#include "transformations/convert_negative.hpp"
|
||||
#include "transformations/convert_scatter_elements_to_scatter.hpp"
|
||||
#include "transformations/convert_reduce_to_pooling.hpp"
|
||||
#include "transformations/convert_subtract.hpp"
|
||||
#include "transformations/convert_depth_to_space.hpp"
|
||||
#include "transformations/convert_space_to_depth.hpp"
|
||||
#include "transformations/convert_broadcast_to_tiles.hpp"
|
||||
#include "transformations/convert_gelu.hpp"
|
||||
#include "transformations/batch_norm_decomposition.hpp"
|
||||
#include "transformations/pull_transpose_through_fq.hpp"
|
||||
#include "transformations/lin_op_sequence_fusoin.hpp"
|
||||
#include "transformations/reduce_l1_decomposition.hpp"
|
||||
#include "transformations/reduce_l2_decomposition.hpp"
|
||||
#include "transformations/remove_filtering_boxes_by_size.hpp"
|
||||
#include "transformations/hswish_decomposition.hpp"
|
||||
#include "transformations/hswish_fusion.hpp"
|
||||
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
@ -55,6 +74,43 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
|
||||
manager.register_pass<ngraph::pass::BidirectionalRNNSequenceDecomposition>();
|
||||
manager.register_pass<ngraph::pass::BidirectionalGRUSequenceDecomposition>();
|
||||
|
||||
auto decomp = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||
decomp->add_matcher<ngraph::pass::ReduceL1Decomposition>();
|
||||
decomp->add_matcher<ngraph::pass::ReduceL2Decomposition>();
|
||||
decomp->add_matcher<ngraph::pass::HSwishDecomposition>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertReduceMeanToPooling>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertReduceMaxToPooling>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertReduceSumToPooling>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertBroadcastToTiles>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertMod>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertGELU>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertMinimum>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertSubtract>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertDivide>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertNegative>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertDepthToSpace>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertSpaceToDepth>();
|
||||
decomp->add_matcher<ngraph::pass::BatchNormDecomposition>();
|
||||
decomp->set_name("ngraph::pass::CommonDecompositions");
|
||||
|
||||
// CF is required after all decompositions
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
|
||||
// LinOpSequenceFusion must be executed after all decompositions
|
||||
manager.register_pass<ngraph::pass::LinOpSequenceFusion>();
|
||||
|
||||
manager.register_pass<ngraph::pass::ConvolutionMultiplyFusion>();
|
||||
manager.register_pass<ngraph::pass::GroupConvolutionMultiplyFusion>();
|
||||
manager.register_pass<ngraph::pass::ConvolutionBackpropDataMultiplyFusion>();
|
||||
manager.register_pass<ngraph::pass::GroupConvolutionBackpropDataMultiplyFusion>();
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
|
||||
auto fq_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||
fq_fusions->add_matcher<ngraph::pass::FakeQuantizeMulFusion>();
|
||||
fq_fusions->add_matcher<ngraph::pass::FakeQuantizeReshapeFusion>();
|
||||
fq_fusions->add_matcher<ngraph::pass::PullTransposeThroughFQUp>();
|
||||
fq_fusions->set_name("ngraph::pass::FakeQuantizeFusions");
|
||||
|
||||
manager.set_callback(m_transformation_callback);
|
||||
manager.run_passes(f);
|
||||
return true;
|
||||
|
@ -20,7 +20,14 @@ ngraph::pass::FakeQuantizeReshapeFusion::FakeQuantizeReshapeFusion() {
|
||||
ngraph::pattern::any_input()},
|
||||
pattern::consumers_count(1));
|
||||
const auto reshape_node_p = ngraph::pattern::wrap_type<opset4::Reshape>(
|
||||
{fq_node_p, ngraph::pattern::any_input()});
|
||||
{fq_node_p, ngraph::pattern::any_input()}, [](const Output<Node> & output) {
|
||||
// WA: check that all Reshape node consumers are not GroupConvolution operations
|
||||
const auto & target_inputs = output.get_target_inputs();
|
||||
return std::all_of(target_inputs.begin(), target_inputs.end(),
|
||||
[](const Input<Node> & input){
|
||||
return input.get_node()->get_type_info() != opset4::GroupConvolution::type_info;
|
||||
});
|
||||
});
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](pattern::Matcher &m) {
|
||||
const auto &pattern_map = m.get_pattern_value_map();
|
||||
|
@ -16,7 +16,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertBroadcastToTiles, "ConvertBroadcastT
|
||||
ngraph::pass::ConvertBroadcastToTiles::ConvertBroadcastToTiles() {
|
||||
auto broadcast = ngraph::pattern::wrap_type<ngraph::opset1::Broadcast>();
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
|
||||
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
||||
auto broadcast = std::dynamic_pointer_cast<ngraph::opset1::Broadcast>(m.get_match_root());
|
||||
|
||||
if (!broadcast) {
|
||||
@ -83,7 +83,7 @@ ngraph::pass::ConvertBroadcastToTiles::ConvertBroadcastToTiles() {
|
||||
}
|
||||
|
||||
auto const_node = std::make_shared<ngraph::opset1::Constant>(element::i64, Shape {dims_count}, dims);
|
||||
auto tile = std::make_shared<ngraph::opset1::Tile>(last_node, const_node);
|
||||
auto tile = register_new_node<ngraph::opset1::Tile>(last_node, const_node);
|
||||
new_ops.push_back(tile);
|
||||
tile->set_friendly_name(broadcast->get_friendly_name());
|
||||
|
||||
|
@ -12,7 +12,7 @@
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertGELU, "ConvertGELU", 0);
|
||||
|
||||
void ngraph::pass::ConvertGELU::convert_gelu() {
|
||||
ngraph::pass::ConvertGELU::ConvertGELU() {
|
||||
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{});
|
||||
auto gelu = std::make_shared<ngraph::opset2::Gelu>(input);
|
||||
|
||||
@ -24,11 +24,11 @@ void ngraph::pass::ConvertGELU::convert_gelu() {
|
||||
auto input_type = input.get_element_type();
|
||||
|
||||
// f(x) = 0.5 * x * (1.0 + erf( x / sqrt(2.0) )
|
||||
auto mul = std::make_shared<ngraph::opset1::Multiply>(input, ngraph::op::Constant::create(input_type, Shape{}, {0.5}));
|
||||
auto sq2 = std::make_shared<ngraph::opset1::Sqrt>(ngraph::op::Constant::create(input_type, Shape{}, {2.0}));
|
||||
auto div = std::make_shared<ngraph::opset1::Divide>(input, sq2);
|
||||
auto mul = std::make_shared<ngraph::opset1::Multiply>(input, ngraph::opset1::Constant::create(input_type, Shape{}, {0.5}));
|
||||
auto sq2 = std::make_shared<ngraph::opset1::Sqrt>(ngraph::opset1::Constant::create(input_type, Shape{}, {2.0}));
|
||||
auto div = register_new_node<ngraph::opset1::Divide>(input, sq2); // can be decomposed
|
||||
auto erf = std::make_shared<ngraph::opset1::Erf>(div);
|
||||
auto add = std::make_shared<ngraph::opset1::Add>(erf, ngraph::op::Constant::create(input_type, Shape{}, {1.0}));
|
||||
auto add = std::make_shared<ngraph::opset1::Add>(erf, ngraph::opset1::Constant::create(input_type, Shape{}, {1.0}));
|
||||
auto res = std::make_shared<ngraph::opset1::Multiply>(mul, add);
|
||||
|
||||
res->set_friendly_name(gelu->get_friendly_name());
|
||||
@ -38,5 +38,5 @@ void ngraph::pass::ConvertGELU::convert_gelu() {
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(gelu, "ConvertGELU");
|
||||
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
register_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
}
|
||||
|
@ -68,48 +68,26 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
|
||||
OV_ITT_SCOPED_TASK(itt::domains::IETransform, "ngraph::pass::ConvertOpSet1ToLegacy");
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
std::vector<std::shared_ptr<ngraph::pass::PassBase> > transforms;
|
||||
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
|
||||
// the following two transformations produce ReduceSum operations so they
|
||||
// must be executed before the ConvertReduceSumToPooling transformation
|
||||
manager.register_pass<ngraph::pass::ReduceL1Decomposition>();
|
||||
manager.register_pass<ngraph::pass::ReduceL2Decomposition>();
|
||||
|
||||
// HSwishDecomposition produce Minimum, Relu and Multiply operations
|
||||
// so it must be executed before
|
||||
manager.register_pass<ngraph::pass::HSwishDecomposition>();
|
||||
|
||||
// List if Decomposition and Conversion transformations that can be
|
||||
// applied simultaneously in a single graph traversal
|
||||
// Some passes before ConvertOpSet1ToLegacy can produce some of this
|
||||
// operations. So for convenience we decompose this operations here and
|
||||
// in CommonOptimizations.
|
||||
auto decomp = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertBroadcastToTiles>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertReduceMeanToPooling>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertReduceMaxToPooling>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertReduceSumToPooling>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertMod>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertMinimum>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertSubtract>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertDivide>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertNegative>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertDepthToSpace>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertSpaceToDepth>();
|
||||
decomp->add_matcher<ngraph::pass::BatchNormDecomposition>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertMatMulToFC>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertMatMulToGemm>();
|
||||
decomp->set_name("ngraph::pass::Decompositions");
|
||||
decomp->set_name("ngraph::pass::LegacyDecompositions");
|
||||
|
||||
// CF is required after all decompositions
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
auto convert_matmul = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||
convert_matmul->add_matcher<ngraph::pass::ConvertMatMulToFC>();
|
||||
convert_matmul->add_matcher<ngraph::pass::PullTransposeThroughFQUp>();
|
||||
convert_matmul->add_matcher<ngraph::pass::ConvertMatMulToGemm>();
|
||||
convert_matmul->set_name("ngraph::pass::ConvertMatMul");
|
||||
|
||||
// LinOpSequenceFusion must be executed after all decompositions
|
||||
manager.register_pass<ngraph::pass::LinOpSequenceFusion>();
|
||||
|
||||
manager.register_pass<ngraph::pass::ConvolutionMultiplyFusion>();
|
||||
manager.register_pass<ngraph::pass::GroupConvolutionMultiplyFusion>();
|
||||
manager.register_pass<ngraph::pass::ConvolutionBackpropDataMultiplyFusion>();
|
||||
manager.register_pass<ngraph::pass::GroupConvolutionBackpropDataMultiplyFusion>();
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
|
||||
// Convolution/Deconvolution/FullyConnected fusions
|
||||
@ -120,18 +98,12 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
|
||||
convert_convolutions->add_matcher<ngraph::pass::ConvertGroupDeconvolution>();
|
||||
convert_convolutions->set_name("ngraph::pass::ConvertConvolutions");
|
||||
|
||||
auto fq_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||
fq_fusions->add_matcher<FakeQuantizeMulFusion>();
|
||||
fq_fusions->add_matcher<FakeQuantizeReshapeFusion>();
|
||||
fq_fusions->add_matcher<PullTransposeThroughFQUp>();
|
||||
fq_fusions->set_name("ngraph::pass::FakeQuantizeFusions");
|
||||
|
||||
// Convolution/Deconvolution/FullyConnected fusions
|
||||
auto fusion = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||
fusion->add_matcher<ngraph::pass::ConvAddFusion>();
|
||||
fusion->add_matcher<ngraph::pass::DeconvAddFusion>();
|
||||
fusion->add_matcher<ngraph::pass::FullyConnectedBiasFusion>();
|
||||
fusion->set_name("ngraph::pass::Fusions");
|
||||
fusion->set_name("ngraph::pass::BiasFusions");
|
||||
|
||||
// CF is required after fusions
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
@ -148,6 +120,7 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
|
||||
anchor->add_matcher<ngraph::pass::ConvertHardSigmoidToLegacyMatcher>();
|
||||
anchor->add_matcher<ngraph::pass::ConvertProposalToLegacyMatcher>();
|
||||
anchor->add_matcher<ngraph::pass::ConvertProposal4ToLegacyMatcher>();
|
||||
anchor->add_matcher<ngraph::pass::ConvertBroadcastToTiles>();
|
||||
anchor->add_matcher<ngraph::pass::ConvertTileToLegacyMatcher>();
|
||||
anchor->add_matcher<ngraph::pass::ConvertLRNToLegacyMatcher>();
|
||||
anchor->add_matcher<ngraph::pass::ConvertPadToLegacyMatcher>();
|
||||
@ -170,7 +143,7 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
|
||||
anchor->add_matcher<ngraph::pass::ConvertGRUSequenceMatcher>();
|
||||
anchor->add_matcher<ngraph::pass::ConvertRNNSequenceMatcher>();
|
||||
anchor->add_matcher<ngraph::pass::ConvertLSTMSequenceMatcher>();
|
||||
anchor->set_name("ngraph::pass::ConvertOpSet1ToLegacy");
|
||||
anchor->set_name("ngraph::pass::LegacyConversions");
|
||||
|
||||
// List of final conversion transformations that must to be executed
|
||||
// after previous group of transformations
|
||||
|
@ -21,7 +21,6 @@ bool ngraph::pass::ConvertOpSet2ToOpSet1::run_on_function(std::shared_ptr<ngraph
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
|
||||
manager.register_pass<ngraph::pass::ConvertGELU>();
|
||||
manager.register_pass<ngraph::pass::ConvertSpaceToBatch>();
|
||||
manager.register_pass<ngraph::pass::ConvertBatchToSpace>();
|
||||
|
||||
|
@ -29,7 +29,7 @@ ngraph::pass::HSwishDecomposition::HSwishDecomposition() {
|
||||
auto add = std::make_shared<ngraph::opset4::Add>(hswish_node->input_value(0), add_constant);
|
||||
auto relu = std::make_shared<ngraph::opset4::Relu>(add);
|
||||
auto min_constant = ngraph::opset4::Constant::create(input_type, ngraph::Shape{}, {6.0});
|
||||
auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
|
||||
auto min = register_new_node<ngraph::opset4::Minimum>(relu, min_constant);
|
||||
auto mul_first = std::make_shared<ngraph::opset4::Multiply>(hswish_node->input_value(0), min);
|
||||
auto mul_constant = ngraph::opset4::Constant::create(input_type, ngraph::Shape{}, {(1.0/6.0)}); // const(1/6)
|
||||
auto mul_second = std::make_shared<ngraph::opset4::Multiply>(mul_first, mul_constant);
|
||||
|
@ -25,7 +25,7 @@ ngraph::pass::ReduceL1Decomposition::ReduceL1Decomposition() {
|
||||
}
|
||||
|
||||
auto abs = std::make_shared<ngraph::opset4::Abs>(reduce_l1_node->input_value(0));
|
||||
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(abs, reduce_l1_node->input_value(1), reduce_l1_node->get_keep_dims());
|
||||
auto reduce_sum = register_new_node<ngraph::opset4::ReduceSum>(abs, reduce_l1_node->input_value(1), reduce_l1_node->get_keep_dims());
|
||||
|
||||
reduce_sum->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||
ngraph::copy_runtime_info(reduce_l1_node,
|
||||
|
@ -26,7 +26,7 @@ ngraph::pass::ReduceL2Decomposition::ReduceL2Decomposition() {
|
||||
|
||||
auto const_2 = ngraph::opset4::Constant::create(reduce_l2_node->input_value(0).get_element_type(), Shape{}, {2.0f});
|
||||
auto square = std::make_shared<ngraph::opset4::Power>(reduce_l2_node->input_value(0), const_2);
|
||||
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(square, reduce_l2_node->input_value(1), reduce_l2_node->get_keep_dims());
|
||||
auto reduce_sum = register_new_node<ngraph::opset4::ReduceSum>(square, reduce_l2_node->input_value(1), reduce_l2_node->get_keep_dims());
|
||||
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
|
||||
reduce_sum->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||
ngraph::copy_runtime_info(reduce_l2_node,
|
||||
|
@ -123,3 +123,49 @@ INSTANTIATE_TEST_CASE_P(NGraph, nGraphFQReshapeFusionTests, testing::Values(
|
||||
FQReshapeFusionTestCase{{1, 2, 1, 3}, {2, 1, 3}, {1}, {1, 1}, {1, 2, 1, 1}, {1, 2, 1, 3}, {}, {}, {}, {}, true},
|
||||
FQReshapeFusionTestCase{{1, 2, 1, 3}, {2, 1, 1}, {1}, {1, 1}, {1, 2, 1, 1}, {6}, {}, {}, {}, {}, true}));
|
||||
} // namespace
|
||||
|
||||
TEST(nGraphFQReshapeFusionTests, FQReshapeGroupConvolution) {
|
||||
auto get_function = [](const FQReshapeFusionTestCase & test_case) {
|
||||
const auto & data = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, test_case.data_shape, 0);
|
||||
auto il = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.il_shape);
|
||||
auto ih = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.ih_shape);
|
||||
auto ol = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.ol_shape);
|
||||
auto oh = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.oh_shape);
|
||||
|
||||
auto fq = std::make_shared<ngraph::opset4::FakeQuantize>(data, il, ih, ol, oh, 42);
|
||||
|
||||
auto reshape_pattern = std::make_shared<ngraph::opset4::Constant>(
|
||||
ngraph::element::i64, ngraph::Shape{test_case.reshape_pattern.size()}, test_case.reshape_pattern);
|
||||
auto reshape = std::make_shared<ngraph::opset4::Reshape>(fq, reshape_pattern, true);
|
||||
|
||||
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.data_shape);
|
||||
ngraph::Strides stride{1, 1};
|
||||
ngraph::CoordinateDiff pad{0, 0};
|
||||
auto group_conv = std::make_shared<ngraph::opset4::GroupConvolution>(input, reshape, stride, pad, pad, stride);
|
||||
|
||||
auto result = std::make_shared<ngraph::op::Result>(group_conv);
|
||||
ngraph::ParameterVector params = {il, ih, ol, oh, input};
|
||||
ngraph::ResultVector results = {result};
|
||||
return std::make_shared<ngraph::Function>(results, params);
|
||||
};
|
||||
|
||||
FQReshapeFusionTestCase params;
|
||||
params.data_shape = {1, 2, 1, 3};
|
||||
params.il_shape = {2, 1, 1};
|
||||
params.ih_shape = {1};
|
||||
params.ol_shape = {1, 1};
|
||||
params.oh_shape = {1, 2, 1, 1};
|
||||
params.reshape_pattern = {2, 3, 1, 1, 1};
|
||||
|
||||
auto f = get_function(params);
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::FakeQuantizeReshapeFusion>();
|
||||
manager.run_passes(f);
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
|
||||
auto res = compare_functions(f, get_function(params));
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
@ -15,9 +15,9 @@ std::shared_ptr<ngraph::Node> makeBatchNormInference(const ngraph::Output<Node>&
|
||||
size_t C = data.get_shape().at(1);
|
||||
bool random = true;
|
||||
std::vector<float> values(C);
|
||||
auto gamma = ngraph::builder::makeConstant(ngPrc, ngraph::Shape{C}, values, random);
|
||||
auto beta = ngraph::builder::makeConstant(ngPrc, ngraph::Shape{C}, values, random);
|
||||
auto mean = ngraph::builder::makeConstant(ngPrc, ngraph::Shape{C}, values, random);
|
||||
auto gamma = ngraph::builder::makeConstant(ngPrc, ngraph::Shape{C}, values, random, 1, 0);
|
||||
auto beta = ngraph::builder::makeConstant(ngPrc, ngraph::Shape{C}, values, random, 1, 0);
|
||||
auto mean = ngraph::builder::makeConstant(ngPrc, ngraph::Shape{C}, values, random, 1, 0);
|
||||
|
||||
// Fill the vector for variance with positive values
|
||||
std::default_random_engine gen;
|
||||
|
Loading…
Reference in New Issue
Block a user