fix TransposeReduce transformations
This commit is contained in:
parent
43ef82320f
commit
20168b251a
@ -140,7 +140,6 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph::Fu
|
||||
}
|
||||
REGISTER_PASS(manager, ConvertQuantizeDequantize)
|
||||
REGISTER_PASS(manager, SimplifyShapeOfSubGraph)
|
||||
|
||||
if (!m_use_shapes) {
|
||||
manager.register_pass<ov::pass::DisableShapeOfConstantFolding>();
|
||||
}
|
||||
@ -194,14 +193,14 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph::Fu
|
||||
|
||||
ADD_MATCHER(common_fusions, DivideFusion)
|
||||
ADD_MATCHER(common_fusions, SubtractFusion)
|
||||
ADD_MATCHER(common_fusions, TransposeToReshape)
|
||||
//ADD_MATCHER(common_fusions, TransposeToReshape)
|
||||
ADD_MATCHER(common_fusions, ReshapeSequenceFusion, m_use_shapes)
|
||||
ADD_MATCHER(common_fusions, MatMulConstTransposesExtraction)
|
||||
ADD_MATCHER(common_fusions, PReluFusion)
|
||||
ADD_MATCHER(common_fusions, DepthToSpaceFusion)
|
||||
ADD_MATCHER(common_fusions, ShuffleChannelsFusion, !m_use_shapes)
|
||||
common_fusions->set_name("ov::pass::CommonFusions");
|
||||
|
||||
manager.register_pass<Serialize>("/home/tikhonov/OpenVINO/tmp/serialized/ts_before_align_eltwise.xml", "/home/tikhonov/OpenVINO/tmp/serialized/ts_before_align_eltwise.bin");
|
||||
REGISTER_PASS(manager, BinarizeWeights)
|
||||
REGISTER_PASS(manager, ConvToBinaryConv)
|
||||
|
||||
@ -224,7 +223,6 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph::Fu
|
||||
ADD_MATCHER(multiply_fusions, MatMulMultiplyFusion)
|
||||
multiply_fusions->set_name("ov::pass::MultiplyFusions");
|
||||
REGISTER_PASS(manager, ConstantFolding)
|
||||
|
||||
auto fq_fusions = manager.register_pass<ov::pass::GraphRewrite>();
|
||||
ADD_MATCHER(fq_fusions, FakeQuantizeMulFusion)
|
||||
ADD_MATCHER(fq_fusions, FakeQuantizeReshapeFusion)
|
||||
|
@ -21,25 +21,51 @@ using namespace ov;
|
||||
|
||||
namespace {
|
||||
|
||||
std::shared_ptr<opset6::Constant> get_reduced_order_constant(const std::shared_ptr<opset6::Constant>& axes_const,
|
||||
const std::shared_ptr<opset6::Constant>& order_const) {
|
||||
auto order = order_const->cast_vector<int64_t>();
|
||||
std::vector<size_t> get_updated_order(std::vector<size_t>& axes_values,
|
||||
std::vector<size_t>& order_values,
|
||||
bool is_forward) {
|
||||
std::sort(axes_values.begin(), axes_values.end());
|
||||
size_t buffer_size = is_forward ? order_values.size() - axes_values.size() : order_values.size() + axes_values.size();
|
||||
|
||||
auto axes = axes_const->cast_vector<int64_t>();
|
||||
std::sort(axes.rbegin(), axes.rend());
|
||||
for (const auto& i : axes)
|
||||
order.erase(order.begin() + i);
|
||||
std::vector<size_t> aligned_order(buffer_size);
|
||||
for (size_t i = 0, j = 0; i < std::max(aligned_order.size(), order_values.size()); ++i) {
|
||||
if (std::find(axes_values.begin(), axes_values.end(), i) != axes_values.end()) {
|
||||
if (is_forward) {
|
||||
continue;
|
||||
} else {
|
||||
aligned_order[i] = i;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const auto& updated_order_size = static_cast<int64_t>(order.size());
|
||||
|
||||
auto order_sorted = order;
|
||||
sort(order_sorted.begin(), order_sorted.end());
|
||||
for (int64_t i = 0; i < updated_order_size; ++i) {
|
||||
auto lowest_greater_eq_i = std::lower_bound(order_sorted.begin(), order_sorted.end(), i);
|
||||
std::replace(order.begin(), order.end(), *lowest_greater_eq_i, i);
|
||||
std::replace(order_sorted.begin(), order_sorted.end(), *lowest_greater_eq_i, i);
|
||||
if (is_forward) {
|
||||
auto ub = std::upper_bound(axes_values.begin(), axes_values.end(), order_values[i]);
|
||||
aligned_order[j] = order_values[i] - (ub - axes_values.begin());
|
||||
} else {
|
||||
auto ub = std::upper_bound(axes_values.begin(), axes_values.end(), order_values[j]);
|
||||
aligned_order[i] = order_values[j] + (ub - axes_values.begin());
|
||||
}
|
||||
++j;
|
||||
}
|
||||
return std::make_shared<opset6::Constant>(ngraph::element::i64, ngraph::Shape{order.size()}, order);
|
||||
std::cout << "new order" << std::endl;
|
||||
for (const auto& it : aligned_order) {
|
||||
std::cout << it << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
return aligned_order;
|
||||
}
|
||||
|
||||
bool get_keep_dims(const std::shared_ptr<Node>& reduction) {
|
||||
auto arithmetic_reduce = std::dynamic_pointer_cast<op::util::ArithmeticReductionKeepDims>(reduction);
|
||||
auto logical_reduce = std::dynamic_pointer_cast<op::util::LogicalReductionKeepDims>(reduction);
|
||||
auto squeeze = std::dynamic_pointer_cast<opset6::Squeeze>(reduction);
|
||||
|
||||
bool keep_dims = false; // squeeze always reduces number of output dimensions
|
||||
if (logical_reduce)
|
||||
keep_dims = logical_reduce->get_keep_dims();
|
||||
else if (arithmetic_reduce)
|
||||
keep_dims = arithmetic_reduce->get_keep_dims();
|
||||
return keep_dims;
|
||||
}
|
||||
|
||||
std::shared_ptr<opset6::Constant> get_reversed_order_constant(const std::shared_ptr<opset6::Constant>& order_const) {
|
||||
@ -135,63 +161,39 @@ ov::pass::TransposeReductionBackward::TransposeReductionBackward() {
|
||||
|
||||
auto reduce_or_squeeze_label =
|
||||
pattern::wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(
|
||||
{pattern::any_input(), pattern::wrap_type<opset6::Constant>()});
|
||||
{pattern::any_input(), pattern::wrap_type<opset6::Constant>()}, transpose_sinking::HasSameOutputTransposeNodes);
|
||||
auto transpose_label =
|
||||
pattern::wrap_type<opset6::Transpose>({reduce_or_squeeze_label, pattern::wrap_type<opset6::Constant>()});
|
||||
|
||||
pattern::wrap_type<opset6::Transpose>({reduce_or_squeeze_label, pattern::wrap_type<opset6::Constant>()});
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
|
||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||
auto reduction = pattern_to_output.at(reduce_or_squeeze_label).get_node_shared_ptr();
|
||||
auto arithmetic_reduce = std::dynamic_pointer_cast<op::util::ArithmeticReductionKeepDims>(reduction);
|
||||
auto logical_reduce = std::dynamic_pointer_cast<op::util::LogicalReductionKeepDims>(reduction);
|
||||
auto squeeze = std::dynamic_pointer_cast<opset6::Squeeze>(reduction);
|
||||
if (!transpose || !(arithmetic_reduce || logical_reduce || squeeze))
|
||||
return false;
|
||||
|
||||
// todo: support keep_dims
|
||||
bool keep_dims = false; // squeeze always reduces number of output dimensions
|
||||
if (logical_reduce)
|
||||
keep_dims = logical_reduce->get_keep_dims();
|
||||
else if (arithmetic_reduce)
|
||||
keep_dims = arithmetic_reduce->get_keep_dims();
|
||||
auto keep_dims = get_keep_dims(reduction);
|
||||
auto transpose_order = std::dynamic_pointer_cast<opset6::Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto reduction_axes = std::dynamic_pointer_cast<opset6::Constant>(reduction->get_input_node_shared_ptr(1));
|
||||
if (!transpose_order || !reduction_axes)
|
||||
return false;
|
||||
const auto& non_negative_axes = normalize_axes(reduction->get_friendly_name(),
|
||||
reduction_axes->cast_vector<int64_t>(),
|
||||
reduction->get_input_partial_shape(0).rank());
|
||||
|
||||
transpose->output(0).replace(reduction);
|
||||
auto non_negative_axes = normalize_axes(reduction->get_friendly_name(),
|
||||
reduction_axes->cast_vector<int64_t>(),
|
||||
reduction->get_input_partial_shape(0).rank());
|
||||
|
||||
for (const auto& it : reduction->output(0).get_target_inputs()) {
|
||||
it.get_node()->output(0).replace(reduction);
|
||||
}
|
||||
auto transpose_order_values = transpose_order->cast_vector<size_t>();
|
||||
if (!keep_dims) {
|
||||
int shift = 0;
|
||||
std::vector<size_t> aligned_order(transpose_order_values.size() + non_negative_axes.size());
|
||||
for (size_t i = 0, j = 0; j < aligned_order.size(); ++j) {
|
||||
std::cout << "XXXXXX j " << j << std::endl;
|
||||
if (std::find(non_negative_axes.begin(), non_negative_axes.end(), j) != non_negative_axes.end()) {
|
||||
aligned_order[j] = j;
|
||||
++shift;
|
||||
continue;
|
||||
}
|
||||
aligned_order[j] = transpose_order_values[i] + shift;
|
||||
++i;
|
||||
}
|
||||
|
||||
transpose_order_values = aligned_order;
|
||||
std::cout << "XXXXX : " << std::endl;
|
||||
for (const auto& it : transpose_order_values) {
|
||||
std::cout << it << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
transpose_order_values = get_updated_order(non_negative_axes, transpose_order_values, false);
|
||||
}
|
||||
|
||||
auto reversed_order_values = transpose_sinking::ReverseTransposeOrder(transpose_order_values);
|
||||
std::vector<int64_t> new_values;
|
||||
std::vector<size_t> new_values;
|
||||
new_values.reserve(non_negative_axes.size());
|
||||
for (const auto& axis : non_negative_axes) {
|
||||
new_values.push_back(reversed_order_values[axis]);
|
||||
}
|
||||
|
||||
auto new_transpose_order = std::make_shared<opset6::Constant>(transpose_order->get_element_type(),
|
||||
Shape{transpose_order_values.size()},
|
||||
transpose_order_values);
|
||||
@ -202,6 +204,7 @@ ov::pass::TransposeReductionBackward::TransposeReductionBackward() {
|
||||
transpose->input(1).replace_source_output(new_transpose_order);
|
||||
reduction->input(1).replace_source_output(new_const);
|
||||
reduction->input(0).replace_source_output(transpose);
|
||||
register_new_node(transpose);
|
||||
return true;
|
||||
};
|
||||
|
||||
@ -224,49 +227,41 @@ ov::pass::TransposeReduction::TransposeReduction() {
|
||||
|
||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||
auto reduction = pattern_to_output.at(reduce_or_squeeze_label).get_node_shared_ptr();
|
||||
auto arithmetic_reduce = std::dynamic_pointer_cast<op::util::ArithmeticReductionKeepDims>(reduction);
|
||||
auto logical_reduce = std::dynamic_pointer_cast<op::util::LogicalReductionKeepDims>(reduction);
|
||||
auto squeeze = std::dynamic_pointer_cast<opset6::Squeeze>(reduction);
|
||||
if (!transpose || !(arithmetic_reduce || logical_reduce || squeeze))
|
||||
return false;
|
||||
|
||||
bool keep_dims = false; // squeeze always reduces number of output dimensions
|
||||
if (logical_reduce)
|
||||
keep_dims = logical_reduce->get_keep_dims();
|
||||
else if (arithmetic_reduce)
|
||||
keep_dims = arithmetic_reduce->get_keep_dims();
|
||||
auto keep_dims = get_keep_dims(reduction);
|
||||
|
||||
auto transpose_order = std::dynamic_pointer_cast<opset6::Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto reduction_axes = std::dynamic_pointer_cast<opset6::Constant>(reduction->get_input_node_shared_ptr(1));
|
||||
if (!transpose_order || !reduction_axes)
|
||||
return false;
|
||||
|
||||
const auto& non_negative_axes = normalize_axes(reduction->get_friendly_name(),
|
||||
auto non_negative_axes = normalize_axes(reduction->get_friendly_name(),
|
||||
reduction_axes->cast_vector<int64_t>(),
|
||||
reduction->get_input_partial_shape(0).rank());
|
||||
reduction_axes = opset6::Constant::create(ngraph::element::i64, {non_negative_axes.size()}, non_negative_axes);
|
||||
|
||||
ngraph::NodeVector new_ops;
|
||||
auto new_axes =
|
||||
ov::op::util::make_try_fold<opset6::Gather>(transpose_order,
|
||||
reduction_axes,
|
||||
opset6::Constant::create(ngraph::element::i64, {}, {0}));
|
||||
new_ops.push_back(new_axes);
|
||||
auto new_reduce = reduction->clone_with_new_inputs({transpose->input_value(0), new_axes});
|
||||
new_ops.push_back(new_reduce);
|
||||
|
||||
auto updated_order = transpose_order;
|
||||
if (!keep_dims) {
|
||||
updated_order = get_reduced_order_constant(reduction_axes, transpose_order);
|
||||
new_ops.push_back(updated_order);
|
||||
reduction->output(0).replace(transpose);
|
||||
auto transpose_order_values = transpose_order->cast_vector<size_t>();
|
||||
std::vector<size_t> new_values;
|
||||
new_values.reserve(non_negative_axes.size());
|
||||
for (const auto& axis : non_negative_axes) {
|
||||
new_values.push_back(transpose_order_values[axis]);
|
||||
}
|
||||
auto new_transpose = register_new_node<opset6::Transpose>(new_reduce, updated_order);
|
||||
new_ops.push_back(new_transpose);
|
||||
new_transpose->set_friendly_name(reduction->get_friendly_name());
|
||||
|
||||
ngraph::copy_runtime_info({reduction, transpose}, new_ops);
|
||||
ngraph::replace_node(reduction, new_transpose);
|
||||
if (!keep_dims) {
|
||||
transpose_order_values = get_updated_order(non_negative_axes, transpose_order_values, true);
|
||||
}
|
||||
std::cout << "XXXXXX TransposeReductionForward" << std::endl;
|
||||
|
||||
auto new_transpose_order = std::make_shared<opset6::Constant>(transpose_order->get_element_type(),
|
||||
Shape{transpose_order_values.size()},
|
||||
transpose_order_values);
|
||||
auto new_const = std::make_shared<opset6::Constant>(reduction_axes->get_element_type(),
|
||||
reduction_axes->get_shape(),
|
||||
new_values);
|
||||
reduction->input(0).replace_source_output(transpose->input_value(0));
|
||||
reduction->input(1).replace_source_output(new_const);
|
||||
transpose->input(1).replace_source_output(new_transpose_order);
|
||||
transpose->input(0).replace_source_output(reduction);
|
||||
register_new_node(transpose);
|
||||
return true;
|
||||
};
|
||||
|
||||
|
@ -45,16 +45,26 @@ bool ov::pass::TransposeSinkingGeneral::run_on_model(const std::shared_ptr<ov::M
|
||||
RUN_ON_FUNCTION_SCOPE(TransposeSinkingGeneral);
|
||||
{
|
||||
ngraph::pass::Manager manager(get_pass_config());
|
||||
manager.register_pass<ov::pass::Serialize>("/home/tikhonov/OpenVINO/tmp/serialized/ts_before_forward.xml",
|
||||
"/home/tikhonov/OpenVINO/tmp/serialized/ts_before_forward.bin");
|
||||
manager.register_pass<ov::pass::TransposeSinkingGeneralForward>();
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
manager.register_pass<ov::pass::Serialize>("/home/tikhonov/OpenVINO/tmp/serialized/ts_after_forward.xml",
|
||||
"/home/tikhonov/OpenVINO/tmp/serialized/ts_after_forward.bin");
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
{
|
||||
std::cout << "XXXXXX Backward start" << std::endl;
|
||||
ngraph::pass::Manager manager(get_pass_config());
|
||||
manager.register_pass<ov::pass::Serialize>("/home/tikhonov/OpenVINO/tmp/serialized/ts_before_backward.xml",
|
||||
"/home/tikhonov/OpenVINO/tmp/serialized/ts_before_backward.bin");
|
||||
manager.register_pass<ov::pass::TransposeSinkingGeneralBackward>();
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
manager.register_pass<ov::pass::Serialize>("/home/tikhonov/OpenVINO/tmp/serialized/ts_after_backward.xml",
|
||||
"/home/tikhonov/OpenVINO/tmp/serialized/ts_after_backward.bin");
|
||||
manager.run_passes(f);
|
||||
std::cout << "XXXXXX Backward end" << std::endl;
|
||||
}
|
||||
|
||||
return false;
|
||||
|
@ -55,9 +55,9 @@ ov::pass::TransposeToReshape::TransposeToReshape() {
|
||||
};
|
||||
std::vector<DimensionToPosition> dims;
|
||||
for (size_t i = 0; i < input_shape_rank; ++i) {
|
||||
if (order_value[i] != static_cast<int64_t>(i)) {
|
||||
//if (order_value[i] != static_cast<int64_t>(i)) {
|
||||
dims.push_back({input_shape[order_value[i]], i});
|
||||
}
|
||||
//}
|
||||
}
|
||||
|
||||
// If number of dimensions != 1 to move equal to 0 we can remove this Transpose
|
||||
@ -79,13 +79,18 @@ ov::pass::TransposeToReshape::TransposeToReshape() {
|
||||
|
||||
Output<Node> reshape_dim;
|
||||
NodeVector new_ops;
|
||||
|
||||
std::cout << "XXXX Replace Transpose " << std::endl;
|
||||
if (count_if(dims.begin(), dims.end(), [](const DimensionToPosition& item) {
|
||||
return item.dim.is_dynamic();
|
||||
}) < 2) {
|
||||
std::vector<int64_t> reshape_value(input_shape_rank, 0);
|
||||
for (const auto& item : dims) {
|
||||
reshape_value[item.pos] = item.dim.is_dynamic() ? -1 : item.dim.get_length();
|
||||
std::cout << reshape_value[item.pos] << std::endl;
|
||||
}
|
||||
std::cout << "reshape value" << std::endl;
|
||||
for (const auto &it : reshape_value) {
|
||||
std::cout << it << std::endl;
|
||||
}
|
||||
reshape_dim = opset3::Constant::create(element::i64, Shape{reshape_value.size()}, reshape_value);
|
||||
} else {
|
||||
@ -103,6 +108,7 @@ ov::pass::TransposeToReshape::TransposeToReshape() {
|
||||
reshape_op->set_friendly_name(transpose->get_friendly_name());
|
||||
copy_runtime_info(transpose, new_ops);
|
||||
replace_node(transpose, reshape_op);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
|
@ -890,10 +890,11 @@ std::vector<size_t> ov::normalize_axes(const std::string& node_description,
|
||||
const Rank& tensor_rank) {
|
||||
std::vector<size_t> new_axes;
|
||||
new_axes.reserve(axes.size());
|
||||
std::cout << "XXXXXX2" << std::endl;
|
||||
for (const auto& axis : axes) {
|
||||
new_axes.push_back(normalize_axis(node_description, axis, tensor_rank));
|
||||
}
|
||||
|
||||
std::cout << "XXXXXX3" << std::endl;
|
||||
return new_axes;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user