fix warning as error, fix tests failures
This commit is contained in:
parent
1de806f9f7
commit
d9ea97bf4b
@ -134,7 +134,7 @@ ov::pass::TransposeReductionBackward::TransposeReductionBackward() {
|
||||
MATCHER_SCOPE(TransposeReductionBackward);
|
||||
|
||||
auto reduce_or_squeeze_label =
|
||||
pattern::wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims, opset6::Squeeze>(
|
||||
pattern::wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(
|
||||
{pattern::any_input(), pattern::wrap_type<opset6::Constant>()});
|
||||
auto transpose_label =
|
||||
pattern::wrap_type<opset6::Transpose>({reduce_or_squeeze_label, pattern::wrap_type<opset6::Constant>()});
|
||||
@ -150,11 +150,12 @@ ov::pass::TransposeReductionBackward::TransposeReductionBackward() {
|
||||
if (!transpose || !(arithmetic_reduce || logical_reduce || squeeze))
|
||||
return false;
|
||||
|
||||
bool keep_dims = false; // squeeze always reduces number of output dimensions
|
||||
// todo: support keep_dims
|
||||
/*bool keep_dims = false; // squeeze always reduces number of output dimensions
|
||||
if (logical_reduce)
|
||||
keep_dims = logical_reduce->get_keep_dims();
|
||||
else if (arithmetic_reduce)
|
||||
keep_dims = arithmetic_reduce->get_keep_dims();
|
||||
keep_dims = arithmetic_reduce->get_keep_dims();*/
|
||||
auto transpose_order = std::dynamic_pointer_cast<opset6::Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto reduction_axes = std::dynamic_pointer_cast<opset6::Constant>(reduction->get_input_node_shared_ptr(1));
|
||||
if (!transpose_order || !reduction_axes)
|
||||
@ -190,7 +191,7 @@ ov::pass::TransposeReduction::TransposeReduction() {
|
||||
pattern::wrap_type<opset6::Transpose>({pattern::any_input(), pattern::wrap_type<opset6::Constant>()},
|
||||
pattern::consumers_count(1));
|
||||
auto reduce_or_squeeze_label =
|
||||
pattern::wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims, opset6::Squeeze>(
|
||||
pattern::wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(
|
||||
{transpose_label, pattern::wrap_type<opset6::Constant>()});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
|
||||
|
Loading…
Reference in New Issue
Block a user