Pruning with FQ support (#5925)
* Add FQ, Concat, exted Eltwise support * Fix tests after rebase + small refactoring * Added Reshape on GroupConv weights mask propagating * Added printing of reduced weights to test transformation * Turn off pruning for test * Fixed comments + revert transformation comment * Fixed last comments
This commit is contained in:
parent
bc7f61be24
commit
6022df6687
@ -54,10 +54,90 @@ public:
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<size_t> get_not_empty_dims() {
|
||||
std::vector<size_t> not_empty_dims;
|
||||
for (size_t i = 0; i < this->size(); i++) {
|
||||
if (!this->at(i).empty())
|
||||
not_empty_dims.push_back(i);
|
||||
}
|
||||
return not_empty_dims;
|
||||
}
|
||||
|
||||
bool is_shape_like() const { return m_is_shape_like; }
|
||||
|
||||
void set_shape_like(bool flag) { m_is_shape_like = flag; }
|
||||
|
||||
void copy_value_from_mask(Mask *const mask) {
|
||||
auto cur_mask_iter = begin();
|
||||
auto mask_iter = mask->begin();
|
||||
while (cur_mask_iter != end() && mask_iter != mask->end()) {
|
||||
*cur_mask_iter = *mask_iter;
|
||||
|
||||
cur_mask_iter++;
|
||||
mask_iter++;
|
||||
}
|
||||
}
|
||||
|
||||
void copy_value_from_mask_reversed(Mask *const mask) {
|
||||
auto cur_mask_iter = rbegin();
|
||||
auto mask_iter = mask->rbegin();
|
||||
while (cur_mask_iter != rend() && mask_iter != mask->rend()) {
|
||||
*cur_mask_iter = *mask_iter;
|
||||
|
||||
cur_mask_iter++;
|
||||
mask_iter++;
|
||||
}
|
||||
}
|
||||
|
||||
Mask::Ptr intersect_masks_reversed(Mask *const mask) {
|
||||
auto result_mask = std::make_shared<Mask>(std::max(size(), mask->size()));
|
||||
auto result_iter = result_mask->rbegin();
|
||||
auto mask_1_iter = rbegin();
|
||||
auto mask_2_iter = mask->rbegin();
|
||||
|
||||
while (mask_1_iter != rend() &&
|
||||
mask_2_iter != mask->rend()) {
|
||||
// Merge mask dimension values for both masks
|
||||
// Example: (MaskValue[1,2,3,4], MaskValue[2,3]) -> MaskValue[2,3]
|
||||
for (const auto & value : *mask_1_iter) {
|
||||
if (mask_2_iter->count(value)) {
|
||||
result_iter->insert(value);
|
||||
}
|
||||
}
|
||||
|
||||
result_iter++;
|
||||
mask_1_iter++;
|
||||
mask_2_iter++;
|
||||
}
|
||||
return result_mask;
|
||||
}
|
||||
|
||||
Mask::Ptr union_masks_reversed(Mask *const mask) {
|
||||
auto result_mask = std::make_shared<Mask>(std::max(size(), mask->size()));
|
||||
auto result_iter = result_mask->rbegin();
|
||||
auto mask_1_iter = rbegin();
|
||||
auto mask_2_iter = mask->rbegin();
|
||||
|
||||
while (mask_1_iter != rend() &&
|
||||
mask_2_iter != mask->rend()) {
|
||||
// Union mask dimension values for both masks
|
||||
// Example: (MaskValue[1,2,3,4], MaskValue[2, 5]) -> MaskValue[1, 2, 3, 4, 5]
|
||||
for (const auto & value : *mask_1_iter) {
|
||||
result_iter->insert(value);
|
||||
}
|
||||
for (const auto & value : *mask_2_iter) {
|
||||
if (!result_iter->count(value)) {
|
||||
result_iter->insert(value);
|
||||
}
|
||||
}
|
||||
|
||||
result_iter++;
|
||||
mask_1_iter++;
|
||||
mask_2_iter++;
|
||||
}
|
||||
return result_mask;
|
||||
}
|
||||
|
||||
void add_callback(const std::function<bool(Mask::Ptr)> & receive_callback, Mask::Ptr mask) {
|
||||
m_callbacks[mask.get()] = receive_callback;
|
||||
m_dependencies.push_back(mask.get());
|
||||
|
@ -14,6 +14,7 @@ namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class InitConstMask;
|
||||
class InitMasks;
|
||||
class PropagateMasks;
|
||||
class ShrinkWeights;
|
||||
|
||||
@ -22,6 +23,16 @@ class Pruning;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief Initialising masks for pruned operations
|
||||
*/
|
||||
class ngraph::pass::InitMasks : public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
InitMasks();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief Check Constant operation values by given dimensions and set
|
||||
|
@ -17,7 +17,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::InitConstMask, "InitConstMask", 0);
|
||||
ngraph::pass::InitConstMask::InitConstMask(const ngraph::AxisSet & dims,
|
||||
const std::function<bool(const double & value)> & condition) {
|
||||
auto constant = pattern::wrap_type<opset6::Constant>(
|
||||
pattern::type_matches_any({element::f16, element::f32, element::f64}));
|
||||
pattern::type_matches_any({element::i8, element::u8, element::f16, element::f32, element::f64}));
|
||||
|
||||
matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
auto const_node = std::dynamic_pointer_cast<opset6::Constant>(m.get_match_root());
|
||||
|
@ -0,0 +1,64 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "pruning.hpp"
|
||||
#include "mask_attribute.hpp"
|
||||
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/opsets/opset6.hpp>
|
||||
#include <ngraph/log.hpp>
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::InitMasks, "InitMasks", 0);
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
namespace init_masks {
|
||||
|
||||
class InitConvMask;
|
||||
|
||||
} // namespace init_masks
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
class ngraph::pass::init_masks::InitConvMask : public MatcherPass {
|
||||
public:
|
||||
InitConvMask() {
|
||||
auto input = pattern::any_input();
|
||||
auto weights = pattern::any_input();
|
||||
auto conv = pattern::wrap_type<opset6::Convolution, opset6::GroupConvolution>({input, weights});
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
const auto & pattern_map = m.get_pattern_value_map();
|
||||
const auto & m_output = pattern_map.at(conv);
|
||||
|
||||
// Initializing weights mask:
|
||||
// 1. Looking for Const node with weights
|
||||
NodeVector weights_calculation_nodes;
|
||||
auto cur_node = m_output.get_node()->get_input_node_shared_ptr(1);
|
||||
|
||||
while (!ngraph::is_type<opset6::Constant>(cur_node) && cur_node->inputs().size()) {
|
||||
weights_calculation_nodes.push_back(cur_node);
|
||||
cur_node = cur_node->get_input_node_shared_ptr(0);
|
||||
}
|
||||
if (!ngraph::is_type<opset6::Constant>(cur_node)) {
|
||||
NGRAPH_DEBUG << "Can't find Constant weights for Convolution: " <<
|
||||
m_output.get_node()->get_friendly_name() << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// 2. Init mask for Const node
|
||||
InitConstMask({0}/* check only output channels dim */).apply(cur_node);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(conv, "ConvolutionInitMask");
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
ngraph::pass::InitMasks::InitMasks() {
|
||||
add_matcher<init_masks::InitConvMask>();
|
||||
}
|
||||
|
@ -7,7 +7,9 @@
|
||||
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/opsets/opset6.hpp>
|
||||
#include <ngraph/opsets/opset5.hpp>
|
||||
#include <ngraph/log.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::PropagateMasks, "PropagateMasks", 0);
|
||||
|
||||
@ -20,11 +22,23 @@ class GroupConvolution;
|
||||
class Elementwise;
|
||||
class PassThrough;
|
||||
class StopPropagation;
|
||||
class FakeQuantize;
|
||||
class Concat;
|
||||
class Reshape;
|
||||
|
||||
} // namespace mask_propagation
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
ngraph::Shape broadcast_shape_to_rank(ngraph::Shape shape_to_broadcast, int64_t dst_rank) {
|
||||
auto initial_rank = static_cast<int64_t>(shape_to_broadcast.size());
|
||||
auto num_of_broadcased_dims = dst_rank - initial_rank;
|
||||
std::vector<size_t> dims(num_of_broadcased_dims, 1);
|
||||
dims.insert(dims.end(), shape_to_broadcast.begin(), shape_to_broadcast.end());
|
||||
auto new_shape = ngraph::Shape(dims);
|
||||
return new_shape;
|
||||
}
|
||||
|
||||
class ngraph::pass::mask_propagation::Convolution : public MatcherPass {
|
||||
public:
|
||||
Convolution() {
|
||||
@ -38,12 +52,15 @@ public:
|
||||
const auto & m_output = pattern_map.at(conv);
|
||||
const auto & m_input = pattern_map.at(input);
|
||||
|
||||
// In case if weights are Constant we initialize Mask
|
||||
InitConstMask({0}/* check only output channel */).apply(m_weights.get_node_shared_ptr());
|
||||
|
||||
auto weights_mask = getMask(m_weights);
|
||||
// If weights are not a Constant and we didn't set Mask value before we will get nullptr
|
||||
if (!weights_mask) return false;
|
||||
|
||||
// Nullptr in weights-mask means that mask for this node wasn't initialized earlier.
|
||||
// Weights mask for convolution should be initialized in the InitMasks pass (and propagate after it).
|
||||
// If mask isn't initialized - this weights (and hence all convolution) can't be pruned for some reason.
|
||||
if (!weights_mask) {
|
||||
NGRAPH_DEBUG << "No weights mask for " << m_output.get_node()->get_friendly_name() << "\n";
|
||||
return false;
|
||||
}
|
||||
auto weights_mask_row = weights_mask.get();
|
||||
|
||||
if (auto input_mask = getMask(m_input)) {
|
||||
@ -119,9 +136,15 @@ public:
|
||||
|
||||
auto weights_mask = getMask(m_weights);
|
||||
if (!weights_mask) {
|
||||
// TODO: only if weights are constant
|
||||
weights_mask = std::make_shared<Mask>(weights_shape.size());
|
||||
setMask(m_weights, weights_mask);
|
||||
// Setting mask only if weights are constant
|
||||
if (ngraph::is_type<opset6::Constant>(m_output.get_node_shared_ptr())) {
|
||||
weights_mask = std::make_shared<Mask>(weights_shape.size());
|
||||
setMask(m_weights, weights_mask);
|
||||
} else {
|
||||
NGRAPH_DEBUG << "GroupConvolution: No weights mask and weights aren't constant for " <<
|
||||
*m_output.get_node() << "\n";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
auto weights_mask_row = weights_mask.get();
|
||||
|
||||
@ -169,13 +192,85 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class ngraph::pass::mask_propagation::Reshape : public MatcherPass {
|
||||
public:
|
||||
Reshape() {
|
||||
auto input = pattern::any_input(pattern::has_static_shape());
|
||||
auto shape = pattern::any_input();
|
||||
// Working only for Reshapes on Group Convolution weights
|
||||
auto reshape = pattern::wrap_type<opset6::Reshape>({input, shape}, pattern::consumers_count(1));
|
||||
auto gconv = pattern::wrap_type<opset6::GroupConvolution>({pattern::any_input(), reshape},
|
||||
pattern::has_static_shape());
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
const auto & pattern_map = m.get_pattern_value_map();
|
||||
const auto & m_shape = pattern_map.at(shape);
|
||||
const auto & m_output = pattern_map.at(reshape);
|
||||
const auto & m_input = pattern_map.at(input);
|
||||
|
||||
auto shape_val = m_shape.get_node_shared_ptr();
|
||||
|
||||
// In Depthwise Convolutions Reshape on weights just add additional dimension for output channels count
|
||||
// (1 in case of the depthwise) of kernel.
|
||||
// Example: Reshape from [G, 1 (I), X, Y, Z] -> [G, 1 (O), 1 (I), X, Y, Z], where G - group numbers,
|
||||
// X, Y, Z - spartial dimensions (can be only X or X, Y), I, O - number of input/output channels of kernel.
|
||||
|
||||
// Checking that matched Reshape meets this conditions (add 1-d dim on 1 position of shape constant)
|
||||
auto inp_shape = m_input.get_shape();
|
||||
auto out_shape = m_output.get_shape();
|
||||
inp_shape.insert(inp_shape.begin() + 1, 1);
|
||||
if (inp_shape != out_shape) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto input_mask = getMask(m_input);
|
||||
if (!input_mask) {
|
||||
return false;
|
||||
}
|
||||
auto input_mask_row = input_mask.get();
|
||||
auto output_mask = std::make_shared<Mask>(m_output.get_partial_shape().rank().get_length());
|
||||
auto output_mask_row = output_mask.get();
|
||||
|
||||
// Depthwise Convolution pruned only by input channels (== groups) ->
|
||||
// Propagating mask from Group (0) dim in Reshape input to Group (0) dim in Reshape output and back
|
||||
input_mask->add_callback([output_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->at(0) = output_mask_row->at(0);
|
||||
return true;
|
||||
}, output_mask);
|
||||
output_mask->add_callback([input_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->at(0) = input_mask_row->at(0);
|
||||
return true;
|
||||
}, input_mask);
|
||||
input_mask->apply_callback(output_mask);
|
||||
|
||||
// To allow pruning on weights (allow reshape input Group (0) dim changing) replace Reshape Shape constant
|
||||
// [G, 1, 1, X, Y, Z] by [-1, 1, 1, X, Y, Z].
|
||||
auto old_shape_const = std::dynamic_pointer_cast<opset6::Constant>(m_shape.get_node_shared_ptr());
|
||||
auto shape_value = old_shape_const.get()->cast_vector<int64_t>();
|
||||
shape_value[0] = -1;
|
||||
auto new_const = opset6::Constant::create(old_shape_const->get_element_type(),
|
||||
old_shape_const->get_shape(), shape_value);
|
||||
new_const->set_friendly_name(old_shape_const->get_friendly_name());
|
||||
ngraph::copy_runtime_info(old_shape_const, new_const);
|
||||
ngraph::replace_node(old_shape_const, new_const);
|
||||
|
||||
setMask(m_output, output_mask);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape, "ReshapeMaskPropagation");
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
};
|
||||
|
||||
class ngraph::pass::mask_propagation::Elementwise : public MatcherPass {
|
||||
public:
|
||||
Elementwise() {
|
||||
auto input = pattern::any_input();
|
||||
auto weights = pattern::any_input();
|
||||
auto eltwise = pattern::wrap_type<op::util::BinaryElementwiseArithmetic>({input, weights},
|
||||
pattern::has_static_rank());
|
||||
auto eltwise = pattern::wrap_type<opset6::Add, opset6::Subtract, opset6::Maximum, opset6::Minimum,
|
||||
opset6::Multiply>({input, weights}, pattern::has_static_rank());
|
||||
// TODO: add Div, Power support
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
const auto & pattern_map = m.get_pattern_value_map();
|
||||
@ -183,82 +278,275 @@ public:
|
||||
const auto & m_output = pattern_map.at(eltwise);
|
||||
const auto & m_input = pattern_map.at(input);
|
||||
|
||||
// TODO: implement check that compares input shape ranks
|
||||
// Case when input masks should be united instead of intersection
|
||||
bool union_eltwise_type = ngraph::is_type<opset6::Multiply>(m_output.get_node_shared_ptr());
|
||||
|
||||
const auto & input_rank = m_input.get_partial_shape().rank().get_length();
|
||||
const auto & weights_rank = m_weights.get_partial_shape().rank().get_length();
|
||||
// Here assuming that masks can be propagated only through 3/4 dimensional tensors
|
||||
// (since channel dim is necessary)
|
||||
if (weights_rank < 3 || input_rank < 3) return false;
|
||||
|
||||
// In case if one of the inputs is constant
|
||||
// TODO: need to find channel dimension instead of hardcoded zero
|
||||
const size_t & channel_dim = (input_rank == weights_rank ? 1 : 0);
|
||||
InitConstMask({channel_dim}).apply(m_input.get_node_shared_ptr());
|
||||
InitConstMask({channel_dim}).apply(m_weights.get_node_shared_ptr());
|
||||
// In case if first of the inputs is constant
|
||||
InitConstMask({0, 1/* potential output channel dim */}).apply(m_input.get_node_shared_ptr());
|
||||
auto input_mask = getMask(m_input);
|
||||
if (!input_mask) {
|
||||
NGRAPH_DEBUG << "No input mask for: " << m_output.get_node()->get_friendly_name() << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
InitConstMask({0, 1}).apply(m_weights.get_node_shared_ptr());
|
||||
|
||||
auto weights_mask = getMask(m_weights);
|
||||
auto input_mask = getMask(m_input);
|
||||
|
||||
if (!weights_mask || !input_mask) {
|
||||
NGRAPH_DEBUG << "No mask for: " << m_output.get_node()->get_friendly_name() << std::endl;
|
||||
if (!weights_mask) {
|
||||
NGRAPH_DEBUG << "No weights mask for: " << m_output.get_node()->get_friendly_name() << std::endl;
|
||||
return false;
|
||||
}
|
||||
auto input_mask_row = input_mask.get();
|
||||
auto weights_mask_row = weights_mask.get();
|
||||
|
||||
// Merge masks from two inputs
|
||||
// Merging masks from two inputs
|
||||
auto output_mask = std::make_shared<Mask>(m_output.get_partial_shape().rank().get_length());
|
||||
auto output_mask_row = output_mask.get();
|
||||
|
||||
auto out_mask_callback = [input_mask_row, weights_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
auto omask_iter = cur_mask->rbegin();
|
||||
auto imask_iter = input_mask_row->rbegin();
|
||||
auto wmask_iter = weights_mask_row->rbegin();
|
||||
|
||||
for (auto & item : *cur_mask) {
|
||||
item.clear();
|
||||
}
|
||||
|
||||
while (imask_iter != input_mask_row->rend() &&
|
||||
wmask_iter != weights_mask_row->rend()) {
|
||||
// Merge mask dimension values for both masks
|
||||
// Example: (MaskValue[1,2,3,4], MaskValue[2,3]) -> MaskValue[2,3]
|
||||
for (const auto & value : *imask_iter) {
|
||||
if (wmask_iter->count(value)) {
|
||||
omask_iter->insert(value);
|
||||
}
|
||||
}
|
||||
|
||||
omask_iter++;
|
||||
imask_iter++;
|
||||
wmask_iter++;
|
||||
auto out_mask_callback = [input_mask_row, weights_mask_row, union_eltwise_type](Mask::Ptr cur_mask) -> bool {
|
||||
Mask::Ptr result_mask;
|
||||
if (union_eltwise_type) {
|
||||
result_mask = input_mask_row->union_masks_reversed(weights_mask_row);
|
||||
} else {
|
||||
result_mask = input_mask_row->intersect_masks_reversed(weights_mask_row);
|
||||
}
|
||||
cur_mask->copy_value_from_mask_reversed(result_mask.get());
|
||||
return true;
|
||||
};
|
||||
output_mask->add_callback(out_mask_callback, input_mask);
|
||||
output_mask->add_callback(out_mask_callback, weights_mask);
|
||||
|
||||
auto callback = [output_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
auto omask_iter = output_mask_row->rbegin();
|
||||
auto cmask_iter = cur_mask->rbegin();
|
||||
while (omask_iter != output_mask_row->rend() &&
|
||||
cmask_iter != cur_mask->rend()) {
|
||||
// TODO: check
|
||||
*cmask_iter = *omask_iter;
|
||||
|
||||
omask_iter++;
|
||||
cmask_iter++;
|
||||
}
|
||||
input_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->copy_value_from_mask_reversed(weights_mask_row);
|
||||
return true;
|
||||
};
|
||||
input_mask->add_callback(callback, output_mask);
|
||||
weights_mask->add_callback(callback, output_mask);
|
||||
}, weights_mask);
|
||||
input_mask->add_callback([output_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->copy_value_from_mask_reversed(output_mask_row);
|
||||
return true;
|
||||
}, output_mask);
|
||||
weights_mask->add_callback([input_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->copy_value_from_mask_reversed(input_mask_row);
|
||||
return true;
|
||||
}, input_mask);
|
||||
|
||||
// Init output mask
|
||||
output_mask->apply_callback(input_mask);
|
||||
weights_mask->apply_callback(input_mask);
|
||||
|
||||
setMask(m_output, output_mask);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(eltwise, "EltwiseMaskPropagation");
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(eltwise, "ElementwiseMaskPropagation");
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
};
|
||||
|
||||
class ngraph::pass::mask_propagation::FakeQuantize : public MatcherPass{
|
||||
public:
|
||||
FakeQuantize(){
|
||||
auto input = pattern::any_input(pattern::has_static_shape());
|
||||
auto input_low = pattern::any_input(pattern::has_static_shape());
|
||||
auto input_high = pattern::any_input(pattern::has_static_shape());
|
||||
auto output_low = pattern::any_input(pattern::has_static_shape());
|
||||
auto output_high = pattern::any_input(pattern::has_static_shape());
|
||||
auto fake_quantize = pattern::wrap_type<opset6::FakeQuantize>({input, input_low, input_high, output_low,
|
||||
output_high});
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
const auto & pattern_map = m.get_pattern_value_map();
|
||||
const auto & m_input = pattern_map.at(input);
|
||||
const auto & m_input_low = pattern_map.at(input_low);
|
||||
const auto & m_input_high = pattern_map.at(input_high);
|
||||
const auto & m_output_low = pattern_map.at(output_low);
|
||||
const auto & m_output_high = pattern_map.at(output_high);
|
||||
const auto & m_output = pattern_map.at(fake_quantize);
|
||||
|
||||
auto input_mask = getMask(m_input);
|
||||
|
||||
// Input mask is the only source of pruning in FQ
|
||||
if (!input_mask) {
|
||||
NGRAPH_DEBUG << "FakeQuantize: No input mask for " << *m_output.get_node() << "\n";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto input_mask_row = input_mask.get();
|
||||
|
||||
// Propagate input mask to output mask and in the opposite direction
|
||||
auto output_mask = std::make_shared<Mask>(m_output.get_partial_shape().rank().get_length());
|
||||
auto output_mask_row = output_mask.get();
|
||||
|
||||
// Output mask is equal to input mask
|
||||
auto output_mask_callback = [input_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->copy_value_from_mask(input_mask_row);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto input_mask_callback = [output_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->copy_value_from_mask(output_mask_row);
|
||||
return true;
|
||||
};
|
||||
|
||||
output_mask->add_callback(output_mask_callback, input_mask);
|
||||
input_mask->add_callback(input_mask_callback, output_mask);
|
||||
|
||||
// Calculate output mask
|
||||
output_mask->apply_callback(input_mask);
|
||||
setMask(m_output, output_mask);
|
||||
|
||||
auto input_low_size = shape_size(m_input_low.get_shape());
|
||||
auto input_high_size = shape_size(m_input_high.get_shape());
|
||||
auto output_low_size = shape_size(m_output_low.get_shape());
|
||||
auto output_high_size = shape_size(m_output_high.get_shape());
|
||||
|
||||
// In the per-tensor case FQ params shouldn't be pruned
|
||||
if (input_low_size == 1 && output_low_size == 1 && input_high_size == 1 && output_high_size == 1) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// If input/output ranges in FQ should be broadcasted to input shape -> broadcast this consant values
|
||||
// for the convenience of working with the masks
|
||||
NodeVector fq_params_nodes{m_input_low.get_node_shared_ptr(),
|
||||
m_input_high.get_node_shared_ptr(),
|
||||
m_output_low.get_node_shared_ptr(),
|
||||
m_output_high.get_node_shared_ptr()};
|
||||
auto fq_node = std::dynamic_pointer_cast<op::FakeQuantize>(m_output.get_node_shared_ptr());
|
||||
size_t idx = 0;
|
||||
if (fq_node->get_auto_broadcast() != ngraph::op::AutoBroadcastType::NONE) {
|
||||
for (auto const_node : fq_params_nodes) {
|
||||
auto new_shape = broadcast_shape_to_rank(const_node->get_shape(),
|
||||
m_input.get_partial_shape().rank().get_length());
|
||||
auto const_copy = const_node->clone_with_new_inputs(const_node->input_values());
|
||||
auto new_const = std::dynamic_pointer_cast<op::Constant>(const_copy);
|
||||
new_const->set_data_shape(new_shape);
|
||||
new_const->validate_and_infer_types();
|
||||
new_const->set_friendly_name(const_node->get_friendly_name());
|
||||
ngraph::copy_runtime_info(const_node, new_const);
|
||||
ngraph::replace_node(const_node, new_const);
|
||||
fq_params_nodes[idx++] = new_const;
|
||||
}
|
||||
}
|
||||
|
||||
auto fq_params_mask_callback = [input_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->at(1/* fq params have same shapes as input */) = input_mask_row->at(1 /* channel dim in data */);
|
||||
return true;
|
||||
};
|
||||
|
||||
for (auto fq_param : fq_params_nodes) {
|
||||
auto mask = std::make_shared<Mask>(fq_param->get_shape().size());
|
||||
mask->add_callback(fq_params_mask_callback, input_mask);
|
||||
input_mask->add_callback([mask](Mask::Ptr cur_mask) -> bool {
|
||||
return true;
|
||||
}, mask);
|
||||
mask->apply_callback(input_mask);
|
||||
setMask(fq_param->output(0), mask);
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(fake_quantize, "FakeQuantizeMaskPropagation");
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
};
|
||||
|
||||
class ngraph::pass::mask_propagation::Concat : public MatcherPass{
|
||||
public:
|
||||
Concat() {
|
||||
auto concat = pattern::wrap_type<opset6::Concat>(pattern::has_static_shape());
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
|
||||
const auto & pattern_map = m.get_pattern_value_map();
|
||||
const auto & m_output = pattern_map.at(concat);
|
||||
auto concat_ptr = std::dynamic_pointer_cast<opset6::Concat>(m_output.get_node_shared_ptr());
|
||||
auto axis = concat_ptr->get_concatenation_axis();
|
||||
|
||||
auto inputs = concat_ptr->inputs();
|
||||
std::map<int64_t , Mask::Ptr> input_masks;
|
||||
std::map<int64_t , Mask *> input_masks_row;
|
||||
std::vector<int64_t> input_sizes;
|
||||
|
||||
size_t first_input_idx = 0;
|
||||
Mask::Ptr first_input_mask;
|
||||
bool first_initialized = false;
|
||||
for (size_t i=0; i < inputs.size(); i++) {
|
||||
auto input = inputs[i];
|
||||
auto input_mask = getMask(input.get_source_output());
|
||||
if (input_mask) {
|
||||
input_masks[i] = input_mask;
|
||||
input_masks_row[i] = input_mask.get();
|
||||
|
||||
if (!first_initialized) {
|
||||
first_input_idx = i;
|
||||
first_input_mask = input_mask;
|
||||
first_initialized = true;
|
||||
}
|
||||
}
|
||||
input_sizes.push_back(input.get_shape().at(axis));
|
||||
}
|
||||
|
||||
if (!first_initialized) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto output_mask = std::make_shared<Mask>(m_output.get_partial_shape().rank().get_length());
|
||||
auto output_mask_row = output_mask.get();
|
||||
|
||||
auto out_mask_callback = [input_masks_row, input_sizes, axis](Mask::Ptr cur_mask) -> bool {
|
||||
int64_t cur_size = 0;
|
||||
cur_mask->at(axis).clear();
|
||||
|
||||
for (size_t i=0; i < input_sizes.size(); ++i) {
|
||||
if (input_masks_row.count(i)) {
|
||||
for (auto idx : input_masks_row.at(i)->at(axis)) {
|
||||
cur_mask->at(axis).insert(idx + cur_size);
|
||||
}
|
||||
}
|
||||
cur_size += input_sizes[i];
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
auto create_input_mask_callback_for_idx = [output_mask_row, input_sizes, axis](size_t input_idx){
|
||||
auto input_mask_callback = [output_mask_row, input_sizes, axis, input_idx](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->clean_dim_values();
|
||||
uint64_t min_val = 0;
|
||||
for (size_t i = 0; i < input_idx; i++) {
|
||||
min_val += input_sizes[i];
|
||||
}
|
||||
uint64_t max_val = min_val + input_sizes[input_idx];
|
||||
for (auto idx : output_mask_row->at(axis)) {
|
||||
if (idx < max_val && idx >= min_val) {
|
||||
cur_mask->at(axis).insert(idx - min_val);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
return input_mask_callback;
|
||||
};
|
||||
output_mask->add_callback(out_mask_callback, first_input_mask);
|
||||
|
||||
for (size_t i=0; i < inputs.size(); ++i) {
|
||||
if (input_masks.count(i) && i != first_input_idx) {
|
||||
auto input_mask = input_masks.at(i);
|
||||
input_mask->add_callback(create_input_mask_callback_for_idx(i),
|
||||
first_input_mask);
|
||||
first_input_mask->add_callback([](Mask::Ptr cur_mask) -> bool {
|
||||
return true;
|
||||
}, input_mask);
|
||||
}
|
||||
}
|
||||
first_input_mask->add_callback(create_input_mask_callback_for_idx(first_input_idx),
|
||||
output_mask);
|
||||
output_mask->apply_callback(first_input_mask);
|
||||
setMask(m_output, output_mask);
|
||||
|
||||
return true;
|
||||
};
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(concat, "ConcatMaskPropagation");
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
};
|
||||
@ -266,7 +554,9 @@ public:
|
||||
class ngraph::pass::mask_propagation::PassThrough : public MatcherPass {
|
||||
public:
|
||||
PassThrough() {
|
||||
auto unary_op = pattern::wrap_type<op::util::UnaryElementwiseArithmetic, opset6::Clamp>();
|
||||
auto unary_op = pattern::wrap_type<op::util::UnaryElementwiseArithmetic, opset6::Clamp,
|
||||
opset6::Convert, opset6::ConvertLike, opset6::AvgPool, opset6::MaxPool,
|
||||
opset6::ROIPooling, opset6::PSROIPooling, opset6::Pad>();
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
const auto & pattern_map = m.get_pattern_value_map();
|
||||
@ -312,5 +602,8 @@ ngraph::pass::PropagateMasks::PropagateMasks() {
|
||||
add_matcher<mask_propagation::GroupConvolution>();
|
||||
add_matcher<mask_propagation::Elementwise>();
|
||||
add_matcher<mask_propagation::PassThrough>();
|
||||
add_matcher<mask_propagation::FakeQuantize>();
|
||||
add_matcher<mask_propagation::Concat>();
|
||||
add_matcher<mask_propagation::Reshape>();
|
||||
add_matcher<mask_propagation::StopPropagation>();
|
||||
}
|
||||
|
@ -15,8 +15,13 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::Pruning, "Pruning", 0);
|
||||
|
||||
bool ngraph::pass::Pruning::run_on_function(std::shared_ptr<Function> f) {
|
||||
Manager manager(get_pass_config());
|
||||
|
||||
// Initialize masks only for Convolutions/GroupConvolutions weights (needed to init mask in source Constant of
|
||||
// weights-calculating subgraph). For other node types masks initialized in PropagateMasks pass.
|
||||
manager.register_pass<InitMasks>();
|
||||
manager.register_pass<PropagateMasks>();
|
||||
|
||||
|
||||
#ifdef NGRAPH_DEBUG_ENABLE
|
||||
// VisualizeTree modifier helps to print Masks and mark nodes with masks
|
||||
/*
|
||||
|
@ -54,6 +54,8 @@ bool ngraph::pass::ShrinkWeights::run_on_function(std::shared_ptr<ngraph::Functi
|
||||
for (size_t dim = 0; dim < mask->size(); ++dim) {
|
||||
const auto &dim_size = mask->at(dim).size();
|
||||
if (dim_size == 0) continue;
|
||||
// Broadcastable 1-size dimension shouldn't be shrank with mask
|
||||
if (const_node->get_shape().at(dim) == 1 && dim_size > 1) continue;
|
||||
|
||||
// Convert dims that we want remove to dims that we need to keep
|
||||
std::vector<int64_t> dims_to_keep;
|
||||
|
@ -15,6 +15,7 @@
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <ngraph/coordinate_transform.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <inference_engine.hpp>
|
||||
|
||||
using namespace testing;
|
||||
using namespace ngraph;
|
||||
@ -67,6 +68,23 @@ TEST(TransformationTests, InitMasksOutputChannel) {
|
||||
compare_masks(*getMask(weights->output(0)), {{}, {1}, {}, {}});
|
||||
}
|
||||
|
||||
// TODO: add test init masks with subgraph
|
||||
TEST(TransformationTests, TestInitMasks) {
|
||||
Shape weights_shape{6, 3, 3, 3};
|
||||
Shape input_shape{1, 3, 64, 64};
|
||||
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
|
||||
auto weights = create_constant_with_zeros(weights_shape, {{1, 2, 3}, {}, {}, {}});
|
||||
auto conv = std::make_shared<opset5::Convolution>(input, weights, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
|
||||
auto f = std::make_shared<Function>(NodeVector{conv}, ParameterVector{input});
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
m.run_passes(f);
|
||||
|
||||
compare_masks(*getMask(weights.get_node_shared_ptr()->output(0)), {{1, 2, 3}, {}, {}, {}});
|
||||
}
|
||||
|
||||
TEST(TransformationTests, InitMasksNegative) {
|
||||
Shape weights_shape{6, 3, 3, 3};
|
||||
auto weights = opset5::Constant::create(element::f32, weights_shape, {0.5});
|
||||
@ -85,6 +103,7 @@ TEST(TransformationTests, PropagateMasksNegative) {
|
||||
auto f = std::make_shared<Function>(NodeVector{conv}, ParameterVector{input});
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
m.register_pass<pass::PropagateMasks>();
|
||||
m.run_passes(f);
|
||||
|
||||
@ -102,27 +121,35 @@ TEST(TransformationTests, PropagateMasksBasic) {
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
auto relu = std::make_shared<opset5::Relu>(conv);
|
||||
|
||||
auto add_const = create_constant_with_zeros(Shape{1, 6, 1, 1}, {{}, {1, 2, 3, 4, 5}, {}, {}});
|
||||
auto add = std::make_shared<opset5::Add>(relu, add_const);
|
||||
|
||||
auto sub_const = create_constant_with_zeros(Shape{6, 1, 1}, {{1, 2, 3}, {}, {}});
|
||||
auto sub = std::make_shared<opset5::Subtract>(relu, sub_const);
|
||||
auto sub = std::make_shared<opset5::Subtract>(add, sub_const);
|
||||
|
||||
auto mul_const = create_constant_with_zeros(Shape{6, 1, 1}, {{2}, {}, {}});
|
||||
auto mul = std::make_shared<opset5::Subtract>(sub, mul_const);
|
||||
auto mul_const = create_constant_with_zeros(Shape{1, 6, 1, 1}, {{}, {4}, {}, {}});
|
||||
auto mul = std::make_shared<opset5::Multiply>(sub, mul_const);
|
||||
|
||||
auto weights2 = opset5::Constant::create(element::f32, weights_shape2, {0});
|
||||
auto weights2 = create_constant_with_zeros(weights_shape2, {{1, 2}, {1, 2, 3}, {}, {}});
|
||||
auto conv2 = std::make_shared<opset5::Convolution>(mul, weights2, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
auto f = std::make_shared<Function>(NodeVector{conv2}, ParameterVector{input});
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
m.register_pass<pass::PropagateMasks>();
|
||||
m.run_passes(f);
|
||||
|
||||
compare_masks(*getMask(weights->output(0)), Mask({{2}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv->output(0)), Mask({{}, {2}, {}, {}}));
|
||||
compare_masks(*getMask(relu->output(0)), Mask({{}, {2}, {}, {}}));
|
||||
compare_masks(*getMask(sub_const), Mask({{2}, {}, {}}));
|
||||
compare_masks(*getMask(mul_const), Mask({{2}, {}, {}}));
|
||||
compare_masks(*getMask(weights2->output(0)), Mask({{}, {2}, {}, {}}));
|
||||
compare_masks(*getMask(weights->output(0)), Mask({{1, 2, 3, 4}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}}));
|
||||
compare_masks(*getMask(relu->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}}));
|
||||
compare_masks(*getMask(add_const), Mask({{}, {1, 2, 3, 4}, {}, {}}));
|
||||
compare_masks(*getMask(sub_const), Mask({{1, 2, 3, 4}, {}, {}}));
|
||||
compare_masks(*getMask(mul_const), Mask({{}, {1, 2, 3, 4}, {}, {}}));
|
||||
compare_masks(*getMask(add->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}}));
|
||||
compare_masks(*getMask(sub->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}}));
|
||||
compare_masks(*getMask(mul->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}}));
|
||||
compare_masks(*getMask(weights2.get_node_shared_ptr()->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}}));
|
||||
compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}}));
|
||||
}
|
||||
|
||||
@ -148,6 +175,7 @@ TEST(TransformationTests, PropagateMasksDynamicConvolution) {
|
||||
auto f = std::make_shared<Function>(NodeVector{conv2}, ParameterVector{input});
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
m.register_pass<pass::PropagateMasks>();
|
||||
m.run_passes(f);
|
||||
|
||||
@ -182,6 +210,7 @@ TEST(TransformationTests, PropagateMasksDynamicGroupConvolution) {
|
||||
auto f = std::make_shared<Function>(NodeVector{conv2}, ParameterVector{input});
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
m.register_pass<pass::PropagateMasks>();
|
||||
m.run_passes(f);
|
||||
}
|
||||
@ -199,15 +228,16 @@ TEST(TransformationTests, PropagateMasksEmpty) {
|
||||
auto sub_const = create_constant_with_zeros(Shape{6, 1, 1}, {{1, 2, 3}, {}, {}});
|
||||
auto sub = std::make_shared<opset5::Subtract>(relu, sub_const);
|
||||
|
||||
auto mul_const = create_constant_with_zeros(Shape{6, 1, 1}, {{1, 2}, {}, {}});
|
||||
auto mul = std::make_shared<opset5::Subtract>(sub, mul_const);
|
||||
auto add_const = create_constant_with_zeros(Shape{6, 1, 1}, {{1, 2}, {}, {}});
|
||||
auto add = std::make_shared<opset5::Subtract>(sub, add_const);
|
||||
|
||||
auto weights2 = opset5::Constant::create(element::f32, weights_shape2, {0});
|
||||
auto conv2 = std::make_shared<opset5::Convolution>(mul, weights2, Strides(2, 1),
|
||||
auto conv2 = std::make_shared<opset5::Convolution>(add, weights2, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
auto f = std::make_shared<Function>(NodeVector{conv2}, ParameterVector{input});
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
m.register_pass<pass::PropagateMasks>();
|
||||
m.run_passes(f);
|
||||
|
||||
@ -215,11 +245,55 @@ TEST(TransformationTests, PropagateMasksEmpty) {
|
||||
compare_masks(*getMask(conv->output(0)), Mask({{}, {}, {}, {}}));
|
||||
compare_masks(*getMask(relu->output(0)), Mask({{}, {}, {}, {}}));
|
||||
compare_masks(*getMask(sub_const), Mask({{}, {}, {}}));
|
||||
compare_masks(*getMask(mul_const), Mask({{}, {}, {}}));
|
||||
compare_masks(*getMask(add_const), Mask({{}, {}, {}}));
|
||||
compare_masks(*getMask(weights2->output(0)), Mask({{}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}}));
|
||||
}
|
||||
|
||||
TEST(TransformationTests, PropagateMaskPassThrough) {
|
||||
Shape input_shape{1, 3, 64, 64};
|
||||
Shape weights_shape{8, 3, 3, 3};
|
||||
Shape weight_shape2{3, 8, 3, 3};
|
||||
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
|
||||
input->set_friendly_name("input");
|
||||
auto weights_const_1 = create_constant_with_zeros(weights_shape, {{1, 2, 3}, {}, {}, {}});
|
||||
weights_const_1.get_node_shared_ptr()->set_friendly_name("weights_1");
|
||||
|
||||
auto conv_1 = std::make_shared<opset5::Convolution>(input, weights_const_1, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
conv_1->set_friendly_name("conv_1");
|
||||
|
||||
// Adding a couple of PassThrough operations
|
||||
auto relu = std::make_shared<opset5::Relu>(conv_1);
|
||||
relu->set_friendly_name("relu");
|
||||
|
||||
auto clamp = std::make_shared<opset5::Clamp>(relu, 0, 6);
|
||||
clamp->set_friendly_name("clamp");
|
||||
|
||||
auto pads_begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 1, 1});
|
||||
auto pads_end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
|
||||
auto pad = std::make_shared<opset5::Pad>(clamp, pads_begin, pads_end, op::PadMode::CONSTANT);
|
||||
auto max_pool = std::make_shared<opset5::MaxPool>(pad, Strides{1, 1},
|
||||
Shape{0, 0}, Shape{1, 1}, Shape{4, 4});
|
||||
max_pool->set_friendly_name("max_pool");
|
||||
|
||||
auto weights2 = opset5::Constant::create(element::f32, weight_shape2, {0});
|
||||
auto conv2 = std::make_shared<opset5::Convolution>(max_pool, weights2, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
auto f = std::make_shared<Function>(NodeVector{conv2}, ParameterVector{input});
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
m.register_pass<pass::PropagateMasks>();
|
||||
m.run_passes(f);
|
||||
|
||||
compare_masks(*getMask(weights_const_1.get_node_shared_ptr()->output(0)), Mask({{1, 2, 3}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv_1->output(0)), Mask({{}, {1, 2, 3}, {}, {}}));
|
||||
compare_masks(*getMask(relu->output(0)), Mask({{}, {1, 2, 3}, {}, {}}));
|
||||
compare_masks(*getMask(clamp->output(0)), Mask({{}, {1, 2, 3}, {}, {}}));
|
||||
compare_masks(*getMask(max_pool->output(0)), Mask({{}, {1, 2, 3}, {}, {}}));
|
||||
}
|
||||
|
||||
TEST(TransformationTests, PropagateMasksHardDependencies) {
|
||||
Shape input_shape{1, 3, 3, 3};
|
||||
|
||||
@ -280,4 +354,344 @@ TEST(TransformationTests, PropagateMasksHardDependencies) {
|
||||
// compare_masks(*getMask(relu), Mask({{}, {0, 1, 2, 3, 4, 5}, {}, {}}));
|
||||
// compare_masks(*getMask(weights2), Mask({{}, {0, 1, 2, 3, 4, 5}, {}, {}}));
|
||||
// compare_masks(*getMask(conv2), Mask({{}, {}, {}, {}}));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TransformationTests, PropagateMasksQuantizedGroupConvolution) {
|
||||
Shape input_shape{1, 3, 64, 64};
|
||||
Shape weights_shape{8, 3, 3, 3};
|
||||
Shape weights_group_shape{8, 1, 3, 3};
|
||||
Shape weight_shape2{3, 8, 3, 3};
|
||||
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
|
||||
input->set_friendly_name("input");
|
||||
|
||||
auto weights1 = create_constant_with_zeros(weights_shape, {{0, 1, 2, 3}, {}, {}, {}});
|
||||
auto conv1 = std::make_shared<opset5::Convolution>(input, weights1, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
auto weights_group = opset5::Constant::create(element::i8, weights_group_shape, {0});
|
||||
weights_group->set_friendly_name("weights_group");
|
||||
|
||||
auto convert = std::make_shared<opset5::Convert>(weights_group, element::f32);
|
||||
convert->set_friendly_name("convert");
|
||||
|
||||
auto sub_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3}, {}, {}, {}});
|
||||
|
||||
auto sub = std::make_shared<opset5::Subtract>(convert, sub_const);
|
||||
sub->set_friendly_name("sub");
|
||||
|
||||
auto mul_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3, 4}, {}, {}, {}});
|
||||
auto mul = std::make_shared<opset5::Multiply>(sub, mul_const);
|
||||
mul->set_friendly_name("mul");
|
||||
|
||||
auto reshape = std::make_shared<opset5::Reshape>(mul, opset5::Constant::create(element::i64, Shape{5}, {8, 1, 1, 3, 3}), false);
|
||||
|
||||
auto conv_group = std::make_shared<opset5::GroupConvolution>(conv1, reshape, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
|
||||
auto add_const = create_constant_with_zeros(Shape{1, 8, 1, 1}, {{}, {0, 1, 2, 3, 4}, {}, {}});;
|
||||
auto add = std::make_shared<opset5::Add>(conv_group, add_const);
|
||||
add->set_friendly_name("add");
|
||||
|
||||
auto weights_2 = opset5::Constant::create(element::f32, weight_shape2, {0});
|
||||
auto conv2 = std::make_shared<opset5::Convolution>(add, weights_2, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
auto f = std::make_shared<Function>(NodeVector{conv2}, ParameterVector{input});
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::Pruning>();
|
||||
m.run_passes(f);
|
||||
|
||||
compare_masks(*getMask(weights1.get_node_shared_ptr()->output(0)), Mask({{0 , 1, 2, 3}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv1->output(0)), Mask({{}, {0 , 1, 2, 3}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(weights_group->output(0)), Mask({{0 , 1, 2, 3}, {}, {}, {}}));
|
||||
compare_masks(*getMask(sub->output(0)), Mask({{0 , 1, 2, 3}, {}, {}, {}}));
|
||||
compare_masks(*getMask(sub_const.get_node_shared_ptr()->output(0)), Mask({{0 , 1, 2, 3}, {}, {}, {}}));
|
||||
compare_masks(*getMask(mul->output(0)), Mask({{0 , 1, 2, 3}, {}, {}, {}}));
|
||||
compare_masks(*getMask(mul_const.get_node_shared_ptr()->output(0)), Mask({{0 , 1, 2, 3}, {}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(reshape->output(0)), Mask({{0 , 1, 2, 3}, {}, {}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(conv_group->output(0)), Mask({{}, {0 , 1, 2, 3}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}}));
|
||||
compare_masks(*getMask(weights_2->output(0)), Mask({{}, {0, 1, 2, 3}, {}, {}}));
|
||||
}
|
||||
|
||||
TEST(TransformationTests, PropagateMasksFakeQuantizePerTensor) {
|
||||
Shape input_shape{1, 3, 64, 64};
|
||||
Shape weights_shape{8, 3, 3, 3};
|
||||
Shape weight_shape2{3, 8, 3, 3};
|
||||
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
|
||||
input->set_friendly_name("input");
|
||||
auto weights_1 = opset5::Constant::create(element::i8, weights_shape, {0});
|
||||
weights_1->set_friendly_name("weights_int8_const");
|
||||
|
||||
auto convert = std::make_shared<opset5::Convert>(weights_1, element::f32);
|
||||
convert->set_friendly_name("convert");
|
||||
|
||||
auto sub_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3}, {}, {}, {}});
|
||||
|
||||
auto sub = std::make_shared<opset5::Subtract>(convert, sub_const);
|
||||
sub->set_friendly_name("sub");
|
||||
|
||||
auto mul_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3, 4}, {}, {}, {}});
|
||||
auto mul = std::make_shared<opset5::Multiply>(sub, mul_const);
|
||||
mul->set_friendly_name("mul");
|
||||
|
||||
auto conv1 = std::make_shared<opset5::Convolution>(input, mul, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
conv1->set_friendly_name("conv1");
|
||||
|
||||
auto add_const = create_constant_with_zeros(Shape{1, 8, 1, 1}, {{}, {0, 1, 2, 3, 4}, {}, {}});;
|
||||
auto add = std::make_shared<opset5::Add>(conv1, add_const);
|
||||
add->set_friendly_name("add");
|
||||
|
||||
auto input_low = opset5::Constant::create(element::f32, Shape{1}, {0});
|
||||
auto input_high = opset5::Constant::create(element::f32, Shape{1, 1, 1, 1}, {20});
|
||||
auto output_low = opset5::Constant::create(element::f32, Shape{}, {1});
|
||||
auto output_high = opset5::Constant::create(element::f32, Shape{}, {10});
|
||||
auto fq = std::make_shared<opset5::FakeQuantize>(add, input_low, input_high, output_low, output_high, 8);
|
||||
|
||||
auto weights_2 = opset5::Constant::create(element::f32, weight_shape2, {0});
|
||||
auto conv2 = std::make_shared<opset5::Convolution>(fq, weights_2, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
auto f = std::make_shared<Function>(NodeVector{conv2}, ParameterVector{input});
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::Pruning>();
|
||||
m.run_passes(f);
|
||||
|
||||
compare_masks(*getMask(weights_1->output(0)), Mask({{0 , 1, 2, 3, 4}, {}, {}, {}}));
|
||||
compare_masks(*getMask(sub_const.get_node_shared_ptr()->output(0)), Mask({{0 , 1, 2, 3, 4}, {}, {}, {}}));
|
||||
compare_masks(*getMask(sub->output(0)), Mask({{0 , 1, 2, 3, 4}, {}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(mul_const.get_node_shared_ptr()->output(0)), Mask({{0 , 1, 2, 3, 4}, {}, {}, {}}));
|
||||
compare_masks(*getMask(mul->output(0)), Mask({{0 , 1, 2, 3, 4}, {}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(conv1->output(0)), Mask({{}, {0 , 1, 2, 3, 4}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(add_const.get_node_shared_ptr()->output(0)), Mask({{}, {0 , 1, 2, 3, 4}, {}, {}}));
|
||||
compare_masks(*getMask(add->output(0)), Mask({{}, {0 , 1, 2, 3, 4}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(fq->output(0)), Mask({{}, {0 , 1, 2, 3, 4}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(weights_2->output(0)), Mask({{}, {0 , 1, 2, 3, 4}, {}, {}}));
|
||||
compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}}));
|
||||
}
|
||||
|
||||
TEST(TransformationTests, PropagateMasksFakeQuantizePerChannel) {
|
||||
Shape input_shape{1, 3, 64, 64};
|
||||
Shape weights_shape{8, 3, 3, 3};
|
||||
Shape weight_shape2{3, 8, 3, 3};
|
||||
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
|
||||
input->set_friendly_name("input");
|
||||
auto weights_1 = opset5::Constant::create(element::i8, weights_shape, {0});
|
||||
weights_1->set_friendly_name("weights_int8_const");
|
||||
|
||||
auto convert = std::make_shared<opset5::Convert>(weights_1, element::f32);
|
||||
convert->set_friendly_name("convert");
|
||||
|
||||
auto sub_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3}, {}, {}, {}});
|
||||
|
||||
auto sub = std::make_shared<opset5::Subtract>(convert, sub_const);
|
||||
sub->set_friendly_name("sub");
|
||||
|
||||
auto mul_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3, 4}, {}, {}, {}});
|
||||
auto mul = std::make_shared<opset5::Multiply>(sub, mul_const);
|
||||
mul->set_friendly_name("mul");
|
||||
|
||||
auto conv1 = std::make_shared<opset5::Convolution>(input, mul, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
conv1->set_friendly_name("conv1");
|
||||
|
||||
auto add_const = create_constant_with_zeros(Shape{1, 8, 1, 1}, {{}, {0, 1, 2, 3, 4}, {}, {}});;
|
||||
auto add = std::make_shared<opset5::Add>(conv1, add_const);
|
||||
add->set_friendly_name("add");
|
||||
|
||||
auto input_low = opset5::Constant::create(element::f32, Shape{1, 8, 1, 1}, {0});
|
||||
auto input_high = opset5::Constant::create(element::f32, Shape{1, 8, 1, 1}, {20});
|
||||
auto output_low = opset5::Constant::create(element::f32, Shape{8, 1, 1}, {1});
|
||||
auto output_high = opset5::Constant::create(element::f32, Shape{8, 1, 1}, {10});
|
||||
auto fq = std::make_shared<opset5::FakeQuantize>(add, input_low, input_high, output_low, output_high, 8);
|
||||
|
||||
auto weights_2 = opset5::Constant::create(element::f32, weight_shape2, {0});
|
||||
auto conv2 = std::make_shared<opset5::Convolution>(fq, weights_2, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
auto f = std::make_shared<Function>(NodeVector{conv2}, ParameterVector{input});
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
m.register_pass<pass::PropagateMasks>();
|
||||
m.run_passes(f);
|
||||
|
||||
compare_masks(*getMask(weights_1->output(0)), Mask({{0 , 1, 2, 3, 4}, {}, {}, {}}));
|
||||
compare_masks(*getMask(sub_const.get_node_shared_ptr()->output(0)), Mask({{0 , 1, 2, 3, 4}, {}, {}, {}}));
|
||||
compare_masks(*getMask(sub->output(0)), Mask({{0 , 1, 2, 3, 4}, {}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(mul_const.get_node_shared_ptr()->output(0)), Mask({{0 , 1, 2, 3, 4}, {}, {}, {}}));
|
||||
compare_masks(*getMask(mul->output(0)), Mask({{0 , 1, 2, 3, 4}, {}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(conv1->output(0)), Mask({{}, {0 , 1, 2, 3, 4}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(add_const.get_node_shared_ptr()->output(0)), Mask({{}, {0 , 1, 2, 3, 4}, {}, {}}));
|
||||
compare_masks(*getMask(add->output(0)), Mask({{}, {0 , 1, 2, 3, 4}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(fq->output(0)), Mask({{}, {0 , 1, 2, 3, 4}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(weights_2->output(0)), Mask({{}, {0 , 1, 2, 3, 4}, {}, {}}));
|
||||
compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(fq->input(1).get_source_output()), Mask({{}, {0 , 1, 2, 3, 4}, {}, {}}));
|
||||
compare_masks(*getMask(fq->input(2).get_source_output()), Mask({{}, {0 , 1, 2, 3, 4}, {}, {}}));
|
||||
compare_masks(*getMask(fq->input(3).get_source_output()), Mask({{}, {0 , 1, 2, 3, 4}, {}, {}}));
|
||||
compare_masks(*getMask(fq->input(4).get_source_output()), Mask({{}, {0 , 1, 2, 3, 4}, {}, {}}));
|
||||
}
|
||||
|
||||
TEST(TransformationTests, TestConcatMaskPropagation) {
|
||||
Shape input_shape{1, 3, 64, 64};
|
||||
Shape weights_shape1{8, 3, 3, 3};
|
||||
Shape weights_shape2{16, 3, 3, 3};
|
||||
Shape weights_shape3{8, 3, 3, 3};
|
||||
|
||||
Shape weight_shape_out_conv{3, 32, 3, 3};
|
||||
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
|
||||
auto weights_1 = create_constant_with_zeros(weights_shape1, {{0, 1, 2, 3}, {}, {}, {}});
|
||||
auto conv1 = std::make_shared<opset5::Convolution>(input, weights_1, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
|
||||
auto weights_2 = create_constant_with_zeros(weights_shape2, {{7, 8, 9, 10}, {}, {}, {}});
|
||||
auto conv2 = std::make_shared<opset5::Convolution>(input, weights_2, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
|
||||
auto weights_3 = create_constant_with_zeros(weights_shape3, {{4, 5, 6, 7}, {}, {}, {}});
|
||||
auto conv3 = std::make_shared<opset5::Convolution>(input, weights_3, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
|
||||
auto concat = std::make_shared<opset5::Concat>(OutputVector{conv1->output(0), conv2->output(0), conv3->output(0)}, 1);
|
||||
|
||||
auto weights_out_conv = create_constant_with_zeros(weight_shape_out_conv, {{}, {}, {}, {}});
|
||||
auto conv_out = std::make_shared<opset5::Convolution>(concat, weights_out_conv, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
|
||||
auto f = std::make_shared<Function>(NodeVector{conv_out}, ParameterVector{input});
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
m.register_pass<pass::PropagateMasks>();
|
||||
m.run_passes(f);
|
||||
|
||||
compare_masks(*getMask(weights_1.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv1->output(0)), Mask({{}, {0, 1, 2, 3}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(weights_2.get_node_shared_ptr()->output(0)), Mask({{7, 8, 9, 10}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv2->output(0)), Mask({{}, {7, 8, 9, 10}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(weights_3.get_node_shared_ptr()->output(0)), Mask({{4, 5, 6, 7}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv3->output(0)), Mask({{}, {4, 5, 6, 7}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(concat->output(0)), Mask({{}, {0, 1, 2, 3, 15, 16, 17, 18, 28, 29, 30, 31}, {}, {}}));
|
||||
compare_masks(*getMask(weights_out_conv.get_node_shared_ptr()->output(0)), Mask({{}, {0, 1, 2, 3, 15, 16, 17, 18, 28, 29, 30, 31}, {}, {}}));
|
||||
}
|
||||
|
||||
|
||||
TEST(TransformationTests, TestConcatMaskPropagationUp) {
|
||||
Shape input_shape{1, 3, 64, 64};
|
||||
Shape weights_shape1{8, 3, 3, 3};
|
||||
Shape weights_shape2{16, 3, 3, 3};
|
||||
Shape weights_shape3{8, 3, 3, 3};
|
||||
|
||||
Shape weight_shape_out_conv{3, 32, 3, 3};
|
||||
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
|
||||
auto weights_1 = create_constant_with_zeros(weights_shape1, {{0, 1, 2, 3, 4, 5}, {}, {}, {}});
|
||||
auto conv1 = std::make_shared<opset5::Convolution>(input, weights_1, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
|
||||
auto weights_2 = create_constant_with_zeros(weights_shape2, {{7, 8, 9, 10}, {}, {}, {}});
|
||||
auto conv2 = std::make_shared<opset5::Convolution>(input, weights_2, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
|
||||
auto weights_3 = create_constant_with_zeros(weights_shape3, {{2, 3, 4, 5, 6, 7}, {}, {}, {}});
|
||||
auto conv3 = std::make_shared<opset5::Convolution>(input, weights_3, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
|
||||
auto concat = std::make_shared<opset5::Concat>(OutputVector{conv1->output(0), conv2->output(0), conv3->output(0)}, 1);
|
||||
|
||||
auto add_const = create_constant_with_zeros(Shape{1, 32, 1, 1}, {{}, {0, 1, 2, 3, 15, 16, 17, 18, 28, 29, 30, 31}, {}, {}});
|
||||
auto add = std::make_shared<opset5::Add>(concat, add_const);
|
||||
|
||||
auto weights_out_conv = create_constant_with_zeros(weight_shape_out_conv, {{}, {}, {}, {}});
|
||||
auto conv_out = std::make_shared<opset5::Convolution>(add, weights_out_conv, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
|
||||
auto f = std::make_shared<Function>(NodeVector{conv_out}, ParameterVector{input});
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
m.register_pass<pass::PropagateMasks>();
|
||||
m.run_passes(f);
|
||||
|
||||
compare_masks(*getMask(weights_1.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv1->output(0)), Mask({{}, {0, 1, 2, 3}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(weights_2.get_node_shared_ptr()->output(0)), Mask({{7, 8, 9, 10}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv2->output(0)), Mask({{}, {7, 8, 9, 10}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(weights_3.get_node_shared_ptr()->output(0)), Mask({{4, 5, 6, 7}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv3->output(0)), Mask({{}, {4, 5, 6, 7}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(add_const.get_node_shared_ptr()->output(0)), Mask({{}, {0, 1, 2, 3, 15, 16, 17, 18, 28, 29, 30, 31}, {}, {}}));
|
||||
compare_masks(*getMask(add->output(0)), Mask({{}, {0, 1, 2, 3, 15, 16, 17, 18, 28, 29, 30, 31}, {}, {}}));
|
||||
|
||||
|
||||
compare_masks(*getMask(concat->output(0)), Mask({{}, {0, 1, 2, 3, 15, 16, 17, 18, 28, 29, 30, 31}, {}, {}}));
|
||||
compare_masks(*getMask(weights_out_conv.get_node_shared_ptr()->output(0)), Mask({{}, {0, 1, 2, 3, 15, 16, 17, 18, 28, 29, 30, 31}, {}, {}}));
|
||||
}
|
||||
|
||||
|
||||
TEST(TransformationTests, TestConcatMaskPropagationUpEmpty) {
|
||||
Shape input_shape{1, 3, 64, 64};
|
||||
Shape weights_shape1{8, 3, 3, 3};
|
||||
Shape weights_shape2{16, 3, 3, 3};
|
||||
Shape weights_shape3{8, 3, 3, 3};
|
||||
|
||||
Shape weight_shape_out_conv{3, 32, 3, 3};
|
||||
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
|
||||
auto weights_1 = create_constant_with_zeros(weights_shape1, {{0, 1, 2, 3, 4, 5}, {}, {}, {}});
|
||||
auto conv1 = std::make_shared<opset5::Convolution>(input, weights_1, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
|
||||
auto weights_2 = create_constant_with_zeros(weights_shape2, {{7, 8, 9, 10}, {}, {}, {}});
|
||||
auto conv2 = std::make_shared<opset5::Convolution>(input, weights_2, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
|
||||
auto weights_3 = create_constant_with_zeros(weights_shape3, {{2, 3, 4, 5, 6, 7}, {}, {}, {}});
|
||||
auto conv3 = std::make_shared<opset5::Convolution>(input, weights_3, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
|
||||
auto concat = std::make_shared<opset5::Concat>(OutputVector{conv1->output(0), conv2->output(0), conv3->output(0)}, 1);
|
||||
|
||||
auto add_const = create_constant_with_zeros(Shape{1, 32, 1, 1}, {{}, {0, 1, 2, 3, 15, 16, 17, 18, 28, 29, 30, 31}, {}, {}});
|
||||
auto add = std::make_shared<opset5::Add>(concat, add_const);
|
||||
|
||||
auto f = std::make_shared<Function>(NodeVector{add}, ParameterVector{input});
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
m.register_pass<pass::PropagateMasks>();
|
||||
m.run_passes(f);
|
||||
|
||||
compare_masks(*getMask(weights_1.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv1->output(0)), Mask({{}, {}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(weights_2.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(weights_3.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv3->output(0)), Mask({{}, {}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(add_const.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}}));
|
||||
compare_masks(*getMask(add->output(0)), Mask({{}, {}, {}, {}}));
|
||||
|
||||
|
||||
compare_masks(*getMask(concat->output(0)), Mask({{}, {}, {}, {}}));
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user