[GPU] Fix outputs are not allocated in loop_inst (#20585)
* [GPU] Fix outputs are not allocated in loop_inst * Fill empty padding when the number of output paddings is less than num_outputs * Fill empty data types when the number of output data types is less than num_outputs * Modify postprocess_output_memory to set output memory without set_output_memory function * In postprocess_output_memory, get concatenated_output_mem using input_info including output idx * Modify gpu functional tests for dynamic loop to check multiple outputs of dynamic loop * update postprocessing for condition * Fix empty dimension issue for scalar value * change code to get output paddings and output data type in primitive * allocate memory for scalar data type with zero dimension * Fix mismatch issue of input layout with shape and data types in body_network * Fix output setting in post-processing * pass bytes_count to gpu_usm params * Fix condition gpu functional test issue * Revert "allocate memory for scalar data type with zero dimension" This reverts commit 2f10f3687c78406b20d52b6e37b1be2a30b4b73f. * reinterpret one dimension memory buffer to zero dimension memory buffer to avoid zero byte memory allocation issue
This commit is contained in:
parent
0139fffc18
commit
9cc1e992f4
@ -47,6 +47,11 @@ struct input_info {
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// @brief Compare
|
||||
bool operator==(const input_info& rhs) const {
|
||||
return ((pid == rhs.pid) && (idx == rhs.idx));
|
||||
}
|
||||
|
||||
primitive_id pid;
|
||||
int32_t idx;
|
||||
struct cmp {
|
||||
@ -259,6 +264,22 @@ public:
|
||||
ib >> num_outputs;
|
||||
}
|
||||
|
||||
virtual padding get_output_padding(size_t idx) const {
|
||||
if (idx < output_paddings.size()) {
|
||||
return output_paddings[idx];
|
||||
} else {
|
||||
return padding();
|
||||
}
|
||||
}
|
||||
|
||||
virtual optional_data_type get_output_data_type(size_t idx) const {
|
||||
if (idx < output_data_types.size()) {
|
||||
return output_data_types[idx];
|
||||
} else {
|
||||
return optional_data_type();
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const { return {}; }
|
||||
class condition;
|
||||
|
@ -105,7 +105,7 @@ std::vector<layout> arg_max_min_inst::calc_output_layouts(arg_max_min_node const
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < desc->num_outputs; ++i) {
|
||||
auto dt = desc->output_data_types[i].value_or(input_layout.data_type);
|
||||
auto dt = desc->get_output_data_type(i).value_or(input_layout.data_type);
|
||||
layouts.push_back({output_shapes[i], dt, format::get_default_format(output_shapes[i].size())});
|
||||
}
|
||||
return layouts;
|
||||
|
@ -240,14 +240,26 @@ void condition_inst::update_output_layout() {
|
||||
auto new_layouts = _node->type()->calc_output_layouts(*_node, *_impl_params);
|
||||
if (new_layouts.empty()) {
|
||||
auto new_layout = _node->type()->calc_output_layout(*_node, *_impl_params);
|
||||
new_layout.data_padding = padding::max(_node->get_primitive()->output_paddings[0], new_layout.data_padding);
|
||||
new_layout.data_padding = padding::max(_node->get_primitive()->get_output_padding(0), new_layout.data_padding);
|
||||
_impl_params->output_layouts[0] = new_layout;
|
||||
} else {
|
||||
for (size_t i = 0; i != new_layouts.size(); ++i) {
|
||||
auto new_layout = new_layouts[i];
|
||||
new_layout.data_padding = padding::max(_node->get_primitive()->output_paddings[i], new_layout.data_padding);
|
||||
new_layout.data_padding = padding::max(_node->get_primitive()->get_output_padding(i), new_layout.data_padding);
|
||||
_impl_params->output_layouts[i] = new_layout;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void condition_inst::postprocess_output_memory(network::ptr executed_net, cldnn::condition::branch& branch) {
|
||||
_outputs.clear();
|
||||
_outputs.resize(outputs_memory_count());
|
||||
for (auto out_mem_map : branch.output_map) {
|
||||
auto out_mem_idx = out_mem_map.first;
|
||||
auto inner_out_id = out_mem_map.second;
|
||||
auto mem_ptr = executed_net->get_output(inner_out_id).get_memory();
|
||||
_outputs[out_mem_idx] = mem_ptr;
|
||||
GPU_DEBUG_LOG << "Inner net - Outputs[" << out_mem_idx << "]" << mem_ptr->get_layout().to_short_string() << std::endl;
|
||||
}
|
||||
}
|
||||
} // namespace cldnn
|
||||
|
@ -59,13 +59,7 @@ struct condition_impl : typed_primitive_impl<condition> {
|
||||
instance.update_output_layout();
|
||||
|
||||
// Set output memory of condition_inst to inner network output memory after inner network execution
|
||||
for (auto out_mem_map : branch.output_map) {
|
||||
auto out_mem_idx = out_mem_map.first;
|
||||
auto inner_out_id = out_mem_map.second;
|
||||
auto mem_ptr = executed_net->get_output(inner_out_id).get_memory();
|
||||
instance.set_output_memory(mem_ptr, false, out_mem_idx);
|
||||
GPU_DEBUG_LOG << "Inner net - Outputs[" << out_mem_idx << "]" << mem_ptr->get_layout().to_short_string() << std::endl;
|
||||
}
|
||||
instance.postprocess_output_memory(executed_net, branch);
|
||||
|
||||
ev->set();
|
||||
return ev;
|
||||
|
@ -79,6 +79,7 @@ public:
|
||||
condition::branch get_branch_false() const { return node->get_branch_false(); }
|
||||
|
||||
void update_output_layout();
|
||||
void postprocess_output_memory(network::ptr executed_net, cldnn::condition::branch& branch);
|
||||
|
||||
private:
|
||||
network::ptr _net_true;
|
||||
|
@ -289,6 +289,7 @@ public:
|
||||
ss << "* iteration_elements : " << iteration_elements << std::endl;
|
||||
ss << "* stride : " << stride << std::endl;
|
||||
ss << "* initial_offset : " << initial_offset << std::endl;
|
||||
ss << "* input_info : " << concat_data_id.to_string() << std::endl;
|
||||
ss << "* sliced_mems :{ ";
|
||||
for (auto mem : sliced_mems) {
|
||||
ss << mem->get_layout().to_short_string() << ",";
|
||||
@ -300,6 +301,7 @@ public:
|
||||
const int64_t axis;
|
||||
std::shared_ptr<primitive_inst> concat_data_prim;
|
||||
std::shared_ptr<primitive_inst> sliced_data_prim;
|
||||
cldnn::input_info concat_data_id;
|
||||
|
||||
private:
|
||||
mutable memory::ptr concatenated_mem;
|
||||
|
@ -306,12 +306,7 @@ void loop_inst::update_input_mapped_memory() {
|
||||
}
|
||||
|
||||
void loop_inst::update_output_mapped_memory() {
|
||||
if (is_dynamic()) {
|
||||
if (!outputs_allocated()) {
|
||||
_outputs = allocate_outputs(_impl_params.get(), true, true);
|
||||
}
|
||||
}
|
||||
|
||||
OPENVINO_ASSERT(outputs_allocated(), "output buffer should be allocated");
|
||||
for (size_t i = 0; i < _output_primitive_maps.size(); ++i) {
|
||||
const auto& output_mapping = _output_primitive_maps.at(i);
|
||||
const primitive_id& external_id = output_mapping.external_id.pid;
|
||||
@ -469,6 +464,7 @@ void loop_inst::preprocess_output_memory(const int64_t trip_count) {
|
||||
if (iter == concatenated_output_mem_mappings.end()) {
|
||||
auto memory_mapping_info = create_concat_memory_map(internal_id, output_mapping, memory, trip_count);
|
||||
memory_mapping_info->concat_data_prim = get_network().get_primitive(external_id.pid);
|
||||
memory_mapping_info->concat_data_id = external_id;
|
||||
concatenated_output_mem_mappings.push_back(memory_mapping_info);
|
||||
GPU_DEBUG_LOG << i << ") generate concat output memory mapping: " << memory_mapping_info->to_string() << std::endl;
|
||||
}
|
||||
@ -702,44 +698,52 @@ void loop_inst::load(BinaryInputBuffer& ib) {
|
||||
|
||||
void loop_inst::postprocess_output_memory(bool is_dynamic) {
|
||||
if (is_dynamic) {
|
||||
std::vector<cldnn::memory::ptr> external_outputs;
|
||||
external_outputs.resize(outputs_memory_count());
|
||||
|
||||
for (size_t i = 0; i < _output_primitive_maps.size(); ++i) {
|
||||
const auto& output_mapping = _output_primitive_maps.at(i);
|
||||
const auto& external_id = output_mapping.external_id;
|
||||
const auto& internal_id = output_mapping.internal_id;
|
||||
bool output_allocated = (static_cast<size_t>(external_id.idx) < _outputs.size() && _outputs[external_id.idx] != nullptr);
|
||||
if (output_mapping.axis < 0) {
|
||||
auto internalOutputPrim = get_body_network()->get_primitive(internal_id.pid);
|
||||
auto internal_mem = internalOutputPrim->output_memory_ptr(internal_id.idx);
|
||||
if (internal_mem == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto externalOutputPrim = _network.get_primitive(external_id.pid);
|
||||
if (!externalOutputPrim->outputs_allocated()) {
|
||||
externalOutputPrim->set_output_memory(internal_mem, external_id.idx);
|
||||
OPENVINO_ASSERT(internal_mem != nullptr, "internal_mem should not be nullptr");
|
||||
if (!output_allocated) {
|
||||
external_outputs[external_id.idx] = internal_mem;
|
||||
} else {
|
||||
auto external_mem = externalOutputPrim->output_memory_ptr(external_id.idx);
|
||||
if (external_mem->get_layout() != internal_mem->get_layout()) {
|
||||
externalOutputPrim->set_output_memory(internal_mem, external_id.idx);
|
||||
} else if (external_mem != internal_mem) {
|
||||
external_mem->copy_from(get_network().get_stream(), *internal_mem);
|
||||
auto external_mem = _outputs[external_id.idx];
|
||||
if (external_mem != internal_mem) {
|
||||
if (external_mem->get_layout() != internal_mem->get_layout()) {
|
||||
external_outputs[external_id.idx] = internal_mem;
|
||||
} else {
|
||||
external_mem->copy_from(get_network().get_stream(), *internal_mem);
|
||||
external_outputs[external_id.idx] = external_mem;
|
||||
}
|
||||
} else {
|
||||
external_outputs[external_id.idx] = external_mem;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto externalOutputPrim = _network.get_primitive(external_id.pid);
|
||||
if (!externalOutputPrim->outputs_allocated() || shape_changed()) {
|
||||
if (!output_allocated || shape_changed()) {
|
||||
auto concat_layout = _impl_params->get_output_layout(external_id.idx);
|
||||
auto concat_mem = _network.get_engine().allocate_memory(concat_layout, 0);
|
||||
externalOutputPrim->set_output_memory(concat_mem, external_id.idx);
|
||||
auto concat_mem = _network.get_engine().allocate_memory(concat_layout, false);
|
||||
external_outputs[external_id.idx] = concat_mem;
|
||||
auto iter = std::find_if(concatenated_output_mem_mappings.begin(),
|
||||
concatenated_output_mem_mappings.end(),
|
||||
[&](std::shared_ptr<loop_inst::concatenated_memory_mapping> &concat_output){
|
||||
return concat_output->concat_data_prim->id() == external_id.pid;
|
||||
return concat_output->concat_data_id == external_id;
|
||||
});
|
||||
if (iter != concatenated_output_mem_mappings.end()) {
|
||||
(*iter)->update_concatenated_mem(concat_mem);
|
||||
}
|
||||
} else {
|
||||
external_outputs[external_id.idx] = _outputs[external_id.idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
_outputs = external_outputs;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < concatenated_output_mem_mappings.size(); ++i) {
|
||||
@ -776,7 +780,7 @@ void loop_inst::update_output_layout() {
|
||||
auto new_layouts = _node->type()->calc_output_layouts(*_node, *_impl_params);
|
||||
if (new_layouts.empty()) {
|
||||
auto new_layout = _node->type()->calc_output_layout(*_node, *_impl_params);
|
||||
new_layout.data_padding = padding::max(_node->get_primitive()->output_paddings[0], new_layout.data_padding);
|
||||
new_layout.data_padding = padding::max(_node->get_primitive()->get_output_padding(0), new_layout.data_padding);
|
||||
_impl_params->output_layouts[0] = new_layout;
|
||||
} else {
|
||||
if (_impl_params->output_layouts.size() < new_layouts.size()) {
|
||||
@ -784,7 +788,7 @@ void loop_inst::update_output_layout() {
|
||||
}
|
||||
for (size_t i = 0; i < new_layouts.size(); ++i) {
|
||||
auto new_layout = new_layouts[i];
|
||||
new_layout.data_padding = padding::max(_node->get_primitive()->output_paddings[i], new_layout.data_padding);
|
||||
new_layout.data_padding = padding::max(_node->get_primitive()->get_output_padding(i), new_layout.data_padding);
|
||||
_impl_params->output_layouts[i] = new_layout;
|
||||
}
|
||||
}
|
||||
|
@ -354,7 +354,7 @@ void primitive_inst::update_shape() {
|
||||
|
||||
auto update_output_layout = [&](layout& layout, size_t idx) {
|
||||
auto data_padding = padding::max(_impl_params->get_output_layout(idx).data_padding, layout.data_padding);
|
||||
layout.data_padding = padding::max(_node->get_primitive()->output_paddings[idx], data_padding);
|
||||
layout.data_padding = padding::max(_node->get_primitive()->get_output_padding(idx), data_padding);
|
||||
if (_impl_params->get_output_layout(idx) != layout) {
|
||||
GPU_DEBUG_TRACE_DETAIL << id() << ": update shape: was: " << _impl_params->get_output_layout(idx).to_short_string()
|
||||
<< " now: " << layout.to_short_string() << std::endl;
|
||||
@ -1013,7 +1013,7 @@ primitive_inst::primitive_inst(network& network, program_node const& node, bool
|
||||
_mem_allocated = allocate_memory;
|
||||
if (allocate_memory) {
|
||||
// In case when output is mutable_data primitive, and other users dependencies are only used for
|
||||
// suychronization, The output memory of such primitive will be fused with mutable_data
|
||||
// synchronization, The output memory of such primitive will be fused with mutable_data
|
||||
auto users = node.get_users();
|
||||
auto user_count = users.size();
|
||||
uint32_t mutable_data_count = 0;
|
||||
|
@ -38,7 +38,7 @@ program_node::program_node(std::shared_ptr<primitive> prim, program& prog)
|
||||
num_outputs = prim->num_outputs;
|
||||
for (size_t i = 0 ; i < num_outputs; ++i) {
|
||||
layout output_layout = layout{ov::PartialShape{}, data_types::f32, format::bfyx};
|
||||
output_layout.data_padding = prim->output_paddings[i];
|
||||
output_layout.data_padding = prim->get_output_padding(i);
|
||||
output_layouts.push_back(output_layout);
|
||||
valid_output_layouts.push_back(false);
|
||||
}
|
||||
|
@ -29,7 +29,7 @@ static cldnn::condition::branch gen_branch(ProgramBuilder& p, const std::shared_
|
||||
}
|
||||
}
|
||||
config.set_property(ov::intel_gpu::max_dynamic_batch(1));
|
||||
config.set_property(ov::intel_gpu::allow_new_shape_infer(op->is_dynamic()));
|
||||
config.set_property(ov::intel_gpu::allow_new_shape_infer(op->is_dynamic() || p.use_new_shape_infer()));
|
||||
|
||||
ProgramBuilder prog(internal_body, p.get_engine(), config, false, false, p.get_task_executor(), p.get_compilation_context(), true);
|
||||
branch.inner_program = prog.get_compiled_program();
|
||||
|
@ -100,12 +100,17 @@ static void create_data(ProgramBuilder& p, const ov::Shape& const_shape, const s
|
||||
p.primitive_ids[initialconstPrimID] = constPrimID;
|
||||
p.profiling_ids.push_back(initialconstPrimID);
|
||||
} else {
|
||||
if (constLayout.count() == 0) {
|
||||
// Convert zero dimension constant layout to 1 dimension to fix the issue
|
||||
// that memory allocation is failed on windows when constant layout is zero dimension.
|
||||
constLayout = cldnn::layout(ov::PartialShape({1}), constLayout.data_type, constLayout.format);
|
||||
cldnn::memory::ptr mem = nullptr;
|
||||
if (constLayout.bytes_count() > 0) {
|
||||
mem = p.get_engine().allocate_memory(constLayout, false);
|
||||
} else {
|
||||
// In the case of empty const data with {0} shape, it has zero byte.
|
||||
// To avoid zero byte memory allocation issue, reinterpret one dimension memory to zero dimension memory.
|
||||
auto one_dim_layout = cldnn::layout(ov::PartialShape({1}), constLayout.data_type, constLayout.format);
|
||||
auto one_dim_mem = p.get_engine().allocate_memory(one_dim_layout, false);
|
||||
mem = p.get_engine().reinterpret_buffer(*one_dim_mem, constLayout);
|
||||
}
|
||||
cldnn::memory::ptr mem = p.get_engine().allocate_memory(constLayout, false);
|
||||
|
||||
GPU_DEBUG_LOG << "[" << initialconstPrimID << ": constant] layout: "
|
||||
<< constLayout.to_short_string() << ", mem_ptr(" << mem << ", " << mem->size() << " bytes)"<< std::endl;
|
||||
auto& stream = p.get_engine().get_service_stream();
|
||||
|
@ -238,12 +238,12 @@ static void CreateCommonLoopOp(ProgramBuilder& p, const std::shared_ptr<ov::op::
|
||||
|
||||
SetLoopInputOutputMap(p, op, inputs, input_primitive_maps, output_primitive_maps, back_edges);
|
||||
|
||||
auto shape = is_dynamic? ngraph::Shape{1} : ngraph::Shape{1, 1, 1, 1};
|
||||
auto shape = is_dynamic? ngraph::Shape{} : ngraph::Shape{1, 1, 1, 1};
|
||||
auto prec = ngraph::element::i64;
|
||||
if (current_iteration_input_op) {
|
||||
current_iteration_input_op->set_output_type(0, prec, shape);
|
||||
current_iteration_input_op->set_partial_shape(shape);
|
||||
current_iteration_input_op->set_element_type(prec);
|
||||
OPENVINO_ASSERT(current_iteration_input_op->get_partial_shape().is_static(), "current_iteration should be static layout");
|
||||
shape = is_dynamic? current_iteration_input_op->get_partial_shape().to_shape() : shape;
|
||||
prec = current_iteration_input_op->get_element_type();
|
||||
|
||||
auto increment_value_id = current_iteration_input_op->get_friendly_name() + "_inc";
|
||||
auto increment_value_op = std::make_shared<op::v0::Constant>(prec, shape, 1);
|
||||
|
@ -260,13 +260,11 @@ protected:
|
||||
class InnerBodyType06 : public InnerBodyGenerator {
|
||||
protected:
|
||||
std::shared_ptr<ngraph::Function> generate(ov::PartialShape& input_shape, ngraph::element::Type prc) override {
|
||||
auto constant = ngraph::opset9::Constant::create(prc, ov::Shape(input_shape.rank().get_length(), 0), {2.0f});
|
||||
constant->set_friendly_name("body1_constant");
|
||||
// constant->get_rt_info().emplace(ov::pass::DisableConstantFolding::get_type_info_static(), ov::pass::DisableConstantFolding{});
|
||||
// constant->get_rt_info().emplace("can_be_folded", false);
|
||||
auto constant = ngraph::opset9::Constant::create(prc, ov::Shape(input_shape.rank().get_length(), 1), {2.0f});
|
||||
constant->set_friendly_name("body6_constant");
|
||||
auto result = std::make_shared<ngraph::opset1::Result>(constant);
|
||||
auto o_layout = result->get_layout();
|
||||
result->set_friendly_name("body1_result");
|
||||
result->set_friendly_name("body6_result");
|
||||
auto body = std::make_shared<ngraph::Function>(
|
||||
ngraph::OutputVector {result},
|
||||
ngraph::ParameterVector{},
|
||||
|
@ -115,7 +115,7 @@ protected:
|
||||
const auto prc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(data_prc);
|
||||
const auto inputShape = data_shapes.first;
|
||||
const auto scalarShape = ngraph::Shape{};
|
||||
init_input_shapes({data_shapes});
|
||||
init_input_shapes({data_shapes, data_shapes});
|
||||
|
||||
ngraph::ParameterVector params{};
|
||||
auto cond_input_create = [¶ms] (ngraph::element::Type prc, const ov::PartialShape &shape, int value = 0, bool is_static = false)
|
||||
@ -128,35 +128,27 @@ protected:
|
||||
return input;
|
||||
};
|
||||
|
||||
auto start = cond_input_create(prc, inputShape);
|
||||
start->set_friendly_name("start");
|
||||
auto start_add = cond_input_create(prc, inputShape, start_value);
|
||||
start_add->set_friendly_name("start_add");
|
||||
auto start_mul = cond_input_create(prc, inputShape, 1);
|
||||
start_mul->set_friendly_name("start_mul");
|
||||
auto count = cond_input_create(ngraph::element::i64, scalarShape, max_iter_num, static_iter_num);
|
||||
count->set_friendly_name("count");
|
||||
auto skip = cond_input_create(ngraph::element::boolean, scalarShape, true, static_continue_cond);
|
||||
skip->set_friendly_name("skip");
|
||||
|
||||
//
|
||||
// count skip start count skip start
|
||||
// / /
|
||||
// ___*___*____ __________*___*____ | idx | data | out |
|
||||
// | idx in | | ex_val idx in | | 0 | 7 | 7 |
|
||||
// | | / | | | / | / | | 1 | 7 | 8 |
|
||||
// | add | | less add | | 2 | 8 | 10 |
|
||||
// | | true | | | | | | 3 | 10 | 13 |
|
||||
// | | | | | | | | ~~~~~ * * * ~~~~~
|
||||
// | out cnd | | cnd out |
|
||||
// |___*____*___| |____*_____*________|
|
||||
// Full loop Dynamic exit loop
|
||||
// n_iter = count n_iter = ex_val
|
||||
//
|
||||
auto b_indx = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::i64, ngraph::Shape{});
|
||||
b_indx->set_friendly_name("body_index");
|
||||
auto b_data = std::make_shared<ngraph::opset5::Parameter>(prc, inputShape);
|
||||
b_data->set_friendly_name("body_data");
|
||||
auto b_data_add = std::make_shared<ngraph::opset5::Parameter>(prc, inputShape);
|
||||
b_data_add->set_friendly_name("b_data_add");
|
||||
auto b_data_mul = std::make_shared<ngraph::opset5::Parameter>(prc, inputShape);
|
||||
b_data_mul->set_friendly_name("b_data_mul");
|
||||
auto b_indx_cast = std::make_shared<ngraph::opset5::Convert>(b_indx, prc);
|
||||
b_indx_cast->set_friendly_name("body_index_cast");
|
||||
auto b_add = std::make_shared<ngraph::opset5::Add>(b_data, b_indx_cast);
|
||||
b_add->set_friendly_name("body_addition");
|
||||
auto b_add = std::make_shared<ngraph::opset5::Add>(b_data_add, b_indx_cast);
|
||||
b_add->set_friendly_name("body_add");
|
||||
auto b_mul = std::make_shared<ngraph::opset5::Multiply>(b_data_mul, b_indx_cast);
|
||||
b_mul->set_friendly_name("body_mul");
|
||||
|
||||
std::shared_ptr<ngraph::Node> b_cond;
|
||||
if (dynamic_exit == -1) {
|
||||
@ -170,22 +162,32 @@ protected:
|
||||
}
|
||||
|
||||
auto body = std::make_shared<ngraph::Function>(
|
||||
ngraph::OutputVector {b_cond, b_add}, // TODO: check with reverse
|
||||
ngraph::ParameterVector {b_indx, b_data}); // TODO: check with reverse
|
||||
ngraph::OutputVector {b_cond, b_add, b_mul}, // TODO: check with reverse
|
||||
ngraph::ParameterVector {b_indx, b_data_add, b_data_mul}); // TODO: check with reverse
|
||||
body->set_friendly_name("body_network");
|
||||
|
||||
auto loop = std::make_shared<ngraph::opset5::Loop>(count, skip);
|
||||
loop->set_friendly_name("loop");
|
||||
loop->set_function(body);
|
||||
loop->set_special_body_ports({0, 0});
|
||||
loop->set_merged_input(b_data, start, b_add);
|
||||
if (axis == -1)
|
||||
loop->set_merged_input(b_data_add, start_add, b_add);
|
||||
loop->set_merged_input(b_data_mul, start_mul, b_mul);
|
||||
if (axis == -1) {
|
||||
loop->get_iter_value(b_add, -1);
|
||||
else
|
||||
loop->get_iter_value(b_mul, -1);
|
||||
} else {
|
||||
loop->get_concatenated_slices(b_add, 0, 1, 1, -1, axis);
|
||||
loop->get_concatenated_slices(b_mul, 0, 1, 1, -1, axis);
|
||||
}
|
||||
|
||||
ngraph::ResultVector results;
|
||||
for (size_t i = 0; i < loop->get_output_size(); i++) {
|
||||
auto res = std::make_shared<ngraph::opset4::Result>(loop->output(i));
|
||||
res->set_friendly_name("loop_output_" + std::to_string(i));
|
||||
results.push_back(res);
|
||||
}
|
||||
function = std::make_shared<ngraph::Function>(
|
||||
ngraph::OutputVector {loop},
|
||||
results,
|
||||
params);
|
||||
function->set_friendly_name("outer_body_network");
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user