fix warning as error, fix tests failures

This commit is contained in:
Tikhonov Ivan 2023-02-13 07:45:54 +00:00
parent 1de806f9f7
commit d9ea97bf4b

View File

@ -134,7 +134,7 @@ ov::pass::TransposeReductionBackward::TransposeReductionBackward() {
MATCHER_SCOPE(TransposeReductionBackward);
auto reduce_or_squeeze_label =
pattern::wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims, opset6::Squeeze>(
pattern::wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(
{pattern::any_input(), pattern::wrap_type<opset6::Constant>()});
auto transpose_label =
pattern::wrap_type<opset6::Transpose>({reduce_or_squeeze_label, pattern::wrap_type<opset6::Constant>()});
@ -150,11 +150,12 @@ ov::pass::TransposeReductionBackward::TransposeReductionBackward() {
if (!transpose || !(arithmetic_reduce || logical_reduce || squeeze))
return false;
bool keep_dims = false; // squeeze always reduces number of output dimensions
// todo: support keep_dims
/*bool keep_dims = false; // squeeze always reduces number of output dimensions
if (logical_reduce)
keep_dims = logical_reduce->get_keep_dims();
else if (arithmetic_reduce)
keep_dims = arithmetic_reduce->get_keep_dims();
keep_dims = arithmetic_reduce->get_keep_dims();*/
auto transpose_order = std::dynamic_pointer_cast<opset6::Constant>(transpose->get_input_node_shared_ptr(1));
auto reduction_axes = std::dynamic_pointer_cast<opset6::Constant>(reduction->get_input_node_shared_ptr(1));
if (!transpose_order || !reduction_axes)
@ -190,7 +191,7 @@ ov::pass::TransposeReduction::TransposeReduction() {
pattern::wrap_type<opset6::Transpose>({pattern::any_input(), pattern::wrap_type<opset6::Constant>()},
pattern::consumers_count(1));
auto reduce_or_squeeze_label =
pattern::wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims, opset6::Squeeze>(
pattern::wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(
{transpose_label, pattern::wrap_type<opset6::Constant>()});
ov::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {