[Transformations] Enable missing runtime info check (#15796)

* Add rt info propagation to StridesOptimization

* Enable rt info check for pruning tests
This commit is contained in:
Tomasz Jankowski 2023-02-23 16:14:13 +01:00 committed by GitHub
parent 6359926815
commit e8d1be6e0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 24 additions and 71 deletions

View File

@ -13,7 +13,9 @@
#include "itt.hpp"
using namespace std;
using namespace ov;
using namespace ov::opset7;
static bool can_propagate_conv_stride(const std::shared_ptr<ngraph::Node>& conv) {
const auto& kernel_shape = conv->input_value(1).get_shape();
@ -39,40 +41,36 @@ static std::tuple<ngraph::Strides, bool> check_next_ops(const std::vector<ngraph
return std::make_tuple(strides[0], all_ops_are_valid);
}
static void insert_pooling(const ngraph::Output<ngraph::Node>& first,
ngraph::Input<ngraph::Node>& second,
const ngraph::Strides& strides) {
static void insert_pooling(const Output<Node>& first, Input<Node>& second, const Strides& strides) {
pass::NodeRegistry rg;
auto first_node = first.get_node_shared_ptr();
auto rank = first.get_partial_shape().rank();
bool do_reshape = rank.is_static() && static_cast<size_t>(rank.get_length()) < strides.size() + 2;
const auto rank = first.get_partial_shape().rank();
const bool do_reshape = rank.is_static() && static_cast<size_t>(rank.get_length()) < strides.size() + 2;
if (do_reshape) {
size_t diff = strides.size() + 2 - static_cast<size_t>(rank.get_length());
auto ones = opset7::Constant::create(ngraph::element::i64, ngraph::Shape{diff}, std::vector<int64_t>(diff, 1));
auto current_shape = std::make_shared<opset7::ShapeOf>(first);
std::shared_ptr<ngraph::Node> new_shape =
std::make_shared<opset7::Concat>(ngraph::OutputVector{ones, current_shape}, 0);
std::shared_ptr<ngraph::Node> constant_new_shape = get_constant_from_source(new_shape);
if (constant_new_shape)
const size_t diff = strides.size() + 2 - static_cast<size_t>(rank.get_length());
const auto ones = rg.make<Constant>(element::i64, Shape{diff}, vector<int64_t>(diff, 1));
const auto current_shape = rg.make<ShapeOf>(first);
shared_ptr<Node> new_shape = rg.make<Concat>(OutputVector{ones, current_shape}, 0);
if (const auto constant_new_shape = get_constant_from_source(new_shape)) {
rg.add(constant_new_shape);
new_shape = constant_new_shape;
first_node = std::make_shared<opset7::Reshape>(first_node, new_shape, false);
}
first_node = rg.make<Reshape>(first_node, new_shape, false);
}
std::shared_ptr<ngraph::Node> new_node = std::make_shared<opset7::MaxPool>(first_node,
strides,
ngraph::Shape{},
ngraph::Shape{},
ngraph::Shape(strides.size(), 1));
shared_ptr<Node> new_node = rg.make<MaxPool>(first_node, strides, Shape{}, Shape{}, Shape(strides.size(), 1));
if (do_reshape) {
// squeeze dimensions back
size_t diff = strides.size() + 2 - static_cast<size_t>(rank.get_length());
std::vector<size_t> axes(diff);
std::iota(axes.begin(), axes.end(), 0);
new_node = std::make_shared<opset7::Squeeze>(
new_node,
opset7::Constant::create(ngraph::element::u64, ngraph::Shape{diff}, axes));
const size_t diff = strides.size() + 2 - static_cast<size_t>(rank.get_length());
vector<size_t> axes(diff);
iota(axes.begin(), axes.end(), 0);
new_node = rg.make<Squeeze>(new_node, rg.make<Constant>(element::u64, Shape{diff}, axes));
}
std::shared_ptr<ngraph::Node> constant_new_node = get_constant_from_source(new_node);
if (constant_new_node)
if (const auto constant_new_node = get_constant_from_source(new_node)) {
rg.add(constant_new_node);
new_node = constant_new_node;
}
copy_runtime_info(as_node_vector({second.get_source_output()}), rg.get());
second.replace_source_output(new_node);
}

View File

@ -264,9 +264,6 @@ TEST_F(TransformationTestsF, StridesOptimization5) {
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{conv_2}, ngraph::ParameterVector{data});
}
// TODO: update transformation and remove this check XXX-68696
disable_rt_info_check();
}
// Pl->Conv(1x1,1x1)->Conv(1x1,2x2)->Conv(3x3,1x1)->Conv(1x1,2x2)
@ -424,8 +421,6 @@ TEST_F(TransformationTestsF, StridesOptimization7) {
function_ref =
std::make_shared<ngraph::Function>(ngraph::NodeVector{conv_3, conv_4}, ngraph::ParameterVector{data});
}
// TODO: update transformation and remove this check XXX-68696
disable_rt_info_check();
}
// Pl--->Conv(1x1,1x1)->ReLU--->Eltwise-->Conv(1x1,2x2)-->Eltwise-->Conv(1x1, 2x2)
@ -517,8 +512,6 @@ TEST_F(TransformationTestsF, StridesOptimization8) {
function_ref =
std::make_shared<ngraph::Function>(ngraph::NodeVector{conv_3}, ngraph::ParameterVector{data, data_2});
}
// TODO: update transformation and remove this check XXX-68696
disable_rt_info_check();
}
// Pl------->Conv(1x1,1x1)------>Eltwise------>Conv(1x1,2x2)---->Eltwise-->Conv(1x1, 2x2)
@ -636,6 +629,4 @@ TEST_F(TransformationTestsF, StridesOptimization9) {
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{conv_3},
ngraph::ParameterVector{data, data_2, data_3});
}
// TODO: update transformation and remove this check XXX-68696
disable_rt_info_check();
}

