fix multi-axis reduce transformation (#18414)

This commit is contained in:
Aleksandr Voron
2023-07-07 08:56:09 +02:00
committed by GitHub
parent df2ed95dab
commit 2b795afc09

View File

@@ -36,17 +36,13 @@ template <class T>
static std::shared_ptr<ov::Model> createRefGraph(ov::Shape param_shape) {
auto param = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, param_shape);
std::vector<int64_t> axes = {0, 1};
ngraph::NodeVector new_ops;
std::shared_ptr<ngraph::Node> node = param;
for (auto axis : axes) {
auto reduction_axis = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{}, {axis});
node = std::make_shared<T>(node, reduction_axis, true);
new_ops.push_back(node);
}
auto reshape_shape = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{param_shape.size()}, {1, 1, 2, 9});
auto reshape = std::make_shared<ngraph::opset1::Reshape>(node, reshape_shape, true);
return std::make_shared<ngraph::Function>(ngraph::NodeVector{ reshape }, ngraph::ParameterVector{ param });
return std::make_shared<ngraph::Function>(ngraph::NodeVector{ node }, ngraph::ParameterVector{ param });
}
template <class T>