[Ngraph transformation][Pruning]Matmul ops pruning support (#10211)
* Linear pruning support * Minor fix * Fix types * Fix: stop 1d multiply propagation
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
#include <set>
|
||||
|
||||
#include <ngraph/node.hpp>
|
||||
#include <ngraph/log.hpp>
|
||||
|
||||
namespace ngraph {
|
||||
|
||||
@@ -49,6 +50,10 @@ public:
|
||||
m_adjust_value(adjust_value) {
|
||||
}
|
||||
|
||||
explicit Mask(const std::vector<value_type> val)
|
||||
: std::vector<value_type>(val) {
|
||||
}
|
||||
|
||||
Mask(std::initializer_list<std::initializer_list<uint64_t>> list)
|
||||
: std::vector<value_type>() {
|
||||
for (const auto & dim_values : list) {
|
||||
@@ -122,7 +127,7 @@ public:
|
||||
return result_mask;
|
||||
}
|
||||
|
||||
Mask::Ptr union_masks_reversed(Mask *const mask) {
|
||||
Mask::Ptr union_masks_reversed(Mask *const mask) const {
|
||||
auto result_mask = std::make_shared<Mask>(std::max(size(), mask->size()));
|
||||
auto result_iter = result_mask->rbegin();
|
||||
auto mask_1_iter = rbegin();
|
||||
@@ -149,9 +154,13 @@ public:
|
||||
return result_mask;
|
||||
}
|
||||
|
||||
void add_callback(const std::function<bool(Mask::Ptr)> & receive_callback, Mask::Ptr mask) {
|
||||
bool add_callback(const std::function<bool(Mask::Ptr)> & receive_callback, Mask::Ptr mask) {
|
||||
if (m_callbacks.find(mask.get()) != m_callbacks.end())
|
||||
NGRAPH_DEBUG << "Attempt to rewrite callback, could lead to unexpected behaviour";
|
||||
|
||||
m_callbacks[mask.get()] = receive_callback;
|
||||
m_dependencies.push_back(mask.get());
|
||||
return true;
|
||||
}
|
||||
|
||||
/* Modify state of this mask by corresponding callback,
|
||||
@@ -173,11 +182,14 @@ public:
|
||||
m_need_initialization = false;
|
||||
// recursively apply callbacks for each dependent mask
|
||||
for (const auto & m_dependency : m_dependencies) {
|
||||
if (m_dependency == mask.get())
|
||||
continue;
|
||||
if (!m_dependency->apply_callback(shared_from_this())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
||||
return mask->apply_callback(shared_from_this());
|
||||
}
|
||||
|
||||
void invalidate() {
|
||||
|
||||
@@ -30,8 +30,12 @@ ngraph::pass::InitConstMask::InitConstMask(const ngraph::AxisSet & dims,
|
||||
|
||||
for (const auto & dim : dims) {
|
||||
if (dim >= shape.size()) {
|
||||
throw ngraph_error("Dim value " + std::to_string(dim) + " is out of range [0;" +std::to_string(shape.size() - 1) + "]");
|
||||
NGRAPH_DEBUG << "[WARNING] Attemt to initialize masks on " << dim
|
||||
<< " dimension which is out of shape " << shape
|
||||
<< " for node (" << const_node->get_friendly_name() << ")";
|
||||
continue;
|
||||
}
|
||||
|
||||
for (size_t value = 0; value < shape[dim]; ++value) {
|
||||
Coordinate begin(shape.size(), 0);
|
||||
Coordinate end(shape);
|
||||
|
||||
@@ -16,6 +16,7 @@ namespace pass {
|
||||
namespace init_masks {
|
||||
|
||||
class InitConvMask;
|
||||
class InitMatMulMask;
|
||||
|
||||
} // namespace init_masks
|
||||
} // namespace pass
|
||||
@@ -58,7 +59,56 @@ public:
|
||||
};
|
||||
|
||||
|
||||
class ngraph::pass::init_masks::InitMatMulMask : public MatcherPass {
|
||||
public:
|
||||
InitMatMulMask() {
|
||||
auto a = pattern::any_input();
|
||||
auto b = pattern::any_input();
|
||||
auto matmul = pattern::wrap_type<opset6::MatMul>({a, b});
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
const auto & pattern_map = m.get_pattern_value_map();
|
||||
const auto & m_output = pattern_map.at(matmul);
|
||||
|
||||
// Assume constant always in the first input port.
|
||||
// Initializing weights mask:
|
||||
// 1. Looking for Const node with weights
|
||||
NodeVector weights_calculation_nodes;
|
||||
auto cur_node = m_output.get_node()->get_input_node_shared_ptr(1);
|
||||
|
||||
while (!ngraph::is_type<opset6::Constant>(cur_node) && cur_node->inputs().size()) {
|
||||
weights_calculation_nodes.push_back(cur_node);
|
||||
cur_node = cur_node->get_input_node_shared_ptr(0);
|
||||
}
|
||||
if (!ngraph::is_type<opset6::Constant>(cur_node)) {
|
||||
NGRAPH_DEBUG << "Can't find Constant weights for MatMul: " <<
|
||||
m_output.get_node()->get_friendly_name() << std::endl;
|
||||
return false;
|
||||
}
|
||||
// 2. Get constant rank to set mask on last dimension
|
||||
const auto const_op = std::dynamic_pointer_cast<opset6::Constant>(cur_node);
|
||||
const auto shape_rank = const_op->get_shape().size();
|
||||
const auto matmul = std::dynamic_pointer_cast<opset6::MatMul>(m_output.get_node_shared_ptr());
|
||||
const auto shift = (matmul->get_transpose_b())? 2 : 1;
|
||||
if (shape_rank < shift) {
|
||||
NGRAPH_DEBUG << "Can't init mask for MatMul: " <<
|
||||
m_output.get_node()->get_friendly_name() << std::endl;
|
||||
return false;
|
||||
}
|
||||
const size_t outer_dim = shape_rank - shift;
|
||||
// 3. Init mask for Const node
|
||||
InitConstMask({outer_dim}/* check only outer dim */).apply(cur_node);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, "MatMulInitMask");
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
ngraph::pass::InitMasks::InitMasks() {
|
||||
add_matcher<init_masks::InitConvMask>();
|
||||
add_matcher<init_masks::InitMatMulMask>();
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
#include "mask_attribute.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <iterator>
|
||||
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/opsets/opset6.hpp>
|
||||
@@ -20,6 +22,7 @@ namespace ngraph {
|
||||
namespace pass {
|
||||
namespace mask_propagation {
|
||||
|
||||
class MatMul;
|
||||
class Convolution;
|
||||
class GroupConvolution;
|
||||
class GroupConvolutionReshape;
|
||||
@@ -45,6 +48,85 @@ static ngraph::Shape broadcast_shape_to_rank(ngraph::Shape shape_to_broadcast, i
|
||||
return new_shape;
|
||||
}
|
||||
|
||||
|
||||
class ngraph::pass::mask_propagation::MatMul : public MatcherPass {
|
||||
public:
|
||||
MatMul() {
|
||||
auto a = pattern::any_input(pattern::has_static_shape());
|
||||
auto b = pattern::any_input(pattern::has_static_shape());
|
||||
auto matmul = pattern::wrap_type<opset6::MatMul>({a, b});
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
const auto & pattern_map = m.get_pattern_value_map();
|
||||
const auto & m_a = pattern_map.at(a);
|
||||
const auto & m_b = pattern_map.at(b);
|
||||
const auto & m_matmul = pattern_map.at(matmul);
|
||||
|
||||
auto a_mask = getMask(m_a);
|
||||
auto b_mask = getMask(m_b);
|
||||
|
||||
if (!a_mask || !b_mask) {
|
||||
NGRAPH_DEBUG << "No mask for any input of " << m_matmul.get_node()->get_friendly_name() << "\n";
|
||||
return false;
|
||||
}
|
||||
auto a_mask_row = a_mask.get();
|
||||
auto b_mask_row = b_mask.get();
|
||||
|
||||
const auto matmul_op = std::dynamic_pointer_cast<opset6::MatMul>(m_matmul.get_node_shared_ptr());
|
||||
const auto transpose_a = matmul_op->get_transpose_a();
|
||||
const auto transpose_b = matmul_op->get_transpose_b();
|
||||
|
||||
const auto shape_a = m_a.get_shape();
|
||||
const auto shape_b = m_b.get_shape();
|
||||
|
||||
const auto a_inner_dim = (transpose_a)? shape_a.size() - 2 : shape_a.size() - 1;
|
||||
const auto a_outer_dim = (transpose_a)? shape_a.size() - 1 : shape_a.size() - 2;
|
||||
const auto b_inner_dim = (transpose_b)? shape_b.size() - 1 : shape_b.size() - 2;
|
||||
const auto b_outer_dim = (transpose_b)? shape_b.size() - 2 : shape_b.size() - 1;
|
||||
|
||||
|
||||
const auto matmul_range = m_matmul.get_shape().size();
|
||||
auto matmul_mask = std::make_shared<Mask>(matmul_range);
|
||||
auto matmul_mask_row = matmul_mask.get();
|
||||
const auto matmul_cols_dim = matmul_range - 1;
|
||||
const auto matmul_rows_dim = matmul_range - 2;
|
||||
|
||||
const auto matmul_callback = [=](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->at(matmul_rows_dim) = a_mask_row->at(a_outer_dim);
|
||||
cur_mask->at(matmul_cols_dim) = b_mask_row->at(b_outer_dim);
|
||||
if (a_mask_row->at(a_inner_dim) != b_mask_row->at(b_inner_dim))
|
||||
cur_mask->initialize_dependencies();
|
||||
return true;
|
||||
};
|
||||
// Connect a with matmul mask
|
||||
matmul_mask->add_callback(matmul_callback, a_mask);
|
||||
a_mask->add_callback([=](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->at(a_inner_dim) = b_mask_row->at(b_inner_dim);
|
||||
cur_mask->at(a_outer_dim) = matmul_mask_row->at(matmul_rows_dim);
|
||||
return true;
|
||||
}, matmul_mask);
|
||||
// connect b with matmul mask
|
||||
matmul_mask->add_callback(matmul_callback, b_mask);
|
||||
b_mask->add_callback([=](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->at(b_inner_dim) = a_mask_row->at(a_inner_dim);
|
||||
cur_mask->at(b_outer_dim) = matmul_mask_row->at(matmul_cols_dim);
|
||||
return true;
|
||||
}, matmul_mask);
|
||||
|
||||
if (!matmul_mask->apply_callback(a_mask)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
setMask(m_matmul, matmul_mask);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, "MatMulMaskPropagation");
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
class ngraph::pass::mask_propagation::Convolution : public MatcherPass {
|
||||
public:
|
||||
Convolution() {
|
||||
@@ -69,42 +151,47 @@ public:
|
||||
}
|
||||
auto weights_mask_row = weights_mask.get();
|
||||
|
||||
if (auto input_mask = getMask(m_input)) {
|
||||
auto input_mask_row = input_mask.get();
|
||||
// Weights input channel is connected to the convolution input channel dimension
|
||||
// so we update weights mask to be aligned with input shape.
|
||||
weights_mask->add_callback([input_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->at(1/* weights input channel */) = input_mask_row->at(1 /* input data channel */);
|
||||
return true;
|
||||
}, input_mask);
|
||||
|
||||
input_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->at(1) = weights_mask_row->at(1);
|
||||
return true;
|
||||
}, weights_mask);
|
||||
|
||||
if (!weights_mask->apply_callback(input_mask)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Create output mask that describes which channel dimensions will be removed
|
||||
auto conv_mask = std::make_shared<Mask>(m_weights.get_shape().size());
|
||||
auto conv_mask_row = conv_mask.get();
|
||||
auto input_mask = getMask(m_input);
|
||||
Mask* input_mask_row = nullptr;
|
||||
if (input_mask)
|
||||
input_mask_row = input_mask.get();
|
||||
|
||||
conv_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->at(1) = weights_mask_row->at(0/*weights output channel dim */);
|
||||
const auto conv_mask_callback = [input_mask_row, weights_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->at(1/*input data channel*/) = weights_mask_row->at(0 /* weights output channel dim*/);
|
||||
if (input_mask_row && input_mask_row->at(1) != weights_mask_row->at(1))
|
||||
cur_mask->initialize_dependencies();
|
||||
return true;
|
||||
}, weights_mask);
|
||||
};
|
||||
|
||||
weights_mask->add_callback([conv_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
if (input_mask) {
|
||||
// Weights input channel is connected to the convolution input channel dimension
|
||||
// so we update weights mask to be aligned with input shape.
|
||||
conv_mask->add_callback(conv_mask_callback, input_mask);
|
||||
input_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->at(1) = weights_mask_row->at(1);
|
||||
return true;
|
||||
}, conv_mask);
|
||||
}
|
||||
|
||||
conv_mask->add_callback(conv_mask_callback, weights_mask);
|
||||
weights_mask->add_callback([input_mask_row, conv_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->at(0) = conv_mask_row->at(1);
|
||||
if (input_mask_row)
|
||||
cur_mask->at(1) = input_mask_row->at(1);
|
||||
return true;
|
||||
}, conv_mask);
|
||||
|
||||
if (!conv_mask->apply_callback(weights_mask)) {
|
||||
bool status;
|
||||
if (input_mask)
|
||||
status = conv_mask->apply_callback(input_mask);
|
||||
else
|
||||
status = conv_mask->apply_callback(weights_mask);
|
||||
|
||||
if (!status)
|
||||
return false;
|
||||
}
|
||||
|
||||
setMask(m_output, conv_mask);
|
||||
return true;
|
||||
@@ -154,37 +241,30 @@ public:
|
||||
}
|
||||
auto weights_mask_row = weights_mask.get();
|
||||
|
||||
// Weights input channel is connected to the convolution input channel dimension
|
||||
// so we update weights mask to be aligned with input shape.
|
||||
weights_mask->add_callback([input_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->at(0) = input_mask_row->at(1);
|
||||
return true;
|
||||
}, input_mask);
|
||||
|
||||
input_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->at(1) = weights_mask_row->at(0);
|
||||
return true;
|
||||
}, weights_mask);
|
||||
|
||||
if (!weights_mask->apply_callback(input_mask)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Update output channels mask dims
|
||||
auto conv_mask = std::make_shared<Mask>(input_shape.rank().get_length());
|
||||
auto conv_mask_row = conv_mask.get();
|
||||
|
||||
conv_mask->add_callback([input_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->at(1/*input data channel*/) = input_mask_row->at(1/*output data channel*/);
|
||||
return true;
|
||||
}, input_mask);
|
||||
|
||||
input_mask->add_callback([conv_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->at(1/*output data channel*/) = conv_mask_row->at(1/*input data channel*/);
|
||||
return true;
|
||||
}, conv_mask);
|
||||
|
||||
conv_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->at(1) = weights_mask_row->at(0);
|
||||
cur_mask->at(1/*input data channel*/) = weights_mask_row->at(0/*weights output channel dim*/);
|
||||
return true;
|
||||
}, weights_mask);
|
||||
|
||||
weights_mask->add_callback([conv_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->at(0) = conv_mask_row->at(1);
|
||||
cur_mask->at(0/*weights output channel dim*/) = conv_mask_row->at(1/*output data channel*/);
|
||||
return true;
|
||||
}, conv_mask);
|
||||
|
||||
if (!conv_mask->apply_callback(weights_mask)) {
|
||||
if (!conv_mask->apply_callback(input_mask)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -297,14 +377,13 @@ public:
|
||||
const auto & m_output = pattern_map.at(eltwise);
|
||||
const auto & m_input = pattern_map.at(input);
|
||||
|
||||
// Case when input masks should be united instead of intersection
|
||||
bool union_eltwise_type = ngraph::is_type<opset6::Multiply>(m_output.get_node_shared_ptr());
|
||||
|
||||
const auto & input_rank = m_input.get_partial_shape().rank().get_length();
|
||||
const auto & weights_rank = m_weights.get_partial_shape().rank().get_length();
|
||||
// Here assuming that masks can be propagated only through 3/4 dimensional tensors
|
||||
// (since channel dim is necessary)
|
||||
if (weights_rank < 3 || input_rank < 3) return false;
|
||||
// (since channel dim is necessary) or tensors with equal rank.
|
||||
if (!((weights_rank > 2 && input_rank > 2) || weights_rank == input_rank)) return false;
|
||||
// Case when input masks should be united instead of intersection
|
||||
bool union_eltwise_type = ngraph::is_type<opset6::Multiply>(m_output.get_node_shared_ptr());
|
||||
|
||||
// In case if first of the inputs is constant
|
||||
InitConstMask({0, 1/* potential output channel dim */}).apply(m_input.get_node_shared_ptr());
|
||||
@@ -578,7 +657,7 @@ public:
|
||||
opset6::Elu, opset6::HardSigmoid, opset6::PRelu, opset6::Mish,
|
||||
opset6::Softmax, opset6::SoftPlus, opset6::Convert, opset6::ConvertLike,
|
||||
opset6::AvgPool, opset6::MaxPool, opset6::ROIPooling, opset6::PSROIPooling,
|
||||
opset6::Pad>();
|
||||
opset6::Pad, opset6::MVN>();
|
||||
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
@@ -646,6 +725,39 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
static std::pair<std::set<uint64_t>, bool> squeeze_mask(
|
||||
const std::set<uint64_t> mask_dim, const size_t elems_per_ch, const bool squeeze) {
|
||||
bool should_init_dep = false;
|
||||
auto ret_set = std::set<uint64_t>();
|
||||
auto mask_dim_copy = std::set<uint64_t>();
|
||||
std::copy(mask_dim.begin(), mask_dim.end(), std::inserter(mask_dim_copy, mask_dim_copy.begin()));
|
||||
while (mask_dim_copy.size()) {
|
||||
const auto elem = *mask_dim_copy.begin();
|
||||
const auto ch = elem / elems_per_ch;
|
||||
// Check all channel is zeroed
|
||||
const auto low = mask_dim_copy.lower_bound(ch * elems_per_ch);
|
||||
const auto upper = mask_dim_copy.lower_bound((ch + 1) * elems_per_ch);
|
||||
auto channel_zeros = std::set<uint64_t>();
|
||||
std::copy(low, upper, std::inserter(channel_zeros, channel_zeros.begin()));
|
||||
|
||||
// Remove all zeros related to current channel from iter mask
|
||||
mask_dim_copy.erase(low, upper);
|
||||
// In case any of elements are not zeroed - skip entire channel
|
||||
if (channel_zeros.size() != elems_per_ch) {
|
||||
should_init_dep = true;
|
||||
continue;
|
||||
}
|
||||
// Add zeros for current channel in current mask
|
||||
if (squeeze)
|
||||
ret_set.insert(ch);
|
||||
else
|
||||
ret_set.insert(channel_zeros.begin(), channel_zeros.end());
|
||||
}
|
||||
return std::make_pair(ret_set, should_init_dep);
|
||||
}
|
||||
|
||||
|
||||
class ngraph::pass::mask_propagation::Reshape : public MatcherPass {
|
||||
public:
|
||||
Reshape() {
|
||||
@@ -665,12 +777,15 @@ public:
|
||||
if (is_type<opset6::GroupConvolution>(inp.get_node()))
|
||||
return true;
|
||||
|
||||
// Can't process non constant node in the shape input by now.
|
||||
if (!std::dynamic_pointer_cast<opset6::Constant>(m_weights.get_node_shared_ptr())) {
|
||||
NGRAPH_DEBUG << "Can't process reshape node " << m_output.get_node()->get_friendly_name()
|
||||
<<" with no constant node " << m_weights.get_node()->get_friendly_name()
|
||||
<< " as shape input.";
|
||||
return false;
|
||||
auto constant = std::dynamic_pointer_cast<opset6::Constant>(m_weights.get_node_shared_ptr());
|
||||
if (!constant) {
|
||||
constant = get_constant_from_source(m_weights.get_node_shared_ptr());
|
||||
if (!constant) {
|
||||
NGRAPH_DEBUG << "Can't process reshape node " << m_output.get_node()->get_friendly_name()
|
||||
<<" with no constant node " << m_weights.get_node()->get_friendly_name()
|
||||
<< " as shape input.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check reshape operation reshape only dimension without masks
|
||||
@@ -683,44 +798,102 @@ public:
|
||||
|
||||
// Check dimensions equality from the begining and allow
|
||||
// to propagate masks only for dimensions which equal from the begining
|
||||
size_t i = 0;
|
||||
for (; i < std::min(input_shape.size(), output_shape.size()); ++i) {
|
||||
if (input_shape[i] != output_shape[i])
|
||||
break;
|
||||
size_t not_reshaped_dims;
|
||||
{
|
||||
size_t i = 0;
|
||||
for (; i < std::min(input_shape.size(), output_shape.size()); ++i) {
|
||||
if (input_shape[i] != output_shape[i])
|
||||
break;
|
||||
}
|
||||
not_reshaped_dims = i;
|
||||
}
|
||||
auto not_reshaped_dims = i;
|
||||
|
||||
auto input_mask_row = input_mask.get();
|
||||
auto weights_mask_row = weights_mask.get();
|
||||
auto output_mask_row = output_mask.get();
|
||||
input_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->copy_value_from_mask(weights_mask_row);
|
||||
return true;
|
||||
}, weights_mask);
|
||||
weights_mask->add_callback([input_mask_row, not_reshaped_dims](Mask::Ptr cur_mask) -> bool{
|
||||
// Propagate masks down through dimension only if this dimension isn't reshaped
|
||||
for (size_t dim = 0; dim < std::min(cur_mask->size(), input_mask_row->size()); ++dim)
|
||||
if (dim < not_reshaped_dims)
|
||||
cur_mask->at(dim) = input_mask_row->at(dim);
|
||||
else if (cur_mask->at(dim) != input_mask_row->at(dim))
|
||||
cur_mask->initialize_dependencies();
|
||||
return true;
|
||||
}, input_mask);
|
||||
|
||||
output_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->copy_value_from_mask(weights_mask_row);
|
||||
return true;
|
||||
}, weights_mask);
|
||||
// Case when reshape make flatten last dimension
|
||||
if (input_shape.size() > output_shape.size() &&
|
||||
output_shape.size() == not_reshaped_dims + 1) {
|
||||
const size_t elems_per_ch = std::accumulate(input_shape.begin() + not_reshaped_dims + 1,
|
||||
input_shape.end(), 1, std::multiplies<size_t>());
|
||||
|
||||
weights_mask->add_callback([output_mask_row, not_reshaped_dims](Mask::Ptr cur_mask) -> bool {
|
||||
// Propagate masks up through dimension only if this dimension isn't reshaped
|
||||
for (size_t dim = 0; dim < std::min(cur_mask->size(), output_mask_row->size()); ++dim)
|
||||
if (dim < not_reshaped_dims)
|
||||
cur_mask->at(dim) = output_mask_row->at(dim);
|
||||
else if (cur_mask->at(dim) != output_mask_row->at(dim))
|
||||
cur_mask->initialize_dependencies();
|
||||
return true;
|
||||
}, output_mask);
|
||||
input_mask->add_callback([weights_mask_row, not_reshaped_dims, elems_per_ch](Mask::Ptr cur_mask) -> bool {
|
||||
for (size_t dim = 0; dim < not_reshaped_dims; ++dim)
|
||||
if (dim < not_reshaped_dims)
|
||||
cur_mask->at(dim) = weights_mask_row->at(dim);
|
||||
|
||||
bool should_init_dep;
|
||||
std::set<uint64_t> updated_mask;
|
||||
std::tie(updated_mask, should_init_dep) = squeeze_mask(weights_mask_row->at(not_reshaped_dims), elems_per_ch, true);
|
||||
|
||||
cur_mask->at(not_reshaped_dims) = updated_mask;
|
||||
if (should_init_dep) cur_mask->initialize_dependencies();
|
||||
return true;
|
||||
}, weights_mask);
|
||||
|
||||
weights_mask->add_callback([input_mask_row, not_reshaped_dims, elems_per_ch](Mask::Ptr cur_mask) -> bool {
|
||||
// Propagate masks down through dimension only if this dimension isn't reshaped
|
||||
for (size_t dim = 0; dim < not_reshaped_dims; ++dim)
|
||||
if (dim < not_reshaped_dims)
|
||||
cur_mask->at(dim) = input_mask_row->at(dim);
|
||||
// Flat the last mask
|
||||
for (auto &ch : input_mask_row->at(not_reshaped_dims))
|
||||
for (auto idx = ch * elems_per_ch; idx < (ch + 1) * elems_per_ch; ++idx)
|
||||
cur_mask->at(not_reshaped_dims).insert(idx);
|
||||
return true;
|
||||
}, input_mask);
|
||||
|
||||
output_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->copy_value_from_mask(weights_mask_row);
|
||||
return true;
|
||||
}, weights_mask);
|
||||
|
||||
weights_mask->add_callback([output_mask_row, not_reshaped_dims, elems_per_ch](Mask::Ptr cur_mask) -> bool {
|
||||
// Propagate masks up through dimension only if this dimension isn't reshaped
|
||||
for (size_t dim = 0; dim < not_reshaped_dims; ++dim)
|
||||
if (dim < not_reshaped_dims)
|
||||
cur_mask->at(dim) = output_mask_row->at(dim);
|
||||
// For the last dimension keep only those zeros which completely
|
||||
// covering a channel
|
||||
bool should_init_dep;
|
||||
std::set<uint64_t> updated_mask;
|
||||
std::tie(updated_mask, should_init_dep) = squeeze_mask(output_mask_row->at(not_reshaped_dims), elems_per_ch, false);
|
||||
|
||||
cur_mask->at(not_reshaped_dims) = updated_mask;
|
||||
if (should_init_dep) cur_mask->initialize_dependencies();
|
||||
return true;
|
||||
}, output_mask);
|
||||
} else {
|
||||
input_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->copy_value_from_mask(weights_mask_row);
|
||||
return true;
|
||||
}, weights_mask);
|
||||
weights_mask->add_callback([input_mask_row, not_reshaped_dims](Mask::Ptr cur_mask) -> bool{
|
||||
// Propagate masks down through dimension only if this dimension isn't reshaped
|
||||
for (size_t dim = 0; dim < std::min(cur_mask->size(), input_mask_row->size()); ++dim)
|
||||
if (dim < not_reshaped_dims)
|
||||
cur_mask->at(dim) = input_mask_row->at(dim);
|
||||
else if (cur_mask->at(dim) != input_mask_row->at(dim))
|
||||
cur_mask->initialize_dependencies();
|
||||
return true;
|
||||
}, input_mask);
|
||||
|
||||
output_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->copy_value_from_mask(weights_mask_row);
|
||||
return true;
|
||||
}, weights_mask);
|
||||
|
||||
weights_mask->add_callback([output_mask_row, not_reshaped_dims](Mask::Ptr cur_mask) -> bool {
|
||||
// Propagate masks up through dimension only if this dimension isn't reshaped
|
||||
for (size_t dim = 0; dim < std::min(cur_mask->size(), output_mask_row->size()); ++dim)
|
||||
if (dim < not_reshaped_dims)
|
||||
cur_mask->at(dim) = output_mask_row->at(dim);
|
||||
else if (cur_mask->at(dim) != output_mask_row->at(dim))
|
||||
cur_mask->initialize_dependencies();
|
||||
return true;
|
||||
}, output_mask);
|
||||
}
|
||||
|
||||
weights_mask->apply_callback(input_mask);
|
||||
setMask(m_output, output_mask);
|
||||
@@ -746,11 +919,11 @@ public:
|
||||
const auto & node = m.get_match_root();
|
||||
|
||||
auto output_mask = std::make_shared<Mask>(m_output.get_partial_shape().rank().get_length());
|
||||
auto output_mask_row = output_mask.get();
|
||||
bool any_input_with_masks = false;
|
||||
for (const auto & input : node->input_values()) {
|
||||
if (auto input_mask = getMask(input)) {
|
||||
auto input_mask_row = input_mask.get();
|
||||
auto output_mask_row = output_mask.get();
|
||||
input_mask->add_callback([output_mask_row](Mask::Ptr cur_mask) -> bool {
|
||||
cur_mask->clean_dim_values();
|
||||
if (!output_mask_row->all_dims_are_empty())
|
||||
@@ -800,6 +973,7 @@ public:
|
||||
};
|
||||
|
||||
ngraph::pass::PropagateMasks::PropagateMasks() {
|
||||
add_matcher<mask_propagation::MatMul>();
|
||||
add_matcher<mask_propagation::Convolution>();
|
||||
add_matcher<mask_propagation::GroupConvolutionReshape>();
|
||||
add_matcher<mask_propagation::GroupConvolution>();
|
||||
|
||||
@@ -55,8 +55,10 @@ bool ngraph::pass::ShrinkWeights::run_on_model(const std::shared_ptr<ngraph::Fun
|
||||
if (mask->adjust_value() && !mask->all_dims_are_empty()) {
|
||||
std::vector<int64_t> new_const_value;
|
||||
auto value = const_node->cast_vector<int64_t>();
|
||||
for (size_t i = 0; i < mask->size(); i++)
|
||||
new_const_value.push_back(value[i] - mask->at(i).size());
|
||||
for (size_t i = 0; i < mask->size(); i++) {
|
||||
const int64_t res = value[i] - mask->at(i).size();
|
||||
new_const_value.push_back((res > 0)? res : value[i]);
|
||||
}
|
||||
|
||||
const auto new_const = opset6::Constant::create(const_node->get_element_type(),
|
||||
const_node->get_shape(), new_const_value);
|
||||
|
||||
@@ -188,10 +188,10 @@ TEST_F(TransformationTestsF, PropagateMasksBasic) {
|
||||
auto add_const = create_constant_with_zeros(Shape{1, 6, 1, 1}, {{}, {1, 2, 3, 4, 5}, {}, {}});
|
||||
auto add = std::make_shared<opset5::Add>(relu, add_const);
|
||||
|
||||
auto sub_const = create_constant_with_zeros(Shape{6, 1, 1}, {{1, 2, 3}, {}, {}});
|
||||
auto sub_const = create_constant_with_zeros(Shape{6, 1, 1}, {{1, 2}, {}, {}});
|
||||
auto sub = std::make_shared<opset5::Subtract>(add, sub_const);
|
||||
|
||||
auto mul_const = create_constant_with_zeros(Shape{1, 6, 1, 1}, {{}, {4}, {}, {}});
|
||||
auto mul_const = create_constant_with_zeros(Shape{1, 6, 1, 1}, {{}, {3}, {}, {}});
|
||||
auto mul = std::make_shared<ov::op::v1::Multiply>(sub, mul_const);
|
||||
|
||||
auto weights2 = create_constant_with_zeros(weights_shape2, {{1, 2}, {1, 2, 3}, {}, {}});
|
||||
@@ -202,21 +202,21 @@ TEST_F(TransformationTestsF, PropagateMasksBasic) {
|
||||
{
|
||||
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
|
||||
|
||||
auto weights = opset5::Constant::create(element::f32, {weights_shape[0] - 4, weights_shape[1], weights_shape[2] , weights_shape[3]}, {0});
|
||||
auto weights = opset5::Constant::create(element::f32, {weights_shape[0] - 3, weights_shape[1], weights_shape[2] , weights_shape[3]}, {0});
|
||||
auto conv = std::make_shared<opset5::Convolution>(input, weights, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
auto relu = std::make_shared<opset5::Relu>(conv);
|
||||
|
||||
auto add_const = opset5::Constant::create(element::f32, Shape{1, 2, 1, 1}, {1});
|
||||
auto add_const = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {1});
|
||||
auto add = std::make_shared<opset5::Add>(relu, add_const);
|
||||
|
||||
auto sub_const = opset5::Constant::create(element::f32, Shape{2, 1, 1}, {1});
|
||||
auto sub_const = opset5::Constant::create(element::f32, Shape{3, 1, 1}, {1});
|
||||
auto sub = std::make_shared<opset5::Subtract>(add, sub_const);
|
||||
|
||||
auto mul_const = opset5::Constant::create(element::f32, Shape{1, 2, 1, 1}, {1});
|
||||
auto mul_const = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {1});
|
||||
auto mul = std::make_shared<ov::op::v1::Multiply>(sub, mul_const);
|
||||
|
||||
auto weights2 = opset5::Constant::create(element::f32, {weights_shape2[0], weights_shape2[1] - 4, weights_shape2[2], weights_shape2[3]}, {1});
|
||||
auto weights2 = opset5::Constant::create(element::f32, {weights_shape2[0], weights_shape2[1] - 3, weights_shape2[2], weights_shape2[3]}, {1});
|
||||
auto conv2 = std::make_shared<opset5::Convolution>(mul, weights2, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
function_ref = std::make_shared<Function>(NodeVector{conv2}, ParameterVector{input});
|
||||
@@ -229,16 +229,16 @@ TEST_F(TransformationTestsF, PropagateMasksBasic) {
|
||||
m.register_pass<pass::PropagateMasks>();
|
||||
m.run_passes(function);
|
||||
}
|
||||
compare_masks(*getMask(weights->output(0)), Mask({{1, 2, 3, 4}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}}));
|
||||
compare_masks(*getMask(relu->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}}));
|
||||
compare_masks(*getMask(add_const), Mask({{}, {1, 2, 3, 4}, {}, {}}));
|
||||
compare_masks(*getMask(sub_const), Mask({{1, 2, 3, 4}, {}, {}}));
|
||||
compare_masks(*getMask(mul_const), Mask({{}, {1, 2, 3, 4}, {}, {}}));
|
||||
compare_masks(*getMask(add->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}}));
|
||||
compare_masks(*getMask(sub->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}}));
|
||||
compare_masks(*getMask(mul->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}}));
|
||||
compare_masks(*getMask(weights2.get_node_shared_ptr()->output(0)), Mask({{}, {1, 2, 3, 4}, {}, {}}));
|
||||
compare_masks(*getMask(weights->output(0)), Mask({{1, 2, 3}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv->output(0)), Mask({{}, {1, 2, 3}, {}, {}}));
|
||||
compare_masks(*getMask(relu->output(0)), Mask({{}, {1, 2, 3}, {}, {}}));
|
||||
compare_masks(*getMask(add_const), Mask({{}, {1, 2, 3}, {}, {}}));
|
||||
compare_masks(*getMask(sub_const), Mask({{1, 2, 3}, {}, {}}));
|
||||
compare_masks(*getMask(mul_const), Mask({{}, {1, 2, 3}, {}, {}}));
|
||||
compare_masks(*getMask(add->output(0)), Mask({{}, {1, 2, 3}, {}, {}}));
|
||||
compare_masks(*getMask(sub->output(0)), Mask({{}, {1, 2, 3}, {}, {}}));
|
||||
compare_masks(*getMask(mul->output(0)), Mask({{}, {1, 2, 3}, {}, {}}));
|
||||
compare_masks(*getMask(weights2.get_node_shared_ptr()->output(0)), Mask({{}, {1, 2, 3}, {}, {}}));
|
||||
compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}}));
|
||||
{
|
||||
pass::Manager m;
|
||||
@@ -495,7 +495,7 @@ TEST_F(TransformationTestsF, PropagateMaskPassThrough) {
|
||||
}
|
||||
|
||||
|
||||
TEST(TransformationTests, PropagateMasksHardDependencies) {
|
||||
TEST_F(TransformationTestsF, PropagateMasksHardDependencies) {
|
||||
Shape input_shape{1, 3, 3, 3};
|
||||
|
||||
auto input1 = std::make_shared<opset5::Parameter>(element::f32, input_shape);
|
||||
@@ -529,7 +529,8 @@ TEST(TransformationTests, PropagateMasksHardDependencies) {
|
||||
auto reshape = std::make_shared<opset5::Reshape>(add1, opset5::Constant::create(element::i64, Shape{2}, {1, 6}), true);
|
||||
reshape->set_friendly_name("reshape");
|
||||
|
||||
auto matmul = std::make_shared<opset5::MatMul>(reshape, opset5::Constant::create(element::f32, Shape{6, 100}, {1.}));
|
||||
auto matmul_const = opset5::Constant::create(element::f32, Shape{6, 100}, {1.});
|
||||
auto matmul = std::make_shared<opset5::MatMul>(reshape, matmul_const);
|
||||
matmul->set_friendly_name("matmul");
|
||||
|
||||
auto add2 = std::make_shared<opset5::Add>(conv2, create_constant_with_zeros({6, 1, 1}, {{2}, {}, {}}));
|
||||
@@ -543,24 +544,91 @@ TEST(TransformationTests, PropagateMasksHardDependencies) {
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
conv3->set_friendly_name("conv3");
|
||||
|
||||
auto f = std::make_shared<Function>(NodeVector{matmul, conv3}, ParameterVector{input1, input2});
|
||||
function = std::make_shared<Function>(NodeVector{matmul, conv3}, ParameterVector{input1, input2});
|
||||
{
|
||||
auto input1 = std::make_shared<opset5::Parameter>(element::f32, input_shape);
|
||||
input1->set_friendly_name("input1");
|
||||
|
||||
Shape weights1_shape{6, 3, 3, 3};
|
||||
auto weights1 = create_constant_with_zeros({
|
||||
weights1_shape[0] - 1,
|
||||
weights1_shape[1],
|
||||
weights1_shape[2],
|
||||
weights1_shape[3]
|
||||
}, {{}, {}, {}, {}});
|
||||
weights1.get_node_shared_ptr()->set_friendly_name("weights1");
|
||||
|
||||
auto conv1 = std::make_shared<opset5::Convolution>(input1, weights1, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
conv1->set_friendly_name("conv1");
|
||||
|
||||
auto relu = std::make_shared<opset5::Relu>(conv1);
|
||||
relu->set_friendly_name("relu");
|
||||
|
||||
auto input2 = std::make_shared<opset5::Parameter>(element::f32, input_shape);
|
||||
input2->set_friendly_name("input2");
|
||||
|
||||
Shape weights2_shape{6, 3, 3, 3};
|
||||
auto weights2 = create_constant_with_zeros({weights2_shape[0] - 1,
|
||||
weights2_shape[1],
|
||||
weights2_shape[2],
|
||||
weights2_shape[3]
|
||||
}, {{2, 3}, {}, {}, {}});
|
||||
weights2.get_node_shared_ptr()->set_friendly_name("weights2");
|
||||
|
||||
auto conv2 = std::make_shared<opset5::Convolution>(input2, weights2, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
conv2->set_friendly_name("conv2");
|
||||
|
||||
auto add1 = std::make_shared<opset5::Add>(conv2, conv1);
|
||||
add1->set_friendly_name("add1");
|
||||
|
||||
auto reshape = std::make_shared<opset5::Reshape>(add1, opset5::Constant::create(element::i64, Shape{2}, {1, 5}), true);
|
||||
reshape->set_friendly_name("reshape");
|
||||
|
||||
auto matmul = std::make_shared<opset5::MatMul>(reshape, opset5::Constant::create(element::f32, Shape{5, 100}, {1.}));
|
||||
matmul->set_friendly_name("matmul");
|
||||
|
||||
auto add2 = std::make_shared<opset5::Add>(conv2, create_constant_with_zeros({5, 1, 1}, {{}, {}, {}}));
|
||||
add2->set_friendly_name("add2");
|
||||
|
||||
Shape weights_shape3{6, 6, 1, 1};
|
||||
auto weights3 = opset5::Constant::create(element::f32,
|
||||
{weights_shape3[0],
|
||||
weights_shape3[1] - 1,
|
||||
weights_shape3[2],
|
||||
weights_shape3[3]
|
||||
}, {0});
|
||||
weights3->set_friendly_name("weights3");
|
||||
|
||||
auto conv3 = std::make_shared<opset5::Convolution>(add2, weights3, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
conv3->set_friendly_name("conv3");
|
||||
|
||||
function_ref = std::make_shared<Function>(NodeVector{matmul, conv3}, ParameterVector{input1, input2});
|
||||
}
|
||||
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksHardDependencies.svg").run_on_function(f);
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksHardDependencies.svg").run_on_function(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
m.register_pass<pass::PropagateMasks>();
|
||||
m.run_passes(function);
|
||||
}
|
||||
compare_masks(*getMask(weights1.get_node_shared_ptr()->output(0)), Mask({{2}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv1->output(0)), Mask({{}, {2}, {}, {}}));
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::Pruning>();
|
||||
m.run_passes(f);
|
||||
compare_masks(*getMask(weights2.get_node_shared_ptr()->output(0)), Mask({{2}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv2->output(0)), Mask({{}, {2}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(weights1.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv1->output(0)), Mask({{}, {}, {}, {}}));
|
||||
compare_masks(*getMask(weights3->output(0)), Mask({{}, {2}, {}, {}}));
|
||||
compare_masks(*getMask(conv3->output(0)), Mask({{}, {}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(weights2.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(add1->output(0)), Mask({{}, {}, {}, {}}));
|
||||
compare_masks(*getMask(add2->output(0)), Mask({{}, {}, {}, {}}));
|
||||
compare_masks(*getMask(add1->output(0)), Mask({{}, {2}, {}, {}}));
|
||||
compare_masks(*getMask(add2->output(0)), Mask({{}, {2}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(matmul_const->output(0)), Mask({{2}, {}}));
|
||||
compare_masks(*getMask(matmul->output(0)), Mask({{}, {}}));
|
||||
|
||||
// TODO: add checks after MatMul/Reshape/Pooling mask propagation is ready
|
||||
@@ -569,6 +637,13 @@ TEST(TransformationTests, PropagateMasksHardDependencies) {
|
||||
//compare_masks(*getMask(relu), Mask({{}, {0, 1, 2, 3, 4, 5}, {}, {}}));
|
||||
//compare_masks(*getMask(weights2), Mask({{}, {0, 1, 2, 3, 4, 5}, {}, {}}));
|
||||
//compare_masks(*getMask(conv2), Mask({{}, {}, {}, {}}));
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::ShrinkWeights>();
|
||||
m.run_passes(function);
|
||||
}
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
}
|
||||
|
||||
|
||||
@@ -580,7 +655,7 @@ TEST_F(TransformationTestsF, PropagateMasksQuantizedGroupConvolution) {
|
||||
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
|
||||
input->set_friendly_name("input");
|
||||
|
||||
auto weights1 = create_constant_with_zeros(weights_shape, {{0, 1, 2, 3}, {}, {}, {}});
|
||||
auto weights1 = create_constant_with_zeros(weights_shape, {{0, 1, 2, 3, 4}, {}, {}, {}});
|
||||
auto conv1 = std::make_shared<opset5::Convolution>(input, weights1, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
auto weights_group = opset5::Constant::create(element::i8, weights_group_shape, {0});
|
||||
@@ -589,12 +664,12 @@ TEST_F(TransformationTestsF, PropagateMasksQuantizedGroupConvolution) {
|
||||
auto convert = std::make_shared<opset5::Convert>(weights_group, element::f32);
|
||||
convert->set_friendly_name("convert");
|
||||
|
||||
auto sub_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3}, {}, {}, {}});
|
||||
auto sub_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3, 4}, {}, {}, {}});
|
||||
|
||||
auto sub = std::make_shared<opset5::Subtract>(convert, sub_const);
|
||||
sub->set_friendly_name("sub");
|
||||
|
||||
auto mul_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3, 4}, {}, {}, {}});
|
||||
auto mul_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{}, {}, {}, {}});
|
||||
auto mul = std::make_shared<opset5::Multiply>(sub, mul_const);
|
||||
mul->set_friendly_name("mul");
|
||||
|
||||
@@ -614,12 +689,12 @@ TEST_F(TransformationTestsF, PropagateMasksQuantizedGroupConvolution) {
|
||||
{
|
||||
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
|
||||
|
||||
auto weights1 = create_constant_with_zeros({weights_shape[0] - 4, weights_shape[1], weights_shape[2], weights_shape[3]}, {{}, {}, {}, {}});
|
||||
auto weights1 = create_constant_with_zeros({weights_shape[0] - 5, weights_shape[1], weights_shape[2], weights_shape[3]}, {{}, {}, {}, {}});
|
||||
auto conv1 = std::make_shared<opset5::Convolution>(input, weights1, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
auto weights_group = opset5::Constant::create(element::i8,
|
||||
{
|
||||
weights_group_shape[0] - 4,
|
||||
weights_group_shape[0] - 5,
|
||||
weights_group_shape[1],
|
||||
weights_group_shape[2],
|
||||
weights_group_shape[3]
|
||||
@@ -627,11 +702,11 @@ TEST_F(TransformationTestsF, PropagateMasksQuantizedGroupConvolution) {
|
||||
|
||||
auto convert = std::make_shared<opset5::Convert>(weights_group, element::f32);
|
||||
|
||||
auto sub_const = create_constant_with_zeros(Shape{4, 1, 1, 1}, {{}, {}, {}, {}});
|
||||
auto sub_const = create_constant_with_zeros(Shape{3, 1, 1, 1}, {{}, {}, {}, {}});
|
||||
|
||||
auto sub = std::make_shared<opset5::Subtract>(convert, sub_const);
|
||||
|
||||
auto mul_const = create_constant_with_zeros(Shape{4, 1, 1, 1}, {{}, {}, {}, {}});
|
||||
auto mul_const = create_constant_with_zeros(Shape{3, 1, 1, 1}, {{}, {}, {}, {}});
|
||||
auto mul = std::make_shared<opset5::Multiply>(sub, mul_const);
|
||||
|
||||
|
||||
@@ -648,10 +723,10 @@ TEST_F(TransformationTestsF, PropagateMasksQuantizedGroupConvolution) {
|
||||
auto conv_group = std::make_shared<opset5::GroupConvolution>(conv1, reshape, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
|
||||
auto add_const = create_constant_with_zeros(Shape{1, 4, 1, 1}, {{}, {}, {}, {}});;
|
||||
auto add_const = create_constant_with_zeros(Shape{1, 3, 1, 1}, {{}, {}, {}, {}});;
|
||||
auto add = std::make_shared<opset5::Add>(conv_group, add_const);
|
||||
|
||||
auto weights_2 = opset5::Constant::create(element::f32, {weight_shape2[0], weight_shape2[1] - 4, weight_shape2[2], weight_shape2[3]}, {0});
|
||||
auto weights_2 = opset5::Constant::create(element::f32, {weight_shape2[0], weight_shape2[1] - 5, weight_shape2[2], weight_shape2[3]}, {0});
|
||||
auto conv2 = std::make_shared<opset5::Convolution>(add, weights_2, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
function_ref = std::make_shared<Function>(NodeVector{conv2}, ParameterVector{input});
|
||||
@@ -665,22 +740,21 @@ TEST_F(TransformationTestsF, PropagateMasksQuantizedGroupConvolution) {
|
||||
m.run_passes(function);
|
||||
}
|
||||
|
||||
compare_masks(*getMask(weights1.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv1->output(0)), Mask({{}, {0 , 1, 2, 3}, {}, {}}));
|
||||
compare_masks(*getMask(weights1.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3, 4}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv1->output(0)), Mask({{}, {0 , 1, 2, 3, 4}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(weights_group->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}}));
|
||||
compare_masks(*getMask(sub->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}}));
|
||||
compare_masks(*getMask(sub_const.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}}));
|
||||
compare_masks(*getMask(mul->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}}));
|
||||
compare_masks(*getMask(mul_const.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}}));
|
||||
compare_masks(*getMask(weights_group->output(0)), Mask({{0, 1, 2, 3, 4}, {}, {}, {}}));
|
||||
compare_masks(*getMask(sub->output(0)), Mask({{0, 1, 2, 3, 4}, {}, {}, {}}));
|
||||
compare_masks(*getMask(sub_const.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3, 4}, {}, {}, {}}));
|
||||
compare_masks(*getMask(mul->output(0)), Mask({{0, 1, 2, 3, 4}, {}, {}, {}}));
|
||||
compare_masks(*getMask(mul_const.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2, 3, 4}, {}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(reshape->output(0)), Mask({{0, 1, 2, 3}, {}, {}, {}, {}}));
|
||||
compare_masks(*getMask(reshape->output(0)), Mask({{0, 1, 2, 3, 4}, {}, {}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(conv_group->output(0)), Mask({{}, {0, 1, 2, 3}, {}, {}}));
|
||||
compare_masks(*getMask(conv_group->output(0)), Mask({{}, {0, 1, 2, 3, 4}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(weights_2->output(0)), Mask({{}, {0, 1, 2, 3, 4}, {}, {}}));
|
||||
compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(weights_2->output(0)), Mask({{}, {0, 1, 2, 3}, {}, {}}));
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::ShrinkWeights>();
|
||||
@@ -841,7 +915,7 @@ TEST_F(TransformationTestsF, PropagateMasksFakeQuantizePerTensor) {
|
||||
auto convert = std::make_shared<opset5::Convert>(weights_1, element::f32);
|
||||
convert->set_friendly_name("convert");
|
||||
|
||||
auto sub_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3}, {}, {}, {}});
|
||||
auto sub_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3, 4}, {}, {}, {}});
|
||||
|
||||
auto sub = std::make_shared<opset5::Subtract>(convert, sub_const);
|
||||
sub->set_friendly_name("sub");
|
||||
@@ -947,6 +1021,70 @@ TEST_F(TransformationTestsF, PropagateMasksFakeQuantizePerTensor) {
|
||||
}
|
||||
|
||||
|
||||
TEST(TransformationTests, PropagateMasksFakeQuantizePerTensor1DScale) {
|
||||
Shape input_shape{1, 3, 64, 64};
|
||||
Shape weights_shape{8, 3, 3, 3};
|
||||
Shape weight_shape2{3, 8, 3, 3};
|
||||
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
|
||||
input->set_friendly_name("input");
|
||||
auto weights_1 = opset5::Constant::create(element::i8, weights_shape, {0});
|
||||
weights_1->set_friendly_name("weights_int8_const");
|
||||
|
||||
auto convert = std::make_shared<opset5::Convert>(weights_1, element::f32);
|
||||
convert->set_friendly_name("convert");
|
||||
|
||||
auto sub_const = create_constant_with_zeros(Shape{1}, {{}});
|
||||
|
||||
auto sub = std::make_shared<opset5::Subtract>(convert, sub_const);
|
||||
sub->set_friendly_name("sub");
|
||||
|
||||
auto mul_const = create_constant_with_zeros(Shape{1}, {{}});
|
||||
auto mul = std::make_shared<opset5::Multiply>(sub, mul_const);
|
||||
mul->set_friendly_name("mul");
|
||||
|
||||
auto conv1 = std::make_shared<opset5::Convolution>(input, mul, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
conv1->set_friendly_name("conv1");
|
||||
|
||||
auto add_const = create_constant_with_zeros(Shape{1, 8, 1, 1}, {{}, {0, 1, 2, 3, 4}, {}, {}});;
|
||||
auto add = std::make_shared<opset5::Add>(conv1, add_const);
|
||||
add->set_friendly_name("add");
|
||||
|
||||
auto input_low = opset5::Constant::create(element::f32, Shape{1}, {0});
|
||||
auto input_high = opset5::Constant::create(element::f32, Shape{1, 1, 1, 1}, {20});
|
||||
auto output_low = opset5::Constant::create(element::f32, Shape{}, {1});
|
||||
auto output_high = opset5::Constant::create(element::f32, Shape{}, {10});
|
||||
auto fq = std::make_shared<opset5::FakeQuantize>(add, input_low, input_high, output_low, output_high, 8);
|
||||
|
||||
auto weights_2 = opset5::Constant::create(element::f32, weight_shape2, {0});
|
||||
auto conv2 = std::make_shared<opset5::Convolution>(fq, weights_2, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
auto function = std::make_shared<Function>(NodeVector{conv2}, ParameterVector{input});
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksFakeQuantizePerTensor1DScale.svg").run_on_function(function);
|
||||
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::Pruning>();
|
||||
m.run_passes(function);
|
||||
}
|
||||
|
||||
compare_masks(*getMask(weights_1->output(0)), Mask({{}, {}, {}, {}}));
|
||||
compare_masks(*getMask(sub->output(0)), Mask({{}, {}, {}, {}}));
|
||||
compare_masks(*getMask(mul->output(0)), Mask({{}, {}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(conv1->output(0)), Mask({{}, {}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(add_const.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}}));
|
||||
compare_masks(*getMask(add->output(0)), Mask({{}, {}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(fq->output(0)), Mask({{}, {}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(weights_2->output(0)), Mask({{}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv2->output(0)), Mask({{}, {}, {}, {}}));
|
||||
}
|
||||
|
||||
|
||||
TEST_F(TransformationTestsF, PropagateMasksFakeQuantizePerChannel) {
|
||||
Shape input_shape{1, 3, 64, 64};
|
||||
Shape weights_shape{8, 3, 3, 3};
|
||||
@@ -959,12 +1097,12 @@ TEST_F(TransformationTestsF, PropagateMasksFakeQuantizePerChannel) {
|
||||
auto convert = std::make_shared<opset5::Convert>(weights_1, element::f32);
|
||||
convert->set_friendly_name("convert");
|
||||
|
||||
auto sub_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3}, {}, {}, {}});
|
||||
auto sub_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3, 4}, {}, {}, {}});
|
||||
|
||||
auto sub = std::make_shared<opset5::Subtract>(convert, sub_const);
|
||||
sub->set_friendly_name("sub");
|
||||
|
||||
auto mul_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{0, 1, 2, 3, 4}, {}, {}, {}});
|
||||
auto mul_const = create_constant_with_zeros(Shape{8, 1, 1, 1}, {{}, {}, {}, {}});
|
||||
auto mul = std::make_shared<opset5::Multiply>(sub, mul_const);
|
||||
mul->set_friendly_name("mul");
|
||||
|
||||
@@ -1841,7 +1979,7 @@ TEST_F(TransformationTestsF, MaskPropagationReshapeUp) {
|
||||
}
|
||||
|
||||
|
||||
TEST(TransformationTests, MaskPropagationReshapeUpWithShapeOf) {
|
||||
TEST_F(TransformationTestsF, MaskPropagationReshapeUpWithShapeOf) {
|
||||
auto inputShapes = PartialShape{1, 6, 8, 8};
|
||||
auto weightsShape = Shape{6, 6, 1, 1};
|
||||
|
||||
@@ -1862,20 +2000,57 @@ TEST(TransformationTests, MaskPropagationReshapeUpWithShapeOf) {
|
||||
CoordinateDiff(2, 0),
|
||||
Strides(2, 1));
|
||||
|
||||
auto function = std::make_shared<ngraph::Function>(OutputVector{conv_1}, ParameterVector{input});
|
||||
function = std::make_shared<ngraph::Function>(OutputVector{conv_1}, ParameterVector{input});
|
||||
{
|
||||
auto input = std::make_shared<opset5::Parameter>(element::f32, inputShapes);
|
||||
auto weights = create_constant_with_zeros({
|
||||
weightsShape[0] - 3,
|
||||
weightsShape[1],
|
||||
weightsShape[2],
|
||||
weightsShape[3],
|
||||
}, {{}, {}, {}, {}});
|
||||
auto conv = std::make_shared<opset5::Convolution>(input, weights, Strides(2, 1),
|
||||
CoordinateDiff(2, 0),
|
||||
CoordinateDiff(2, 0),
|
||||
Strides(2, 1));
|
||||
|
||||
auto shape_of_conv = std::make_shared<opset5::ShapeOf>(conv);
|
||||
auto reshape = std::make_shared<opset5::Reshape>(conv, shape_of_conv, true);
|
||||
|
||||
auto conv_1_shape = Shape{6, 6, 1, 1};
|
||||
auto conv_1_weights = create_constant_with_zeros({
|
||||
conv_1_shape[0],
|
||||
conv_1_shape[1] - 3,
|
||||
conv_1_shape[2],
|
||||
conv_1_shape[3],
|
||||
}, {{}, {}, {}, {}});
|
||||
auto conv_1 = std::make_shared<opset5::Convolution>(reshape, conv_1_weights, Strides(2, 1),
|
||||
CoordinateDiff(2, 0),
|
||||
CoordinateDiff(2, 0),
|
||||
Strides(2, 1));
|
||||
|
||||
function_ref = std::make_shared<ngraph::Function>(OutputVector{conv_1}, ParameterVector{input});
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "MaskPropagationReshapeUpWithShapeOf.svg").run_on_function(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::Pruning>();
|
||||
m.register_pass<pass::InitMasks>();
|
||||
m.register_pass<pass::PropagateMasks>();
|
||||
m.run_passes(function);
|
||||
}
|
||||
compare_masks(*getMask(weights.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv->output(0)), Mask({{}, {}, {}, {}}));
|
||||
compare_masks(*getMask(weights.get_node_shared_ptr()->output(0)), Mask({{1, 2, 3}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv->output(0)), Mask({{}, {1, 2, 3}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(conv_1_weights.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv_1_weights.get_node_shared_ptr()->output(0)), Mask({{}, {1, 2, 3}, {}, {}}));
|
||||
compare_masks(*getMask(conv_1->output(0)), Mask({{}, {}, {}, {}}));
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::ShrinkWeights>();
|
||||
m.run_passes(function);
|
||||
}
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
}
|
||||
|
||||
|
||||
@@ -1999,7 +2174,7 @@ TEST(TransformationTests, MaskPropagationStopReshapeUp) {
|
||||
CoordinateDiff(2, 0),
|
||||
Strides(2, 1));
|
||||
|
||||
auto function = std::make_shared<ngraph::Function>(OutputVector{conv_1}, ParameterVector{input}, "GoodReshapeUp");
|
||||
auto function = std::make_shared<ngraph::Function>(OutputVector{conv_1}, ParameterVector{input});
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "MaskPropagationStopReshapeUp.svg").run_on_function(function);
|
||||
{
|
||||
@@ -2049,7 +2224,7 @@ TEST(TransformationTests, MaskPropagationStopReshapeDown) {
|
||||
CoordinateDiff(2, 0),
|
||||
Strides(2, 1));
|
||||
|
||||
auto function = std::make_shared<ngraph::Function>(OutputVector{last_conv}, ParameterVector{input}, "BadReshapeElementwiseModel");
|
||||
auto function = std::make_shared<ngraph::Function>(OutputVector{last_conv}, ParameterVector{input});
|
||||
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "MaskPropagationStopReshapeDown.svg").run_on_function(function);
|
||||
@@ -2101,7 +2276,7 @@ TEST(TransformationTests, MaskPropagationWrongDimsElementwise) {
|
||||
CoordinateDiff(2, 0),
|
||||
Strides(2, 1));
|
||||
|
||||
auto function = std::make_shared<ngraph::Function>(OutputVector{last_conv}, ParameterVector{input}, "BadReshapeElementwiseModel");
|
||||
auto function = std::make_shared<ngraph::Function>(OutputVector{last_conv}, ParameterVector{input});
|
||||
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "MaskPropagationWrongDimsElementwise.svg").run_on_function(function);
|
||||
@@ -2160,7 +2335,7 @@ TEST_F(TransformationTestsF, PruneSEBlock) {
|
||||
CoordinateDiff(2, 0),
|
||||
Strides(2, 1));
|
||||
|
||||
function = std::make_shared<ngraph::Function>(OutputVector{end_conv}, ParameterVector{input}, "SEBlock");
|
||||
function = std::make_shared<ngraph::Function>(OutputVector{end_conv}, ParameterVector{input});
|
||||
{
|
||||
auto input = std::make_shared<opset5::Parameter>(element::f32, inputShapes);
|
||||
auto first_conv_weights = create_constant_with_zeros({
|
||||
@@ -2230,3 +2405,455 @@ TEST_F(TransformationTestsF, PruneSEBlock) {
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
}
|
||||
|
||||
|
||||
TEST_F(TransformationTestsF, PropagateMasksLinear) {
|
||||
const auto linear_input_features = 62 * 62 * 6;
|
||||
Shape input_shape{1, 3, 64, 64};
|
||||
Shape weights_shape{6, 3, 3, 3};
|
||||
Shape weights_linear_shape{linear_input_features, 100};
|
||||
Shape weights_last_linear_shape{100, 10};
|
||||
|
||||
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
|
||||
auto weights = create_constant_with_zeros(weights_shape, {{0, 1, 2}, {}, {}, {}});
|
||||
auto conv = std::make_shared<opset5::Convolution>(input, weights, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
auto relu = std::make_shared<opset5::Relu>(conv);
|
||||
|
||||
auto reshape_const = opset5::Constant::create(element::i64, Shape{2}, {1, linear_input_features});
|
||||
auto reshape = std::make_shared<opset5::Reshape>(relu, reshape_const, true);
|
||||
|
||||
auto weights_linear = create_constant_with_zeros(weights_linear_shape, {{}, {0, 1, 2}});
|
||||
auto linear = std::make_shared<opset5::MatMul>(reshape, weights_linear);
|
||||
|
||||
// Do net search 0 dim zeros by now
|
||||
// Check stop mask prop for outer dim (1)
|
||||
auto weights_last_linear = create_constant_with_zeros(weights_last_linear_shape, {{3, 4, 5}, {2, 3, 4}});
|
||||
auto last_linear = std::make_shared<opset5::MatMul>(linear, weights_last_linear);
|
||||
function = std::make_shared<Function>(NodeVector{last_linear}, ParameterVector{input});
|
||||
|
||||
{
|
||||
auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
|
||||
auto weights = create_constant_with_zeros({
|
||||
weights_shape[0] - 3,
|
||||
weights_shape[1],
|
||||
weights_shape[2],
|
||||
weights_shape[3],
|
||||
}, {{}, {}, {}, {}});
|
||||
auto conv = std::make_shared<opset5::Convolution>(input, weights, Strides(2, 1),
|
||||
CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
|
||||
auto relu = std::make_shared<opset5::Relu>(conv);
|
||||
|
||||
auto reshape_const = opset5::Constant::create(element::i64, Shape{2}, {1, linear_input_features / 2});
|
||||
auto reshape = std::make_shared<opset5::Reshape>(relu, reshape_const, true);
|
||||
|
||||
auto weights_linear = create_constant_with_zeros({
|
||||
weights_linear_shape[0] / 2,
|
||||
weights_linear_shape[1] - 3,
|
||||
}, {{}, {}});
|
||||
auto linear = std::make_shared<opset5::MatMul>(reshape, weights_linear);
|
||||
|
||||
auto weights_last_linear = create_constant_with_zeros({
|
||||
weights_last_linear_shape[0] - 3,
|
||||
weights_last_linear_shape[1],
|
||||
}, {{}, {}});
|
||||
auto last_linear = std::make_shared<opset5::MatMul>(linear, weights_last_linear);
|
||||
function_ref = std::make_shared<Function>(NodeVector{last_linear}, ParameterVector{input});
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksLinear.svg").run_on_function(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
m.register_pass<pass::PropagateMasks>();
|
||||
m.run_passes(function);
|
||||
}
|
||||
compare_masks(*getMask(weights.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv->output(0)), Mask({{}, {0, 1, 2}, {}, {}}));
|
||||
compare_masks(*getMask(relu->output(0)), Mask({{}, {0, 1, 2}, {}, {}}));
|
||||
auto ref_flatten_mask = std::set<uint64_t>();
|
||||
for (uint64_t i = 0; i < linear_input_features / 2; ++i)
|
||||
ref_flatten_mask.insert(i);
|
||||
|
||||
using nested_vector = std::vector<std::set<uint64_t>>;
|
||||
auto reshape_ref_mask = nested_vector();
|
||||
reshape_ref_mask.push_back({});
|
||||
reshape_ref_mask.push_back(ref_flatten_mask);
|
||||
auto linear_ref_mask = nested_vector();
|
||||
linear_ref_mask.push_back(ref_flatten_mask);
|
||||
linear_ref_mask.push_back({0, 1, 2});
|
||||
|
||||
compare_masks(*getMask(reshape_const->output(0)), Mask(reshape_ref_mask));
|
||||
compare_masks(*getMask(reshape->output(0)), Mask(reshape_ref_mask));
|
||||
compare_masks(*getMask(weights_linear.get_node_shared_ptr()->output(0)), Mask(linear_ref_mask));
|
||||
compare_masks(*getMask(linear->output(0)), Mask{{}, {0, 1, 2}});
|
||||
compare_masks(*getMask(weights_last_linear.get_node_shared_ptr()->output(0)), Mask{{0, 1, 2}, {}});
|
||||
compare_masks(*getMask(last_linear->output(0)), Mask{{}, {}});
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::ShrinkWeights>();
|
||||
m.run_passes(function);
|
||||
}
|
||||
disable_rt_info_check();
|
||||
enable_accuracy_check();
|
||||
}
|
||||
|
||||
|
||||
TEST(TransformationTests, PruneLinearUp) {
|
||||
const auto linear_input_features = 6 * 2 * 2;
|
||||
auto inputShapes = PartialShape{1, 6, 2, 2};
|
||||
auto weightsShape = Shape{6, 6, 1, 1};
|
||||
auto linearShape = Shape{linear_input_features, linear_input_features};
|
||||
auto lastLinearShape = Shape{10, linear_input_features};
|
||||
|
||||
auto input = std::make_shared<opset5::Parameter>(element::f32, inputShapes);
|
||||
auto weights = create_constant_with_zeros(weightsShape, {{1, 2, 3}, {}, {}, {}});
|
||||
auto conv = std::make_shared<opset5::Convolution>(input, weights, Strides(2, 1),
|
||||
CoordinateDiff(2, 0),
|
||||
CoordinateDiff(2, 0),
|
||||
Strides(2, 1));
|
||||
|
||||
auto reshape_const = opset5::Constant::create(element::i64, Shape{2}, {1, linear_input_features});
|
||||
auto reshape = std::make_shared<opset5::Reshape>(conv, reshape_const, true);
|
||||
|
||||
auto linear_mask = Mask();
|
||||
auto outer_dim_zeros = std::set<uint64_t>();
|
||||
for (auto i = 0; i < linear_input_features / 2; ++i)
|
||||
outer_dim_zeros.insert(i);
|
||||
linear_mask.push_back({10, 11});
|
||||
linear_mask.push_back(outer_dim_zeros);
|
||||
auto linear_const = create_constant_with_zeros(linearShape, linear_mask);
|
||||
auto linear = std::make_shared<opset5::MatMul>(reshape, linear_const);
|
||||
|
||||
auto add_mask = Mask();
|
||||
add_mask.push_back({});
|
||||
add_mask.push_back(outer_dim_zeros);
|
||||
auto add_const = create_constant_with_zeros({1, linear_input_features}, add_mask);
|
||||
auto add = std::make_shared<opset5::Add>(linear, add_const);
|
||||
auto add_const_1 = create_constant_with_zeros({1, linear_input_features}, add_mask);
|
||||
auto add_1 = std::make_shared<opset5::Add>(add, add_const_1);
|
||||
auto add_2 = std::make_shared<opset5::Add>(add_1, reshape);
|
||||
|
||||
auto bad_add_const = create_constant_with_zeros({1, linear_input_features}, {{}, {}});
|
||||
auto bad_add = std::make_shared<opset5::Add>(add_2, bad_add_const);
|
||||
|
||||
auto weights_end_linear = create_constant_with_zeros(lastLinearShape, {{1, 2, 3}, {3, 4, 6}});
|
||||
auto last_linear = std::make_shared<opset5::MatMul>(bad_add, weights_end_linear, false, true);
|
||||
auto function = std::make_shared<ngraph::Function>(OutputVector{last_linear}, ParameterVector{input});
|
||||
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneLinearUp.svg").run_on_function(function);
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::Pruning>();
|
||||
m.run_passes(function);
|
||||
|
||||
compare_masks(*getMask(weights.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}}));
|
||||
compare_masks(*getMask(conv->output(0)), Mask({{}, {}, {}, {}}));
|
||||
|
||||
compare_masks(*getMask(weights_end_linear.get_node_shared_ptr()->output(0)), Mask({{}, {}}));
|
||||
compare_masks(*getMask(last_linear->output(0)), Mask({{}, {}}));
|
||||
}
|
||||
|
||||
|
||||
TEST(TransformationTests, PruneConvUpShort) {
    // Checks that pruning initiated at a final convolution does NOT propagate
    // upward through a chain of adds when one of the add constants
    // (bad_add_const) carries no zeros: the whole group must be invalidated,
    // so every mask in the graph is expected to end up empty.
    //
    // NOTE: the original test kept `linear_input_features` and
    // `lastLinearShape` copied from a linear-pruning test; neither was used,
    // so both were removed.
    auto inputShapes = PartialShape{1, 6, 2, 2};
    auto convShape = Shape{1, 6, 2, 2};
    auto weightsShape = Shape{6, 6, 1, 1};

    auto input = std::make_shared<opset5::Parameter>(element::f32, inputShapes);
    // Zeros in output-channel dim (channels 1, 2, 3) — candidate for pruning.
    auto weights = create_constant_with_zeros(weightsShape, {{1, 2, 3}, {}, {}, {}});
    auto conv = std::make_shared<opset5::Convolution>(input, weights, Strides(2, 1),
                                                      CoordinateDiff(2, 0),
                                                      CoordinateDiff(2, 0),
                                                      Strides(2, 1));

    auto conv_1_const = create_constant_with_zeros(weightsShape, {{1, 2, 3}, {}, {}, {}});
    auto conv_1 = std::make_shared<opset5::Convolution>(conv, conv_1_const, Strides(2, 1),
                                                        CoordinateDiff(2, 0),
                                                        CoordinateDiff(2, 0),
                                                        Strides(2, 1));

    // Chain of adds sharing the same zeroed channels; add_2 also consumes
    // `conv` directly, tying the masks of both convolutions together.
    auto add_const = create_constant_with_zeros(convShape, {{}, {1, 2, 3}, {}, {}});
    auto add = std::make_shared<opset5::Add>(conv_1, add_const);
    auto add_const_1 = create_constant_with_zeros(convShape, {{}, {1, 2, 3}, {}, {}});
    auto add_1 = std::make_shared<opset5::Add>(add, add_const_1);
    auto add_2 = std::make_shared<opset5::Add>(add_1, conv);

    // This constant has no zeros at all — it should stop pruning of the group.
    auto bad_add_const = create_constant_with_zeros(convShape, {{}, {}, {}, {}});
    auto bad_add = std::make_shared<opset5::Add>(add_2, bad_add_const);

    auto weights_end_conv = create_constant_with_zeros(weightsShape, {{1, 2, 3}, {1, 2, 3}, {}, {}});
    auto last_conv = std::make_shared<opset5::Convolution>(bad_add, weights_end_conv, Strides(2, 1),
                                                           CoordinateDiff(2, 0),
                                                           CoordinateDiff(2, 0),
                                                           Strides(2, 1));

    auto function = std::make_shared<ngraph::Function>(OutputVector{last_conv}, ParameterVector{input});

    if (VISUALIZE_TESTS_TREE)
        ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneConvUpShort.svg").run_on_function(function);

    pass::Manager m;
    m.register_pass<pass::Pruning>();
    m.run_passes(function);

    // All masks must be empty: bad_add_const invalidated the pruning group.
    compare_masks(*getMask(weights.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}}));
    compare_masks(*getMask(conv->output(0)), Mask({{}, {}, {}, {}}));

    compare_masks(*getMask(weights_end_conv.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}}));
    compare_masks(*getMask(last_conv->output(0)), Mask({{}, {}, {}, {}}));
}
|
||||
|
||||
TEST_F(TransformationTestsF, PruneMasksMatMulColsStopRowsUp) {
    // Checks column pruning of a MatMul: zeros in the second dim of
    // weights_linear (cols {0, 1, 2}) are prunable and must shrink both
    // weights_linear's cols and weights_last_linear's rows, while the zeros
    // coming from the convolution (rows direction, through the flattening
    // reshape) must be stopped and invalidated.
    const auto linear_input_features = 62 * 62;
    Shape input_shape{1, 3, 64, 64};
    Shape weights_shape{6, 3, 3, 3};
    Shape weights_linear_shape{linear_input_features, 100};
    Shape weights_last_linear_shape{100, 10};

    auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
    // Zeroed conv output channels {0, 1, 2}: expected to be invalidated below.
    auto weights = create_constant_with_zeros(weights_shape, {{0, 1, 2}, {}, {}, {}});
    auto conv = std::make_shared<opset5::Convolution>(input, weights, Strides(2, 1),
                                                      CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
    auto relu = std::make_shared<opset5::Relu>(conv);

    // Flatten spatial dims: (1, 6, 62, 62) -> (1, 6, 62*62).
    auto reshape_const = opset5::Constant::create(element::i64, Shape{3}, {1, 6, linear_input_features});
    auto reshape = std::make_shared<opset5::Reshape>(relu, reshape_const, true);

    // Zeroed columns {0, 1, 2}: the prunable dimension in this test.
    auto weights_linear = create_constant_with_zeros(weights_linear_shape, {{}, {0, 1, 2}});
    auto linear = std::make_shared<opset5::MatMul>(reshape, weights_linear);

    // Do not search 0 dim zeros by now
    auto weights_last_linear = create_constant_with_zeros(weights_last_linear_shape, {{3, 4, 5}, {}});
    auto last_linear = std::make_shared<opset5::MatMul>(linear, weights_last_linear);
    function = std::make_shared<Function>(NodeVector{last_linear}, ParameterVector{input});

    {
        // Reference graph after shrinking: only the MatMul cols/rows pair
        // loses 3 elements; conv weights and reshape stay at original shapes.
        auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
        auto weights = create_constant_with_zeros(weights_shape, {{}, {}, {}, {}});
        auto conv = std::make_shared<opset5::Convolution>(input, weights, Strides(2, 1),
                                                          CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
        auto relu = std::make_shared<opset5::Relu>(conv);

        auto reshape_const = opset5::Constant::create(element::i64, Shape{3}, {1, 6, linear_input_features});
        auto reshape = std::make_shared<opset5::Reshape>(relu, reshape_const, true);

        auto weights_linear = create_constant_with_zeros({
            weights_linear_shape[0],
            weights_linear_shape[1] - 3,
        }, {{}, {}});
        auto linear = std::make_shared<opset5::MatMul>(reshape, weights_linear);

        auto weights_last_linear = create_constant_with_zeros({
            weights_last_linear_shape[0] - 3,
            weights_last_linear_shape[1],
        }, {{}, {}});
        auto last_linear = std::make_shared<opset5::MatMul>(linear, weights_last_linear);
        function_ref = std::make_shared<Function>(NodeVector{last_linear}, ParameterVector{input});
    }
    if (VISUALIZE_TESTS_TREE)
        ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneMasksMatMulColsStopRowsUp.svg").run_on_function(function);
    {
        pass::Manager m;
        m.register_pass<pass::InitMasks>();
        m.register_pass<pass::PropagateMasks>();
        m.run_passes(function);
    }
    // Conv-side zeros invalidated: empty masks up to the reshape.
    compare_masks(*getMask(weights.get_node_shared_ptr()->output(0)), Mask({{}, {}, {}, {}}));
    compare_masks(*getMask(conv->output(0)), Mask({{}, {}, {}, {}}));
    compare_masks(*getMask(relu->output(0)), Mask({{}, {}, {}, {}}));

    compare_masks(*getMask(reshape_const->output(0)), Mask{{}, {}, {}});
    compare_masks(*getMask(reshape->output(0)), Mask{{}, {}, {}});
    // Cols {0, 1, 2} of weights_linear survive and propagate to the rows of
    // weights_last_linear.
    compare_masks(*getMask(weights_linear.get_node_shared_ptr()->output(0)), Mask({{}, {0, 1, 2}}));
    compare_masks(*getMask(linear->output(0)), Mask{{}, {}, {0, 1, 2}});
    compare_masks(*getMask(weights_last_linear.get_node_shared_ptr()->output(0)), Mask{{0, 1, 2}, {}});
    compare_masks(*getMask(last_linear->output(0)), Mask{{}, {}, {}});
    {
        pass::Manager m;
        m.register_pass<pass::ShrinkWeights>();
        m.run_passes(function);
    }
    disable_rt_info_check();
    enable_accuracy_check();
}
|
||||
|
||||
|
||||
TEST_F(TransformationTestsF, PruneMasksMatMulRowsStopColsUp) {
    // Checks rows matmul pruning + transpose input in matmul:
    // zeros in conv output channels {0, 1, 2} flow through the flattening
    // reshape into the MatMul rows direction (via transpose_a) and are
    // prunable; zeros on the cols side of weights_linear must be stopped.
    const auto linear_input_features = 62 * 62;
    Shape input_shape{1, 3, 64, 64};
    Shape weights_shape{6, 3, 3, 3};
    Shape weights_linear_shape{linear_input_features, 100};
    Shape weights_last_linear_shape{10, 6};

    auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
    // Zeroed conv output channels {0, 1, 2}: the prunable dimension here.
    auto weights = create_constant_with_zeros(weights_shape, {{0, 1, 2}, {}, {}, {}});
    auto conv = std::make_shared<opset5::Convolution>(input, weights, Strides(2, 1),
                                                      CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
    auto relu = std::make_shared<opset5::Relu>(conv);

    // Flatten spatial dims: (1, 6, 62, 62) -> (1, 6, 62*62).
    auto reshape_const = opset5::Constant::create(element::i64, Shape{3}, {1, 6, linear_input_features});
    auto reshape = std::make_shared<opset5::Reshape>(relu, reshape_const, true);

    // These zeros must be invalidated (cols pruning stopped in this test).
    auto weights_linear = create_constant_with_zeros(weights_linear_shape, {{3, 4, 5}, {3, 4}});
    auto linear = std::make_shared<opset5::MatMul>(reshape, weights_linear);

    // Do not search this zeros by now
    auto weights_last_linear = create_constant_with_zeros(weights_last_linear_shape, {{}, {3, 4, 5}});
    // To prune rows we should transpose featuremap. Did it by transpose_a = true MatMul constructor attr
    auto last_linear = std::make_shared<opset5::MatMul>(linear, weights_last_linear, true, true);
    function = std::make_shared<Function>(NodeVector{last_linear}, ParameterVector{input});

    {
        // Reference graph after shrinking: 3 conv output channels removed,
        // reshape and weights_last_linear cols shrunk accordingly.
        auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
        auto weights = create_constant_with_zeros({
            weights_shape[0] - 3,
            weights_shape[1],
            weights_shape[2],
            weights_shape[3],
        }, {{}, {}, {}, {}});
        auto conv = std::make_shared<opset5::Convolution>(input, weights, Strides(2, 1),
                                                          CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
        auto relu = std::make_shared<opset5::Relu>(conv);

        auto reshape_const = opset5::Constant::create(element::i64, Shape{3}, {1, 3, linear_input_features});
        auto reshape = std::make_shared<opset5::Reshape>(relu, reshape_const, true);

        auto weights_linear = create_constant_with_zeros({
            weights_linear_shape[0],
            weights_linear_shape[1],
        }, {{}, {}});
        auto linear = std::make_shared<opset5::MatMul>(reshape, weights_linear);

        auto weights_last_linear = create_constant_with_zeros({weights_last_linear_shape[0],
                                                               weights_last_linear_shape[1] - 3}, {{}, {}});
        // To prune rows we should transpose featuremap. Did it by transpose_a = true MatMul constructor attr
        auto last_linear = std::make_shared<opset5::MatMul>(linear, weights_last_linear, true, true);
        function_ref = std::make_shared<Function>(NodeVector{last_linear}, ParameterVector{input});
    }
    if (VISUALIZE_TESTS_TREE)
        ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneMasksMatMulRowsStopColsUp.svg").run_on_function(function);
    {
        pass::Manager m;
        m.register_pass<pass::InitMasks>();
        m.register_pass<pass::PropagateMasks>();
        m.run_passes(function);
    }
    // Channels {0, 1, 2} survive from the conv all the way to the transposed
    // MatMul input; the cols mask of weights_linear is invalidated (empty).
    compare_masks(*getMask(weights.get_node_shared_ptr()->output(0)), Mask({{0, 1, 2}, {}, {}, {}}));
    compare_masks(*getMask(conv->output(0)), Mask({{}, {0, 1, 2}, {}, {}}));
    compare_masks(*getMask(relu->output(0)), Mask({{}, {0, 1, 2}, {}, {}}));

    compare_masks(*getMask(reshape_const->output(0)), Mask{{}, {0, 1, 2}, {}});
    compare_masks(*getMask(reshape->output(0)), Mask{{}, {0, 1, 2}, {}});
    compare_masks(*getMask(weights_linear.get_node_shared_ptr()->output(0)), Mask{{}, {}});
    compare_masks(*getMask(linear->output(0)), Mask{{}, {0, 1, 2}, {}});
    compare_masks(*getMask(weights_last_linear.get_node_shared_ptr()->output(0)), Mask{{}, {0, 1, 2}});
    compare_masks(*getMask(last_linear->output(0)), Mask{{}, {}, {}});
    {
        pass::Manager m;
        m.register_pass<pass::ShrinkWeights>();
        m.run_passes(function);
    }
    disable_rt_info_check();
    enable_accuracy_check();
}
|
||||
|
||||
|
||||
TEST_F(TransformationTestsF, PropagateFlattenUp) {
    // Propagate Flatten down is the same as in
    // PruneLinearIsClosingAndInGroup test
    // Here the mask must propagate UP through the flattening reshape: zeros in
    // a per-feature Add constant map back to whole conv channels, and any
    // channel whose flattened slice is not fully zeroed gets invalidated.
    using nested_vector = std::vector<std::set<uint64_t>>;
    constexpr auto linear_input_features = 6 * 8 * 8;
    Shape input_shape{1, 3, 8, 8};
    Shape weights_shape{6, 3, 1, 1};
    Shape weights_linear_shape{linear_input_features, 100};

    auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
    auto weights = create_constant_with_zeros(weights_shape, {{0, 1, 2}, {}, {}, {}});
    auto conv = std::make_shared<opset5::Convolution>(input, weights, Strides(2, 1),
                                                      CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
    auto relu = std::make_shared<opset5::Relu>(conv);

    // Flatten: (1, 6, 8, 8) -> (1, 6*8*8).
    auto reshape_const = opset5::Constant::create(element::i64, Shape{2}, {1, linear_input_features});
    auto reshape = std::make_shared<opset5::Reshape>(relu, reshape_const, true);

    // Skipping just one zero in a dim should lead to
    // invalidating the whole dimension: the add mask zeros features
    // [1, linear_input_features / 2), i.e. channel 0 is missing its first
    // element, so only channels 1 and 2 remain fully zeroed.
    auto add_zeros = std::set<uint64_t>();
    for (size_t i = 1; i < linear_input_features / 2; i++)
        add_zeros.insert(i);
    auto add_mask = nested_vector();
    add_mask.push_back({});
    add_mask.push_back(add_zeros);
    auto weights_add = create_constant_with_zeros({1, linear_input_features}, Mask(add_mask));
    auto add = std::make_shared<opset5::Add>(reshape, weights_add);

    auto weights_linear = create_constant_with_zeros(weights_linear_shape, {{}, {0, 1, 2}});
    auto linear = std::make_shared<opset5::MatMul>(add, weights_linear);

    function = std::make_shared<Function>(NodeVector{linear}, ParameterVector{input});
    {
        // Reference graph: 2 of 6 conv channels pruned, so the flattened
        // feature count shrinks to 2/3 of the original everywhere downstream.
        auto input = std::make_shared<opset5::Parameter>(element::f32, input_shape);
        auto weights = create_constant_with_zeros({
            weights_shape[0] - 2,
            weights_shape[1],
            weights_shape[2],
            weights_shape[3],
        }, {{}, {}, {}, {}});
        auto conv = std::make_shared<opset5::Convolution>(input, weights, Strides(2, 1),
                                                          CoordinateDiff(2, 0), CoordinateDiff(2, 0), Strides(2, 1));
        auto relu = std::make_shared<opset5::Relu>(conv);

        auto reshape_const = opset5::Constant::create(element::i64, Shape{2}, {1, 2 * linear_input_features / 3});
        auto reshape = std::make_shared<opset5::Reshape>(relu, reshape_const, true);

        auto weights_add = create_constant_with_zeros({1, 2 * linear_input_features / 3}, Mask{{}, {}});
        auto add = std::make_shared<opset5::Add>(reshape, weights_add);

        auto weights_linear = create_constant_with_zeros({
            2 * weights_linear_shape[0] / 3,
            weights_linear_shape[1],
        }, {{}, {}});
        auto linear = std::make_shared<opset5::MatMul>(add, weights_linear);

        function_ref = std::make_shared<Function>(NodeVector{linear}, ParameterVector{input});
    }
    if (VISUALIZE_TESTS_TREE)
        ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateFlattenUp.svg").run_on_function(function);
    {
        pass::Manager m;
        m.register_pass<pass::InitMasks>();
        m.register_pass<pass::PropagateMasks>();
        m.run_passes(function);
    }
    // Channel 0 dropped out of the mask (its first flattened element was not
    // zeroed by the add constant); channels {1, 2} remain prunable.
    compare_masks(*getMask(weights.get_node_shared_ptr()->output(0)), Mask({{1, 2}, {}, {}, {}}));
    compare_masks(*getMask(conv->output(0)), Mask({{}, {1, 2}, {}, {}}));
    compare_masks(*getMask(relu->output(0)), Mask({{}, {1, 2}, {}, {}}));
    // Flattened counterpart of channels {1, 2}: features
    // [linear_input_features / 6, linear_input_features / 2).
    auto ref_flatten_mask = std::set<uint64_t>();
    for (uint64_t i = linear_input_features / 6; i < linear_input_features / 2; ++i)
        ref_flatten_mask.insert(i);

    auto reshape_ref_mask = nested_vector();
    reshape_ref_mask.push_back({});
    reshape_ref_mask.push_back(ref_flatten_mask);
    auto linear_ref_mask = nested_vector();
    linear_ref_mask.push_back(ref_flatten_mask);
    linear_ref_mask.push_back({});

    compare_masks(*getMask(reshape_const->output(0)), Mask(reshape_ref_mask));
    compare_masks(*getMask(reshape->output(0)), Mask(reshape_ref_mask));
    compare_masks(*getMask(weights_linear.get_node_shared_ptr()->output(0)), Mask(linear_ref_mask));
    compare_masks(*getMask(linear->output(0)), Mask{{}, {}});
    {
        pass::Manager m;
        m.register_pass<pass::ShrinkWeights>();
        m.run_passes(function);
    }
    disable_rt_info_check();
    enable_accuracy_check();
}
|
||||
|
||||
Reference in New Issue
Block a user