Enabled clang-format for offline_transformations (#13410)

parent b00796324c
commit 54eede2e6a
@@ -32,6 +32,8 @@ target_include_directories(${TARGET_NAME} PUBLIC ${PUBLIC_HEADERS_DIR}
 
 add_cpplint_target(${TARGET_NAME}_cpplint FOR_TARGETS ${TARGET_NAME})
 
+add_clang_format_target(${TARGET_NAME}_clang FOR_TARGETS ${TARGET_NAME})
+
 # developer package
 
 openvino_developer_export_targets(COMPONENT core TARGETS ${TARGET_NAME})
@@ -15,10 +15,9 @@ class ZeroPointOptimizer;
 }  // namespace pass
 }  // namespace ngraph
 
 /*
-    CompressQuantizeWeights transformation goal is to pre-quantize data to minimize runtime calculations with constant data.
-    To achieve this goal we perform FakeQuantize decomposition to separate quantization from dequantization in it.
+    CompressQuantizeWeights transformation goal is to pre-quantize data to minimize runtime calculations with constant
+    data. To achieve this goal we perform FakeQuantize decomposition to separate quantization from dequantization in it.
 
     Initial graph (FakeQuantize where all inputs are Constants):
 
@@ -59,7 +58,7 @@ class ZeroPointOptimizer;
     Such constant data packing reduces IR size (.bin file size) in offline transformations.
     With that we can skip same calculations in the runtime and make loading of such sub-graphs to the plugin faster.
 */
-class ngraph::pass::CompressQuantizeWeights: public ngraph::pass::MatcherPass {
+class ngraph::pass::CompressQuantizeWeights : public ngraph::pass::MatcherPass {
 public:
     OPENVINO_RTTI("CompressQuantizeWeights", "0");
     CompressQuantizeWeights();
@@ -86,7 +85,7 @@ public:
            |
            v
 */
-class ngraph::pass::ZeroPointOptimizer: public ngraph::pass::MatcherPass {
+class ngraph::pass::ZeroPointOptimizer : public ngraph::pass::MatcherPass {
 public:
     OPENVINO_RTTI("ZeroPointOptimizer");
     ZeroPointOptimizer();
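For context, both passes above are matcher passes and are normally executed through a pass manager. A minimal sketch of wiring them up (assuming the ngraph::pass::Manager API used elsewhere in this repository; the header name is taken from the includes further down in this diff):

    #include <ngraph/pass/manager.hpp>
    #include <compress_quantize_weights.hpp>

    // Run offline weight compression on a function `f`.
    void compress_weights(const std::shared_ptr<ngraph::Function>& f) {
        ngraph::pass::Manager manager;
        // quantize constant weights and fold them to a low-precision type
        manager.register_pass<ngraph::pass::CompressQuantizeWeights>();
        // drop zero points that round to (almost) zero
        manager.register_pass<ngraph::pass::ZeroPointOptimizer>();
        manager.run_passes(f);
    }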
@@ -8,8 +8,8 @@
 #include <iostream>
 #include <type_traits>
 
-#include "openvino/frontend/extension/decoder_transformation.hpp"
 #include "openvino/core/extension.hpp"
+#include "openvino/frontend/extension/decoder_transformation.hpp"
 #include "openvino/pass/graph_rewrite.hpp"
 #include "openvino/pass/manager.hpp"
 #include "openvino/pass/pass.hpp"
@@ -5,7 +5,6 @@
 #pragma once
 
 #include <memory>
-
 #include <ngraph/pass/graph_rewrite.hpp>
 
 namespace ngraph {
@@ -19,13 +18,15 @@ class GenerateMappingFile;
 /**
  * @brief Generate mapping file based on output tensor names.
  */
-class ngraph::pass::GenerateMappingFile: public ngraph::pass::FunctionPass {
+class ngraph::pass::GenerateMappingFile : public ngraph::pass::FunctionPass {
     std::string m_path_to_file;
     bool m_extract_name;
 
 public:
     OPENVINO_RTTI("GenerateMappingFile", "0");
-    explicit GenerateMappingFile(const std::string & path, bool extract_name = true)
-        : m_path_to_file(path), m_extract_name(extract_name) {}
+    explicit GenerateMappingFile(const std::string& path, bool extract_name = true)
+        : m_path_to_file(path),
+          m_extract_name(extract_name) {}
 
     bool run_on_model(const std::shared_ptr<ngraph::Function>&) override;
 };
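GenerateMappingFile is a function pass whose constructor takes the output path and an optional extract_name flag; a short usage sketch (the file name is hypothetical, the Manager API is the one used by this repository):

    #include <ngraph/pass/manager.hpp>
    #include "generate_mapping_file.hpp"

    void dump_mapping(const std::shared_ptr<ngraph::Function>& f) {
        ngraph::pass::Manager manager;
        // writes an XML <mapping> file relating framework tensor names to IR names
        manager.register_pass<ngraph::pass::GenerateMappingFile>("mapping.xml", /*extract_name=*/true);
        manager.run_passes(f);
    }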
@@ -10,13 +10,13 @@
 #pragma once
 
 #include <assert.h>
+
 #include <functional>
 #include <memory>
-#include <string>
-#include <set>
-
-#include <ngraph/node.hpp>
-#include <ngraph/log.hpp>
+#include <ngraph/log.hpp>
+#include <ngraph/node.hpp>
+#include <set>
+#include <string>
 
 namespace ngraph {
@@ -25,8 +25,7 @@ namespace ngraph {
  * @brief each element in vector represents dimension and each element
  * in set is an id of dimensions which contains zeros.
  */
-class Mask : public std::vector<std::set<uint64_t>>,
-             public std::enable_shared_from_this<Mask> {
+class Mask : public std::vector<std::set<uint64_t>>, public std::enable_shared_from_this<Mask> {
 public:
     static const ::ov::DiscreteTypeInfo& get_type_info_static() {
         static const ::ov::DiscreteTypeInfo type_info_static{"Mask", 0, "0"};
@@ -37,35 +36,26 @@ public:
 
     Mask() = default;
 
-    explicit Mask(const ngraph::PartialShape & shape)
-        : std::vector<value_type>(shape.rank().get_length()) {
-    }
+    explicit Mask(const ngraph::PartialShape& shape) : std::vector<value_type>(shape.rank().get_length()) {}
 
-    explicit Mask(const size_t & size)
-        : std::vector<value_type>(size) {
-    }
+    explicit Mask(const size_t& size) : std::vector<value_type>(size) {}
 
-    explicit Mask(const size_t & size, const bool adjust_value)
-        : std::vector<value_type>(size),
-          m_adjust_value(adjust_value) {
-    }
+    explicit Mask(const size_t& size, const bool adjust_value)
+        : std::vector<value_type>(size),
+          m_adjust_value(adjust_value) {}
 
-    explicit Mask(const std::vector<value_type> val)
-        : std::vector<value_type>(val) {
-    }
+    explicit Mask(const std::vector<value_type> val) : std::vector<value_type>(val) {}
 
-    Mask(std::initializer_list<std::initializer_list<uint64_t>> list)
-        : std::vector<value_type>() {
-        for (const auto & dim_values : list) {
+    Mask(std::initializer_list<std::initializer_list<uint64_t>> list) : std::vector<value_type>() {
+        for (const auto& dim_values : list) {
             push_back(dim_values);
         }
     }
 
     bool all_dims_are_empty() const {
-        return std::all_of(begin(), end(),
-                           [](const value_type & value) {
-                               return value.empty();
-                           });
+        return std::all_of(begin(), end(), [](const value_type& value) {
+            return value.empty();
+        });
     }
 
     std::vector<size_t> get_not_empty_dims() {
@@ -77,11 +67,15 @@ public:
         return not_empty_dims;
     }
 
-    bool is_shape_like() const { return m_is_shape_like; }
+    bool is_shape_like() const {
+        return m_is_shape_like;
+    }
 
-    void set_shape_like(bool flag) { m_is_shape_like = flag; }
+    void set_shape_like(bool flag) {
+        m_is_shape_like = flag;
+    }
 
-    void copy_value_from_mask(Mask *const mask) {
+    void copy_value_from_mask(Mask* const mask) {
         auto cur_mask_iter = begin();
         auto mask_iter = mask->begin();
         while (cur_mask_iter != end() && mask_iter != mask->end()) {
@@ -95,7 +89,7 @@ public:
     /* Copy values from given mask in reversed order.
        param: mask - given mask.
     */
-    void copy_value_from_mask_reversed(Mask *const mask) {
+    void copy_value_from_mask_reversed(Mask* const mask) {
         auto cur_mask_iter = rbegin();
         auto mask_iter = mask->rbegin();
         while (cur_mask_iter != rend() && mask_iter != mask->rend()) {
@@ -114,7 +108,9 @@ public:
        param: idx_mask - current mask dimensions indexes which will be skipped during copying.
        param: invert_mask - whether the mask needs to be inverted. Default value == false.
     */
-    void copy_value_from_mask_reversed_masked(Mask *const mask, const std::set<int64_t> &idx_mask, const bool invert_mask = false) {
+    void copy_value_from_mask_reversed_masked(Mask* const mask,
+                                              const std::set<int64_t>& idx_mask,
+                                              const bool invert_mask = false) {
         auto cur_mask_iter = rbegin();
         auto mask_iter = mask->rbegin();
         while (cur_mask_iter != rend() && mask_iter != mask->rend()) {
@@ -132,18 +128,16 @@ public:
        param: mask - given mask.
        returns: intersected masks aligned from the end.
     */
-    Mask::Ptr intersect_masks_reversed(Mask *const mask) const {
+    Mask::Ptr intersect_masks_reversed(Mask* const mask) const {
         auto result_mask = std::make_shared<Mask>(std::max(size(), mask->size()));
         auto result_iter = result_mask->rbegin();
         auto mask_1_iter = rbegin();
         auto mask_2_iter = mask->rbegin();
 
-        while (mask_1_iter != rend() &&
-               mask_2_iter != mask->rend() &&
-               result_iter != result_mask->rend()) {
+        while (mask_1_iter != rend() && mask_2_iter != mask->rend() && result_iter != result_mask->rend()) {
             // Merge mask dimension values for both masks
             // Example: (MaskValue[1,2,3,4], MaskValue[2,3]) -> MaskValue[2,3]
-            for (const auto & value : *mask_1_iter) {
+            for (const auto& value : *mask_1_iter) {
                 if (mask_2_iter->count(value)) {
                     result_iter->insert(value);
                 }
@@ -161,21 +155,19 @@ public:
        param: mask - given mask.
        returns: united masks aligned from the end.
     */
-    Mask::Ptr union_masks_reversed(Mask *const mask) const {
+    Mask::Ptr union_masks_reversed(Mask* const mask) const {
         auto result_mask = std::make_shared<Mask>(std::max(size(), mask->size()));
         auto result_iter = result_mask->rbegin();
         auto mask_1_iter = rbegin();
         auto mask_2_iter = mask->rbegin();
 
-        while (mask_1_iter != rend() &&
-               mask_2_iter != mask->rend() &&
-               result_iter != result_mask->rend()) {
+        while (mask_1_iter != rend() && mask_2_iter != mask->rend() && result_iter != result_mask->rend()) {
             // Union mask dimension values for both masks
             // Example: (MaskValue[1,2,3,4], MaskValue[2, 5]) -> MaskValue[1, 2, 3, 4, 5]
-            for (const auto & value : *mask_1_iter) {
+            for (const auto& value : *mask_1_iter) {
                 result_iter->insert(value);
             }
-            for (const auto & value : *mask_2_iter) {
+            for (const auto& value : *mask_2_iter) {
                 if (!result_iter->count(value)) {
                     result_iter->insert(value);
                 }
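The two comment examples above can be restated as a small illustrative snippet; Mask is constructible from an initializer list of per-dimension sets, as the constructors earlier in this header show:

    ngraph::Mask a{{1, 2, 3, 4}};
    ngraph::Mask b{{2, 3}};
    // values present in both masks survive the intersection: {2, 3}
    auto inter = a.intersect_masks_reversed(&b);

    ngraph::Mask c{{2, 5}};
    // the union keeps values from either mask: {1, 2, 3, 4, 5}
    auto uni = a.union_masks_reversed(&c);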
@@ -188,7 +180,7 @@ public:
         return result_mask;
     }
 
-    bool add_callback(const std::function<bool(Mask::Ptr)> & receive_callback, Mask::Ptr mask) {
+    bool add_callback(const std::function<bool(Mask::Ptr)>& receive_callback, Mask::Ptr mask) {
         if (m_callbacks.find(mask.get()) != m_callbacks.end())
             NGRAPH_DEBUG << "Attempt to rewrite callback, could lead to unexpected behaviour";
 
@@ -202,7 +194,7 @@ public:
        modify all dependent masks by their corresponding callbacks*/
     bool apply_callback(Mask::Ptr mask) {
         // TODO: in case if callback returns false we need to propagate original value
-        const auto & ref_state = Mask(*this);
+        const auto& ref_state = Mask(*this);
         // Modify this mask by received mask
         if (!m_callbacks.at(mask.get())(shared_from_this())) {
             return false;
@@ -215,7 +207,7 @@ public:
         // Mark mask as visited
         m_need_initialization = false;
         // recursively apply callbacks for each dependent mask
-        for (const auto & m_dependency : m_dependencies) {
+        for (const auto& m_dependency : m_dependencies) {
             if (m_dependency == mask.get())
                 continue;
             if (!m_dependency->apply_callback(shared_from_this())) {
@@ -228,7 +220,7 @@ public:
 
     void invalidate() {
         clean_dim_values();
-        for (const auto & d : m_dependencies) {
+        for (const auto& d : m_dependencies) {
             if (d->apply_callback(shared_from_this())) {
                 // TODO: throw an exception if zero dims can't be propagated
             }
@@ -236,7 +228,7 @@ public:
     }
 
     void clean_dim_values() {
-        for (auto & item : *this) {
+        for (auto& item : *this) {
             item.clear();
         }
     }
@@ -264,29 +256,29 @@ private:
 
     // Masks dependent on this mask vs methods, specifying how
     // this mask will be modified by the corresponding dependent mask
-    std::map<Mask *, std::function<bool(Mask::Ptr)>> m_callbacks;
+    std::map<Mask*, std::function<bool(Mask::Ptr)>> m_callbacks;
     // Vector of all dependent masks
-    std::vector<Mask *> m_dependencies;
+    std::vector<Mask*> m_dependencies;
     // Param used like visiting label (visited or not) during mask applying call
     bool m_need_initialization{true};
 };
 
-std::ostream & operator<< (std::ostream & out, const Mask & mask);
+std::ostream& operator<<(std::ostream& out, const Mask& mask);
 
-Mask::Ptr getMask(const Output<const Node> & output);
+Mask::Ptr getMask(const Output<const Node>& output);
 
-Mask::Ptr getMask(const Output<Node> & output);
+Mask::Ptr getMask(const Output<Node>& output);
 
-void setMask(Output<Node> output, const Mask::Ptr & mask);
+void setMask(Output<Node> output, const Mask::Ptr& mask);
 
-void setMask(Input<Node> node, const Mask::Ptr & mask);
+void setMask(Input<Node> node, const Mask::Ptr& mask);
 
 #ifdef ENABLE_OPENVINO_DEBUG
 /* Get mask which was defined on InitMasks matcher pass*/
-Mask::Ptr getInitMask(const Output<Node> & output);
+Mask::Ptr getInitMask(const Output<Node>& output);
 
 /* Set mask which was defined on InitMasks matcher pass*/
-void setInitMask(Output<Node> output, const Mask::Ptr & mask);
+void setInitMask(Output<Node> output, const Mask::Ptr& mask);
 #endif
 
 }  // namespace ngraph
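The getMask/setMask helpers above store the mask in a node output's rt_info; a minimal sketch of writing and reading one back (variable names and the channel index are illustrative, and the output is assumed to have rank at least 1):

    ngraph::Output<ngraph::Node> out = node->output(0);
    auto mask = std::make_shared<ngraph::Mask>(out.get_partial_shape());
    (*mask)[0].insert(42);                   // mark channel 42 of dimension 0 as prunable
    ngraph::setMask(out, mask);              // attach to the output's rt_info
    auto read_back = ngraph::getMask(out);   // nullptr if no mask was ever set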
@@ -5,9 +5,8 @@
 #pragma once
 
 #include <memory>
-#include <string>
-
 #include <ngraph/pass/graph_rewrite.hpp>
+#include <string>
 
 namespace ngraph {
 namespace pass {

@@ -22,7 +21,7 @@ class POTTransformations;
  * executed inside POT.
  */
 
-class ngraph::pass::POTTransformations: public ngraph::pass::FunctionPass {
+class ngraph::pass::POTTransformations : public ngraph::pass::FunctionPass {
     std::string m_device;
 
 public:
@@ -5,10 +5,9 @@
 #pragma once
 
 #include <memory>
-#include <vector>
-
 #include <ngraph/pass/graph_rewrite.hpp>
 #include <ngraph/pass/manager.hpp>
+#include <vector>
 
 namespace ngraph {
 namespace pass {
@@ -20,8 +19,8 @@ class ShrinkWeights;
 
 class Pruning;
 
-} // namespace pass
-} // namespace ngraph
+}  // namespace pass
+}  // namespace ngraph
 
 /**
  * @ingroup ie_transformation_common_api

@@ -42,8 +41,11 @@ public:
 class ngraph::pass::InitConstMask : public ngraph::pass::MatcherPass {
 public:
     OPENVINO_RTTI("InitConstMask", "0");
-    explicit InitConstMask(const ngraph::AxisSet & dims,
-                           const std::function<bool(const double & value)> & condition = [](const double & value) { return value == 0; });
+    explicit InitConstMask(
+        const ngraph::AxisSet& dims,
+        const std::function<bool(const double& value)>& condition = [](const double& value) {
+            return value == 0;
+        });
 };
 
 /**
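InitConstMask's second parameter defaults to an exact value == 0 predicate, so by default only channels that are exactly zero are marked; a sketch of overriding it (the threshold is illustrative, and `constant_node` stands for a matched Constant as in the .cpp below):

    #include <cmath>

    // treat near-zero weights as prunable instead of exact zeros
    ngraph::pass::InitConstMask init({0}, [](const double& value) {
        return std::fabs(value) < 1e-6;
    });
    init.apply(constant_node);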
@@ -2,31 +2,34 @@
 // SPDX-License-Identifier: Apache-2.0
 //
 
-#include <ngraph/pattern/op/wrap_type.hpp>
-#include <ngraph/opsets/opset8.hpp>
-#include <ngraph/validation_util.hpp>
-#include <ngraph/rt_info.hpp>
-#include <openvino/pass/constant_folding.hpp>
 #include <compress_quantize_weights.hpp>
+#include <ngraph/opsets/opset8.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+#include <ngraph/rt_info.hpp>
+#include <ngraph/validation_util.hpp>
+#include <openvino/pass/constant_folding.hpp>
 
 static bool has_dequantization_subgraph(const std::shared_ptr<ngraph::Node>& first_convert) {
     auto first_convert_users = first_convert->get_users();
-    const auto second_convert = std::find_if(first_convert_users.begin(), first_convert_users.end(),
-        [] (const std::shared_ptr<ngraph::Node>& n) -> bool {
-            return ov::is_type<ngraph::opset8::Convert>(n);
-        });
+    const auto second_convert = std::find_if(first_convert_users.begin(),
+                                             first_convert_users.end(),
+                                             [](const std::shared_ptr<ngraph::Node>& n) -> bool {
+                                                 return ov::is_type<ngraph::opset8::Convert>(n);
+                                             });
     if (second_convert == first_convert_users.end())
         return false;
     auto convert_or_subtract_users = (*second_convert)->get_users();
-    const auto subtract = std::find_if(convert_or_subtract_users.begin(), convert_or_subtract_users.end(),
-        [] (const std::shared_ptr<ngraph::Node>& n) -> bool {
-            return ov::is_type<ngraph::opset8::Subtract>(n);
-        });
+    const auto subtract = std::find_if(convert_or_subtract_users.begin(),
+                                       convert_or_subtract_users.end(),
+                                       [](const std::shared_ptr<ngraph::Node>& n) -> bool {
+                                           return ov::is_type<ngraph::opset8::Subtract>(n);
+                                       });
     if (subtract != convert_or_subtract_users.end()) {
         convert_or_subtract_users = (*subtract)->get_users();
     }
-    const auto multiply = std::find_if(convert_or_subtract_users.begin(), convert_or_subtract_users.end(),
-        [] (const std::shared_ptr<ngraph::Node>& n) -> bool {
-            return ov::is_type<ngraph::opset8::Multiply>(n);
-        });
+    const auto multiply = std::find_if(convert_or_subtract_users.begin(),
+                                       convert_or_subtract_users.end(),
+                                       [](const std::shared_ptr<ngraph::Node>& n) -> bool {
+                                           return ov::is_type<ngraph::opset8::Multiply>(n);
+                                       });
     return multiply != convert_or_subtract_users.end();
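Roughly, the user chain that has_dequantization_subgraph searches for can be built like this (a sketch only; element types, shapes, and constant values are illustrative):

    auto weights = ngraph::opset8::Constant::create(ngraph::element::i8, ngraph::Shape{8, 8}, {0});
    auto first = std::make_shared<ngraph::opset8::Convert>(weights, ngraph::element::f32);
    auto second = std::make_shared<ngraph::opset8::Convert>(first, ngraph::element::f32);
    auto zp = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{}, {1.f});
    auto sub = std::make_shared<ngraph::opset8::Subtract>(second, zp);    // the Subtract is optional
    auto scale = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{}, {0.1f});
    auto mul = std::make_shared<ngraph::opset8::Multiply>(sub, scale);    // a Multiply ends the chain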
@@ -38,8 +41,8 @@ ngraph::pass::CompressQuantizeWeights::CompressQuantizeWeights() {
     auto input_high_pattern = pattern::wrap_type<opset8::Constant>();
     auto output_low_pattern = pattern::wrap_type<opset8::Constant>();
     auto output_high_pattern = pattern::wrap_type<opset8::Constant>();
-    auto fq_pattern = pattern::wrap_type<opset8::FakeQuantize>({weights_pattern, input_low_pattern, input_high_pattern,
-                                                                output_low_pattern, output_high_pattern});
+    auto fq_pattern = pattern::wrap_type<opset8::FakeQuantize>(
+        {weights_pattern, input_low_pattern, input_high_pattern, output_low_pattern, output_high_pattern});
 
     ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
         auto fq = std::dynamic_pointer_cast<opset8::FakeQuantize>(m.get_match_root());
@@ -86,12 +89,13 @@ ngraph::pass::CompressQuantizeWeights::CompressQuantizeWeights() {
             */
             std::shared_ptr<Node> new_input_low;
             auto new_output_low = op::Constant::create(input_type, Shape{}, {-static_cast<float>(levels / 2)});
-            auto new_output_high = std::make_shared<opset8::Add>(new_output_low, op::Constant::create(input_type, Shape{}, {levels - 1}));
+            auto new_output_high =
+                std::make_shared<opset8::Add>(new_output_low, op::Constant::create(input_type, Shape{}, {levels - 1}));
             const auto& weights = pattern_value_map.at(weights_pattern);
             const auto& input_low = pattern_value_map.at(input_low_pattern);
             const auto& input_high = pattern_value_map.at(input_high_pattern);
-            auto quantize = fq->clone_with_new_inputs({weights, input_low, input_high,
-                                                       new_output_low, new_output_high});
+            auto quantize =
+                fq->clone_with_new_inputs({weights, input_low, input_high, new_output_low, new_output_high});
             // Convert quantized weights to low precision type
             std::shared_ptr<Node> new_weights = std::make_shared<opset8::Convert>(quantize, quantized_type);
             // Constant fold quantized weights
@@ -169,47 +173,53 @@ ngraph::pass::ZeroPointOptimizer::ZeroPointOptimizer() {
         const auto& pattern_value_map = m.get_pattern_value_map();
         auto convert = pattern_value_map.at(convert_pattern).get_node_shared_ptr();
         auto sub = pattern_value_map.at(sub_pattern).get_node_shared_ptr();
-        auto weights = std::dynamic_pointer_cast<opset8::Constant>(pattern_value_map.at(weights_pattern).get_node_shared_ptr());
+        auto weights =
+            std::dynamic_pointer_cast<opset8::Constant>(pattern_value_map.at(weights_pattern).get_node_shared_ptr());
         if (!weights || weights->get_element_type() != element::i8)
             return false;
-        auto zero_point = std::dynamic_pointer_cast<opset8::Constant>(pattern_value_map.at(zero_point_pattern).get_node_shared_ptr());
+        auto zero_point =
+            std::dynamic_pointer_cast<opset8::Constant>(pattern_value_map.at(zero_point_pattern).get_node_shared_ptr());
         if (!zero_point)
             return false;
 
         auto zp_value = zero_point->cast_vector<float>();
-        if (std::all_of(zp_value.begin(), zp_value.end(), [] (float f) -> bool { return std::fabs(f) <= std::numeric_limits<float>::epsilon(); })) {
+        if (std::all_of(zp_value.begin(), zp_value.end(), [](float f) -> bool {
+                return std::fabs(f) <= std::numeric_limits<float>::epsilon();
+            })) {
             copy_runtime_info(sub, convert);
             replace_node(sub, convert);
         }
 
         auto int8_zero_point = std::make_shared<opset8::Convert>(
-            std::make_shared<opset8::Round>(zero_point, opset8::Round::RoundMode::HALF_TO_EVEN),
-            weights->get_element_type());
-        auto adj_zero_point = std::make_shared<opset8::Subtract>(zero_point, std::make_shared<opset8::Convert>(int8_zero_point, convert->get_element_type()));
+            std::make_shared<opset8::Round>(zero_point, opset8::Round::RoundMode::HALF_TO_EVEN),
+            weights->get_element_type());
+        auto adj_zero_point = std::make_shared<opset8::Subtract>(
+            zero_point,
+            std::make_shared<opset8::Convert>(int8_zero_point, convert->get_element_type()));
 
         auto adj_zero_point_const = ov::get_constant_from_source(adj_zero_point);
         if (!adj_zero_point_const)
             return false;
         auto adj_zero_point_val = adj_zero_point_const->cast_vector<float>();
-        bool is_adj_zero_point_close_to_zero = std::all_of(adj_zero_point_val.begin(), adj_zero_point_val.end(),
-            [] (float f) -> bool {
-                return std::fabs(f) < 1e-4;
-            });
+        bool is_adj_zero_point_close_to_zero =
+            std::all_of(adj_zero_point_val.begin(), adj_zero_point_val.end(), [](float f) -> bool {
+                return std::fabs(f) < 1e-4;
+            });
         if (!is_adj_zero_point_close_to_zero)
             return false;
 
         auto transformed = std::make_shared<opset8::Subtract>(
-            std::make_shared<opset8::Convert>(std::make_shared<opset8::Subtract>(weights, int8_zero_point), convert->get_element_type()),
+            std::make_shared<opset8::Convert>(std::make_shared<opset8::Subtract>(weights, int8_zero_point),
+                                              convert->get_element_type()),
             adj_zero_point);
         auto diff = std::make_shared<opset8::Subtract>(sub, transformed);
         auto diff_const = ov::get_constant_from_source(diff);
        if (!diff_const)
             return false;
         auto diff_val = diff_const->cast_vector<float>();
-        bool is_transformed_and_original_equal = std::all_of(diff_val.begin(), diff_val.end(),
-            [] (float f) -> bool {
-                return std::fabs(f) < std::numeric_limits<float>::epsilon();
-            });
+        bool is_transformed_and_original_equal = std::all_of(diff_val.begin(), diff_val.end(), [](float f) -> bool {
+            return std::fabs(f) < std::numeric_limits<float>::epsilon();
+        });
         if (!is_transformed_and_original_equal)
             return false;
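The checks above encode a simple identity. With an integer-rounded zero point z_i = round(z) and residual r = z - z_i, the subtraction can be regrouped so that only integer arithmetic touches the weights:

    (w - z) = (w - z_i) - (z - z_i) = (w - z_i) - r

ZeroPointOptimizer only applies the rewrite when r is numerically negligible (element-wise |r| < 1e-4) and when the transformed expression constant-folds to exactly the original values.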
@@ -2,11 +2,11 @@
 // SPDX-License-Identifier: Apache-2.0
 //
 
+#include "generate_mapping_file.hpp"
+
+#include <fstream>
 #include <memory>
 #include <ostream>
-#include <fstream>
-
-#include "generate_mapping_file.hpp"
 
 #include "pugixml.hpp"
@@ -14,8 +14,10 @@ bool ngraph::pass::GenerateMappingFile::run_on_model(const std::shared_ptr<ngrap
     pugi::xml_document xml_doc;
     pugi::xml_node root_node = xml_doc.append_child("mapping");
 
-    auto add_mapping = [&](const std::string & fw_name, const std::string & fw_port_name,
-                           const std::string & ir_name, const std::string & ir_port_name) {
+    auto add_mapping = [&](const std::string& fw_name,
+                           const std::string& fw_port_name,
+                           const std::string& ir_name,
+                           const std::string& ir_port_name) {
         auto map_node = root_node.append_child("map");
         auto framework_node = map_node.append_child("framework");
         auto ir_node = map_node.append_child("IR");
@@ -27,24 +29,24 @@ bool ngraph::pass::GenerateMappingFile::run_on_model(const std::shared_ptr<ngrap
         ir_node.append_attribute("output_port_id").set_value(ir_port_name.c_str());
     };
 
-    auto extract_name = [](const std::string & port_name) -> std::string {
+    auto extract_name = [](const std::string& port_name) -> std::string {
         return port_name.substr(0, port_name.find(':'));
     };
 
-    for (auto && node : f->get_ordered_ops()) {
+    for (auto&& node : f->get_ordered_ops()) {
         uint64_t ie_port_index{node->inputs().size()};
         uint64_t ng_port_index{0};
         if (std::dynamic_pointer_cast<ov::op::v0::Result>(node))
             continue;
-        for (auto && output : node->outputs()) {
-            const auto & node_name = node->get_friendly_name();
-            const auto & t = output.get_tensor_ptr();
+        for (auto&& output : node->outputs()) {
+            const auto& node_name = node->get_friendly_name();
+            const auto& t = output.get_tensor_ptr();
 
-            for (const auto & port_name : t->get_names()) {
+            for (const auto& port_name : t->get_names()) {
                 add_mapping(node_name, port_name, node_name, std::to_string(ie_port_index));
 
                 if (m_extract_name) {
-                    for (auto &name : t->get_names()) {
+                    for (auto& name : t->get_names()) {
                         add_mapping(extract_name(name), port_name, node_name, std::to_string(ie_port_index));
                     }
                 }
@@ -4,14 +4,19 @@
 
 #include "extension/json_config.hpp"
 
+#include "openvino/core/deprecated.hpp"
+
+OPENVINO_SUPPRESS_DEPRECATED_START
 #include "nlohmann/json-schema.hpp"
+#include "openvino/frontend/extension/decoder_transformation.hpp"
+OPENVINO_SUPPRESS_DEPRECATED_END
 
 #include "extension/json_transformation.hpp"
-#include "openvino/frontend/extension/decoder_transformation.hpp"
 #include "so_extension.hpp"
 
 namespace {
 static const nlohmann::json validation_schema =
-R"(
+    R"(
 {
     "definitions": {},
     "$schema": "http://json-schema.org/draft-07/schema#",

@@ -139,7 +144,7 @@ R"(
     }
 }
 )"_json;
-} // namespace
+}  // namespace
 
 using namespace ov;
 using namespace ov::frontend;
@@ -2,17 +2,15 @@
 // SPDX-License-Identifier: Apache-2.0
 //
 
+#include "pot_transformations.hpp"
+
 #include <memory>
-
 #include <ngraph/pass/manager.hpp>
-
 #include <transformations/op_conversions/bidirectional_sequences_decomposition.hpp>
 #include <transformations/op_conversions/convert_sequences_to_tensor_iterator.hpp>
 #include <transformations/op_conversions/gru_cell_decomposition.hpp>
 #include <transformations/op_conversions/lstm_cell_decomposition.hpp>
 
-#include "pot_transformations.hpp"
-
 bool ngraph::pass::POTTransformations::run_on_model(const std::shared_ptr<ngraph::Function>& f) {
     ngraph::pass::Manager manager(get_pass_config());
     if (m_device == "GNA") {
@@ -3,34 +3,33 @@
 //
 
 #include <memory>
-
-#include "pruning.hpp"
-#include "mask_attribute.hpp"
-
-#include <ngraph/pattern/op/wrap_type.hpp>
-#include <ngraph/opsets/opset6.hpp>
 #include <ngraph/coordinate_transform.hpp>
 #include <ngraph/log.hpp>
+#include <ngraph/opsets/opset6.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
 
-ngraph::pass::InitConstMask::InitConstMask(const ngraph::AxisSet & dims,
-                                           const std::function<bool(const double & value)> & condition) {
+#include "mask_attribute.hpp"
+#include "pruning.hpp"
+
+ngraph::pass::InitConstMask::InitConstMask(const ngraph::AxisSet& dims,
+                                           const std::function<bool(const double& value)>& condition) {
     auto constant = pattern::wrap_type<opset6::Constant>(
-            pattern::type_matches_any({element::i8, element::u8, element::f16, element::f32, element::f64}));
+        pattern::type_matches_any({element::i8, element::u8, element::f16, element::f32, element::f64}));
 
     matcher_pass_callback callback = [=](pattern::Matcher& m) {
         auto const_node = std::dynamic_pointer_cast<opset6::Constant>(m.get_match_root());
-        if (!const_node) return false;
+        if (!const_node)
+            return false;
 
-        const auto & shape = const_node->get_shape();
-        const auto & values = const_node->cast_vector<double>();
+        const auto& shape = const_node->get_shape();
+        const auto& values = const_node->cast_vector<double>();
 
         auto mask = std::make_shared<Mask>(shape);
 
-        for (const auto & dim : dims) {
+        for (const auto& dim : dims) {
             if (dim >= shape.size()) {
-                NGRAPH_DEBUG << "[WARNING] Attempt to initialize masks on " << dim
-                             << " dimension which is out of shape " << shape
-                             << " for node (" << const_node->get_friendly_name() << ")";
+                NGRAPH_DEBUG << "[WARNING] Attempt to initialize masks on " << dim << " dimension which is out of shape "
+                             << shape << " for node (" << const_node->get_friendly_name() << ")";
                 continue;
             }
 

@@ -44,7 +43,7 @@ ngraph::pass::InitConstMask::InitConstMask(const ngraph::AxisSet& dims,
             bool skip_dim_value = false;
             NGRAPH_SUPPRESS_DEPRECATED_START
             CoordinateTransform iter(shape, begin, end);
-            for (const Coordinate & coord : iter) {
+            for (const Coordinate& coord : iter) {
                 if (!condition(values.at(iter.index(coord)))) {
                     skip_dim_value = true;
                     break;
@@ -2,14 +2,14 @@
 // SPDX-License-Identifier: Apache-2.0
 //
 
-#include "pruning.hpp"
-#include "mask_attribute.hpp"
-
-#include <ngraph/pattern/op/wrap_type.hpp>
-#include <ngraph/opsets/opset6.hpp>
 #include <ngraph/log.hpp>
+#include <ngraph/opsets/opset6.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
 #include <ngraph/validation_util.hpp>
 
+#include "mask_attribute.hpp"
+#include "pruning.hpp"
+
 namespace ngraph {
 namespace pass {
 namespace init_masks {
@@ -17,9 +17,9 @@ namespace init_masks {
 class InitConvMask;
 class InitMatMulMask;
 
-} // namespace init_masks
-} // namespace pass
-} // namespace ngraph
+}  // namespace init_masks
+}  // namespace pass
+}  // namespace ngraph
 
 class ngraph::pass::init_masks::InitConvMask : public MatcherPass {
 public:
@@ -29,8 +29,8 @@ public:
         auto conv = pattern::wrap_type<opset6::Convolution, opset6::GroupConvolution>({input, weights});
 
         ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
-            const auto & pattern_map = m.get_pattern_value_map();
-            const auto & m_output = pattern_map.at(conv);
+            const auto& pattern_map = m.get_pattern_value_map();
+            const auto& m_output = pattern_map.at(conv);
 
             // Initializing weights mask:
             // 1. Looking for Const node with weights
@@ -42,13 +42,13 @@ public:
                 cur_node = cur_node->get_input_node_shared_ptr(0);
             }
             if (!ngraph::is_type<opset6::Constant>(cur_node)) {
-                NGRAPH_DEBUG << "Can't find Constant weights for Convolution: " <<
-                                m_output.get_node()->get_friendly_name() << std::endl;
+                NGRAPH_DEBUG << "Can't find Constant weights for Convolution: "
+                             << m_output.get_node()->get_friendly_name() << std::endl;
                 return false;
             }
 
             // 2. Init mask for Const node
-            InitConstMask({0}/* check only output channels dim */).apply(cur_node);
+            InitConstMask({0} /* check only output channels dim */).apply(cur_node);
             return true;
         };
 
@@ -57,7 +57,6 @@ public:
     }
 };
 
-
 class ngraph::pass::init_masks::InitMatMulMask : public MatcherPass {
 public:
     InitMatMulMask() {
@@ -66,9 +65,11 @@ public:
         auto matmul_pattern = pattern::wrap_type<opset6::MatMul>({a, b});
 
         ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
-            const auto & pattern_map = m.get_pattern_value_map();
-            const auto & matmul = std::dynamic_pointer_cast<opset6::MatMul>(pattern_map.at(matmul_pattern).get_node_shared_ptr());
-            if (!matmul) return false;
+            const auto& pattern_map = m.get_pattern_value_map();
+            const auto& matmul =
+                std::dynamic_pointer_cast<opset6::MatMul>(pattern_map.at(matmul_pattern).get_node_shared_ptr());
+            if (!matmul)
+                return false;
 
             // Assume constant always in the first input port.
             // Initializing weights mask:
@@ -76,7 +77,8 @@ public:
             NodeVector weights_calculation_nodes;
             auto cur_node = matmul->get_input_node_shared_ptr(1);
 
-            if (cur_node->get_output_partial_shape(0).is_dynamic()) return false;
+            if (cur_node->get_output_partial_shape(0).is_dynamic())
+                return false;
             const auto input_size = cur_node->get_output_shape(0).size();
             auto dim_order = std::vector<int64_t>(input_size);
             std::iota(dim_order.begin(), dim_order.end(), 0);
@@ -85,9 +87,11 @@ public:
                 weights_calculation_nodes.push_back(cur_node);
                 if (ngraph::is_type<opset6::Transpose>(cur_node)) {
                     const auto forward_order = get_constant_from_source(cur_node->get_input_node_shared_ptr(1));
-                    if (!forward_order) return false;
+                    if (!forward_order)
+                        return false;
                     const auto forward_order_vec = forward_order->cast_vector<int64_t>();
-                    if (forward_order_vec.size() != input_size) return false;
+                    if (forward_order_vec.size() != input_size)
+                        return false;
                     auto new_order = std::vector<int64_t>(forward_order_vec.size());
                     for (size_t i = 0; i < forward_order_vec.size(); ++i) {
                         new_order[forward_order_vec[i]] = dim_order[i];
@@ -95,32 +99,30 @@ public:
                     dim_order = new_order;
                 } else {
                     if (ngraph::is_type<opset6::Reshape>(cur_node) || ngraph::is_type<opset6::MatMul>(cur_node)) {
-                        NGRAPH_DEBUG << "Can't init mask for MatMul: " <<
-                                        matmul->get_friendly_name() << " because of node " <<
-                                        cur_node->get_friendly_name() << " in the way from weights to Matmul" << std::endl;
+                        NGRAPH_DEBUG << "Can't init mask for MatMul: " << matmul->get_friendly_name()
+                                     << " because of node " << cur_node->get_friendly_name()
+                                     << " in the way from weights to Matmul" << std::endl;
                         return false;
                     }
                 }
                 cur_node = cur_node->get_input_node_shared_ptr(0);
             }
             if (!ngraph::is_type<opset6::Constant>(cur_node)) {
-                NGRAPH_DEBUG << "Can't find Constant weights for MatMul: " <<
-                                matmul->get_friendly_name() << std::endl;
+                NGRAPH_DEBUG << "Can't find Constant weights for MatMul: " << matmul->get_friendly_name() << std::endl;
                 return false;
             }
             // 2. Get constant rank to set mask on last dimension
             const auto const_op = std::dynamic_pointer_cast<opset6::Constant>(cur_node);
             const auto shape_rank = const_op->get_shape().size();
-            const auto shift = (matmul->get_transpose_b())? 2 : 1;
+            const auto shift = (matmul->get_transpose_b()) ? 2 : 1;
             if (shape_rank < shift) {
-                NGRAPH_DEBUG << "Can't init mask for MatMul: " <<
-                                matmul->get_friendly_name() << std::endl;
+                NGRAPH_DEBUG << "Can't init mask for MatMul: " << matmul->get_friendly_name() << std::endl;
                 return false;
             }
             const auto idx = shape_rank - shift;
             const size_t outer_dim = std::find(dim_order.begin(), dim_order.end(), idx) - dim_order.begin();
             // 3. Init mask for Const node
-            InitConstMask({outer_dim}/* check only outer dim */).apply(cur_node);
+            InitConstMask({outer_dim} /* check only outer dim */).apply(cur_node);
             return true;
         };
 
@@ -129,9 +131,7 @@ public:
     }
 };
 
-
 ngraph::pass::InitMasks::InitMasks() {
     add_matcher<init_masks::InitConvMask>();
     add_matcher<init_masks::InitMatMulMask>();
 }
@@ -2,69 +2,71 @@
 // SPDX-License-Identifier: Apache-2.0
 //
 
-#include <functional>
-#include <ostream>
+#include "mask_attribute.hpp"
 
+#include <functional>
 #include <ngraph/node.hpp>
 #include <ngraph/variant.hpp>
-
-#include "mask_attribute.hpp"
+#include <ostream>
 
 namespace ngraph {
 
-Mask::Ptr getMask(const Output<const Node> & output) {
-    auto &rtInfo = output.get_rt_info();
+Mask::Ptr getMask(const Output<const Node>& output) {
+    auto& rtInfo = output.get_rt_info();
 
     const auto attr_it = rtInfo.find(Mask::get_type_info_static());
-    if (attr_it == rtInfo.end()) return nullptr;
+    if (attr_it == rtInfo.end())
+        return nullptr;
 
-    const auto &attr = attr_it->second;
+    const auto& attr = attr_it->second;
     return attr.as<Mask::Ptr>();
 }
 
-Mask::Ptr getMask(const Output<Node> & output) {
-    auto &rtInfo = output.get_rt_info();
+Mask::Ptr getMask(const Output<Node>& output) {
+    auto& rtInfo = output.get_rt_info();
 
     const auto attr_it = rtInfo.find(Mask::get_type_info_static());
-    if (attr_it == rtInfo.end()) return nullptr;
+    if (attr_it == rtInfo.end())
+        return nullptr;
 
-    const auto &attr = attr_it->second;
+    const auto& attr = attr_it->second;
     return attr.as<Mask::Ptr>();
 }
 
-void setMask(Output<Node> output, const Mask::Ptr & mask) {
-    auto &rtInfo = output.get_rt_info();
+void setMask(Output<Node> output, const Mask::Ptr& mask) {
+    auto& rtInfo = output.get_rt_info();
     rtInfo[Mask::get_type_info_static()] = mask;
 }
 
-void setMask(Input<Node> node, const Mask::Ptr & mask) {
-    auto &rtInfo = node.get_rt_info();
+void setMask(Input<Node> node, const Mask::Ptr& mask) {
+    auto& rtInfo = node.get_rt_info();
     rtInfo[Mask::get_type_info_static()] = mask;
 }
 
 #ifdef ENABLE_OPENVINO_DEBUG
 static const char g_init_mask_key[] = "InitMask";
-Mask::Ptr getInitMask(const Output<Node> & output) {
-    auto &rtInfo = output.get_rt_info();
+Mask::Ptr getInitMask(const Output<Node>& output) {
+    auto& rtInfo = output.get_rt_info();
 
     const auto attr_it = rtInfo.find(g_init_mask_key);
-    if (attr_it == rtInfo.end()) return nullptr;
+    if (attr_it == rtInfo.end())
+        return nullptr;
 
-    const auto &attr = attr_it->second;
+    const auto& attr = attr_it->second;
     return attr.as<Mask::Ptr>();
 }
 
-void setInitMask(Output<Node> output, const Mask::Ptr & mask) {
-    auto &rtInfo = output.get_rt_info();
+void setInitMask(Output<Node> output, const Mask::Ptr& mask) {
+    auto& rtInfo = output.get_rt_info();
     auto copy_mask = std::make_shared<Mask>();
     std::copy(mask->begin(), mask->end(), std::back_inserter(*copy_mask));
     rtInfo[g_init_mask_key] = copy_mask;
 }
 #endif
 
-std::ostream & operator<< (std::ostream & out, const Mask & mask) {
+std::ostream& operator<<(std::ostream& out, const Mask& mask) {
     out << "[ ";
-    for (auto & dim : mask) {
+    for (auto& dim : mask) {
         out << "{";
         out << dim.size();
         // Uncomment this to print values
(File diff suppressed because it is too large)
@@ -3,32 +3,29 @@
 //
 
 #include <memory>
-
-#include "pruning.hpp"
-#include "mask_attribute.hpp"
-
-#include <ngraph/pass/manager.hpp>
-#include <ngraph/pattern/op/wrap_type.hpp>
-#include <ngraph/opsets/opset6.hpp>
 #include <ngraph/log.hpp>
 #include <ngraph/ngraph.hpp>
+#include <ngraph/opsets/opset6.hpp>
+#include <ngraph/pass/manager.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
 
-template<typename T>
+#include "mask_attribute.hpp"
+#include "pruning.hpp"
+
+template <typename T>
 static std::string vec_to_str(const std::vector<T> m) {
     std::ostringstream out;
     out << "[ ";
-    for (const auto & val : m)
+    for (const auto& val : m)
         out << val << ' ';
     out << "]";
     return out.str();
 }
 
-
 static bool not_empty_mask(ngraph::Mask::Ptr mask) {
     return mask && !mask->all_dims_are_empty();
 }
 
-
 static bool is_static_reshape_op(std::shared_ptr<ov::Node> node) {
     auto reshape_node = std::dynamic_pointer_cast<ngraph::opset6::Reshape>(node);
     if (!reshape_node)
@@ -56,13 +53,12 @@ static bool maybe_adopt_reshape_node(std::shared_ptr<ov::Node> reshape, ngraph::
     const auto shape = reshape->input_value(1);
     const auto consumers = shape.get_node()->get_output_target_inputs(0);
     if (shape.get_node()->outputs().size() != 1 || consumers.size() != 1) {
-        NGRAPH_DEBUG << "Adoption for node " << shape.get_node()->get_friendly_name()
-                     << " is not supported.";
+        NGRAPH_DEBUG << "Adoption for node " << shape.get_node()->get_friendly_name() << " is not supported.";
         return false;
     }
 
     auto sub_const_vector = std::vector<int64_t>();
-    for (auto & dim : *mask.get())
+    for (auto& dim : *mask.get())
         sub_const_vector.push_back(dim.size());
 
     const auto sub_const = ngraph::opset6::Constant::create(shape.get_element_type(), {mask->size()}, sub_const_vector);
@@ -70,15 +66,14 @@ static bool maybe_adopt_reshape_node(std::shared_ptr<ov::Node> reshape, ngraph::
     consumers.begin()->replace_source_output(sub);
 
     NGRAPH_DEBUG << "Adopting values in (" << shape.get_node()->get_friendly_name() << ")"
-                    << " by substracting " << vec_to_str(sub_const_vector);
+                 << " by subtracting " << vec_to_str(sub_const_vector);
     return true;
 }
 
-
 bool ngraph::pass::ShrinkWeights::run_on_model(const std::shared_ptr<ngraph::Function>& f) {
     int64_t reduced_weights_count{0};
     int64_t total_weights_count{0};
-    for (const auto & node : f->get_ordered_ops()) {
+    for (const auto& node : f->get_ordered_ops()) {
         // calculate shape for every node in graph as the input shape may change
         // during Constant shrinking
         auto mask = getMask(node->output(0));
@@ -94,22 +89,26 @@ bool ngraph::pass::ShrinkWeights::run_on_model(const std::shared_ptr<ngraph::Fun
 
         node->revalidate_and_infer_types();
 
-        if (!mask) continue;
+        if (!mask)
+            continue;
 
         // TODO: constant can be shared across functions so we need to avoid consumers from other functions
         auto const_node = std::dynamic_pointer_cast<opset6::Constant>(node);
-        if (!const_node) continue;
+        if (!const_node)
+            continue;
 
-        const auto & const_shape = const_node->get_shape();
+        const auto& const_shape = const_node->get_shape();
         total_weights_count += shape_size(const_shape);
 
 #ifdef ENABLE_OPENVINO_DEBUG
         if (init_mask) {
             for (size_t dim = 0; dim < init_mask->size(); ++dim) {
                 auto& dim_init_set = (*init_mask)[dim];
-                auto& dim_current_set = (*mask)[dim];
-                if (!dim_init_set.empty() && !std::includes(dim_current_set.begin(), dim_current_set.end(),
-                                                            dim_init_set.begin(), dim_init_set.end())) {
+                auto& dim_current_set = (*mask)[dim];
+                if (!dim_init_set.empty() && !std::includes(dim_current_set.begin(),
+                                                            dim_current_set.end(),
+                                                            dim_init_set.begin(),
+                                                            dim_init_set.end())) {
                     NGRAPH_DEBUG << "Mask was ruined for node:" << const_node->get_friendly_name()
                                  << "\nInit mask: " << *init_mask << "\nCurrent mask: " << *mask;
                     break;
@@ -124,17 +123,17 @@ bool ngraph::pass::ShrinkWeights::run_on_model(const std::shared_ptr<ngraph::Fun
             auto value = const_node->cast_vector<int64_t>();
             for (size_t i = 0; i < mask->size(); i++) {
                 const int64_t res = value[i] - mask->at(i).size();
-                new_const_value.push_back((res > 0)? res : value[i]);
+                new_const_value.push_back((res > 0) ? res : value[i]);
             }
 
-            const auto new_const = opset6::Constant::create(const_node->get_element_type(),
-                                                            const_node->get_shape(), new_const_value);
+            const auto new_const =
+                opset6::Constant::create(const_node->get_element_type(), const_node->get_shape(), new_const_value);
             new_const->set_friendly_name(const_node->get_friendly_name());
             ngraph::copy_runtime_info(const_node, new_const);
             ngraph::replace_node(const_node, new_const);
 
-            NGRAPH_DEBUG << "Adjust value in (" << const_node->get_friendly_name() << "): "
-                         << vec_to_str(value) << " to " << vec_to_str(new_const_value);
+            NGRAPH_DEBUG << "Adjust value in (" << const_node->get_friendly_name() << "): " << vec_to_str(value)
+                         << " to " << vec_to_str(new_const_value);
             continue;
         }
         auto last_output = const_node->output(0);
@@ -144,22 +143,25 @@ bool ngraph::pass::ShrinkWeights::run_on_model(const std::shared_ptr<ngraph::Fun
             // TODO: think about it
             auto res = const_node->get_shape_val();
             if (res.size() != mask->size()) {
-                throw ngraph_error("Mask size (" + std::to_string(mask->size()) + ") is not equal to (" + std::to_string(res.size()) + ")");
+                throw ngraph_error("Mask size (" + std::to_string(mask->size()) + ") is not equal to (" +
+                                   std::to_string(res.size()) + ")");
             }
             for (size_t dim = 0; dim < mask->size(); ++dim) {
                 res[dim] -= mask->at(dim).size();
             }
             auto new_const = opset6::Constant::create(const_node->get_element_type(), Shape{res.size()}, res);
             replace_node(const_node, new_const);
-            NGRAPH_DEBUG << "Transform shape like (" << last_output.get_node()->get_friendly_name() << "): "
-                         << const_node->get_shape_val() << " to " << new_const->get_shape_val() << std::endl;
+            NGRAPH_DEBUG << "Transform shape like (" << last_output.get_node()->get_friendly_name()
+                         << "): " << const_node->get_shape_val() << " to " << new_const->get_shape_val() << std::endl;
             new_const->set_friendly_name(const_node->get_friendly_name());
         } else {
             for (size_t dim = 0; dim < mask->size(); ++dim) {
-                const auto &dim_size = mask->at(dim).size();
-                if (dim_size == 0) continue;
+                const auto& dim_size = mask->at(dim).size();
+                if (dim_size == 0)
+                    continue;
                 // Broadcastable 1-size dimension shouldn't be shrunk with mask
-                if (const_node->get_shape().at(dim) == 1 && dim_size > 1) continue;
+                if (const_node->get_shape().at(dim) == 1 && dim_size > 1)
+                    continue;
 
                 // Convert dims that we want to remove to dims that we need to keep
                 std::vector<int64_t> dims_to_keep;
@@ -169,12 +171,14 @@ bool ngraph::pass::ShrinkWeights::run_on_model(const std::shared_ptr<ngraph::Fun
                 }
             }
 
-            const auto & prev_shape = last_output.get_partial_shape();
-            const auto & prev_name = last_output.get_node()->get_friendly_name();
-            last_output = std::make_shared<opset6::Gather>(last_output,
-                    opset6::Constant::create(element::i64, Shape{dims_to_keep.size()}, dims_to_keep),
-                    opset6::Constant::create(element::i64, Shape{}, {dim}));
-            NGRAPH_DEBUG << "Transform(" << prev_name << "): " << prev_shape << " to " << last_output.get_partial_shape();
+            const auto& prev_shape = last_output.get_partial_shape();
+            const auto& prev_name = last_output.get_node()->get_friendly_name();
+            last_output = std::make_shared<opset6::Gather>(
+                last_output,
+                opset6::Constant::create(element::i64, Shape{dims_to_keep.size()}, dims_to_keep),
+                opset6::Constant::create(element::i64, Shape{}, {dim}));
+            NGRAPH_DEBUG << "Transform(" << prev_name << "): " << prev_shape << " to "
+                         << last_output.get_partial_shape();
 
             if (prev_shape.is_static() && last_output.get_partial_shape().is_static()) {
                 reduced_weights_count += shape_size(prev_shape.get_shape()) - shape_size(last_output.get_shape());
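ShrinkWeights is the final stage of the pruning flow declared in pruning.hpp above; a minimal sketch of invoking it (assuming the Pruning pass wraps mask initialization, propagation, and weight shrinking, as its declaration alongside InitMasks and ShrinkWeights suggests):

    #include <ngraph/pass/manager.hpp>
    #include "pruning.hpp"

    void prune(const std::shared_ptr<ngraph::Function>& f) {
        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::Pruning>();  // InitMasks -> mask propagation -> ShrinkWeights
        manager.run_passes(f);
    }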