View File

@ -287,7 +287,6 @@ TEST_F(TransformationTestsF, PropagateMasksBasic) {
compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}}));
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -371,7 +370,6 @@ TEST_F(TransformationTestsF, PropagateMasksDynamicConvolution) {
compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}}));
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -599,7 +597,6 @@ TEST_F(TransformationTestsF, PropagateMaskPassThrough) {
compare_masks(*getMask(max_pool->output(0)), Mask({{}, {1, 2, 3}, {}, {}}));
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -768,7 +765,6 @@ TEST_F(TransformationTestsF, PropagateMasksHardDependencies) {
// compare_masks(*getMask(conv2), Mask({{}, {}, {}, {}}));
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -915,7 +911,6 @@ TEST_F(TransformationTestsF, PropagateMasksQuantizedGroupConvolution) {
compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}}));
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -1084,7 +1079,6 @@ TEST_F(TransformationTestsF, PropagateMasksQuantizedGroupConvolutionWithShapeOf)
compare_masks(*getMask(weights_2->output(0)), Mask({{}, {0, 1, 2, 3}, {}, {}}));
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -1222,7 +1216,6 @@ TEST_F(TransformationTestsF, PropagateMasksFakeQuantizePerTensor) {
compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}}));
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -1427,7 +1420,6 @@ TEST_F(TransformationTestsF, PropagateMasksFakeQuantizePerChannel) {
compare_masks(*getMask(fq->input(4).get_source_output()), Mask({{}, {0, 1, 2, 3, 4}, {}, {}}));
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -1559,7 +1551,6 @@ TEST_F(TransformationTestsF, TestConcatMaskPropagation) {
Mask({{}, {0, 1, 2, 3, 15, 16, 17, 18, 28, 29, 30, 31}, {}, {}}));
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -1707,7 +1698,6 @@ TEST_F(TransformationTestsF, TestConcatMaskPropagationUp) {
Mask({{}, {0, 1, 2, 3, 15, 16, 17, 18, 28, 29, 30, 31}, {}, {}}));
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -1878,7 +1868,6 @@ TEST_F(TransformationTestsF, PruneConvIsClosingAndInGroup) {
compare_masks(*getMask(end_conv->output(0)), Mask({{}, {}, {}, {}}));
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -2070,7 +2059,6 @@ TEST_F(TransformationTestsF, PruneReducelayerUp) {
compare_masks(*getMask(conv_1->output(0)), Mask({{}, {}, {}, {}}));
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -2174,7 +2162,6 @@ TEST_F(TransformationTestsF, PruneReduceLayerDown) {
compare_masks(*getMask(end_conv->output(0)), Mask({{}, {}, {}, {}}));
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -2354,7 +2341,6 @@ TEST_F(TransformationTestsF, MaskPropagationReshapeUp) {
compare_masks(*getMask(conv_1->output(0)), Mask({{}, {}, {}, {}}));
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -2467,7 +2453,6 @@ TEST_P(TransformationTestsBoolParamF, MaskPropagationReshapeUpWithShapeOf) {
compare_masks(*getMask(conv_1->output(0)), Mask({{}, {}, {}, {}}));
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -2579,7 +2564,6 @@ TEST_F(TransformationTestsF, MaskPropagationReshapeUpShapeSubGraph) {
compare_masks(*getMask(conv_1->output(0)), Mask({{}, {}, {}, {}}));
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -2678,7 +2662,6 @@ TEST_F(TransformationTestsF, MaskPropagationReshapeExtend) {
compare_masks(*getMask(conv_1->output(0)), Mask({{}, {}, {}, {}}));
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -2784,7 +2767,6 @@ TEST_F(DISABLED_TransformationTestsF, MaskPropagationReshapeDownMul) {
compare_masks(*getMask(last_conv->output(0)), Mask({{}, {}, {}, {}}));
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -2889,7 +2871,6 @@ TEST_F(TransformationTestsF, MaskPropagationReshapeDownAdd) {
compare_masks(*getMask(last_conv->output(0)), Mask({{}, {}, {}, {}}));
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -3054,7 +3035,6 @@ TEST_F(TransformationTestsF, MaskPropagationReshapeUnsqueezeUp) {
compare_masks(*getMask(mul_left->output(0)), Mask({{}, {}, {}}));
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -3119,7 +3099,6 @@ TEST_F(TransformationTestsF, MaskPropagationReshapeUnsqueezeDown) {
compare_masks(*getMask(mul_left->output(0)), Mask({{}, {}, {}}));
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -3292,7 +3271,6 @@ TEST_F(TransformationTestsF, PruneSEBlock) {
compare_masks(*getMask(end_conv->output(0)), Mask({{}, {}, {}, {}}));
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -3395,7 +3373,6 @@ TEST_F(TransformationTestsF, PropagateMasksLinear) {
compare_masks(*getMask(last_linear->output(0)), Mask{{}, {}});
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -3658,7 +3635,6 @@ TEST_F(TransformationTestsF, MaskPropagationLinearOuterDims) {
compare_masks(*getMask(last_mul->output(0)), Mask({{}, {}}));
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -3808,7 +3784,6 @@ TEST_F(TransformationTestsF, PruneMasksMatMulColsStopRowsUp) {
compare_masks(*getMask(last_linear->output(0)), Mask{{}, {}, {}});
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -3898,7 +3873,6 @@ TEST_F(TransformationTestsF, PruneMasksMatMulRowsStopColsUp) {
compare_masks(*getMask(last_linear->output(0)), Mask{{}, {}, {}});
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -4003,7 +3977,6 @@ TEST_F(TransformationTestsF, PropagateFlattenUp) {
compare_masks(*getMask(linear->output(0)), Mask{{}, {}});
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -4076,7 +4049,6 @@ TEST_F(TransformationTestsF, PropagateFlattenDown) {
compare_masks(*getMask(linear->output(0)), {{}, {}});
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -4126,7 +4098,6 @@ TEST_F(TransformationTestsF, PropagateMasksTranspose) {
compare_masks(*getMask(last_mul->output(0)), Mask{{}, {}});
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -4200,7 +4171,6 @@ TEST_F(TransformationTestsF, PropagateMasksTransposeComplex) {
compare_masks(*getMask(last_mul->output(0)), Mask{{}, {}, {}, {}, {}});
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -4402,7 +4372,6 @@ TEST_F(DISABLED_TransformationTestsF, PropagateMasksBroadcastedEltwiseWithInputs
compare_masks(*getMask(last_mul->output(0)), Mask({{}, {}}));
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -4583,7 +4552,6 @@ TEST_F(TransformationTestsF, PropagateMasksBroadcastedEltwise) {
compare_masks(*getMask(last_mul->output(0)), Mask({{}, {}}));
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -4773,7 +4741,6 @@ TEST_F(TransformationTestsF, MaskPropagationComplexReshape) {
std::string(VISUALIZE_TREE_ROOT) + "MaskPropagationComplexReshapeWithMasks.svg",
modifier);
}
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -4966,7 +4933,6 @@ TEST_P(TransformationTestsBoolParamF, MaskPropagationReshapedPassThroughP) {
manager.register_pass<ngraph::pass::VisualizeTree>(
std::string(VISUALIZE_TREE_ROOT) + "MaskPropagationReverseFlattenWithMasks" + postfix + ".svg",
modifier);
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -5032,7 +4998,6 @@ TEST_P(TransformationTestsBoolParamF, MaskPropagationBroadcastedSameRankEltwiseS
compare_masks(*getMask(mul_last->output(0)), Mask{{}, {}, {}});
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -5194,7 +5159,6 @@ TEST_F(TransformationTestsF, MaskPropagationMatMulWithSeveralOutputs) {
compare_masks(*getMask(right_matmul), Mask{{}, {}});
manager.register_pass<pass::ShrinkWeights>();
disable_rt_info_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}