[Ngraph transformation][Pruning] MatMul ops pruning support (#10211)

* Linear pruning support

* Minor fix

* Fix types

* Fix: stop 1d multiply propagation
This commit is contained in:
Daniil Lyakhov
2022-02-14 22:00:29 +03:00
committed by GitHub
parent 2f876e3b5b
commit 2f9c5df271
6 changed files with 1028 additions and 159 deletions

View File

@@ -16,6 +16,7 @@
#include <set>
#include <ngraph/node.hpp>
#include <ngraph/log.hpp>
namespace ngraph {
@@ -49,6 +50,10 @@ public:
m_adjust_value(adjust_value) {
}
explicit Mask(const std::vector<value_type> val)
: std::vector<value_type>(val) {
}
Mask(std::initializer_list<std::initializer_list<uint64_t>> list)
: std::vector<value_type>() {
for (const auto & dim_values : list) {
@@ -122,7 +127,7 @@ public:
return result_mask;
}
Mask::Ptr union_masks_reversed(Mask *const mask) {
Mask::Ptr union_masks_reversed(Mask *const mask) const {
auto result_mask = std::make_shared<Mask>(std::max(size(), mask->size()));
auto result_iter = result_mask->rbegin();
auto mask_1_iter = rbegin();
@@ -149,9 +154,13 @@ public:
return result_mask;
}
void add_callback(const std::function<bool(Mask::Ptr)> & receive_callback, Mask::Ptr mask) {
bool add_callback(const std::function<bool(Mask::Ptr)> & receive_callback, Mask::Ptr mask) {
if (m_callbacks.find(mask.get()) != m_callbacks.end())
NGRAPH_DEBUG << "Attempt to rewrite callback, could lead to unexpected behaviour";
m_callbacks[mask.get()] = receive_callback;
m_dependencies.push_back(mask.get());
return true;
}
/* Modify state of this mask by corresponding callback,
@@ -173,11 +182,14 @@ public:
m_need_initialization = false;
// recursively apply callbacks for each dependent mask
for (const auto & m_dependency : m_dependencies) {
if (m_dependency == mask.get())
continue;
if (!m_dependency->apply_callback(shared_from_this())) {
return false;
}
}
return true;
return mask->apply_callback(shared_from_this());
}
void invalidate() {

View File

@@ -30,8 +30,12 @@ ngraph::pass::InitConstMask::InitConstMask(const ngraph::AxisSet & dims,
for (const auto & dim : dims) {
if (dim >= shape.size()) {
throw ngraph_error("Dim value " + std::to_string(dim) + " is out of range [0;" +std::to_string(shape.size() - 1) + "]");
NGRAPH_DEBUG << "[WARNING] Attemt to initialize masks on " << dim
<< " dimension which is out of shape " << shape
<< " for node (" << const_node->get_friendly_name() << ")";
continue;
}
for (size_t value = 0; value < shape[dim]; ++value) {
Coordinate begin(shape.size(), 0);
Coordinate end(shape);

View File

@@ -16,6 +16,7 @@ namespace pass {
namespace init_masks {
class InitConvMask;
class InitMatMulMask;
} // namespace init_masks
} // namespace pass
@@ -58,7 +59,56 @@ public:
};
class ngraph::pass::init_masks::InitMatMulMask : public MatcherPass {
public:
InitMatMulMask() {
auto a = pattern::any_input();
auto b = pattern::any_input();
auto matmul = pattern::wrap_type<opset6::MatMul>({a, b});
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto & pattern_map = m.get_pattern_value_map();
const auto & m_output = pattern_map.at(matmul);
// Assume constant always in the first input port.
// Initializing weights mask:
// 1. Looking for Const node with weights
NodeVector weights_calculation_nodes;
auto cur_node = m_output.get_node()->get_input_node_shared_ptr(1);
while (!ngraph::is_type<opset6::Constant>(cur_node) && cur_node->inputs().size()) {
weights_calculation_nodes.push_back(cur_node);
cur_node = cur_node->get_input_node_shared_ptr(0);
}
if (!ngraph::is_type<opset6::Constant>(cur_node)) {
NGRAPH_DEBUG << "Can't find Constant weights for MatMul: " <<
m_output.get_node()->get_friendly_name() << std::endl;
return false;
}
// 2. Get constant rank to set mask on last dimension
const auto const_op = std::dynamic_pointer_cast<opset6::Constant>(cur_node);
const auto shape_rank = const_op->get_shape().size();
const auto matmul = std::dynamic_pointer_cast<opset6::MatMul>(m_output.get_node_shared_ptr());
const auto shift = (matmul->get_transpose_b())? 2 : 1;
if (shape_rank < shift) {
NGRAPH_DEBUG << "Can't init mask for MatMul: " <<
m_output.get_node()->get_friendly_name() << std::endl;
return false;
}
const size_t outer_dim = shape_rank - shift;
// 3. Init mask for Const node
InitConstMask({outer_dim}/* check only outer dim */).apply(cur_node);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, "MatMulInitMask");
register_matcher(m, callback);
}
};
ngraph::pass::InitMasks::InitMasks() {
add_matcher<init_masks::InitConvMask>();
add_matcher<init_masks::InitMatMulMask>();
}

View File

@@ -6,6 +6,8 @@
#include "mask_attribute.hpp"
#include <algorithm>
#include <memory>
#include <iterator>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/opsets/opset6.hpp>
@@ -20,6 +22,7 @@ namespace ngraph {
namespace pass {
namespace mask_propagation {
class MatMul;
class Convolution;
class GroupConvolution;
class GroupConvolutionReshape;
@@ -45,6 +48,85 @@ static ngraph::Shape broadcast_shape_to_rank(ngraph::Shape shape_to_broadcast, i
return new_shape;
}
class ngraph::pass::mask_propagation::MatMul : public MatcherPass {
public:
MatMul() {
auto a = pattern::any_input(pattern::has_static_shape());
auto b = pattern::any_input(pattern::has_static_shape());
auto matmul = pattern::wrap_type<opset6::MatMul>({a, b});
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto & pattern_map = m.get_pattern_value_map();
const auto & m_a = pattern_map.at(a);
const auto & m_b = pattern_map.at(b);
const auto & m_matmul = pattern_map.at(matmul);
auto a_mask = getMask(m_a);
auto b_mask = getMask(m_b);
if (!a_mask || !b_mask) {
NGRAPH_DEBUG << "No mask for any input of " << m_matmul.get_node()->get_friendly_name() << "\n";
return false;
}
auto a_mask_row = a_mask.get();
auto b_mask_row = b_mask.get();
const auto matmul_op = std::dynamic_pointer_cast<opset6::MatMul>(m_matmul.get_node_shared_ptr());
const auto transpose_a = matmul_op->get_transpose_a();
const auto transpose_b = matmul_op->get_transpose_b();
const auto shape_a = m_a.get_shape();
const auto shape_b = m_b.get_shape();
const auto a_inner_dim = (transpose_a)? shape_a.size() - 2 : shape_a.size() - 1;
const auto a_outer_dim = (transpose_a)? shape_a.size() - 1 : shape_a.size() - 2;
const auto b_inner_dim = (transpose_b)? shape_b.size() - 1 : shape_b.size() - 2;
const auto b_outer_dim = (transpose_b)? shape_b.size() - 2 : shape_b.size() - 1;
const auto matmul_range = m_matmul.get_shape().size();
auto matmul_mask = std::make_shared<Mask>(matmul_range);
auto matmul_mask_row = matmul_mask.get();
const auto matmul_cols_dim = matmul_range - 1;
const auto matmul_rows_dim = matmul_range - 2;
const auto matmul_callback = [=](Mask::Ptr cur_mask) -> bool {
cur_mask->at(matmul_rows_dim) = a_mask_row->at(a_outer_dim);
cur_mask->at(matmul_cols_dim) = b_mask_row->at(b_outer_dim);
if (a_mask_row->at(a_inner_dim) != b_mask_row->at(b_inner_dim))
cur_mask->initialize_dependencies();
return true;
};
// Connect a with matmul mask
matmul_mask->add_callback(matmul_callback, a_mask);
a_mask->add_callback([=](Mask::Ptr cur_mask) -> bool {
cur_mask->at(a_inner_dim) = b_mask_row->at(b_inner_dim);
cur_mask->at(a_outer_dim) = matmul_mask_row->at(matmul_rows_dim);
return true;
}, matmul_mask);
// connect b with matmul mask
matmul_mask->add_callback(matmul_callback, b_mask);
b_mask->add_callback([=](Mask::Ptr cur_mask) -> bool {
cur_mask->at(b_inner_dim) = a_mask_row->at(a_inner_dim);
cur_mask->at(b_outer_dim) = matmul_mask_row->at(matmul_cols_dim);
return true;
}, matmul_mask);
if (!matmul_mask->apply_callback(a_mask)) {
return false;
}
setMask(m_matmul, matmul_mask);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, "MatMulMaskPropagation");
register_matcher(m, callback);
}
};
class ngraph::pass::mask_propagation::Convolution : public MatcherPass {
public:
Convolution() {
@@ -69,42 +151,47 @@ public:
}
auto weights_mask_row = weights_mask.get();
if (auto input_mask = getMask(m_input)) {
auto input_mask_row = input_mask.get();
// Weights input channel is connected to the convolution input channel dimension
// so we update weights mask to be aligned with input shape.
weights_mask->add_callback([input_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->at(1/* weights input channel */) = input_mask_row->at(1 /* input data channel */);
return true;
}, input_mask);
input_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->at(1) = weights_mask_row->at(1);
return true;
}, weights_mask);
if (!weights_mask->apply_callback(input_mask)) {
return false;
}
}
// Create output mask that describes which channel dimensions will be removed
auto conv_mask = std::make_shared<Mask>(m_weights.get_shape().size());
auto conv_mask_row = conv_mask.get();
auto input_mask = getMask(m_input);
Mask* input_mask_row = nullptr;
if (input_mask)
input_mask_row = input_mask.get();
conv_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->at(1) = weights_mask_row->at(0/*weights output channel dim */);
const auto conv_mask_callback = [input_mask_row, weights_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->at(1/*input data channel*/) = weights_mask_row->at(0 /* weights output channel dim*/);
if (input_mask_row && input_mask_row->at(1) != weights_mask_row->at(1))
cur_mask->initialize_dependencies();
return true;
}, weights_mask);
};
weights_mask->add_callback([conv_mask_row](Mask::Ptr cur_mask) -> bool {
if (input_mask) {
// Weights input channel is connected to the convolution input channel dimension
// so we update weights mask to be aligned with input shape.
conv_mask->add_callback(conv_mask_callback, input_mask);
input_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->at(1) = weights_mask_row->at(1);
return true;
}, conv_mask);
}
conv_mask->add_callback(conv_mask_callback, weights_mask);
weights_mask->add_callback([input_mask_row, conv_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->at(0) = conv_mask_row->at(1);
if (input_mask_row)
cur_mask->at(1) = input_mask_row->at(1);
return true;
}, conv_mask);
if (!conv_mask->apply_callback(weights_mask)) {
bool status;
if (input_mask)
status = conv_mask->apply_callback(input_mask);
else
status = conv_mask->apply_callback(weights_mask);
if (!status)
return false;
}
setMask(m_output, conv_mask);
return true;
@@ -154,37 +241,30 @@ public:
}
auto weights_mask_row = weights_mask.get();
// Weights input channel is connected to the convolution input channel dimension
// so we update weights mask to be aligned with input shape.
weights_mask->add_callback([input_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->at(0) = input_mask_row->at(1);
return true;
}, input_mask);
input_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->at(1) = weights_mask_row->at(0);
return true;
}, weights_mask);
if (!weights_mask->apply_callback(input_mask)) {
return false;
}
// Update output channels mask dims
auto conv_mask = std::make_shared<Mask>(input_shape.rank().get_length());
auto conv_mask_row = conv_mask.get();
conv_mask->add_callback([input_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->at(1/*input data channel*/) = input_mask_row->at(1/*output data channel*/);
return true;
}, input_mask);
input_mask->add_callback([conv_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->at(1/*output data channel*/) = conv_mask_row->at(1/*input data channel*/);
return true;
}, conv_mask);
conv_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->at(1) = weights_mask_row->at(0);
cur_mask->at(1/*input data channel*/) = weights_mask_row->at(0/*weights output channel dim*/);
return true;
}, weights_mask);
weights_mask->add_callback([conv_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->at(0) = conv_mask_row->at(1);
cur_mask->at(0/*weights output channel dim*/) = conv_mask_row->at(1/*output data channel*/);
return true;
}, conv_mask);
if (!conv_mask->apply_callback(weights_mask)) {
if (!conv_mask->apply_callback(input_mask)) {
return false;
}
@@ -297,14 +377,13 @@ public:
const auto & m_output = pattern_map.at(eltwise);
const auto & m_input = pattern_map.at(input);
// Case when input masks should be united instead of intersection
bool union_eltwise_type = ngraph::is_type<opset6::Multiply>(m_output.get_node_shared_ptr());
const auto & input_rank = m_input.get_partial_shape().rank().get_length();
const auto & weights_rank = m_weights.get_partial_shape().rank().get_length();
// Here assuming that masks can be propagated only through 3/4 dimensional tensors
// (since channel dim is necessary)
if (weights_rank < 3 || input_rank < 3) return false;
// (since channel dim is necessary) or tensors with equal rank.
if (!((weights_rank > 2 && input_rank > 2) || weights_rank == input_rank)) return false;
// Case when input masks should be united instead of intersection
bool union_eltwise_type = ngraph::is_type<opset6::Multiply>(m_output.get_node_shared_ptr());
// In case if first of the inputs is constant
InitConstMask({0, 1/* potential output channel dim */}).apply(m_input.get_node_shared_ptr());
@@ -578,7 +657,7 @@ public:
opset6::Elu, opset6::HardSigmoid, opset6::PRelu, opset6::Mish,
opset6::Softmax, opset6::SoftPlus, opset6::Convert, opset6::ConvertLike,
opset6::AvgPool, opset6::MaxPool, opset6::ROIPooling, opset6::PSROIPooling,
opset6::Pad>();
opset6::Pad, opset6::MVN>();
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
@@ -646,6 +725,39 @@ public:
}
};
static std::pair<std::set<uint64_t>, bool> squeeze_mask(
const std::set<uint64_t> mask_dim, const size_t elems_per_ch, const bool squeeze) {
bool should_init_dep = false;
auto ret_set = std::set<uint64_t>();
auto mask_dim_copy = std::set<uint64_t>();
std::copy(mask_dim.begin(), mask_dim.end(), std::inserter(mask_dim_copy, mask_dim_copy.begin()));
while (mask_dim_copy.size()) {
const auto elem = *mask_dim_copy.begin();
const auto ch = elem / elems_per_ch;
// Check all channel is zeroed
const auto low = mask_dim_copy.lower_bound(ch * elems_per_ch);
const auto upper = mask_dim_copy.lower_bound((ch + 1) * elems_per_ch);
auto channel_zeros = std::set<uint64_t>();
std::copy(low, upper, std::inserter(channel_zeros, channel_zeros.begin()));
// Remove all zeros related to current channel from iter mask
mask_dim_copy.erase(low, upper);
// In case any of elements are not zeroed - skip entire channel
if (channel_zeros.size() != elems_per_ch) {
should_init_dep = true;
continue;
}
// Add zeros for current channel in current mask
if (squeeze)
ret_set.insert(ch);
else
ret_set.insert(channel_zeros.begin(), channel_zeros.end());
}
return std::make_pair(ret_set, should_init_dep);
}
class ngraph::pass::mask_propagation::Reshape : public MatcherPass {
public:
Reshape() {
@@ -665,12 +777,15 @@ public:
if (is_type<opset6::GroupConvolution>(inp.get_node()))
return true;
// Can't process non constant node in the shape input by now.
if (!std::dynamic_pointer_cast<opset6::Constant>(m_weights.get_node_shared_ptr())) {
NGRAPH_DEBUG << "Can't process reshape node " << m_output.get_node()->get_friendly_name()
<<" with no constant node " << m_weights.get_node()->get_friendly_name()
<< " as shape input.";
return false;
auto constant = std::dynamic_pointer_cast<opset6::Constant>(m_weights.get_node_shared_ptr());
if (!constant) {
constant = get_constant_from_source(m_weights.get_node_shared_ptr());
if (!constant) {
NGRAPH_DEBUG << "Can't process reshape node " << m_output.get_node()->get_friendly_name()
<<" with no constant node " << m_weights.get_node()->get_friendly_name()
<< " as shape input.";
return false;
}
}
// Check reshape operation reshape only dimension without masks
@@ -683,44 +798,102 @@ public:
// Check dimensions equality from the begining and allow
// to propagate masks only for dimensions which equal from the begining
size_t i = 0;
for (; i < std::min(input_shape.size(), output_shape.size()); ++i) {
if (input_shape[i] != output_shape[i])
break;
size_t not_reshaped_dims;
{
size_t i = 0;
for (; i < std::min(input_shape.size(), output_shape.size()); ++i) {
if (input_shape[i] != output_shape[i])
break;
}
not_reshaped_dims = i;
}
auto not_reshaped_dims = i;
auto input_mask_row = input_mask.get();
auto weights_mask_row = weights_mask.get();
auto output_mask_row = output_mask.get();
input_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->copy_value_from_mask(weights_mask_row);
return true;
}, weights_mask);
weights_mask->add_callback([input_mask_row, not_reshaped_dims](Mask::Ptr cur_mask) -> bool{
// Propagate masks down through dimension only if this dimension isn't reshaped
for (size_t dim = 0; dim < std::min(cur_mask->size(), input_mask_row->size()); ++dim)
if (dim < not_reshaped_dims)
cur_mask->at(dim) = input_mask_row->at(dim);
else if (cur_mask->at(dim) != input_mask_row->at(dim))
cur_mask->initialize_dependencies();
return true;
}, input_mask);
output_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->copy_value_from_mask(weights_mask_row);
return true;
}, weights_mask);
// Case when reshape make flatten last dimension
if (input_shape.size() > output_shape.size() &&
output_shape.size() == not_reshaped_dims + 1) {
const size_t elems_per_ch = std::accumulate(input_shape.begin() + not_reshaped_dims + 1,
input_shape.end(), 1, std::multiplies<size_t>());
weights_mask->add_callback([output_mask_row, not_reshaped_dims](Mask::Ptr cur_mask) -> bool {
// Propagate masks up through dimension only if this dimension isn't reshaped
for (size_t dim = 0; dim < std::min(cur_mask->size(), output_mask_row->size()); ++dim)
if (dim < not_reshaped_dims)
cur_mask->at(dim) = output_mask_row->at(dim);
else if (cur_mask->at(dim) != output_mask_row->at(dim))
cur_mask->initialize_dependencies();
return true;
}, output_mask);
input_mask->add_callback([weights_mask_row, not_reshaped_dims, elems_per_ch](Mask::Ptr cur_mask) -> bool {
for (size_t dim = 0; dim < not_reshaped_dims; ++dim)
if (dim < not_reshaped_dims)
cur_mask->at(dim) = weights_mask_row->at(dim);
bool should_init_dep;
std::set<uint64_t> updated_mask;
std::tie(updated_mask, should_init_dep) = squeeze_mask(weights_mask_row->at(not_reshaped_dims), elems_per_ch, true);
cur_mask->at(not_reshaped_dims) = updated_mask;
if (should_init_dep) cur_mask->initialize_dependencies();
return true;
}, weights_mask);
weights_mask->add_callback([input_mask_row, not_reshaped_dims, elems_per_ch](Mask::Ptr cur_mask) -> bool {
// Propagate masks down through dimension only if this dimension isn't reshaped
for (size_t dim = 0; dim < not_reshaped_dims; ++dim)
if (dim < not_reshaped_dims)
cur_mask->at(dim) = input_mask_row->at(dim);
// Flat the last mask
for (auto &ch : input_mask_row->at(not_reshaped_dims))
for (auto idx = ch * elems_per_ch; idx < (ch + 1) * elems_per_ch; ++idx)
cur_mask->at(not_reshaped_dims).insert(idx);
return true;
}, input_mask);
output_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->copy_value_from_mask(weights_mask_row);
return true;
}, weights_mask);
weights_mask->add_callback([output_mask_row, not_reshaped_dims, elems_per_ch](Mask::Ptr cur_mask) -> bool {
// Propagate masks up through dimension only if this dimension isn't reshaped
for (size_t dim = 0; dim < not_reshaped_dims; ++dim)
if (dim < not_reshaped_dims)
cur_mask->at(dim) = output_mask_row->at(dim);
// For the last dimension keep only those zeros which completely
// covering a channel
bool should_init_dep;
std::set<uint64_t> updated_mask;
std::tie(updated_mask, should_init_dep) = squeeze_mask(output_mask_row->at(not_reshaped_dims), elems_per_ch, false);
cur_mask->at(not_reshaped_dims) = updated_mask;
if (should_init_dep) cur_mask->initialize_dependencies();
return true;
}, output_mask);
} else {
input_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->copy_value_from_mask(weights_mask_row);
return true;
}, weights_mask);
weights_mask->add_callback([input_mask_row, not_reshaped_dims](Mask::Ptr cur_mask) -> bool{
// Propagate masks down through dimension only if this dimension isn't reshaped
for (size_t dim = 0; dim < std::min(cur_mask->size(), input_mask_row->size()); ++dim)
if (dim < not_reshaped_dims)
cur_mask->at(dim) = input_mask_row->at(dim);
else if (cur_mask->at(dim) != input_mask_row->at(dim))
cur_mask->initialize_dependencies();
return true;
}, input_mask);
output_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->copy_value_from_mask(weights_mask_row);
return true;
}, weights_mask);
weights_mask->add_callback([output_mask_row, not_reshaped_dims](Mask::Ptr cur_mask) -> bool {
// Propagate masks up through dimension only if this dimension isn't reshaped
for (size_t dim = 0; dim < std::min(cur_mask->size(), output_mask_row->size()); ++dim)
if (dim < not_reshaped_dims)
cur_mask->at(dim) = output_mask_row->at(dim);
else if (cur_mask->at(dim) != output_mask_row->at(dim))
cur_mask->initialize_dependencies();
return true;
}, output_mask);
}
weights_mask->apply_callback(input_mask);
setMask(m_output, output_mask);
@@ -746,11 +919,11 @@ public:
const auto & node = m.get_match_root();
auto output_mask = std::make_shared<Mask>(m_output.get_partial_shape().rank().get_length());
auto output_mask_row = output_mask.get();
bool any_input_with_masks = false;
for (const auto & input : node->input_values()) {
if (auto input_mask = getMask(input)) {
auto input_mask_row = input_mask.get();
auto output_mask_row = output_mask.get();
input_mask->add_callback([output_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->clean_dim_values();
if (!output_mask_row->all_dims_are_empty())
@@ -800,6 +973,7 @@ public:
};
ngraph::pass::PropagateMasks::PropagateMasks() {
add_matcher<mask_propagation::MatMul>();
add_matcher<mask_propagation::Convolution>();
add_matcher<mask_propagation::GroupConvolutionReshape>();
add_matcher<mask_propagation::GroupConvolution>();

View File

@@ -55,8 +55,10 @@ bool ngraph::pass::ShrinkWeights::run_on_model(const std::shared_ptr<ngraph::Fun
if (mask->adjust_value() && !mask->all_dims_are_empty()) {
std::vector<int64_t> new_const_value;
auto value = const_node->cast_vector<int64_t>();
for (size_t i = 0; i < mask->size(); i++)
new_const_value.push_back(value[i] - mask->at(i).size());
for (size_t i = 0; i < mask->size(); i++) {
const int64_t res = value[i] - mask->at(i).size();
new_const_value.push_back((res > 0)? res : value[i]);
}
const auto new_const = opset6::Constant::create(const_node->get_element_type(),
const_node->get_shape(), new_const_value);

View File

@@ -188,10 +188,10 @@ TEST_F(TransformationTestsF, PropagateMasksBasic) {
auto add_const = create_constant_with_zeros(Shape{1, 6, 1, 1}, {{}, {1, 2, 3, 4, 5}, {}, {}});
auto add = std::make_shared<opset5::Add>(relu, add_const);
auto sub_const = create_constant_with_zeros(Shape{6, 1, 1}, {{1, 2, 3}, {}, {}});
auto sub_const = create_constant_with_zeros(Shape{6, 1, 1}, {{1, 2}, {}, {}});
auto sub = std::make_shared<opset5::Subtract>(add, sub_const);
auto mul_const = create_constant_with_zeros(Shape{1, 6, 1, 1}, {{}, {4}, {}, {}});
auto mul_const = create_constant_with_zeros(Shape{1, 6, 1, 1}, {{}, {3}, {}, {}});
auto mul = std::make_shared<ov::op::v1::Multiply>(sub, mul_const);
auto weights2 = create_constant_with_zeros(weights_shape2, {{1, 2}, {1, 2, 3}, {}, {}});
@@ -202,21 +202,21 @@ TEST_F(TransformationTestsF, PropagateMasksBasic) {
{
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
auto weights = opset5::Constant::create(element::f32, {weights_shape[0] - 4, weights_shape[1], weights_shape[2] , weights_shape[3]}, {0});
auto weights = opset5::Constant::create(element::f32, {weights_shape[0] - 3, weights_shape[1], weights_shape[2] , weights_shape[3]}, {0});
auto conv = std::make_shared<opset5::Convolution>(input, weights, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto relu = std::make_shared<opset5::Relu>(conv);
auto add_const = opset5::Constant::create(element::f32, Shape{1, 2, 1, 1}, {1});
auto add_const = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {1});
auto add = std::make_shared<opset5::Add>(relu, add_const);
auto sub_const = opset5::Constant::create(element::f32, Shape{2, 1, 1}, {1});
auto sub_const = opset5::Constant::create(element::f32, Shape{3, 1, 1}, {1});
auto sub = std::make_shared<opset5::Subtract>(add, sub_const);
auto mul_const = opset5::Constant::create(element::f32, Shape{1, 2, 1, 1}, {1});
auto mul_const = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {1});
auto mul = std::make_shared<ov::op::v1::Multiply>(sub, mul_const);
auto weights2 = opset5::Constant::create(element::f32, {weights_shape2[0], weights_shape2[1] - 4, weights_shape2[2], weights_shape2[3]}, {1});
auto weights2 = opset5::Constant::create(element::f32, {weights_shape2[0], weights_shape2[1] - 3, weights_shape2[2], weights_shape2[3]}, {1});
auto conv2 = std::make_shared<opset5::Convolution>(mul, weights2, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
function_ref = std::make_shared<Function>(NodeVector{conv2}, ParameterVector{input});
@@ -229,16 +229,16 @@ TEST_F(TransformationTestsF, PropagateMasksBasic) {
m.register_pass<pass::PropagateMasks>();
m.run_passes(function);
}
compare_masks(*getMask(weights->output(0)), Mask({{1, 2, 3, 4}, {}, {}, {}}));
compare_masks(*getMask(conv->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(relu->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(add_const), Mask({{}, {1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(sub_const), Mask({{1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(mul_const), Mask({{}, {1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(add->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(sub->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(mul->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(weights2.get_node_shared_ptr()->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(weights->output(0)), Mask({{1, 2, 3}, {}, {}, {}}));
compare_masks(*getMask(conv->output(0)), Mask({{}, {1, 2, 3}, {}, {}}));
compare_masks(*getMask(relu->output(0)), Mask({{}, {1, 2, 3}, {}, {}}));
compare_masks(*getMask(add_const), Mask({{}, {1, 2, 3}, {}, {}}));
compare_masks(*getMask(sub_const), Mask({{1, 2, 3}, {}, {}}));
compare_masks(*getMask(mul_const), Mask({{}, {1, 2, 3}, {}, {}}));
compare_masks(*getMask(add->output(0)), Mask({{}, {1, 2, 3}, {}, {}}));
compare_masks(*getMask(sub->output(0)), Mask({{}, {1, 2, 3}, {}, {}}));
compare_masks(*getMask(mul->output(0)), Mask({{}, {1, 2, 3}, {}, {}}));
compare_masks(*getMask(weights2.get_node_shared_ptr()->output(0)), Mask({{}, {1, 2, 3}, {}, {}}));
compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}}));
{
pass::Manager m;
@@ -495,7 +495,7 @@ TEST_F(TransformationTestsF, PropagateMaskPassThrough) {
}
TEST(TransformationTests, PropagateMasksHardDependencies) {
TEST_F(TransformationTestsF, PropagateMasksHardDependencies) {
Shape input_shape{1, 3, 3, 3};
auto input1 = std::make_shared<opset5::Parameter>(element::f32, input_shape);
@@ -529,7 +529,8 @@ TEST(TransformationTests, PropagateMasksHardDependencies) {
auto reshape = std::make_shared<opset5::Reshape>(add1, opset5::Constant::create(element::i64, Shape{2}, {1, 6}), true);
reshape->set_friendly_name("reshape");
auto matmul = std::make_shared<opset5::MatMul>(reshape, opset5::Constant::create(element::f32, Shape{6, 100}, {1.}));
auto matmul_const = opset5::Constant::create(element::f32, Shape{6, 100}, {1.});
auto matmul = std::make_shared<opset5::MatMul>(reshape, matmul_const);
matmul->set_friendly_name("matmul");
auto add2 = std::make_shared<opset5::Add>(conv2, create_constant_with_zeros({6, 1, 1}, {{2}, {}, {}}));
@@ -543,24 +544,91 @@ TEST(TransformationTests, PropagateMasksHardDependencies) {
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
conv3->set_friendly_name("conv3");
auto f = std::make_shared<Function>(NodeVector{matmul, conv3}, ParameterVector{input1, input2});
function = std::make_shared<Function>(NodeVector{matmul, conv3}, ParameterVector{input1, input2});
{
auto input1 = std::make_shared<opset5::Parameter>(element::f32, input_shape);
input1->set_friendly_name("input1");
Shape weights1_shape{6, 3, 3, 3};
auto weights1 = create_constant_with_zeros({
weights1_shape[0] - 1,
weights1_shape[1],
weights1_shape[2],
weights1_shape[3]
}, {{}, {}, {}, {}});
weights1.get_node_shared_ptr()->set_friendly_name("weights1");
auto conv1 = std::make_shared<opset5::Convolution>(input1, weights1, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
conv1->set_friendly_name("conv1");
auto relu = std::make_shared<opset5::Relu>(conv1);
relu->set_friendly_name("relu");
auto input2 = std::make_shared<opset5::Parameter>(element::f32, input_shape);
input2->set_friendly_name("input2");
Shape weights2_shape{6, 3, 3, 3};
auto weights2 = create_constant_with_zeros({weights2_shape[0] - 1,
weights2_shape[1],
weights2_shape[2],
weights2_shape[3]
}, {{2, 3}, {}, {}, {}});
weights2.get_node_shared_ptr()->set_friendly_name("weights2");
auto conv2 = std::make_shared<opset5::Convolution>(input2, weights2, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
conv2->set_friendly_name("conv2");
auto add1 = std::make_shared<opset5::Add>(conv2, conv1);
add1->set_friendly_name("add1");
auto reshape = std::make_shared<opset5::Reshape>(add1, opset5::Constant::create(element::i64, Shape{2}, {1, 5}), true);
reshape->set_friendly_name("reshape");
auto matmul = std::make_shared<opset5::MatMul>(reshape, opset5::Constant::create(element::f32, Shape{5, 100}, {1.}));
matmul->set_friendly_name("matmul");
auto add2 = std::make_shared<opset5::Add>(conv2, create_constant_with_zeros({5, 1, 1}, {{}, {}, {}}));
add2->set_friendly_name("add2");
Shape weights_shape3{6, 6, 1, 1};
auto weights3 = opset5::Constant::create(element::f32,
{weights_shape3[0],
weights_shape3[1] - 1,
weights_shape3[2],
weights_shape3[3]
}, {0});
weights3->set_friendly_name("weights3");
auto conv3 = std::make_shared<opset5::Convolution>(add2, weights3, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
conv3->set_friendly_name("conv3");
function_ref = std::make_shared<Function>(NodeVector{matmul, conv3}, ParameterVector{input1, input2});
}
if (VISUALIZE_TESTS_TREE)
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksHardDependencies.svg").run_on_function(f);
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksHardDependencies.svg").run_on_function(function);
{
pass::Manager m;
m.register_pass<pass::InitMasks>();
m.register_pass<pass::PropagateMasks>();
m.run_passes(function);
}
compare_masks(*getMask(weights1.get_node_shared_ptr()->output(0)), Mask({{2}, {}, {}, {}}));
compare_masks(*getMask(conv1->output(0)), Mask({{}, {2}, {}, {}}));
pass::Manager m;
m.register_pass<pass::Pruning>();
m.run_passes(f);
compare_masks(*getMask(weights2.get_node_shared_ptr()->output(0)), Mask({{2}, {}, {}, {}}));
compare_masks(*getMask(conv2->output(0)), Mask({{}, {2}, {}, {}}));
compare_masks(*getMask(weights1.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(conv1->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(weights3->output(0)), Mask({{}, {2}, {}, {}}));
compare_masks(*getMask(conv3->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(weights2.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(add1->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(add2->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(add1->output(0)), Mask({{}, {2}, {}, {}}));
compare_masks(*getMask(add2->output(0)), Mask({{}, {2}, {}, {}}));
compare_masks(*getMask(matmul_const->output(0)), Mask({{2}, {}}));
compare_masks(*getMask(matmul->output(0)), Mask({{}, {}}));
// TODO: add checks after MatMul/Reshape/Pooling mask propagation is ready
@@ -569,6 +637,13 @@ TEST(TransformationTests, PropagateMasksHardDependencies) {
//compare_masks(*getMask(relu), Mask({{}, {0, 1, 2, 3, 4, 5}, {}, {}}));
//compare_masks(*getMask(weights2), Mask({{}, {0, 1, 2, 3, 4, 5}, {}, {}}));
//compare_masks(*getMask(conv2), Mask({{}, {}, {}, {}}));
{
pass::Manager m;
m.register_pass<pass::ShrinkWeights>();
m.run_passes(function);
}
disable_rt_info_check();
enable_accuracy_check();
}
@@ -580,7 +655,7 @@ TEST_F(TransformationTestsF, PropagateMasksQuantizedGroupConvolution) {
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
input->set_friendly_name("input");
auto weights1 = create_constant_with_zeros(weights_shape, {{0, 1, 2, 3}, {}, {}, {}});
auto weights1 = create_constant_with_zeros(weights_shape, {{0, 1, 2, 3, 4}, {}, {}, {}});
auto conv1 = std::make_shared<opset5::Convolution>(input, weights1, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto weights_group = opset5::Constant::create(element::i8, weights_group_shape, {0});
@@ -589,12 +664,12 @@ TEST_F(TransformationTestsF, PropagateMasksQuantizedGroupConvolution) {
auto convert = std::make_shared<opset5::Convert>(weights_group, element::f32);
convert->set_friendly_name("convert");
auto sub_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3}, {}, {}, {}});
auto sub_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3, 4}, {}, {}, {}});
auto sub = std::make_shared<opset5::Subtract>(convert, sub_const);
sub->set_friendly_name("sub");
auto mul_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3, 4}, {}, {}, {}});
auto mul_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{}, {}, {}, {}});
auto mul = std::make_shared<opset5::Multiply>(sub, mul_const);
mul->set_friendly_name("mul");
@@ -614,12 +689,12 @@ TEST_F(TransformationTestsF, PropagateMasksQuantizedGroupConvolution) {
{
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
auto weights1 = create_constant_with_zeros({weights_shape[0] - 4, weights_shape[1], weights_shape[2], weights_shape[3]}, {{}, {}, {}, {}});
auto weights1 = create_constant_with_zeros({weights_shape[0] - 5, weights_shape[1], weights_shape[2], weights_shape[3]}, {{}, {}, {}, {}});
auto conv1 = std::make_shared<opset5::Convolution>(input, weights1, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto weights_group = opset5::Constant::create(element::i8,
{
weights_group_shape[0] - 4,
weights_group_shape[0] - 5,
weights_group_shape[1],
weights_group_shape[2],
weights_group_shape[3]
@@ -627,11 +702,11 @@ TEST_F(TransformationTestsF, PropagateMasksQuantizedGroupConvolution) {
auto convert = std::make_shared<opset5::Convert>(weights_group, element::f32);
auto sub_const = create_constant_with_zeros(Shape{4, 1, 1, 1}, {{}, {}, {}, {}});
auto sub_const = create_constant_with_zeros(Shape{3, 1, 1, 1}, {{}, {}, {}, {}});
auto sub = std::make_shared<opset5::Subtract>(convert, sub_const);
auto mul_const = create_constant_with_zeros(Shape{4, 1, 1, 1}, {{}, {}, {}, {}});
auto mul_const = create_constant_with_zeros(Shape{3, 1, 1, 1}, {{}, {}, {}, {}});
auto mul = std::make_shared<opset5::Multiply>(sub, mul_const);
@@ -648,10 +723,10 @@ TEST_F(TransformationTestsF, PropagateMasksQuantizedGroupConvolution) {
auto conv_group = std::make_shared<opset5::GroupConvolution>(conv1, reshape, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto add_const = create_constant_with_zeros(Shape{1, 4, 1, 1}, {{}, {}, {}, {}});;
auto add_const = create_constant_with_zeros(Shape{1, 3, 1, 1}, {{}, {}, {}, {}});;
auto add = std::make_shared<opset5::Add>(conv_group, add_const);
auto weights_2 = opset5::Constant::create(element::f32, {weight_shape2[0], weight_shape2[1] - 4, weight_shape2[2], weight_shape2[3]}, {0});
auto weights_2 = opset5::Constant::create(element::f32, {weight_shape2[0], weight_shape2[1] - 5, weight_shape2[2], weight_shape2[3]}, {0});
auto conv2 = std::make_shared<opset5::Convolution>(add, weights_2, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
function_ref = std::make_shared<Function>(NodeVector{conv2}, ParameterVector{input});
@@ -665,22 +740,21 @@ TEST_F(TransformationTestsF, PropagateMasksQuantizedGroupConvolution) {
m.run_passes(function);
}
compare_masks(*getMask(weights1.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}}));
compare_masks(*getMask(conv1->output(0)), Mask({{}, {0 , 1, 2, 3}, {}, {}}));
compare_masks(*getMask(weights1.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3, 4}, {}, {}, {}}));
compare_masks(*getMask(conv1->output(0)), Mask({{}, {0 , 1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(weights_group->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}}));
compare_masks(*getMask(sub->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}}));
compare_masks(*getMask(sub_const.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}}));
compare_masks(*getMask(mul->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}}));
compare_masks(*getMask(mul_const.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}}));
compare_masks(*getMask(weights_group->output(0)), Mask({{0, 1, 2, 3, 4}, {}, {}, {}}));
compare_masks(*getMask(sub->output(0)), Mask({{0, 1, 2, 3, 4}, {}, {}, {}}));
compare_masks(*getMask(sub_const.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3, 4}, {}, {}, {}}));
compare_masks(*getMask(mul->output(0)), Mask({{0, 1, 2, 3, 4}, {}, {}, {}}));
compare_masks(*getMask(mul_const.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3, 4}, {}, {}, {}}));
compare_masks(*getMask(reshape->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}, {}}));
compare_masks(*getMask(reshape->output(0)), Mask({{0, 1, 2, 3, 4}, {}, {}, {}, {}}));
compare_masks(*getMask(conv_group->output(0)), Mask({{}, {0, 1, 2, 3}, {}, {}}));
compare_masks(*getMask(conv_group->output(0)), Mask({{}, {0, 1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(weights_2->output(0)), Mask({{}, {0, 1, 2, 3, 4}, {}, {}}));
compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(weights_2->output(0)), Mask({{}, {0, 1, 2, 3}, {}, {}}));
{
pass::Manager m;
m.register_pass<pass::ShrinkWeights>();
@@ -841,7 +915,7 @@ TEST_F(TransformationTestsF, PropagateMasksFakeQuantizePerTensor) {
auto convert = std::make_shared<opset5::Convert>(weights_1, element::f32);
convert->set_friendly_name("convert");
auto sub_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3}, {}, {}, {}});
auto sub_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3, 4}, {}, {}, {}});
auto sub = std::make_shared<opset5::Subtract>(convert, sub_const);
sub->set_friendly_name("sub");
@@ -947,6 +1021,70 @@ TEST_F(TransformationTestsF, PropagateMasksFakeQuantizePerTensor) {
}
// Checks that a 1-D (per-tensor) scale/shift on quantized weights stops mask
// propagation: with Shape{1} sub/mul constants no channel can be pruned, so
// every mask in the Conv -> FakeQuantize -> Conv chain must stay empty.
TEST(TransformationTests, PropagateMasksFakeQuantizePerTensor1DScale) {
Shape input_shape{1, 3, 64, 64};
Shape weights_shape{8, 3, 3, 3};
Shape weight_shape2{3, 8, 3, 3};
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
input->set_friendly_name("input");
auto weights_1 = opset5::Constant::create(element::i8, weights_shape, {0});
weights_1->set_friendly_name("weights_int8_const");
auto convert = std::make_shared<opset5::Convert>(weights_1, element::f32);
convert->set_friendly_name("convert");
// 1-D (per-tensor) zero point — the case under test.
auto sub_const = create_constant_with_zeros(Shape{1}, {{}});
auto sub = std::make_shared<opset5::Subtract>(convert, sub_const);
sub->set_friendly_name("sub");
// 1-D (per-tensor) scale — the case under test.
auto mul_const = create_constant_with_zeros(Shape{1}, {{}});
auto mul = std::make_shared<opset5::Multiply>(sub, mul_const);
mul->set_friendly_name("mul");
auto conv1 = std::make_shared<opset5::Convolution>(input, mul, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
conv1->set_friendly_name("conv1");
auto add_const = create_constant_with_zeros(Shape{1, 8, 1, 1}, {{}, {0, 1, 2, 3, 4}, {}, {}});;
auto add = std::make_shared<opset5::Add>(conv1, add_const);
add->set_friendly_name("add");
auto input_low = opset5::Constant::create(element::f32, Shape{1}, {0});
auto input_high = opset5::Constant::create(element::f32, Shape{1, 1, 1, 1}, {20});
auto output_low = opset5::Constant::create(element::f32, Shape{}, {1});
auto output_high = opset5::Constant::create(element::f32, Shape{}, {10});
auto fq = std::make_shared<opset5::FakeQuantize>(add, input_low, input_high, output_low, output_high, 8);
auto weights_2 = opset5::Constant::create(element::f32, weight_shape2, {0});
auto conv2 = std::make_shared<opset5::Convolution>(fq, weights_2, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto function = std::make_shared<Function>(NodeVector{conv2}, ParameterVector{input});
if (VISUALIZE_TESTS_TREE)
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksFakeQuantizePerTensor1DScale.svg").run_on_function(function);
{
pass::Manager m;
m.register_pass<pass::Pruning>();
m.run_passes(function);
}
// All masks must be empty: the 1-D scale/shift blocks channel pruning.
compare_masks(*getMask(weights_1->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(sub->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(mul->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(conv1->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(add_const.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(add->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(fq->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(weights_2->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}}));
}
TEST_F(TransformationTestsF, PropagateMasksFakeQuantizePerChannel) {
Shape input_shape{1, 3, 64, 64};
Shape weights_shape{8, 3, 3, 3};
@@ -959,12 +1097,12 @@ TEST_F(TransformationTestsF, PropagateMasksFakeQuantizePerChannel) {
auto convert = std::make_shared<opset5::Convert>(weights_1, element::f32);
convert->set_friendly_name("convert");
auto sub_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3}, {}, {}, {}});
auto sub_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3, 4}, {}, {}, {}});
auto sub = std::make_shared<opset5::Subtract>(convert, sub_const);
sub->set_friendly_name("sub");
auto mul_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3, 4}, {}, {}, {}});
auto mul_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{}, {}, {}, {}});
auto mul = std::make_shared<opset5::Multiply>(sub, mul_const);
mul->set_friendly_name("mul");
@@ -1841,7 +1979,7 @@ TEST_F(TransformationTestsF, MaskPropagationReshapeUp) {
}
TEST(TransformationTests, MaskPropagationReshapeUpWithShapeOf) {
TEST_F(TransformationTestsF, MaskPropagationReshapeUpWithShapeOf) {
auto inputShapes = PartialShape{1, 6, 8, 8};
auto weightsShape = Shape{6, 6, 1, 1};
@@ -1862,20 +2000,57 @@ TEST(TransformationTests, MaskPropagationReshapeUpWithShapeOf) {
CoordinateDiff(2, 0),
Strides(2, 1));
auto function = std::make_shared<ngraph::Function>(OutputVector{conv_1}, ParameterVector{input});
function = std::make_shared<ngraph::Function>(OutputVector{conv_1}, ParameterVector{input});
{
auto input = std::make_shared<opset5::Parameter>(element::f32, inputShapes);
auto weights = create_constant_with_zeros({
weightsShape[0] - 3,
weightsShape[1],
weightsShape[2],
weightsShape[3],
}, {{}, {}, {}, {}});
auto conv = std::make_shared<opset5::Convolution>(input, weights, Strides(2, 1),
CoordinateDiff(2, 0),
CoordinateDiff(2, 0),
Strides(2, 1));
auto shape_of_conv = std::make_shared<opset5::ShapeOf>(conv);
auto reshape = std::make_shared<opset5::Reshape>(conv, shape_of_conv, true);
auto conv_1_shape = Shape{6, 6, 1, 1};
auto conv_1_weights = create_constant_with_zeros({
conv_1_shape[0],
conv_1_shape[1] - 3,
conv_1_shape[2],
conv_1_shape[3],
}, {{}, {}, {}, {}});
auto conv_1 = std::make_shared<opset5::Convolution>(reshape, conv_1_weights, Strides(2, 1),
CoordinateDiff(2, 0),
CoordinateDiff(2, 0),
Strides(2, 1));
function_ref = std::make_shared<ngraph::Function>(OutputVector{conv_1}, ParameterVector{input});
}
if (VISUALIZE_TESTS_TREE)
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "MaskPropagationReshapeUpWithShapeOf.svg").run_on_function(function);
{
pass::Manager m;
m.register_pass<pass::Pruning>();
m.register_pass<pass::InitMasks>();
m.register_pass<pass::PropagateMasks>();
m.run_passes(function);
}
compare_masks(*getMask(weights.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(conv->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(weights.get_node_shared_ptr()->output(0)), Mask({{1, 2, 3}, {}, {}, {}}));
compare_masks(*getMask(conv->output(0)), Mask({{}, {1, 2, 3}, {}, {}}));
compare_masks(*getMask(conv_1_weights.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(conv_1_weights.get_node_shared_ptr()->output(0)), Mask({{}, {1, 2, 3}, {}, {}}));
compare_masks(*getMask(conv_1->output(0)), Mask({{}, {}, {}, {}}));
{
pass::Manager m;
m.register_pass<pass::ShrinkWeights>();
m.run_passes(function);
}
disable_rt_info_check();
enable_accuracy_check();
}
@@ -1999,7 +2174,7 @@ TEST(TransformationTests, MaskPropagationStopReshapeUp) {
CoordinateDiff(2, 0),
Strides(2, 1));
auto function = std::make_shared<ngraph::Function>(OutputVector{conv_1}, ParameterVector{input}, "GoodReshapeUp");
auto function = std::make_shared<ngraph::Function>(OutputVector{conv_1}, ParameterVector{input});
if (VISUALIZE_TESTS_TREE)
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "MaskPropagationStopReshapeUp.svg").run_on_function(function);
{
@@ -2049,7 +2224,7 @@ TEST(TransformationTests, MaskPropagationStopReshapeDown) {
CoordinateDiff(2, 0),
Strides(2, 1));
auto function = std::make_shared<ngraph::Function>(OutputVector{last_conv}, ParameterVector{input}, "BadReshapeElementwiseModel");
auto function = std::make_shared<ngraph::Function>(OutputVector{last_conv}, ParameterVector{input});
if (VISUALIZE_TESTS_TREE)
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "MaskPropagationStopReshapeDown.svg").run_on_function(function);
@@ -2101,7 +2276,7 @@ TEST(TransformationTests, MaskPropagationWrongDimsElementwise) {
CoordinateDiff(2, 0),
Strides(2, 1));
auto function = std::make_shared<ngraph::Function>(OutputVector{last_conv}, ParameterVector{input}, "BadReshapeElementwiseModel");
auto function = std::make_shared<ngraph::Function>(OutputVector{last_conv}, ParameterVector{input});
if (VISUALIZE_TESTS_TREE)
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "MaskPropagationWrongDimsElementwise.svg").run_on_function(function);
@@ -2160,7 +2335,7 @@ TEST_F(TransformationTestsF, PruneSEBlock) {
CoordinateDiff(2, 0),
Strides(2, 1));
function = std::make_shared<ngraph::Function>(OutputVector{end_conv}, ParameterVector{input}, "SEBlock");
function = std::make_shared<ngraph::Function>(OutputVector{end_conv}, ParameterVector{input});
{
auto input = std::make_shared<opset5::Parameter>(element::f32, inputShapes);
auto first_conv_weights = create_constant_with_zeros({
@@ -2230,3 +2405,455 @@ TEST_F(TransformationTestsF, PruneSEBlock) {
disable_rt_info_check();
enable_accuracy_check();
}
// Verifies mask propagation through a linear (MatMul) head:
// Conv (output channels 0..2 zeroed) -> Relu -> Reshape (flatten to
// [1, linear_input_features]) -> MatMul (columns 0..2 zeroed) -> MatMul.
// The conv channel mask must expand through the flatten into the first
// MatMul's input-feature rows; the reference graph models the shapes
// expected after ShrinkWeights.
TEST_F(TransformationTestsF, PropagateMasksLinear) {
const auto linear_input_features = 62 * 62 * 6;
Shape input_shape{1, 3, 64, 64};
Shape weights_shape{6, 3, 3, 3};
Shape weights_linear_shape{linear_input_features, 100};
Shape weights_last_linear_shape{100, 10};
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
// Zeros in conv output channels 0..2 make half of the 6 channels prunable.
auto weights = create_constant_with_zeros(weights_shape, {{0, 1, 2}, {}, {}, {}});
auto conv = std::make_shared<opset5::Convolution>(input, weights, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto relu = std::make_shared<opset5::Relu>(conv);
auto reshape_const = opset5::Constant::create(element::i64, Shape{2}, {1, linear_input_features});
auto reshape = std::make_shared<opset5::Reshape>(relu, reshape_const, true);
auto weights_linear = create_constant_with_zeros(weights_linear_shape, {{}, {0, 1, 2}});
auto linear = std::make_shared<opset5::MatMul>(reshape, weights_linear);
// Zeros in dim 0 are not searched for yet.
// Check that mask propagation stops for the outer dim (1).
auto weights_last_linear = create_constant_with_zeros(weights_last_linear_shape, {{3, 4, 5}, {2, 3, 4}});
auto last_linear = std::make_shared<opset5::MatMul>(linear, weights_last_linear);
function = std::make_shared<Function>(NodeVector{last_linear}, ParameterVector{input});
{
// Reference graph: shapes after pruning 3 conv channels and 3 linear columns.
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
auto weights = create_constant_with_zeros({
weights_shape[0] - 3,
weights_shape[1],
weights_shape[2],
weights_shape[3],
}, {{}, {}, {}, {}});
auto conv = std::make_shared<opset5::Convolution>(input, weights, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto relu = std::make_shared<opset5::Relu>(conv);
auto reshape_const = opset5::Constant::create(element::i64, Shape{2}, {1, linear_input_features / 2});
auto reshape = std::make_shared<opset5::Reshape>(relu, reshape_const, true);
auto weights_linear = create_constant_with_zeros({
weights_linear_shape[0] / 2,
weights_linear_shape[1] - 3,
}, {{}, {}});
auto linear = std::make_shared<opset5::MatMul>(reshape, weights_linear);
auto weights_last_linear = create_constant_with_zeros({
weights_last_linear_shape[0] - 3,
weights_last_linear_shape[1],
}, {{}, {}});
auto last_linear = std::make_shared<opset5::MatMul>(linear, weights_last_linear);
function_ref = std::make_shared<Function>(NodeVector{last_linear}, ParameterVector{input});
}
if (VISUALIZE_TESTS_TREE)
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksLinear.svg").run_on_function(function);
{
pass::Manager m;
m.register_pass<pass::InitMasks>();
m.register_pass<pass::PropagateMasks>();
m.run_passes(function);
}
compare_masks(*getMask(weights.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2}, {}, {}, {}}));
compare_masks(*getMask(conv->output(0)), Mask({{}, {0, 1, 2}, {}, {}}));
compare_masks(*getMask(relu->output(0)), Mask({{}, {0, 1, 2}, {}, {}}));
// Build the expected flattened mask: pruning 3 of 6 channels zeroes the
// first half of the flattened feature dimension.
auto ref_flatten_mask = std::set<uint64_t>();
for (uint64_t i = 0; i < linear_input_features / 2; ++i)
ref_flatten_mask.insert(i);
using nested_vector = std::vector<std::set<uint64_t>>;
auto reshape_ref_mask = nested_vector();
reshape_ref_mask.push_back({});
reshape_ref_mask.push_back(ref_flatten_mask);
auto linear_ref_mask = nested_vector();
linear_ref_mask.push_back(ref_flatten_mask);
linear_ref_mask.push_back({0, 1, 2});
compare_masks(*getMask(reshape_const->output(0)), Mask(reshape_ref_mask));
compare_masks(*getMask(reshape->output(0)), Mask(reshape_ref_mask));
compare_masks(*getMask(weights_linear.get_node_shared_ptr()->output(0)), Mask(linear_ref_mask));
compare_masks(*getMask(linear->output(0)), Mask{{}, {0, 1, 2}});
compare_masks(*getMask(weights_last_linear.get_node_shared_ptr()->output(0)), Mask{{0, 1, 2}, {}});
compare_masks(*getMask(last_linear->output(0)), Mask{{}, {}});
{
pass::Manager m;
m.register_pass<pass::ShrinkWeights>();
m.run_passes(function);
}
disable_rt_info_check();
enable_accuracy_check();
}
// Checks upward mask propagation through a MatMul is invalidated when an
// elementwise consumer (bad_add with an empty mask) cannot be pruned:
// all masks along Conv -> Reshape -> MatMul -> Add chain must end up empty.
TEST(TransformationTests, PruneLinearUp) {
const auto linear_input_features = 6 * 2 * 2;
auto inputShapes = PartialShape{1, 6, 2, 2};
auto weightsShape = Shape{6, 6, 1, 1};
auto linearShape = Shape{linear_input_features, linear_input_features};
auto lastLinearShape = Shape{10, linear_input_features};
auto input = std::make_shared<opset5::Parameter>(element::f32, inputShapes);
auto weights = create_constant_with_zeros(weightsShape, {{1, 2, 3}, {}, {}, {}});
auto conv = std::make_shared<opset5::Convolution>(input, weights, Strides(2, 1),
CoordinateDiff(2, 0),
CoordinateDiff(2, 0),
Strides(2, 1));
auto reshape_const = opset5::Constant::create(element::i64, Shape{2}, {1, linear_input_features});
auto reshape = std::make_shared<opset5::Reshape>(conv, reshape_const, true);
// MatMul weights with zeros in rows {10, 11} and the first half of columns.
auto linear_mask = Mask();
auto outer_dim_zeros = std::set<uint64_t>();
for (auto i = 0; i < linear_input_features / 2; ++i)
outer_dim_zeros.insert(i);
linear_mask.push_back({10, 11});
linear_mask.push_back(outer_dim_zeros);
auto linear_const = create_constant_with_zeros(linearShape, linear_mask);
auto linear = std::make_shared<opset5::MatMul>(reshape, linear_const);
auto add_mask = Mask();
add_mask.push_back({});
add_mask.push_back(outer_dim_zeros);
auto add_const = create_constant_with_zeros({1, linear_input_features}, add_mask);
auto add = std::make_shared<opset5::Add>(linear, add_const);
auto add_const_1 = create_constant_with_zeros({1, linear_input_features}, add_mask);
auto add_1 = std::make_shared<opset5::Add>(add, add_const_1);
auto add_2 = std::make_shared<opset5::Add>(add_1, reshape);
// This constant has no zeros, so it blocks pruning of the whole chain.
auto bad_add_const = create_constant_with_zeros({1, linear_input_features}, {{}, {}});
auto bad_add = std::make_shared<opset5::Add>(add_2, bad_add_const);
auto weights_end_linear = create_constant_with_zeros(lastLinearShape, {{1, 2, 3}, {3, 4, 6}});
auto last_linear = std::make_shared<opset5::MatMul>(bad_add, weights_end_linear, false, true);
auto function = std::make_shared<ngraph::Function>(OutputVector{last_linear}, ParameterVector{input});
if (VISUALIZE_TESTS_TREE)
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneLinearUp.svg").run_on_function(function);
pass::Manager m;
m.register_pass<pass::Pruning>();
m.run_passes(function);
// Nothing may be pruned: every mask is expected to be empty.
compare_masks(*getMask(weights.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(conv->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(weights_end_linear.get_node_shared_ptr()->output(0)), Mask({{}, {}}));
compare_masks(*getMask(last_linear->output(0)), Mask({{}, {}}));
}
// Convolution counterpart of PruneLinearUp: an Add with a fully dense
// constant (bad_add) sits before the last conv, so upward propagation is
// invalidated and no channel may be pruned anywhere in the chain.
TEST(TransformationTests, PruneConvUpShort) {
const auto linear_input_features = 6 * 2 * 2;
auto inputShapes = PartialShape{1, 6, 2, 2};
auto convShape = Shape{1, 6, 2, 2};
auto weightsShape = Shape{6, 6, 1, 1};
auto lastLinearShape = Shape{10, linear_input_features};
auto input = std::make_shared<opset5::Parameter>(element::f32, inputShapes);
auto weights = create_constant_with_zeros(weightsShape, {{1, 2, 3}, {}, {}, {}});
auto conv = std::make_shared<opset5::Convolution>(input, weights, Strides(2, 1),
CoordinateDiff(2, 0),
CoordinateDiff(2, 0),
Strides(2, 1));
auto conv_1_const = create_constant_with_zeros(weightsShape, {{1, 2, 3}, {}, {}, {}});
auto conv_1 = std::make_shared<opset5::Convolution>(conv, conv_1_const, Strides(2, 1),
CoordinateDiff(2, 0),
CoordinateDiff(2, 0),
Strides(2, 1));
auto add_const = create_constant_with_zeros(convShape, {{}, {1, 2, 3}, {}, {}});
auto add = std::make_shared<opset5::Add>(conv_1, add_const);
auto add_const_1 = create_constant_with_zeros(convShape, {{}, {1, 2, 3}, {}, {}});
auto add_1 = std::make_shared<opset5::Add>(add, add_const_1);
auto add_2 = std::make_shared<opset5::Add>(add_1, conv)
;
auto bad_add_const = create_constant_with_zeros(convShape, {{}, {}, {}, {}});
auto bad_add = std::make_shared<opset5::Add>(add_2, bad_add_const);
auto weights_end_conv = create_constant_with_zeros(weightsShape, {{1, 2, 3}, {1, 2, 3}, {}, {}});
auto last_conv = std::make_shared<opset5::Convolution>(bad_add, weights_end_conv, Strides(2, 1),
CoordinateDiff(2, 0),
CoordinateDiff(2, 0),
Strides(2, 1));
auto function = std::make_shared<ngraph::Function>(OutputVector{last_conv}, ParameterVector{input});
if (VISUALIZE_TESTS_TREE)
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneConvUpShort.svg").run_on_function(function);
pass::Manager m;
m.register_pass<pass::Pruning>();
m.run_passes(function);
// Nothing may be pruned: every mask is expected to be empty.
compare_masks(*getMask(weights.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(conv->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(weights_end_conv.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(last_conv->output(0)), Mask({{}, {}, {}, {}}));
}
// Verifies that MatMul column pruning works while row propagation is
// stopped: the 3-D reshape ([1, 6, features]) prevents the conv channel
// mask from reaching the MatMul rows, but the MatMul output columns
// {0, 1, 2} are still pruned by the following MatMul.
TEST_F(TransformationTestsF, PruneMasksMatMulColsStopRowsUp) {
const auto linear_input_features = 62 * 62;
Shape input_shape{1, 3, 64, 64};
Shape weights_shape{6, 3, 3, 3};
Shape weights_linear_shape{linear_input_features, 100};
Shape weights_last_linear_shape{100, 10};
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
auto weights = create_constant_with_zeros(weights_shape, {{0, 1, 2}, {}, {}, {}});
auto conv = std::make_shared<opset5::Convolution>(input, weights, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto relu = std::make_shared<opset5::Relu>(conv);
auto reshape_const = opset5::Constant::create(element::i64, Shape{3}, {1, 6, linear_input_features});
auto reshape = std::make_shared<opset5::Reshape>(relu, reshape_const, true);
auto weights_linear = create_constant_with_zeros(weights_linear_shape, {{}, {0, 1, 2}});
auto linear = std::make_shared<opset5::MatMul>(reshape, weights_linear);
// Zeros in dim 0 are not searched for yet.
auto weights_last_linear = create_constant_with_zeros(weights_last_linear_shape, {{3, 4, 5}, {}});
auto last_linear = std::make_shared<opset5::MatMul>(linear, weights_last_linear);
function = std::make_shared<Function>(NodeVector{last_linear}, ParameterVector{input});
{
// Reference graph: conv untouched, only MatMul columns/rows shrunk by 3.
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
auto weights = create_constant_with_zeros(weights_shape, {{}, {}, {}, {}});
auto conv = std::make_shared<opset5::Convolution>(input, weights, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto relu = std::make_shared<opset5::Relu>(conv);
auto reshape_const = opset5::Constant::create(element::i64, Shape{3}, {1, 6, linear_input_features});
auto reshape = std::make_shared<opset5::Reshape>(relu, reshape_const, true);
auto weights_linear = create_constant_with_zeros({
weights_linear_shape[0],
weights_linear_shape[1] - 3,
}, {{}, {}});
auto linear = std::make_shared<opset5::MatMul>(reshape, weights_linear);
auto weights_last_linear = create_constant_with_zeros({
weights_last_linear_shape[0] - 3,
weights_last_linear_shape[1],
}, {{}, {}});
auto last_linear = std::make_shared<opset5::MatMul>(linear, weights_last_linear);
function_ref = std::make_shared<Function>(NodeVector{last_linear}, ParameterVector{input});
}
if (VISUALIZE_TESTS_TREE)
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneMasksMatMulColsStopRowsUp.svg").run_on_function(function);
{
pass::Manager m;
m.register_pass<pass::InitMasks>();
m.register_pass<pass::PropagateMasks>();
m.run_passes(function);
}
// Conv-side masks are invalidated (rows stopped), column masks survive.
compare_masks(*getMask(weights.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(conv->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(relu->output(0)), Mask({{}, {}, {}, {}}));
compare_masks(*getMask(reshape_const->output(0)), Mask{{}, {}, {}});
compare_masks(*getMask(reshape->output(0)), Mask{{}, {}, {}});
compare_masks(*getMask(weights_linear.get_node_shared_ptr()->output(0)), Mask({{}, {0, 1, 2}}));
compare_masks(*getMask(linear->output(0)), Mask{{}, {}, {0, 1, 2}});
compare_masks(*getMask(weights_last_linear.get_node_shared_ptr()->output(0)), Mask{{0, 1, 2}, {}});
compare_masks(*getMask(last_linear->output(0)), Mask{{}, {}, {}});
{
pass::Manager m;
m.register_pass<pass::ShrinkWeights>();
m.run_passes(function);
}
disable_rt_info_check();
enable_accuracy_check();
}
// Verifies MatMul row pruning (via transpose_a = true on the last MatMul)
// while column propagation is stopped: the conv channel mask {0, 1, 2}
// reaches the transposed featuremap rows, so the last MatMul's input
// features are pruned, but the first MatMul's weights stay intact.
TEST_F(TransformationTestsF, PruneMasksMatMulRowsStopColsUp) {
// Checks rows matmul pruning + transpose input in matmul
const auto linear_input_features = 62 * 62;
Shape input_shape{1, 3, 64, 64};
Shape weights_shape{6, 3, 3, 3};
Shape weights_linear_shape{linear_input_features, 100};
Shape weights_last_linear_shape{10, 6};
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
auto weights = create_constant_with_zeros(weights_shape, {{0, 1, 2}, {}, {}, {}});
auto conv = std::make_shared<opset5::Convolution>(input, weights, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto relu = std::make_shared<opset5::Relu>(conv);
auto reshape_const = opset5::Constant::create(element::i64, Shape{3}, {1, 6, linear_input_features});
auto reshape = std::make_shared<opset5::Reshape>(relu, reshape_const, true);
auto weights_linear = create_constant_with_zeros(weights_linear_shape, {{3, 4, 5}, {3, 4}});
auto linear = std::make_shared<opset5::MatMul>(reshape, weights_linear);
// These zeros are not searched for yet.
auto weights_last_linear = create_constant_with_zeros(weights_last_linear_shape, {{}, {3, 4, 5}});
// To prune rows we should transpose featuremap. Did it by transpose_a = true MatMul constructor attr
auto last_linear = std::make_shared<opset5::MatMul>(linear, weights_last_linear, true, true);
function = std::make_shared<Function>(NodeVector{last_linear}, ParameterVector{input});
{
// Reference graph: 3 conv channels pruned, last MatMul input dim shrunk.
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
auto weights = create_constant_with_zeros({
weights_shape[0] - 3,
weights_shape[1],
weights_shape[2],
weights_shape[3],
}, {{}, {}, {}, {}});
auto conv = std::make_shared<opset5::Convolution>(input, weights, Strides(2, 1),
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
auto relu = std::make_shared<opset5::Relu>(conv);
auto reshape_const = opset5::Constant::create(element::i64, Shape{3}, {1, 3, linear_input_features});
auto reshape = std::make_shared<opset5::Reshape>(relu, reshape_const, true);
auto weights_linear = create_constant_with_zeros({
weights_linear_shape[0],
weights_linear_shape[1],
}, {{}, {}});
auto linear = std::make_shared<opset5::MatMul>(reshape, weights_linear);
auto weights_last_linear = create_constant_with_zeros({weights_last_linear_shape[0],
weights_last_linear_shape[1] - 3}, {{}, {}});
// To prune rows we should transpose featuremap. Did it by transpose_a = true MatMul constructor attr
auto last_linear = std::make_shared<opset5::MatMul>(linear, weights_last_linear, true, true);
function_ref = std::make_shared<Function>(NodeVector{last_linear}, ParameterVector{input});
}
if (VISUALIZE_TESTS_TREE)
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneMasksMatMulRowsStopColsUp.svg").run_on_function(function);
{
pass::Manager m;
m.register_pass<pass::InitMasks>();
m.register_pass<pass::PropagateMasks>();
m.run_passes(function);
}
// Channel mask propagates along dim 1 down to the transposed MatMul input.
compare_masks(*getMask(weights.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2}, {}, {}, {}}));
compare_masks(*getMask(conv->output(0)), Mask({{}, {0, 1, 2}, {}, {}}));
compare_masks(*getMask(relu->output(0)), Mask({{}, {0, 1, 2}, {}, {}}));
compare_masks(*getMask(reshape_const->output(0)), Mask{{}, {0, 1, 2}, {}});
compare_masks(*getMask(reshape->output(0)), Mask{{}, {0, 1, 2}, {}});
compare_masks(*getMask(weights_linear.get_node_shared_ptr()->output(0)), Mask{{}, {}});
compare_masks(*getMask(linear->output(0)), Mask{{}, {0, 1, 2}, {}});
compare_masks(*getMask(weights_last_linear.get_node_shared_ptr()->output(0)), Mask{{}, {0, 1, 2}});
compare_masks(*getMask(last_linear->output(0)), Mask{{}, {}, {}});
{
pass::Manager m;
m.register_pass<pass::ShrinkWeights>();
m.run_passes(function);
}
disable_rt_info_check();
enable_accuracy_check();
}
TEST_F(TransformationTestsF, PropagateFlattenUp) {
    // Verifies that a pruning mask coming from below a Flatten-style Reshape is
    // propagated upwards through the reshape back to the convolution channels.
    // The down-propagation counterpart is exercised by the
    // PruneLinearIsClosingAndInGroup test.
    using mask_sets = std::vector<std::set<uint64_t>>;
    constexpr auto flattened_features = 6 * 8 * 8;
    Shape input_shape{1, 3, 8, 8};
    Shape conv_weights_shape{6, 3, 1, 1};
    Shape fc_weights_shape{flattened_features, 100};

    auto param = std::make_shared<opset5::Parameter>(element::f32, input_shape);
    auto conv_weights = create_constant_with_zeros(conv_weights_shape, {{0, 1, 2}, {}, {}, {}});
    auto conv = std::make_shared<opset5::Convolution>(param, conv_weights, Strides(2, 1),
                                                      CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
    auto relu = std::make_shared<opset5::Relu>(conv);
    auto reshape_const = opset5::Constant::create(element::i64, Shape{2}, {1, flattened_features});
    auto reshape = std::make_shared<opset5::Reshape>(relu, reshape_const, true);
    // Leaving a single position (index 0) out of the zeroed set must invalidate
    // the whole corresponding channel dimension of the mask.
    std::set<uint64_t> zeroed_positions;
    for (size_t pos = 1; pos < flattened_features / 2; ++pos)
        zeroed_positions.insert(pos);
    mask_sets add_mask{{}, zeroed_positions};
    auto add_weights = create_constant_with_zeros({1, flattened_features}, Mask(add_mask));
    auto add = std::make_shared<opset5::Add>(reshape, add_weights);
    auto fc_weights = create_constant_with_zeros(fc_weights_shape, {{}, {0, 1, 2}});
    auto fc = std::make_shared<opset5::MatMul>(add, fc_weights);
    function = std::make_shared<Function>(NodeVector{fc}, ParameterVector{param});
    {
        // Reference graph: 2 of 6 conv channels pruned away, so the flattened
        // dimension shrinks from 6*8*8 to 4*8*8 (= 2/3 of the original).
        auto param = std::make_shared<opset5::Parameter>(element::f32, input_shape);
        auto conv_weights = create_constant_with_zeros({
                                                           conv_weights_shape[0] - 2,
                                                           conv_weights_shape[1],
                                                           conv_weights_shape[2],
                                                           conv_weights_shape[3],
                                                       }, {{}, {}, {}, {}});
        auto conv = std::make_shared<opset5::Convolution>(param, conv_weights, Strides(2, 1),
                                                          CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
        auto relu = std::make_shared<opset5::Relu>(conv);
        auto reshape_const = opset5::Constant::create(element::i64, Shape{2}, {1, 2 * flattened_features / 3});
        auto reshape = std::make_shared<opset5::Reshape>(relu, reshape_const, true);
        auto add_weights = create_constant_with_zeros({1, 2 * flattened_features / 3}, Mask{{}, {}});
        auto add = std::make_shared<opset5::Add>(reshape, add_weights);
        auto fc_weights = create_constant_with_zeros({
                                                         2 * fc_weights_shape[0] / 3,
                                                         fc_weights_shape[1],
                                                     }, {{}, {}});
        auto fc = std::make_shared<opset5::MatMul>(add, fc_weights);
        function_ref = std::make_shared<Function>(NodeVector{fc}, ParameterVector{param});
    }
    if (VISUALIZE_TESTS_TREE)
        ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateFlattenUp.svg").run_on_function(function);
    {
        pass::Manager m;
        m.register_pass<pass::InitMasks>();
        m.register_pass<pass::PropagateMasks>();
        m.run_passes(function);
    }
    // Channel 0 was not fully zeroed through the Add mask, so only channels 1
    // and 2 remain prunable on the convolution.
    compare_masks(*getMask(conv_weights.get_node_shared_ptr()->output(0)), Mask({{1, 2}, {}, {}, {}}));
    compare_masks(*getMask(conv->output(0)), Mask({{}, {1, 2}, {}, {}}));
    compare_masks(*getMask(relu->output(0)), Mask({{}, {1, 2}, {}, {}}));
    // Channels 1 and 2 map onto flattened positions [64, 192) after the reshape.
    std::set<uint64_t> expected_flat;
    for (uint64_t i = flattened_features / 6; i < flattened_features / 2; ++i)
        expected_flat.insert(i);
    mask_sets reshape_ref_mask{{}, expected_flat};
    mask_sets linear_ref_mask{expected_flat, {}};
    compare_masks(*getMask(reshape_const->output(0)), Mask(reshape_ref_mask));
    compare_masks(*getMask(reshape->output(0)), Mask(reshape_ref_mask));
    compare_masks(*getMask(fc_weights.get_node_shared_ptr()->output(0)), Mask(linear_ref_mask));
    compare_masks(*getMask(fc->output(0)), Mask{{}, {}});
    {
        pass::Manager m;
        m.register_pass<pass::ShrinkWeights>();
        m.run_passes(function);
    }
    disable_rt_info_check();
    enable_accuracy_check();
}