fix multi-axis reduce transformation (#18414)
This commit is contained in:
@@ -36,17 +36,13 @@ template <class T>
|
||||
static std::shared_ptr<ov::Model> createRefGraph(ov::Shape param_shape) {
|
||||
auto param = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, param_shape);
|
||||
std::vector<int64_t> axes = {0, 1};
|
||||
ngraph::NodeVector new_ops;
|
||||
std::shared_ptr<ngraph::Node> node = param;
|
||||
for (auto axis : axes) {
|
||||
auto reduction_axis = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{}, {axis});
|
||||
node = std::make_shared<T>(node, reduction_axis, true);
|
||||
new_ops.push_back(node);
|
||||
}
|
||||
auto reshape_shape = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{param_shape.size()}, {1, 1, 2, 9});
|
||||
auto reshape = std::make_shared<ngraph::opset1::Reshape>(node, reshape_shape, true);
|
||||
|
||||
return std::make_shared<ngraph::Function>(ngraph::NodeVector{ reshape }, ngraph::ParameterVector{ param });
|
||||
return std::make_shared<ngraph::Function>(ngraph::NodeVector{ node }, ngraph::ParameterVector{ param });
|
||||
}
|
||||
|
||||
template <class T>
|
||||
|
||||
Reference in New Issue
Block a user