[IE CLDNN] Free up first copy of weights/biases that were transferred to USM device memory (#561)

This commit is contained in:
Mikhail Letavin 2020-06-01 12:01:28 +03:00 committed by GitHub
parent 004f414b89
commit 65f62945dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 104 additions and 21 deletions

View File

@ -167,6 +167,8 @@ struct memory {
/// C API memory handle
memory_impl* get() const { return _impl; }
void reset();
private:
friend struct engine;
memory_impl* _impl;

View File

@ -140,6 +140,8 @@ public:
allocation_type type);
void clear_pool();
void clear_pool_for_network(uint32_t network_id);
void release_memory(memory_impl* memory,
const primitive_id& id);
void color_graph(const program_impl&);
void dump_memory_pool(const program_impl&, std::string&, std::string&);

View File

@ -72,26 +72,34 @@ memory memory::share_surface(const engine& engine, const layout& layout, shared_
#endif
/// Returns the element count reported by this memory's layout.
/// A handle with no implementation attached (`_impl == nullptr`) reports 0.
size_t memory::count() const {
    // The null check has to precede any layout access; an unconditional
    // return ahead of it would make the guard unreachable.
    if (!_impl)
        return 0;
    return get_layout().count();
}
/// Returns the size reported by the underlying implementation,
/// or 0 when no implementation is attached (`_impl == nullptr`).
size_t memory::size() const {
    // Check the pointer before dereferencing it.
    if (!_impl)
        return 0;
    return _impl->size();
}
/// Returns the layout of the underlying implementation.
/// @throws std::runtime_error if no implementation is attached, since there
///         is no layout object a reference could be bound to.
const layout& memory::get_layout() const {
    // Validate first: dereferencing a null `_impl` must never be reached.
    if (!_impl)
        throw std::runtime_error("empty memory object");
    return _impl->get_layout();
}
/// Returns the network id stored in the underlying implementation.
/// @throws std::runtime_error if no implementation is attached.
int memory::get_net_id() const {
    // Validate first: dereferencing a null `_impl` must never be reached.
    if (!_impl)
        throw std::runtime_error("empty memory object");
    return _impl->get_net_id();
}
/// Asks the underlying implementation whether it was allocated by @p engine.
/// A handle with no implementation attached is never reported as allocated
/// by any engine.
bool memory::is_allocated_by(const engine& engine) const {
    // Short-circuit keeps the dereference behind the null check.
    return _impl != nullptr && _impl->is_allocated_by(*engine.get());
}
bool memory::is_the_same_buffer(const memory& other) const {
if (_impl == nullptr)
return false;
if (_impl == other.get())
return true;
@ -107,7 +115,8 @@ bool memory::is_the_same_buffer(const memory& other) const {
}
/// Returns the internal parameters reported by the underlying implementation.
/// @throws std::runtime_error if no implementation is attached.
shared_mem_params memory::get_internal_params() const {
    // Validate first: dereferencing a null `_impl` must never be reached.
    if (!_impl)
        throw std::runtime_error("empty memory object");
    return _impl->get_internal_params();
}
memory memory::attach_impl(const cldnn::layout& layout, void* ptr, uint32_t net_id) {
@ -115,18 +124,24 @@ memory memory::attach_impl(const cldnn::layout& layout, void* ptr, uint32_t net_
}
/// Locks the underlying implementation and returns the pointer it yields.
/// Returns nullptr when no implementation is attached.
void* memory::lock_impl() const {
    // Check the pointer before dereferencing it.
    if (!_impl)
        return nullptr;
    return _impl->lock();
}
/// Unlocks the underlying implementation; a no-op when none is attached.
void memory::unlock() const {
    // Exactly one guarded call: an additional unguarded call would both
    // dereference a null `_impl` and unlock twice per invocation.
    if (_impl)
        _impl->unlock();
}
/// Calls add_ref() on the underlying implementation; a no-op when none is
/// attached.
void memory::retain() {
    // Exactly one guarded call: an additional unguarded call would both
    // dereference a null `_impl` and add two references per invocation.
    if (_impl)
        _impl->add_ref();
}
/// Calls release() on the underlying implementation; a no-op when none is
/// attached.
void memory::release() {
    // Exactly one guarded call: an additional unguarded call would both
    // dereference a null `_impl` and release twice per invocation.
    if (_impl)
        _impl->release();
}
void memory::reset() {
release();
_impl = nullptr;
}
} // namespace cldnn

View File

@ -141,6 +141,75 @@ bool memory_pool::has_conflict(const memory_set& a,
return !intersection.empty();
}
/// Removes the user identified by (@p id, network id of @p mem) from the pool
/// record that holds @p mem, and drops that record once it has no users left.
///
/// The non-padded pool is searched first (keyed by the layout's byte count);
/// only if no matching record is found there is the padded pool (keyed by the
/// layout itself) consulted. A record matches when its memory pointer,
/// network id and allocation type all equal those of @p mem.
///
/// @param mem memory whose pool record should be updated; dereferenced
///            unconditionally, so it must not be null.
/// @param id  primitive id of the user giving up the memory.
void memory_pool::release_memory(memory_impl* mem,
                                 const primitive_id& id) {
    auto mem_layout = mem->get_layout();
    auto alloc_type = mem->get_allocation_type();
    auto net_id = mem->get_net_id();

    // --- Pass 1: non-padded pool ------------------------------------------
    auto candidates = _non_padded_pool.equal_range(mem_layout.bytes_count());
    for (auto rec = candidates.first;
         rec != candidates.second && rec != _non_padded_pool.end();
         ++rec) {
        auto& record = rec->second;
        const bool is_match = record._network_id == net_id &&
                              record._type == alloc_type &&
                              record._memory.get() == mem;
        if (!is_match)
            continue;

        // Normally there is a single entry for this user.
        auto user = record._users.find({ id, net_id });
        if (user != record._users.end())
            record._users.erase(user);

        // No users left: the record itself can go.
        if (record._users.empty())
            _non_padded_pool.erase(rec);

        // Record found and handled, nothing more to do.
        return;
    }

    // --- Pass 2: padded pool ----------------------------------------------
    auto bucket = _padded_pool.find(mem_layout);
    if (bucket == _padded_pool.end())
        return;

    auto& records = bucket->second;
    for (auto rec = records.begin(); rec != records.end(); ++rec) {
        const bool is_match = rec->_memory.get() == mem &&
                              rec->_network_id == net_id &&
                              rec->_type == alloc_type;
        if (!is_match)
            continue;

        // Normally there is a single entry for this user.
        auto user = rec->_users.find({ id, net_id });
        if (user != rec->_users.end())
            rec->_users.erase(user);

        // No users left: the record itself can go.
        if (rec->_users.empty())
            records.erase(rec);

        // Leave the loop (not the function) so the empty-bucket cleanup
        // below still runs; `rec` is not touched again after the erase.
        break;
    }

    // Drop the bucket once its record list is empty.
    if (records.empty())
        _padded_pool.erase(bucket);
}
memory_impl::ptr memory_pool::get_from_non_padded_pool(const layout& layout,
const primitive_id& id,
uint32_t network_id,

View File

@ -736,11 +736,13 @@ void network_impl::transfer_memory_to_device(std::shared_ptr<primitive_inst> ins
if (alloc_type == allocation_type::usm_host || alloc_type == allocation_type::usm_shared) {
// Allocate and transfer memory
auto& mem_pool = inst_mem.get_engine()->get_memory_pool();
auto device_mem = inst_mem.get_engine()->allocate_memory(
inst_mem.get_layout(),
allocation_type::usm_device,
inst_mem.get_net_id());
dynamic_cast<gpu::gpu_usm&>(*device_mem).copy_from_other(dynamic_cast<gpu::gpu_usm&>(inst_mem));
mem_pool.release_memory(&inst_mem, node.id());
instance->set_output_memory(*device_mem);
}
}

View File

@ -367,7 +367,7 @@ void program_impl::build_program(bool is_internal) {
if (!is_internal)
prim_info = get_current_stage_info();
transfer_memory_to_device();
if (!is_internal) transfer_memory_to_device();
cleanup();
}
@ -523,6 +523,7 @@ void program_impl::transfer_memory_to_device() {
mem.get_net_id());
dynamic_cast<gpu::gpu_usm&>(*device_mem).copy_from_other(dynamic_cast<gpu::gpu_usm&>(mem));
data_node.attach_memory(*device_mem);
const_cast<memory&>(data_node.get_primitive()->mem).reset();
}
}
}

View File

@ -408,22 +408,14 @@ TEST(memory_pool, shared_mem_pool_diff_batches) {
auto outputs = network_first.execute();
auto dev_info = engine.get_info();
if (dev_info.supports_usm) {
EXPECT_EQ(engine.get_max_used_device_memory_size(), (uint64_t)4312);
} else {
EXPECT_EQ(engine.get_max_used_device_memory_size(), (uint64_t)3928);
}
EXPECT_EQ(engine.get_max_used_device_memory_size(), (uint64_t)3928);
topo.change_input_layout("input", input_1.get_layout());//change input layout to batch=1
network network_second(engine, topo, bo);
network_second.set_input_data("input", input_1);
auto outputs_second = network_second.execute();
if (dev_info.supports_usm) {
EXPECT_EQ(engine.get_max_used_device_memory_size(), (uint64_t)4312);
} else {
EXPECT_EQ(engine.get_max_used_device_memory_size(), (uint64_t)3928);
}
EXPECT_EQ(engine.get_max_used_device_memory_size(), (uint64_t)3928);
}
TEST(memory_pool, shared_dep_two_output) {