Pruning with FQ support (#5925)

* Add FQ, Concat, extended Eltwise support

* Fix tests after rebase + small refactoring

* Added mask propagation for Reshape on GroupConv weights

* Added printing of reduced weights to test transformation

* Turn off pruning for test

* Fixed comments + revert transformation comment

* Fixed last comments
Maria Kaglinskaya 2021-06-08 09:49:53 +03:00 committed by GitHub
parent bc7f61be24
commit 6022df6687
8 changed files with 946 additions and 77 deletions


@@ -54,10 +54,90 @@ public:
});
}
std::vector<size_t> get_not_empty_dims() {
std::vector<size_t> not_empty_dims;
for (size_t i = 0; i < this->size(); i++) {
if (!this->at(i).empty())
not_empty_dims.push_back(i);
}
return not_empty_dims;
}
bool is_shape_like() const { return m_is_shape_like; }
void set_shape_like(bool flag) { m_is_shape_like = flag; }
void copy_value_from_mask(Mask *const mask) {
auto cur_mask_iter = begin();
auto mask_iter = mask->begin();
while (cur_mask_iter != end() && mask_iter != mask->end()) {
*cur_mask_iter = *mask_iter;
cur_mask_iter++;
mask_iter++;
}
}
void copy_value_from_mask_reversed(Mask *const mask) {
auto cur_mask_iter = rbegin();
auto mask_iter = mask->rbegin();
while (cur_mask_iter != rend() && mask_iter != mask->rend()) {
*cur_mask_iter = *mask_iter;
cur_mask_iter++;
mask_iter++;
}
}
Mask::Ptr intersect_masks_reversed(Mask *const mask) {
auto result_mask = std::make_shared<Mask>(std::max(size(), mask->size()));
auto result_iter = result_mask->rbegin();
auto mask_1_iter = rbegin();
auto mask_2_iter = mask->rbegin();
while (mask_1_iter != rend() &&
mask_2_iter != mask->rend()) {
// Merge mask dimension values for both masks
// Example: (MaskValue[1,2,3,4], MaskValue[2,3]) -> MaskValue[2,3]
for (const auto & value : *mask_1_iter) {
if (mask_2_iter->count(value)) {
result_iter->insert(value);
}
}
result_iter++;
mask_1_iter++;
mask_2_iter++;
}
return result_mask;
}
Mask::Ptr union_masks_reversed(Mask *const mask) {
auto result_mask = std::make_shared<Mask>(std::max(size(), mask->size()));
auto result_iter = result_mask->rbegin();
auto mask_1_iter = rbegin();
auto mask_2_iter = mask->rbegin();
while (mask_1_iter != rend() &&
mask_2_iter != mask->rend()) {
// Union mask dimension values for both masks
// Example: (MaskValue[1,2,3,4], MaskValue[2, 5]) -> MaskValue[1, 2, 3, 4, 5]
for (const auto & value : *mask_1_iter) {
result_iter->insert(value);
}
for (const auto & value : *mask_2_iter) {
if (!result_iter->count(value)) {
result_iter->insert(value);
}
}
result_iter++;
mask_1_iter++;
mask_2_iter++;
}
return result_mask;
}
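A minimal sketch of how the two reversed merge helpers above behave, assuming the Mask container shown here (a per-dimension set of pruned channel indices, constructible from a rank as elsewhere in this diff); illustrative only, not part of the commit:

auto a = std::make_shared<Mask>(2);
auto b = std::make_shared<Mask>(2);
a->at(1) = {1, 2, 3, 4};
b->at(1) = {2, 3};
auto common = a->intersect_masks_reversed(b.get()); // common->at(1) == {2, 3}
auto merged = a->union_masks_reversed(b.get());     // merged->at(1) == {1, 2, 3, 4}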
void add_callback(const std::function<bool(Mask::Ptr)> & receive_callback, Mask::Ptr mask) {
m_callbacks[mask.get()] = receive_callback;
m_dependencies.push_back(mask.get());


@@ -14,6 +14,7 @@ namespace ngraph {
namespace pass {
class InitConstMask;
class InitMasks;
class PropagateMasks;
class ShrinkWeights;
@@ -22,6 +23,16 @@ class Pruning;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief Initialising masks for pruned operations
*/
class ngraph::pass::InitMasks : public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
InitMasks();
};
/**
* @ingroup ie_transformation_common_api
* @brief Check Constant operation values by given dimensions and set


@@ -17,7 +17,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::InitConstMask, "InitConstMask", 0);
ngraph::pass::InitConstMask::InitConstMask(const ngraph::AxisSet & dims,
const std::function<bool(const double & value)> & condition) {
auto constant = pattern::wrap_type<opset6::Constant>(
pattern::type_matches_any({element::f16, element::f32, element::f64}));
pattern::type_matches_any({element::i8, element::u8, element::f16, element::f32, element::f64}));
matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto const_node = std::dynamic_pointer_cast<opset6::Constant>(m.get_match_root());


@@ -0,0 +1,64 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "pruning.hpp"
#include "mask_attribute.hpp"
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/log.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::InitMasks, "InitMasks", 0);
namespace ngraph {
namespace pass {
namespace init_masks {
class InitConvMask;
} // namespace init_masks
} // namespace pass
} // namespace ngraph
class ngraph::pass::init_masks::InitConvMask : public MatcherPass {
public:
InitConvMask() {
auto input = pattern::any_input();
auto weights = pattern::any_input();
auto conv = pattern::wrap_type<opset6::Convolution, opset6::GroupConvolution>({input, weights});
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto & pattern_map = m.get_pattern_value_map();
const auto & m_output = pattern_map.at(conv);
// Initializing weights mask:
// 1. Looking for Const node with weights
NodeVector weights_calculation_nodes;
auto cur_node = m_output.get_node()->get_input_node_shared_ptr(1);
while (!ngraph::is_type<opset6::Constant>(cur_node) && cur_node->inputs().size()) {
weights_calculation_nodes.push_back(cur_node);
cur_node = cur_node->get_input_node_shared_ptr(0);
}
if (!ngraph::is_type<opset6::Constant>(cur_node)) {
NGRAPH_DEBUG << "Can't find Constant weights for Convolution: " <<
m_output.get_node()->get_friendly_name() << std::endl;
return false;
}
// 2. Init mask for Const node
InitConstMask({0}/* check only output channels dim */).apply(cur_node);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(conv, "ConvolutionInitMask");
register_matcher(m, callback);
}
};
ngraph::pass::InitMasks::InitMasks() {
add_matcher<init_masks::InitConvMask>();
}
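A minimal usage sketch (mirroring the tests below) of how InitMasks is expected to run ahead of mask propagation; f is assumed to be a std::shared_ptr<ngraph::Function>:

pass::Manager m;
m.register_pass<pass::InitMasks>();
m.register_pass<pass::PropagateMasks>();
m.run_passes(f);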


@@ -7,7 +7,9 @@
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/log.hpp>
#include <ngraph/rt_info.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::PropagateMasks, "PropagateMasks", 0);
@@ -20,11 +22,23 @@ class GroupConvolution;
class Elementwise;
class PassThrough;
class StopPropagation;
class FakeQuantize;
class Concat;
class Reshape;
} // namespace mask_propagation
} // namespace pass
} // namespace ngraph
ngraph::Shape broadcast_shape_to_rank(ngraph::Shape shape_to_broadcast, int64_t dst_rank) {
auto initial_rank = static_cast<int64_t>(shape_to_broadcast.size());
auto num_of_broadcased_dims = dst_rank - initial_rank;
std::vector<size_t> dims(num_of_broadcased_dims, 1);
dims.insert(dims.end(), shape_to_broadcast.begin(), shape_to_broadcast.end());
auto new_shape = ngraph::Shape(dims);
return new_shape;
}
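For example (illustrative only), broadcast_shape_to_rank prepends ones until the target rank is reached:

// broadcast_shape_to_rank(Shape{8, 1, 1}, 4)   -> Shape{1, 8, 1, 1}
// broadcast_shape_to_rank(Shape{3, 64, 64}, 4) -> Shape{1, 3, 64, 64}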
class ngraph::pass::mask_propagation::Convolution : public MatcherPass {
public:
Convolution() {
@@ -38,12 +52,15 @@ public:
const auto & m_output = pattern_map.at(conv);
const auto & m_input = pattern_map.at(input);
// In case if weights are Constant we initialize Mask
InitConstMask({0}/* check only output channel */).apply(m_weights.get_node_shared_ptr());
auto weights_mask = getMask(m_weights);
// If weights are not a Constant and we didn't set Mask value before we will get nullptr
if (!weights_mask) return false;
// A nullptr weights mask means that the mask for this node wasn't initialized earlier.
// The weights mask for a convolution should be initialized in the InitMasks pass (and propagated after it).
// If the mask isn't initialized, these weights (and hence the whole convolution) can't be pruned for some reason.
if (!weights_mask) {
NGRAPH_DEBUG << "No weights mask for " << m_output.get_node()->get_friendly_name() << "\n";
return false;
}
auto weights_mask_row = weights_mask.get();
if (auto input_mask = getMask(m_input)) {
@@ -119,9 +136,15 @@ public:
auto weights_mask = getMask(m_weights);
if (!weights_mask) {
// TODO: only if weights are constant
weights_mask = std::make_shared<Mask>(weights_shape.size());
setMask(m_weights, weights_mask);
// Setting mask only if weights are constant
if (ngraph::is_type<opset6::Constant>(m_output.get_node_shared_ptr())) {
weights_mask = std::make_shared<Mask>(weights_shape.size());
setMask(m_weights, weights_mask);
} else {
NGRAPH_DEBUG << "GroupConvolution: No weights mask and weights aren't constant for " <<
*m_output.get_node() << "\n";
return false;
}
}
auto weights_mask_row = weights_mask.get();
@@ -169,13 +192,85 @@ public:
}
};
class ngraph::pass::mask_propagation::Reshape : public MatcherPass {
public:
Reshape() {
auto input = pattern::any_input(pattern::has_static_shape());
auto shape = pattern::any_input();
// Working only for Reshapes on Group Convolution weights
auto reshape = pattern::wrap_type<opset6::Reshape>({input, shape}, pattern::consumers_count(1));
auto gconv = pattern::wrap_type<opset6::GroupConvolution>({pattern::any_input(), reshape},
pattern::has_static_shape());
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto & pattern_map = m.get_pattern_value_map();
const auto & m_shape = pattern_map.at(shape);
const auto & m_output = pattern_map.at(reshape);
const auto & m_input = pattern_map.at(input);
auto shape_val = m_shape.get_node_shared_ptr();
// In depthwise convolutions the Reshape on weights just adds an additional output-channels dimension
// (1 in the depthwise case) to the kernel.
// Example: Reshape from [G, 1 (I), X, Y, Z] -> [G, 1 (O), 1 (I), X, Y, Z], where G is the number of groups,
// X, Y, Z are spatial dimensions (can be only X or X, Y), and I, O are the numbers of input/output channels of the kernel.
// Check that the matched Reshape meets these conditions (adds a 1-sized dim at position 1 of the shape constant)
auto inp_shape = m_input.get_shape();
auto out_shape = m_output.get_shape();
inp_shape.insert(inp_shape.begin() + 1, 1);
if (inp_shape != out_shape) {
return false;
}
auto input_mask = getMask(m_input);
if (!input_mask) {
return false;
}
auto input_mask_row = input_mask.get();
auto output_mask = std::make_shared<Mask>(m_output.get_partial_shape().rank().get_length());
auto output_mask_row = output_mask.get();
// A depthwise convolution is pruned only by input channels (== groups) ->
// propagate the mask from the Group (0) dim of the Reshape input to the Group (0) dim of the Reshape output and back
input_mask->add_callback([output_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->at(0) = output_mask_row->at(0);
return true;
}, output_mask);
output_mask->add_callback([input_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->at(0) = input_mask_row->at(0);
return true;
}, input_mask);
input_mask->apply_callback(output_mask);
// To allow pruning of the weights (i.e. changes to the Group (0) dim of the Reshape input), replace the Reshape
// shape constant [G, 1, 1, X, Y, Z] with [-1, 1, 1, X, Y, Z].
auto old_shape_const = std::dynamic_pointer_cast<opset6::Constant>(m_shape.get_node_shared_ptr());
auto shape_value = old_shape_const->cast_vector<int64_t>();
shape_value[0] = -1;
auto new_const = opset6::Constant::create(old_shape_const->get_element_type(),
old_shape_const->get_shape(), shape_value);
new_const->set_friendly_name(old_shape_const->get_friendly_name());
ngraph::copy_runtime_info(old_shape_const, new_const);
ngraph::replace_node(old_shape_const, new_const);
setMask(m_output, output_mask);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape, "ReshapeMaskPropagation");
register_matcher(m, callback);
}
};
class ngraph::pass::mask_propagation::Elementwise : public MatcherPass {
public:
Elementwise() {
auto input = pattern::any_input();
auto weights = pattern::any_input();
auto eltwise = pattern::wrap_type<op::util::BinaryElementwiseArithmetic>({input, weights},
pattern::has_static_rank());
auto eltwise = pattern::wrap_type<opset6::Add, opset6::Subtract, opset6::Maximum, opset6::Minimum,
opset6::Multiply>({input, weights}, pattern::has_static_rank());
// TODO: add Div, Power support
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto & pattern_map = m.get_pattern_value_map();
@@ -183,82 +278,275 @@ public:
const auto & m_output = pattern_map.at(eltwise);
const auto & m_input = pattern_map.at(input);
// TODO: implement check that compares input shape ranks
// The case when input masks should be united instead of intersected
bool union_eltwise_type = ngraph::is_type<opset6::Multiply>(m_output.get_node_shared_ptr());
const auto & input_rank = m_input.get_partial_shape().rank().get_length();
const auto & weights_rank = m_weights.get_partial_shape().rank().get_length();
// Here we assume that masks can be propagated only through 3/4-dimensional tensors
// (since the channel dim is necessary)
if (weights_rank < 3 || input_rank < 3) return false;
// In case one of the inputs is a constant
// TODO: need to find channel dimension instead of hardcoded zero
const size_t & channel_dim = (input_rank == weights_rank ? 1 : 0);
InitConstMask({channel_dim}).apply(m_input.get_node_shared_ptr());
InitConstMask({channel_dim}).apply(m_weights.get_node_shared_ptr());
// In case the first input is a constant
InitConstMask({0, 1/* potential output channel dim */}).apply(m_input.get_node_shared_ptr());
auto input_mask = getMask(m_input);
if (!input_mask) {
NGRAPH_DEBUG << "No input mask for: " << m_output.get_node()->get_friendly_name() << std::endl;
return false;
}
InitConstMask({0, 1}).apply(m_weights.get_node_shared_ptr());
auto weights_mask = getMask(m_weights);
auto input_mask = getMask(m_input);
if (!weights_mask || !input_mask) {
NGRAPH_DEBUG << "No mask for: " << m_output.get_node()->get_friendly_name() << std::endl;
if (!weights_mask) {
NGRAPH_DEBUG << "No weights mask for: " << m_output.get_node()->get_friendly_name() << std::endl;
return false;
}
auto input_mask_row = input_mask.get();
auto weights_mask_row = weights_mask.get();
// Merge masks from two inputs
// Merging masks from two inputs
auto output_mask = std::make_shared<Mask>(m_output.get_partial_shape().rank().get_length());
auto output_mask_row = output_mask.get();
auto out_mask_callback = [input_mask_row, weights_mask_row](Mask::Ptr cur_mask) -> bool {
auto omask_iter = cur_mask->rbegin();
auto imask_iter = input_mask_row->rbegin();
auto wmask_iter = weights_mask_row->rbegin();
for (auto & item : *cur_mask) {
item.clear();
}
while (imask_iter != input_mask_row->rend() &&
wmask_iter != weights_mask_row->rend()) {
// Merge mask dimension values for both masks
// Example: (MaskValue[1,2,3,4], MaskValue[2,3]) -> MaskValue[2,3]
for (const auto & value : *imask_iter) {
if (wmask_iter->count(value)) {
omask_iter->insert(value);
}
}
omask_iter++;
imask_iter++;
wmask_iter++;
auto out_mask_callback = [input_mask_row, weights_mask_row, union_eltwise_type](Mask::Ptr cur_mask) -> bool {
Mask::Ptr result_mask;
if (union_eltwise_type) {
result_mask = input_mask_row->union_masks_reversed(weights_mask_row);
} else {
result_mask = input_mask_row->intersect_masks_reversed(weights_mask_row);
}
cur_mask->copy_value_from_mask_reversed(result_mask.get());
return true;
};
output_mask->add_callback(out_mask_callback, input_mask);
output_mask->add_callback(out_mask_callback, weights_mask);
auto callback = [output_mask_row](Mask::Ptr cur_mask) -> bool {
auto omask_iter = output_mask_row->rbegin();
auto cmask_iter = cur_mask->rbegin();
while (omask_iter != output_mask_row->rend() &&
cmask_iter != cur_mask->rend()) {
// TODO: check
*cmask_iter = *omask_iter;
omask_iter++;
cmask_iter++;
}
input_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->copy_value_from_mask_reversed(weights_mask_row);
return true;
};
input_mask->add_callback(callback, output_mask);
weights_mask->add_callback(callback, output_mask);
}, weights_mask);
input_mask->add_callback([output_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->copy_value_from_mask_reversed(output_mask_row);
return true;
}, output_mask);
weights_mask->add_callback([input_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->copy_value_from_mask_reversed(input_mask_row);
return true;
}, input_mask);
// Init output mask
output_mask->apply_callback(input_mask);
weights_mask->apply_callback(input_mask);
setMask(m_output, output_mask);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(eltwise, "EltwiseMaskPropagation");
auto m = std::make_shared<ngraph::pattern::Matcher>(eltwise, "ElementwiseMaskPropagation");
register_matcher(m, callback);
}
};
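To sanity-check the union-vs-intersection choice above (illustrative only, same Mask assumptions as earlier): a channel zeroed in either factor of a Multiply is zero in the product, so the masks are united, while for Add/Subtract/Minimum/Maximum a channel is prunable only if both operands agree, so the masks are intersected:

auto in_mask = std::make_shared<Mask>(4);
auto w_mask = std::make_shared<Mask>(4);
in_mask->at(1) = {1, 2};
w_mask->at(1) = {2, 3};
auto mul_out = in_mask->union_masks_reversed(w_mask.get());     // at(1) == {1, 2, 3}
auto add_out = in_mask->intersect_masks_reversed(w_mask.get()); // at(1) == {2}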
class ngraph::pass::mask_propagation::FakeQuantize : public MatcherPass {
public:
FakeQuantize() {
auto input = pattern::any_input(pattern::has_static_shape());
auto input_low = pattern::any_input(pattern::has_static_shape());
auto input_high = pattern::any_input(pattern::has_static_shape());
auto output_low = pattern::any_input(pattern::has_static_shape());
auto output_high = pattern::any_input(pattern::has_static_shape());
auto fake_quantize = pattern::wrap_type<opset6::FakeQuantize>({input, input_low, input_high, output_low,
output_high});
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto & pattern_map = m.get_pattern_value_map();
const auto & m_input = pattern_map.at(input);
const auto & m_input_low = pattern_map.at(input_low);
const auto & m_input_high = pattern_map.at(input_high);
const auto & m_output_low = pattern_map.at(output_low);
const auto & m_output_high = pattern_map.at(output_high);
const auto & m_output = pattern_map.at(fake_quantize);
auto input_mask = getMask(m_input);
// Input mask is the only source of pruning in FQ
if (!input_mask) {
NGRAPH_DEBUG << "FakeQuantize: No input mask for " << *m_output.get_node() << "\n";
return false;
}
auto input_mask_row = input_mask.get();
// Propagate input mask to output mask and in the opposite direction
auto output_mask = std::make_shared<Mask>(m_output.get_partial_shape().rank().get_length());
auto output_mask_row = output_mask.get();
// Output mask is equal to input mask
auto output_mask_callback = [input_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->copy_value_from_mask(input_mask_row);
return true;
};
auto input_mask_callback = [output_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->copy_value_from_mask(output_mask_row);
return true;
};
output_mask->add_callback(output_mask_callback, input_mask);
input_mask->add_callback(input_mask_callback, output_mask);
// Calculate output mask
output_mask->apply_callback(input_mask);
setMask(m_output, output_mask);
auto input_low_size = shape_size(m_input_low.get_shape());
auto input_high_size = shape_size(m_input_high.get_shape());
auto output_low_size = shape_size(m_output_low.get_shape());
auto output_high_size = shape_size(m_output_high.get_shape());
// In the per-tensor case FQ params shouldn't be pruned
if (input_low_size == 1 && output_low_size == 1 && input_high_size == 1 && output_high_size == 1) {
return true;
}
// If the FQ input/output ranges need to be broadcast to the input shape, broadcast these constant values
// for the convenience of working with the masks
NodeVector fq_params_nodes{m_input_low.get_node_shared_ptr(),
m_input_high.get_node_shared_ptr(),
m_output_low.get_node_shared_ptr(),
m_output_high.get_node_shared_ptr()};
auto fq_node = std::dynamic_pointer_cast<op::FakeQuantize>(m_output.get_node_shared_ptr());
size_t idx = 0;
if (fq_node->get_auto_broadcast() != ngraph::op::AutoBroadcastType::NONE) {
for (auto const_node : fq_params_nodes) {
auto new_shape = broadcast_shape_to_rank(const_node->get_shape(),
m_input.get_partial_shape().rank().get_length());
auto const_copy = const_node->clone_with_new_inputs(const_node->input_values());
auto new_const = std::dynamic_pointer_cast<op::Constant>(const_copy);
new_const->set_data_shape(new_shape);
new_const->validate_and_infer_types();
new_const->set_friendly_name(const_node->get_friendly_name());
ngraph::copy_runtime_info(const_node, new_const);
ngraph::replace_node(const_node, new_const);
fq_params_nodes[idx++] = new_const;
}
}
auto fq_params_mask_callback = [input_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->at(1/* fq params have same shapes as input */) = input_mask_row->at(1 /* channel dim in data */);
return true;
};
for (auto fq_param : fq_params_nodes) {
auto mask = std::make_shared<Mask>(fq_param->get_shape().size());
mask->add_callback(fq_params_mask_callback, input_mask);
input_mask->add_callback([mask](Mask::Ptr cur_mask) -> bool {
return true;
}, mask);
mask->apply_callback(input_mask);
setMask(fq_param->output(0), mask);
}
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(fake_quantize, "FakeQuantizeMaskPropagation");
register_matcher(m, callback);
}
};
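The net effect on FQ range constants (matching the per-channel test below); an illustrative trace, not part of the commit:

// input mask on the data input: {{}, {0, 1, 2, 3, 4}, {}, {}}
// input_low Shape{8, 1, 1}     -> broadcast to {1, 8, 1, 1} -> mask {{}, {0, 1, 2, 3, 4}, {}, {}}
// scalar ranges (per-tensor)   -> early return above, no mask set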
class ngraph::pass::mask_propagation::Concat : public MatcherPass {
public:
Concat() {
auto concat = pattern::wrap_type<opset6::Concat>(pattern::has_static_shape());
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
const auto & pattern_map = m.get_pattern_value_map();
const auto & m_output = pattern_map.at(concat);
auto concat_ptr = std::dynamic_pointer_cast<opset6::Concat>(m_output.get_node_shared_ptr());
auto axis = concat_ptr->get_concatenation_axis();
auto inputs = concat_ptr->inputs();
std::map<int64_t, Mask::Ptr> input_masks;
std::map<int64_t, Mask *> input_masks_row;
std::vector<int64_t> input_sizes;
size_t first_input_idx = 0;
Mask::Ptr first_input_mask;
bool first_initialized = false;
for (size_t i=0; i < inputs.size(); i++) {
auto input = inputs[i];
auto input_mask = getMask(input.get_source_output());
if (input_mask) {
input_masks[i] = input_mask;
input_masks_row[i] = input_mask.get();
if (!first_initialized) {
first_input_idx = i;
first_input_mask = input_mask;
first_initialized = true;
}
}
input_sizes.push_back(input.get_shape().at(axis));
}
if (!first_initialized) {
return false;
}
auto output_mask = std::make_shared<Mask>(m_output.get_partial_shape().rank().get_length());
auto output_mask_row = output_mask.get();
auto out_mask_callback = [input_masks_row, input_sizes, axis](Mask::Ptr cur_mask) -> bool {
int64_t cur_size = 0;
cur_mask->at(axis).clear();
for (size_t i=0; i < input_sizes.size(); ++i) {
if (input_masks_row.count(i)) {
for (auto idx : input_masks_row.at(i)->at(axis)) {
cur_mask->at(axis).insert(idx + cur_size);
}
}
cur_size += input_sizes[i];
}
return true;
};
auto create_input_mask_callback_for_idx = [output_mask_row, input_sizes, axis](size_t input_idx){
auto input_mask_callback = [output_mask_row, input_sizes, axis, input_idx](Mask::Ptr cur_mask) -> bool {
cur_mask->clean_dim_values();
uint64_t min_val = 0;
for (size_t i = 0; i < input_idx; i++) {
min_val += input_sizes[i];
}
uint64_t max_val = min_val + input_sizes[input_idx];
for (auto idx : output_mask_row->at(axis)) {
if (idx < max_val && idx >= min_val) {
cur_mask->at(axis).insert(idx - min_val);
}
}
return true;
};
return input_mask_callback;
};
output_mask->add_callback(out_mask_callback, first_input_mask);
for (size_t i=0; i < inputs.size(); ++i) {
if (input_masks.count(i) && i != first_input_idx) {
auto input_mask = input_masks.at(i);
input_mask->add_callback(create_input_mask_callback_for_idx(i),
first_input_mask);
first_input_mask->add_callback([](Mask::Ptr cur_mask) -> bool {
return true;
}, input_mask);
}
}
first_input_mask->add_callback(create_input_mask_callback_for_idx(first_input_idx),
output_mask);
output_mask->apply_callback(first_input_mask);
setMask(m_output, output_mask);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(concat, "ConcatMaskPropagation");
register_matcher(m, callback);
}
};
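A worked example of the index shifting above (it matches TestConcatMaskPropagation below): for channel sizes 8, 16 and 8 concatenated along axis 1, each input's mask is offset by the running channel count, and the inverse callbacks slice the output mask back into per-input ranges:

// offsets: 0, 8, 24
// {0, 1, 2, 3}  + 0  -> {0, 1, 2, 3}
// {7, 8, 9, 10} + 8  -> {15, 16, 17, 18}
// {4, 5, 6, 7}  + 24 -> {28, 29, 30, 31}
// output mask at axis 1: {0, 1, 2, 3, 15, 16, 17, 18, 28, 29, 30, 31}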
@@ -266,7 +554,9 @@ public:
class ngraph::pass::mask_propagation::PassThrough : public MatcherPass {
public:
PassThrough() {
auto unary_op = pattern::wrap_type<op::util::UnaryElementwiseArithmetic, opset6::Clamp>();
auto unary_op = pattern::wrap_type<op::util::UnaryElementwiseArithmetic, opset6::Clamp,
opset6::Convert, opset6::ConvertLike, opset6::AvgPool, opset6::MaxPool,
opset6::ROIPooling, opset6::PSROIPooling, opset6::Pad>();
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto & pattern_map = m.get_pattern_value_map();
@@ -312,5 +602,8 @@ ngraph::pass::PropagateMasks::PropagateMasks() {
add_matcher<mask_propagation::GroupConvolution>();
add_matcher<mask_propagation::Elementwise>();
add_matcher<mask_propagation::PassThrough>();
add_matcher<mask_propagation::FakeQuantize>();
add_matcher<mask_propagation::Concat>();
add_matcher<mask_propagation::Reshape>();
add_matcher<mask_propagation::StopPropagation>();
}


@@ -15,8 +15,13 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::Pruning, "Pruning", 0);
bool ngraph::pass::Pruning::run_on_function(std::shared_ptr<Function> f) {
Manager manager(get_pass_config());
// Initialize masks only for Convolution/GroupConvolution weights (needed to init the mask in the source Constant of the
// weights-calculating subgraph). For other node types, masks are initialized in the PropagateMasks pass.
manager.register_pass<InitMasks>();
manager.register_pass<PropagateMasks>();
#ifdef NGRAPH_DEBUG_ENABLE
// VisualizeTree modifier helps to print Masks and mark nodes with masks
/*


@@ -54,6 +54,8 @@ bool ngraph::pass::ShrinkWeights::run_on_function(std::shared_ptr<ngraph::Functi
for (size_t dim = 0; dim < mask->size(); ++dim) {
const auto &dim_size = mask->at(dim).size();
if (dim_size == 0) continue;
// A broadcastable 1-sized dimension shouldn't be shrunk by the mask
if (const_node->get_shape().at(dim) == 1 && dim_size > 1) continue;
// Convert dims that we want remove to dims that we need to keep
std::vector<int64_t> dims_to_keep;
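(The guard above matters for broadcast constants; the following reading is an editorial assumption, not stated in the commit:

// e.g. a const of shape {1, 1, 1} can receive a multi-element channel mask such as {{1, 2, 3, 4}, {}, {}}
// on its size-1 axis via the reversed-alignment callbacks; shrinking that axis would be invalid,
// so it is skipped and broadcasting keeps the arithmetic intact.)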


@@ -15,6 +15,7 @@
#include <transformations/init_node_info.hpp>
#include <ngraph/coordinate_transform.hpp>
#include <ngraph/pass/manager.hpp>
#include <inference_engine.hpp>
using namespace testing;
using namespace ngraph;
@@ -67,6 +68,23 @@ TEST(TransformationTests, InitMasksOutputChannel) {
compare_masks(*getMask(weights->output(0)), {{}, {1}, {}, {}});
}
// TODO: add a test that inits masks with a subgraph
TEST(TransformationTests, TestInitMasks) {
Shape weights_shape{6, 3, 3, 3};
Shape input_shape{1, 3, 64, 64};
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
auto weights = create_constant_with_zeros(weights_shape, {{1, 2, 3}, {}, {}, {}});
auto conv = std::make_shared<opset5::Convolution>(input, weights, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto f = std::make_shared<Function>(NodeVector{conv}, ParameterVector{input});
pass::Manager m;
m.register_pass<pass::InitMasks>();
m.run_passes(f);
compare_masks(*getMask(weights.get_node_shared_ptr()->output(0)), {{1, 2, 3}, {}, {}, {}});
}
TEST(TransformationTests, InitMasksNegative) {
Shape weights_shape{6, 3, 3, 3};
auto weights = opset5::Constant::create(element::f32, weights_shape, {0.5});
@@ -85,6 +103,7 @@ TEST(TransformationTests, PropagateMasksNegative) {
auto f = std::make_shared<Function>(NodeVector{conv}, ParameterVector{input});
pass::Manager m;
m.register_pass<pass::InitMasks>();
m.register_pass<pass::PropagateMasks>();
m.run_passes(f);
@@ -102,27 +121,35 @@ TEST(TransformationTests, PropagateMasksBasic) {
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto relu = std::make_shared<opset5::Relu>(conv);
auto add_const = create_constant_with_zeros(Shape{1, 6, 1, 1}, {{}, {1, 2, 3, 4, 5}, {}, {}});
auto add = std::make_shared<opset5::Add>(relu, add_const);
auto sub_const = create_constant_with_zeros(Shape{6, 1, 1}, {{1, 2, 3}, {}, {}});
auto sub = std::make_shared<opset5::Subtract>(relu, sub_const);
auto sub = std::make_shared<opset5::Subtract>(add, sub_const);
auto mul_const = create_constant_with_zeros(Shape{6, 1, 1}, {{2}, {}, {}});
auto mul = std::make_shared<opset5::Subtract>(sub, mul_const);
auto mul_const = create_constant_with_zeros(Shape{1, 6, 1, 1}, {{}, {4}, {}, {}});
auto mul = std::make_shared<opset5::Multiply>(sub, mul_const);
auto weights2 = opset5::Constant::create(element::f32, weights_shape2, {0});
auto weights2 = create_constant_with_zeros(weights_shape2, {{1, 2}, {1, 2, 3}, {}, {}});
auto conv2 = std::make_shared<opset5::Convolution>(mul, weights2, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto f = std::make_shared<Function>(NodeVector{conv2}, ParameterVector{input});
pass::Manager m;
m.register_pass<pass::InitMasks>();
m.register_pass<pass::PropagateMasks>();
m.run_passes(f);
compare_masks(*getMask(weights->output(0)), Mask({{2}, {}, {}, {}}));
compare_masks(*getMask(conv->output(0)), Mask({{}, {2}, {}, {}}));
compare_masks(*getMask(relu->output(0)), Mask({{}, {2}, {}, {}}));
compare_masks(*getMask(sub_const), Mask({{2}, {}, {}}));
compare_masks(*getMask(mul_const), Mask({{2}, {}, {}}));
compare_masks(*getMask(weights2->output(0)), Mask({{}, {2}, {}, {}}));
compare_masks(*getMask(weights->output(0)), Mask({{1, 2, 3, 4}, {}, {}, {}}));
compare_masks(*getMask(conv->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(relu->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(add_const), Mask({{}, {1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(sub_const), Mask({{1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(mul_const), Mask({{}, {1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(add->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(sub->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(mul->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(weights2.get_node_shared_ptr()->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}}));
}
@@ -148,6 +175,7 @@ TEST(TransformationTests, PropagateMasksDynamicConvolution) {
auto f = std::make_shared<Function>(NodeVector{conv2}, ParameterVector{input});
pass::Manager m;
m.register_pass<pass::InitMasks>();
m.register_pass<pass::PropagateMasks>();
m.run_passes(f);
@@ -182,6 +210,7 @@ TEST(TransformationTests, PropagateMasksDynamicGroupConvolution) {
auto f = std::make_shared<Function>(NodeVector{conv2}, ParameterVector{input});
pass::Manager m;
m.register_pass<pass::InitMasks>();
m.register_pass<pass::PropagateMasks>();
m.run_passes(f);
}
@@ -199,15 +228,16 @@ TEST(TransformationTests, PropagateMasksEmpty) {
auto sub_const = create_constant_with_zeros(Shape{6, 1, 1}, {{1, 2, 3}, {}, {}});
auto sub = std::make_shared<opset5::Subtract>(relu, sub_const);
auto mul_const = create_constant_with_zeros(Shape{6, 1, 1}, {{1, 2}, {}, {}});
auto mul = std::make_shared<opset5::Subtract>(sub, mul_const);
auto add_const = create_constant_with_zeros(Shape{6, 1, 1}, {{1, 2}, {}, {}});
auto add = std::make_shared<opset5::Subtract>(sub, add_const);
auto weights2 = opset5::Constant::create(element::f32, weights_shape2, {0});
auto conv2 = std::make_shared<opset5::Convolution>(mul, weights2, Strides(2, 1),
auto conv2 = std::make_shared<opset5::Convolution>(add, weights2, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto f = std::make_shared<Function>(NodeVector{conv2}, ParameterVector{input});
pass::Manager m;
m.register_pass<pass::InitMasks>();
m.register_pass<pass::PropagateMasks>();
m.run_passes(f);
@@ -215,11 +245,55 @@ TEST(TransformationTests, PropagateMasksEmpty) {
compare_masks(*getMask(conv->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(relu->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(sub_const), Mask({{}, {}, {}}));
compare_masks(*getMask(mul_const), Mask({{}, {}, {}}));
compare_masks(*getMask(add_const), Mask({{}, {}, {}}));
compare_masks(*getMask(weights2->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}}));
}
TEST(TransformationTests, PropagateMaskPassThrough) {
Shape input_shape{1, 3, 64, 64};
Shape weights_shape{8, 3, 3, 3};
Shape weight_shape2{3, 8, 3, 3};
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
input->set_friendly_name("input");
auto weights_const_1 = create_constant_with_zeros(weights_shape, {{1, 2, 3}, {}, {}, {}});
weights_const_1.get_node_shared_ptr()->set_friendly_name("weights_1");
auto conv_1 = std::make_shared<opset5::Convolution>(input, weights_const_1, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
conv_1->set_friendly_name("conv_1");
// Adding a couple of PassThrough operations
auto relu = std::make_shared<opset5::Relu>(conv_1);
relu->set_friendly_name("relu");
auto clamp = std::make_shared<opset5::Clamp>(relu, 0, 6);
clamp->set_friendly_name("clamp");
auto pads_begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 1, 1});
auto pads_end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
auto pad = std::make_shared<opset5::Pad>(clamp, pads_begin, pads_end, op::PadMode::CONSTANT);
auto max_pool = std::make_shared<opset5::MaxPool>(pad, Strides{1, 1},
Shape{0, 0}, Shape{1, 1}, Shape{4, 4});
max_pool->set_friendly_name("max_pool");
auto weights2 = opset5::Constant::create(element::f32, weight_shape2, {0});
auto conv2 = std::make_shared<opset5::Convolution>(max_pool, weights2, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto f = std::make_shared<Function>(NodeVector{conv2}, ParameterVector{input});
pass::Manager m;
m.register_pass<pass::InitMasks>();
m.register_pass<pass::PropagateMasks>();
m.run_passes(f);
compare_masks(*getMask(weights_const_1.get_node_shared_ptr()->output(0)), Mask({{1, 2, 3}, {}, {}, {}}));
compare_masks(*getMask(conv_1->output(0)), Mask({{}, {1, 2, 3}, {}, {}}));
compare_masks(*getMask(relu->output(0)), Mask({{}, {1, 2, 3}, {}, {}}));
compare_masks(*getMask(clamp->output(0)), Mask({{}, {1, 2, 3}, {}, {}}));
compare_masks(*getMask(max_pool->output(0)), Mask({{}, {1, 2, 3}, {}, {}}));
}
TEST(TransformationTests, PropagateMasksHardDependencies) {
Shape input_shape{1, 3, 3, 3};
@@ -280,4 +354,344 @@ TEST(TransformationTests, PropagateMasksHardDependencies) {
// compare_masks(*getMask(relu), Mask({{}, {0, 1, 2, 3, 4, 5}, {}, {}}));
// compare_masks(*getMask(weights2), Mask({{}, {0, 1, 2, 3, 4, 5}, {}, {}}));
// compare_masks(*getMask(conv2), Mask({{}, {}, {}, {}}));
}
}
TEST(TransformationTests, PropagateMasksQuantizedGroupConvolution) {
Shape input_shape{1, 3, 64, 64};
Shape weights_shape{8, 3, 3, 3};
Shape weights_group_shape{8, 1, 3, 3};
Shape weight_shape2{3, 8, 3, 3};
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
input->set_friendly_name("input");
auto weights1 = create_constant_with_zeros(weights_shape, {{0, 1, 2, 3}, {}, {}, {}});
auto conv1 = std::make_shared<opset5::Convolution>(input, weights1, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto weights_group = opset5::Constant::create(element::i8, weights_group_shape, {0});
weights_group->set_friendly_name("weights_group");
auto convert = std::make_shared<opset5::Convert>(weights_group, element::f32);
convert->set_friendly_name("convert");
auto sub_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3}, {}, {}, {}});
auto sub = std::make_shared<opset5::Subtract>(convert, sub_const);
sub->set_friendly_name("sub");
auto mul_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3, 4}, {}, {}, {}});
auto mul = std::make_shared<opset5::Multiply>(sub, mul_const);
mul->set_friendly_name("mul");
auto reshape = std::make_shared<opset5::Reshape>(mul, opset5::Constant::create(element::i64, Shape{5}, {8, 1, 1, 3, 3}), false);
auto conv_group = std::make_shared<opset5::GroupConvolution>(conv1, reshape, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto add_const = create_constant_with_zeros(Shape{1, 8, 1, 1}, {{}, {0, 1, 2, 3, 4}, {}, {}});
auto add = std::make_shared<opset5::Add>(conv_group, add_const);
add->set_friendly_name("add");
auto weights_2 = opset5::Constant::create(element::f32, weight_shape2, {0});
auto conv2 = std::make_shared<opset5::Convolution>(add, weights_2, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto f = std::make_shared<Function>(NodeVector{conv2}, ParameterVector{input});
pass::Manager m;
m.register_pass<pass::Pruning>();
m.run_passes(f);
compare_masks(*getMask(weights1.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}}));
compare_masks(*getMask(conv1->output(0)), Mask({{}, {0, 1, 2, 3}, {}, {}}));
compare_masks(*getMask(weights_group->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}}));
compare_masks(*getMask(sub->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}}));
compare_masks(*getMask(sub_const.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}}));
compare_masks(*getMask(mul->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}}));
compare_masks(*getMask(mul_const.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}}));
compare_masks(*getMask(reshape->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}, {}}));
compare_masks(*getMask(conv_group->output(0)), Mask({{}, {0, 1, 2, 3}, {}, {}}));
compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(weights_2->output(0)), Mask({{}, {0, 1, 2, 3}, {}, {}}));
}
TEST(TransformationTests, PropagateMasksFakeQuantizePerTensor) {
Shape input_shape{1, 3, 64, 64};
Shape weights_shape{8, 3, 3, 3};
Shape weight_shape2{3, 8, 3, 3};
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
input->set_friendly_name("input");
auto weights_1 = opset5::Constant::create(element::i8, weights_shape, {0});
weights_1->set_friendly_name("weights_int8_const");
auto convert = std::make_shared<opset5::Convert>(weights_1, element::f32);
convert->set_friendly_name("convert");
auto sub_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3}, {}, {}, {}});
auto sub = std::make_shared<opset5::Subtract>(convert, sub_const);
sub->set_friendly_name("sub");
auto mul_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3, 4}, {}, {}, {}});
auto mul = std::make_shared<opset5::Multiply>(sub, mul_const);
mul->set_friendly_name("mul");
auto conv1 = std::make_shared<opset5::Convolution>(input, mul, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
conv1->set_friendly_name("conv1");
auto add_const = create_constant_with_zeros(Shape{1, 8, 1, 1}, {{}, {0, 1, 2, 3, 4}, {}, {}});
auto add = std::make_shared<opset5::Add>(conv1, add_const);
add->set_friendly_name("add");
auto input_low = opset5::Constant::create(element::f32, Shape{1}, {0});
auto input_high = opset5::Constant::create(element::f32, Shape{1, 1, 1, 1}, {20});
auto output_low = opset5::Constant::create(element::f32, Shape{}, {1});
auto output_high = opset5::Constant::create(element::f32, Shape{}, {10});
auto fq = std::make_shared<opset5::FakeQuantize>(add, input_low, input_high, output_low, output_high, 8);
auto weights_2 = opset5::Constant::create(element::f32, weight_shape2, {0});
auto conv2 = std::make_shared<opset5::Convolution>(fq, weights_2, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto f = std::make_shared<Function>(NodeVector{conv2}, ParameterVector{input});
pass::Manager m;
m.register_pass<pass::Pruning>();
m.run_passes(f);
compare_masks(*getMask(weights_1->output(0)), Mask({{0, 1, 2, 3, 4}, {}, {}, {}}));
compare_masks(*getMask(sub_const.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3, 4}, {}, {}, {}}));
compare_masks(*getMask(sub->output(0)), Mask({{0, 1, 2, 3, 4}, {}, {}, {}}));
compare_masks(*getMask(mul_const.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3, 4}, {}, {}, {}}));
compare_masks(*getMask(mul->output(0)), Mask({{0, 1, 2, 3, 4}, {}, {}, {}}));
compare_masks(*getMask(conv1->output(0)), Mask({{}, {0, 1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(add_const.get_node_shared_ptr()->output(0)), Mask({{}, {0, 1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(add->output(0)), Mask({{}, {0, 1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(fq->output(0)), Mask({{}, {0, 1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(weights_2->output(0)), Mask({{}, {0, 1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}}));
}
TEST(TransformationTests, PropagateMasksFakeQuantizePerChannel) {
Shape input_shape{1, 3, 64, 64};
Shape weights_shape{8, 3, 3, 3};
Shape weight_shape2{3, 8, 3, 3};
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
input->set_friendly_name("input");
auto weights_1 = opset5::Constant::create(element::i8, weights_shape, {0});
weights_1->set_friendly_name("weights_int8_const");
auto convert = std::make_shared<opset5::Convert>(weights_1, element::f32);
convert->set_friendly_name("convert");
auto sub_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3}, {}, {}, {}});
auto sub = std::make_shared<opset5::Subtract>(convert, sub_const);
sub->set_friendly_name("sub");
auto mul_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3, 4}, {}, {}, {}});
auto mul = std::make_shared<opset5::Multiply>(sub, mul_const);
mul->set_friendly_name("mul");
auto conv1 = std::make_shared<opset5::Convolution>(input, mul, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
conv1->set_friendly_name("conv1");
auto add_const = create_constant_with_zeros(Shape{1, 8, 1, 1}, {{}, {0, 1, 2, 3, 4}, {}, {}});
auto add = std::make_shared<opset5::Add>(conv1, add_const);
add->set_friendly_name("add");
auto input_low = opset5::Constant::create(element::f32, Shape{1, 8, 1, 1}, {0});
auto input_high = opset5::Constant::create(element::f32, Shape{1, 8, 1, 1}, {20});
auto output_low = opset5::Constant::create(element::f32, Shape{8, 1, 1}, {1});
auto output_high = opset5::Constant::create(element::f32, Shape{8, 1, 1}, {10});
auto fq = std::make_shared<opset5::FakeQuantize>(add, input_low, input_high, output_low, output_high, 8);
auto weights_2 = opset5::Constant::create(element::f32, weight_shape2, {0});
auto conv2 = std::make_shared<opset5::Convolution>(fq, weights_2, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto f = std::make_shared<Function>(NodeVector{conv2}, ParameterVector{input});
pass::Manager m;
m.register_pass<pass::InitMasks>();
m.register_pass<pass::PropagateMasks>();
m.run_passes(f);
compare_masks(*getMask(weights_1->output(0)), Mask({{0, 1, 2, 3, 4}, {}, {}, {}}));
compare_masks(*getMask(sub_const.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3, 4}, {}, {}, {}}));
compare_masks(*getMask(sub->output(0)), Mask({{0, 1, 2, 3, 4}, {}, {}, {}}));
compare_masks(*getMask(mul_const.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3, 4}, {}, {}, {}}));
compare_masks(*getMask(mul->output(0)), Mask({{0, 1, 2, 3, 4}, {}, {}, {}}));
compare_masks(*getMask(conv1->output(0)), Mask({{}, {0, 1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(add_const.get_node_shared_ptr()->output(0)), Mask({{}, {0, 1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(add->output(0)), Mask({{}, {0, 1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(fq->output(0)), Mask({{}, {0, 1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(weights_2->output(0)), Mask({{}, {0, 1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(fq->input(1).get_source_output()), Mask({{}, {0, 1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(fq->input(2).get_source_output()), Mask({{}, {0, 1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(fq->input(3).get_source_output()), Mask({{}, {0, 1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(fq->input(4).get_source_output()), Mask({{}, {0, 1, 2, 3, 4}, {}, {}}));
}
TEST(TransformationTests, TestConcatMaskPropagation) {
Shape input_shape{1, 3, 64, 64};
Shape weights_shape1{8, 3, 3, 3};
Shape weights_shape2{16, 3, 3, 3};
Shape weights_shape3{8, 3, 3, 3};
Shape weight_shape_out_conv{3, 32, 3, 3};
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
auto weights_1 = create_constant_with_zeros(weights_shape1, {{0, 1, 2, 3}, {}, {}, {}});
auto conv1 = std::make_shared<opset5::Convolution>(input, weights_1, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto weights_2 = create_constant_with_zeros(weights_shape2, {{7, 8, 9, 10}, {}, {}, {}});
auto conv2 = std::make_shared<opset5::Convolution>(input, weights_2, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto weights_3 = create_constant_with_zeros(weights_shape3, {{4, 5, 6, 7}, {}, {}, {}});
auto conv3 = std::make_shared<opset5::Convolution>(input, weights_3, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto concat = std::make_shared<opset5::Concat>(OutputVector{conv1->output(0), conv2->output(0), conv3->output(0)}, 1);
auto weights_out_conv = create_constant_with_zeros(weight_shape_out_conv, {{}, {}, {}, {}});
auto conv_out = std::make_shared<opset5::Convolution>(concat, weights_out_conv, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto f = std::make_shared<Function>(NodeVector{conv_out}, ParameterVector{input});
pass::Manager m;
m.register_pass<pass::InitMasks>();
m.register_pass<pass::PropagateMasks>();
m.run_passes(f);
compare_masks(*getMask(weights_1.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}}));
compare_masks(*getMask(conv1->output(0)), Mask({{}, {0, 1, 2, 3}, {}, {}}));
compare_masks(*getMask(weights_2.get_node_shared_ptr()->output(0)), Mask({{7, 8, 9, 10}, {}, {}, {}}));
compare_masks(*getMask(conv2->output(0)), Mask({{}, {7, 8, 9, 10}, {}, {}}));
compare_masks(*getMask(weights_3.get_node_shared_ptr()->output(0)), Mask({{4, 5, 6, 7}, {}, {}, {}}));
compare_masks(*getMask(conv3->output(0)), Mask({{}, {4, 5, 6, 7}, {}, {}}));
compare_masks(*getMask(concat->output(0)), Mask({{}, {0, 1, 2, 3, 15, 16, 17, 18, 28, 29, 30, 31}, {}, {}}));
compare_masks(*getMask(weights_out_conv.get_node_shared_ptr()->output(0)), Mask({{}, {0, 1, 2, 3, 15, 16, 17, 18, 28, 29, 30, 31}, {}, {}}));
}
TEST(TransformationTests, TestConcatMaskPropagationUp) {
Shape input_shape{1, 3, 64, 64};
Shape weights_shape1{8, 3, 3, 3};
Shape weights_shape2{16, 3, 3, 3};
Shape weights_shape3{8, 3, 3, 3};
Shape weight_shape_out_conv{3, 32, 3, 3};
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
auto weights_1 = create_constant_with_zeros(weights_shape1, {{0, 1, 2, 3, 4, 5}, {}, {}, {}});
auto conv1 = std::make_shared<opset5::Convolution>(input, weights_1, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto weights_2 = create_constant_with_zeros(weights_shape2, {{7, 8, 9, 10}, {}, {}, {}});
auto conv2 = std::make_shared<opset5::Convolution>(input, weights_2, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto weights_3 = create_constant_with_zeros(weights_shape3, {{2, 3, 4, 5, 6, 7}, {}, {}, {}});
auto conv3 = std::make_shared<opset5::Convolution>(input, weights_3, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto concat = std::make_shared<opset5::Concat>(OutputVector{conv1->output(0), conv2->output(0), conv3->output(0)}, 1);
auto add_const = create_constant_with_zeros(Shape{1, 32, 1, 1}, {{}, {0, 1, 2, 3, 15, 16, 17, 18, 28, 29, 30, 31}, {}, {}});
auto add = std::make_shared<opset5::Add>(concat, add_const);
auto weights_out_conv = create_constant_with_zeros(weight_shape_out_conv, {{}, {}, {}, {}});
auto conv_out = std::make_shared<opset5::Convolution>(add, weights_out_conv, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto f = std::make_shared<Function>(NodeVector{conv_out}, ParameterVector{input});
pass::Manager m;
m.register_pass<pass::InitMasks>();
m.register_pass<pass::PropagateMasks>();
m.run_passes(f);
compare_masks(*getMask(weights_1.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}}));
compare_masks(*getMask(conv1->output(0)), Mask({{}, {0, 1, 2, 3}, {}, {}}));
compare_masks(*getMask(weights_2.get_node_shared_ptr()->output(0)), Mask({{7, 8, 9, 10}, {}, {}, {}}));
compare_masks(*getMask(conv2->output(0)), Mask({{}, {7, 8, 9, 10}, {}, {}}));
compare_masks(*getMask(weights_3.get_node_shared_ptr()->output(0)), Mask({{4, 5, 6, 7}, {}, {}, {}}));
compare_masks(*getMask(conv3->output(0)), Mask({{}, {4, 5, 6, 7}, {}, {}}));
compare_masks(*getMask(add_const.get_node_shared_ptr()->output(0)), Mask({{}, {0, 1, 2, 3, 15, 16, 17, 18, 28, 29, 30, 31}, {}, {}}));
compare_masks(*getMask(add->output(0)), Mask({{}, {0, 1, 2, 3, 15, 16, 17, 18, 28, 29, 30, 31}, {}, {}}));
compare_masks(*getMask(concat->output(0)), Mask({{}, {0, 1, 2, 3, 15, 16, 17, 18, 28, 29, 30, 31}, {}, {}}));
compare_masks(*getMask(weights_out_conv.get_node_shared_ptr()->output(0)), Mask({{}, {0, 1, 2, 3, 15, 16, 17, 18, 28, 29, 30, 31}, {}, {}}));
}
TEST(TransformationTests, TestConcatMaskPropagationUpEmpty) {
Shape input_shape{1, 3, 64, 64};
Shape weights_shape1{8, 3, 3, 3};
Shape weights_shape2{16, 3, 3, 3};
Shape weights_shape3{8, 3, 3, 3};
Shape weight_shape_out_conv{3, 32, 3, 3};
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
auto weights_1 = create_constant_with_zeros(weights_shape1, {{0, 1, 2, 3, 4, 5}, {}, {}, {}});
auto conv1 = std::make_shared<opset5::Convolution>(input, weights_1, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto weights_2 = create_constant_with_zeros(weights_shape2, {{7, 8, 9, 10}, {}, {}, {}});
auto conv2 = std::make_shared<opset5::Convolution>(input, weights_2, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto weights_3 = create_constant_with_zeros(weights_shape3, {{2, 3, 4, 5, 6, 7}, {}, {}, {}});
auto conv3 = std::make_shared<opset5::Convolution>(input, weights_3, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto concat = std::make_shared<opset5::Concat>(OutputVector{conv1->output(0), conv2->output(0), conv3->output(0)}, 1);
auto add_const = create_constant_with_zeros(Shape{1, 32, 1, 1}, {{}, {0, 1, 2, 3, 15, 16, 17, 18, 28, 29, 30, 31}, {}, {}});
auto add = std::make_shared<opset5::Add>(concat, add_const);
auto f = std::make_shared<Function>(NodeVector{add}, ParameterVector{input});
pass::Manager m;
m.register_pass<pass::InitMasks>();
m.register_pass<pass::PropagateMasks>();
m.run_passes(f);
compare_masks(*getMask(weights_1.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(conv1->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(weights_2.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(weights_3.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(conv3->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(add_const.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(add->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(concat->output(0)), Mask({{}, {}, {}, {}}));
}