Removed redundant memcpy calls. Share weights from original blob (#4259)
* Removed redundant calls of memcpy. Share weights from original blob * Fixed nGraph tests
This commit is contained in:
parent
63f3a5d99c
commit
a327b72481
@ -349,13 +349,6 @@ public:
|
|||||||
|
|
||||||
void on_adapter(const std::string& name, ::ngraph::ValueAccessor<void>& adapter) override;
|
void on_adapter(const std::string& name, ::ngraph::ValueAccessor<void>& adapter) override;
|
||||||
|
|
||||||
void on_adapter(const std::string& name, ::ngraph::ValueAccessor<void*>& adapter) override {
|
|
||||||
if (std::string(node->get_type_name()) != "Constant") {
|
|
||||||
const auto data_beg = static_cast<char*>(adapter.get_ptr());
|
|
||||||
params[name] = std::string(data_beg, adapter.size());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<::ngraph::Node> node;
|
std::shared_ptr<::ngraph::Node> node;
|
||||||
std::map<std::string, std::string> params;
|
std::map<std::string, std::string> params;
|
||||||
@ -394,6 +387,11 @@ void InferenceEngine::details::CNNLayerCreator::on_adapter(const std::string& na
|
|||||||
(void)a;
|
(void)a;
|
||||||
} else if (auto a = ::ngraph::as_type<::ngraph::AttributeAdapter<ngraph::op::v5::Loop::SpecialBodyPorts>>(& adapter)) {
|
} else if (auto a = ::ngraph::as_type<::ngraph::AttributeAdapter<ngraph::op::v5::Loop::SpecialBodyPorts>>(& adapter)) {
|
||||||
(void)a;
|
(void)a;
|
||||||
|
} else if (auto a = ::ngraph::as_type<::ngraph::AttributeAdapter<std::shared_ptr<ngraph::runtime::AlignedBuffer>>>(& adapter)) {
|
||||||
|
if (std::string(node->get_type_name()) != "Constant") {
|
||||||
|
const auto data_beg = static_cast<char*>(a->get()->get_ptr());
|
||||||
|
params[name] = std::string(data_beg, a->get()->size());
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
THROW_IE_EXCEPTION << "Error converting ngraph to CNN network. "
|
THROW_IE_EXCEPTION << "Error converting ngraph to CNN network. "
|
||||||
"Attribute adapter can not be found for " << name << " parameter";
|
"Attribute adapter can not be found for " << name << " parameter";
|
||||||
|
@ -224,6 +224,11 @@ ngraph::op::v5::Loop::SpecialBodyPorts V10Parser::XmlDeserializer::parsePurposeA
|
|||||||
}
|
}
|
||||||
|
|
||||||
void V10Parser::XmlDeserializer::on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) {
|
void V10Parser::XmlDeserializer::on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) {
|
||||||
|
static const std::unordered_set<std::string> skip_names = {
|
||||||
|
"input_descriptions",
|
||||||
|
"output_descriptions",
|
||||||
|
"special_body_ports"
|
||||||
|
};
|
||||||
std::string val;
|
std::string val;
|
||||||
|
|
||||||
// for TensorIterator look for 'port_map' as 'data' does not exist
|
// for TensorIterator look for 'port_map' as 'data' does not exist
|
||||||
@ -239,7 +244,7 @@ void V10Parser::XmlDeserializer::on_adapter(const std::string& name, ngraph::Val
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!getStrAttribute(node.child("data"), name, val)) return;
|
if (skip_names.count(name) && !getStrAttribute(node.child("data"), name, val)) return;
|
||||||
if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::element::Type>>(&adapter)) {
|
if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::element::Type>>(&adapter)) {
|
||||||
static_cast<ngraph::element::Type&>(*a) = details::convertPrecision(val);
|
static_cast<ngraph::element::Type&>(*a) = details::convertPrecision(val);
|
||||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::PartialShape>>(&adapter)) {
|
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::PartialShape>>(&adapter)) {
|
||||||
@ -292,6 +297,44 @@ void V10Parser::XmlDeserializer::on_adapter(const std::string& name, ngraph::Val
|
|||||||
ngraph::element::dynamic, variable_id});
|
ngraph::element::dynamic, variable_id});
|
||||||
}
|
}
|
||||||
a->set(variables[variable_id]);
|
a->set(variables[variable_id]);
|
||||||
|
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::runtime::AlignedBuffer>>>(&adapter)) {
|
||||||
|
std::string value;
|
||||||
|
pugi::xml_node dn = node.child("data");
|
||||||
|
auto type = XMLParseUtils::GetStrAttr(node, "type");
|
||||||
|
|
||||||
|
if (dn.empty())
|
||||||
|
THROW_IE_EXCEPTION << "No attrtibutes defined for " << type << " op!";
|
||||||
|
|
||||||
|
if (getStrAttribute(dn, name, value)) {
|
||||||
|
auto buffer = std::make_shared<ngraph::runtime::AlignedBuffer>(value.size());
|
||||||
|
auto data = static_cast<char*>(buffer->get_ptr());
|
||||||
|
value.copy(data, value.size());
|
||||||
|
a->set(buffer);
|
||||||
|
} else if (name == "value" && type == "Const") {
|
||||||
|
std::vector<int64_t> shape;
|
||||||
|
std::string el_type_str;
|
||||||
|
|
||||||
|
size_t offset = XMLParseUtils::GetUInt64Attr(dn, "offset");
|
||||||
|
size_t size = XMLParseUtils::GetUInt64Attr(dn, "size");
|
||||||
|
if (!getStrAttribute(dn, "element_type", el_type_str)) return;
|
||||||
|
if (!getParameters<int64_t>(dn, "shape", shape)) return;
|
||||||
|
|
||||||
|
ngraph::element::Type el_type = details::convertPrecision(el_type_str);
|
||||||
|
|
||||||
|
size_t length = weights->byteSize();
|
||||||
|
if (!length)
|
||||||
|
THROW_IE_EXCEPTION << "Empty weights data in bin file or bin file cannot be found!";
|
||||||
|
if (length < offset + size)
|
||||||
|
THROW_IE_EXCEPTION << "Incorrect weights in bin file!";
|
||||||
|
if (size < std::ceil(ngraph::shape_size(shape) * el_type.bitwidth() / 8.f))
|
||||||
|
THROW_IE_EXCEPTION << "Attribute and shape size are inconsistent for " << type << " op!";
|
||||||
|
|
||||||
|
char* data = weights->cbuffer().as<char*>() + offset;
|
||||||
|
|
||||||
|
using SharedBuffer = ngraph::runtime::SharedBuffer<const Blob::CPtr>;
|
||||||
|
auto buffer = std::make_shared<SharedBuffer>(data, size, weights);
|
||||||
|
a->set(buffer);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
THROW_IE_EXCEPTION << "Error IR reading. Attribute adapter can not be found for " << name
|
THROW_IE_EXCEPTION << "Error IR reading. Attribute adapter can not be found for " << name
|
||||||
<< " parameter";
|
<< " parameter";
|
||||||
@ -742,6 +785,10 @@ std::shared_ptr<ngraph::Node> V10Parser::XmlDeserializer::createNode(
|
|||||||
if (!ngraphNode) {
|
if (!ngraphNode) {
|
||||||
THROW_IE_EXCEPTION << "Opset " << params.version << " doesn't contain the operation with type: " << type;
|
THROW_IE_EXCEPTION << "Opset " << params.version << " doesn't contain the operation with type: " << type;
|
||||||
}
|
}
|
||||||
|
// Share Weights form constant blob
|
||||||
|
if (auto constant = std::dynamic_pointer_cast<ngraph::opset6::Constant>(ngraphNode)) {
|
||||||
|
constant->alloc_buffer_on_visit_attributes(false);
|
||||||
|
}
|
||||||
ngraphNode->set_arguments(inputs);
|
ngraphNode->set_arguments(inputs);
|
||||||
XmlDeserializer visitor(node, weights, opsets, variables);
|
XmlDeserializer visitor(node, weights, opsets, variables);
|
||||||
if (ngraphNode->visit_attributes(visitor)) {
|
if (ngraphNode->visit_attributes(visitor)) {
|
||||||
|
@ -216,43 +216,6 @@ private:
|
|||||||
stringToType<double>(val, value);
|
stringToType<double>(val, value);
|
||||||
adapter.set(value);
|
adapter.set(value);
|
||||||
}
|
}
|
||||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<void*>& adapter) override {
|
|
||||||
std::string value;
|
|
||||||
pugi::xml_node dn = node.child("data");
|
|
||||||
auto type = XMLParseUtils::GetStrAttr(node, "type");
|
|
||||||
|
|
||||||
if (dn.empty())
|
|
||||||
THROW_IE_EXCEPTION << "No attrtibutes defined for " << type << " op!";
|
|
||||||
|
|
||||||
if (getStrAttribute(dn, name, value)) {
|
|
||||||
auto data = static_cast<char*>(adapter.get_ptr());
|
|
||||||
size_t length = std::min(value.size(), adapter.size());
|
|
||||||
value.copy(data, length);
|
|
||||||
} else if (name == "value" && type == "Const") {
|
|
||||||
std::vector<int64_t> shape;
|
|
||||||
std::string el_type_str;
|
|
||||||
|
|
||||||
size_t offset = XMLParseUtils::GetUInt64Attr(dn, "offset");
|
|
||||||
size_t size = XMLParseUtils::GetUInt64Attr(dn, "size");
|
|
||||||
if (!getStrAttribute(dn, "element_type", el_type_str)) return;
|
|
||||||
if (!getParameters<int64_t>(dn, "shape", shape)) return;
|
|
||||||
|
|
||||||
ngraph::element::Type el_type = details::convertPrecision(el_type_str);
|
|
||||||
|
|
||||||
size_t length = weights->byteSize();
|
|
||||||
if (!length)
|
|
||||||
THROW_IE_EXCEPTION << "Empty weights data in bin file or bin file cannot be found!";
|
|
||||||
if (length < offset + size)
|
|
||||||
THROW_IE_EXCEPTION << "Incorrect weights in bin file!";
|
|
||||||
if (size < std::ceil(ngraph::shape_size(shape) * el_type.bitwidth() / 8.f))
|
|
||||||
THROW_IE_EXCEPTION << "Attribute and shape size are inconsistent for " << type << " op!";
|
|
||||||
|
|
||||||
auto data = static_cast<char*>(adapter.get_ptr());
|
|
||||||
char* weights_data = weights->cbuffer().as<char*>() + offset;
|
|
||||||
|
|
||||||
std::memcpy(data, weights_data, size);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<int64_t>& adapter) override {
|
void on_adapter(const std::string& name, ngraph::ValueAccessor<int64_t>& adapter) override {
|
||||||
std::string val;
|
std::string val;
|
||||||
if (!getStrAttribute(node.child("data"), name, val))
|
if (!getStrAttribute(node.child("data"), name, val))
|
||||||
|
@ -247,22 +247,15 @@ public:
|
|||||||
}
|
}
|
||||||
} else if (const auto& a = ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::Variable>>>(&adapter)) {
|
} else if (const auto& a = ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::Variable>>>(&adapter)) {
|
||||||
m_xml_node.append_attribute(name.c_str()).set_value(a->get()->get_info().variable_id.c_str());
|
m_xml_node.append_attribute(name.c_str()).set_value(a->get()->get_info().variable_id.c_str());
|
||||||
}
|
} else if (const auto& a = ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::runtime::AlignedBuffer>>>(&adapter)) {
|
||||||
}
|
|
||||||
|
|
||||||
void on_adapter(const std::string& name,
|
|
||||||
ngraph::ValueAccessor<void*>& adapter) override {
|
|
||||||
if (name == "value" && translate_type_name(m_node_type_name) == "Const") {
|
if (name == "value" && translate_type_name(m_node_type_name) == "Const") {
|
||||||
using AlignedBufferAdapter =
|
const int64_t size = a->get()->size();
|
||||||
ngraph::AttributeAdapter<std::shared_ptr<runtime::AlignedBuffer>>;
|
|
||||||
if (auto a = ngraph::as_type<AlignedBufferAdapter>(&adapter)) {
|
|
||||||
const int64_t size = a->size();
|
|
||||||
const int64_t offset = m_bin_data.tellp();
|
const int64_t offset = m_bin_data.tellp();
|
||||||
|
|
||||||
m_xml_node.append_attribute("offset").set_value(offset);
|
m_xml_node.append_attribute("offset").set_value(offset);
|
||||||
m_xml_node.append_attribute("size").set_value(size);
|
m_xml_node.append_attribute("size").set_value(size);
|
||||||
|
|
||||||
auto data = static_cast<const char*>(a->get_ptr());
|
auto data = static_cast<const char*>(a->get()->get_ptr());
|
||||||
m_bin_data.write(data, size);
|
m_bin_data.write(data, size);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -240,16 +240,16 @@ public:
|
|||||||
class ReadAndStoreAttributes : public ngraph::AttributeVisitor, protected storage::Storage {
|
class ReadAndStoreAttributes : public ngraph::AttributeVisitor, protected storage::Storage {
|
||||||
public:
|
public:
|
||||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) override {
|
void on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) override {
|
||||||
if (auto inputs =
|
if (auto inputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpInputDescription>>(&adapter)) {
|
||||||
ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpInputDescription>>(&adapter)) {
|
|
||||||
insert(name, inputs->get());
|
insert(name, inputs->get());
|
||||||
} else if (
|
} else if (auto outputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpOutputDescription>>(&adapter)) {
|
||||||
auto outputs =
|
|
||||||
ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpOutputDescription>>(&adapter)) {
|
|
||||||
insert(name, outputs->get());
|
insert(name, outputs->get());
|
||||||
} else if (
|
} else if (auto ports = ngraph::as_type<ngraph::AttributeAdapter<SpecialBodyPorts>>(&adapter)) {
|
||||||
auto ports = ngraph::as_type<ngraph::AttributeAdapter<SpecialBodyPorts>>(&adapter)) {
|
|
||||||
insert(name, ports->get());
|
insert(name, ports->get());
|
||||||
|
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::runtime::AlignedBuffer>>>(&adapter)) {
|
||||||
|
const auto beg = static_cast<unsigned char*>(a->get()->get_ptr());
|
||||||
|
const auto end = beg + a->get()->size();
|
||||||
|
insert(name, storage::MemoryChunk{storage::MemoryChunk::Data(beg, end)});
|
||||||
} else {
|
} else {
|
||||||
m_read_result += "store attr [ ERR ]: " + name +
|
m_read_result += "store attr [ ERR ]: " + name +
|
||||||
" [drop `void` comparison which is '" + adapter.get_type_info().name +
|
" [drop `void` comparison which is '" + adapter.get_type_info().name +
|
||||||
@ -257,12 +257,6 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<void*>& adapter) override {
|
|
||||||
const auto beg = static_cast<unsigned char*>(adapter.get_ptr());
|
|
||||||
const auto end = beg + adapter.size();
|
|
||||||
insert(name, storage::MemoryChunk{storage::MemoryChunk::Data(beg, end)});
|
|
||||||
}
|
|
||||||
|
|
||||||
#define ON_ADAPTER(TYPE) \
|
#define ON_ADAPTER(TYPE) \
|
||||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<TYPE>& adapter) override { \
|
void on_adapter(const std::string& name, ngraph::ValueAccessor<TYPE>& adapter) override { \
|
||||||
insert(name, adapter.get()); \
|
insert(name, adapter.get()); \
|
||||||
@ -506,27 +500,13 @@ public:
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
m_visited_attributes.insert(name);
|
m_visited_attributes.insert(name);
|
||||||
if (auto inputs =
|
if (auto inputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpInputDescription>>(&adapter)) {
|
||||||
ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpInputDescription>>(&adapter)) {
|
|
||||||
verify(name, inputs->get());
|
verify(name, inputs->get());
|
||||||
} else if (
|
} else if (auto outputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpOutputDescription>>(&adapter)) {
|
||||||
auto outputs =
|
|
||||||
ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpOutputDescription>>(&adapter)) {
|
|
||||||
verify(name, outputs->get());
|
verify(name, outputs->get());
|
||||||
} else if (
|
} else if (auto ports = ngraph::as_type<ngraph::AttributeAdapter<SpecialBodyPorts>>(&adapter)) {
|
||||||
auto ports = ngraph::as_type<ngraph::AttributeAdapter<SpecialBodyPorts>>(&adapter)) {
|
|
||||||
verify(name, ports->get());
|
verify(name, ports->get());
|
||||||
} else {
|
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::runtime::AlignedBuffer>>>(&adapter)) {
|
||||||
m_cmp_result += "compare attr [ ERR ]: " + name +
|
|
||||||
" [drop `void` comparison which is '" + adapter.get_type_info().name +
|
|
||||||
"']";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<void*>& adapter) override {
|
|
||||||
if (should_return()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
m_visited_attributes.insert(name);
|
m_visited_attributes.insert(name);
|
||||||
const auto ref_value = m_attr_ref.get<storage::MemoryChunk>(name);
|
const auto ref_value = m_attr_ref.get<storage::MemoryChunk>(name);
|
||||||
if (!ref_value) {
|
if (!ref_value) {
|
||||||
@ -534,11 +514,16 @@ public:
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (adapter.size() != ref_value->size() ||
|
if (a->get()->size() != ref_value->size() ||
|
||||||
std::memcmp(ref_value->data(), adapter.get_ptr(), ref_value->size()) != 0) {
|
std::memcmp(ref_value->data(), a->get()->get_ptr(), ref_value->size()) != 0) {
|
||||||
m_cmp_result += "mismatch in value: '" + name + "' : look in to the mem buffer";
|
m_cmp_result += "mismatch in value: '" + name + "' : look in to the mem buffer";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
m_cmp_result += "compare attr [ ERR ]: " + name +
|
||||||
|
" [drop `void` comparison which is '" + adapter.get_type_info().name +
|
||||||
|
"']";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#define ON_ADAPTER(TYPE) \
|
#define ON_ADAPTER(TYPE) \
|
||||||
|
@ -470,6 +470,14 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
std::string convert_value_to_string(size_t index) const;
|
std::string convert_value_to_string(size_t index) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Allows to avoid buffer allocation on the visit_attributes call
|
||||||
|
*/
|
||||||
|
void alloc_buffer_on_visit_attributes(bool val)
|
||||||
|
{
|
||||||
|
m_alloc_buffer_on_visit_attributes = val;
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
template <typename IN_T, typename OUT_T>
|
template <typename IN_T, typename OUT_T>
|
||||||
void cast_vector(std::vector<OUT_T>& output_vector) const
|
void cast_vector(std::vector<OUT_T>& output_vector) const
|
||||||
@ -591,6 +599,7 @@ namespace ngraph
|
|||||||
std::shared_ptr<runtime::AlignedBuffer> m_data;
|
std::shared_ptr<runtime::AlignedBuffer> m_data;
|
||||||
bool m_all_elements_bitwise_identical;
|
bool m_all_elements_bitwise_identical;
|
||||||
bool are_all_data_elements_bitwise_identical() const;
|
bool are_all_data_elements_bitwise_identical() const;
|
||||||
|
bool m_alloc_buffer_on_visit_attributes = true;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
using v0::Constant;
|
using v0::Constant;
|
||||||
|
@ -75,17 +75,13 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
template <>
|
template <>
|
||||||
class NGRAPH_API AttributeAdapter<std::shared_ptr<runtime::AlignedBuffer>>
|
class NGRAPH_API AttributeAdapter<std::shared_ptr<runtime::AlignedBuffer>>
|
||||||
: public ValueAccessor<void*>
|
: public DirectValueAccessor<std::shared_ptr<runtime::AlignedBuffer>>
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
AttributeAdapter(std::shared_ptr<runtime::AlignedBuffer>& value);
|
AttributeAdapter(std::shared_ptr<runtime::AlignedBuffer>& value);
|
||||||
void* get_ptr() override;
|
|
||||||
size_t size() override;
|
|
||||||
|
|
||||||
static constexpr DiscreteTypeInfo type_info{
|
static constexpr DiscreteTypeInfo type_info{
|
||||||
"AttributeAdapter<std::shared_ptr<runtime::AlignedBuffer>>", 0};
|
"AttributeAdapter<std::shared_ptr<runtime::AlignedBuffer>>", 0};
|
||||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
||||||
protected:
|
|
||||||
std::shared_ptr<runtime::AlignedBuffer>& m_ref;
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -628,9 +628,13 @@ bool op::Constant::are_all_data_elements_bitwise_identical() const
|
|||||||
bool op::v0::Constant::visit_attributes(AttributeVisitor& visitor)
|
bool op::v0::Constant::visit_attributes(AttributeVisitor& visitor)
|
||||||
{
|
{
|
||||||
NGRAPH_OP_SCOPE(v0_Constant_visit_attributes);
|
NGRAPH_OP_SCOPE(v0_Constant_visit_attributes);
|
||||||
|
Shape prev_shape = m_shape;
|
||||||
|
element::Type prev_type = m_element_type;
|
||||||
visitor.on_attribute("element_type", m_element_type);
|
visitor.on_attribute("element_type", m_element_type);
|
||||||
visitor.on_attribute("shape", m_shape);
|
visitor.on_attribute("shape", m_shape);
|
||||||
if (m_data == nullptr)
|
|
||||||
|
bool need_to_reallocate = (m_shape != prev_shape || prev_type != m_element_type);
|
||||||
|
if (m_alloc_buffer_on_visit_attributes && need_to_reallocate)
|
||||||
{
|
{
|
||||||
// Filling in a fresh constant
|
// Filling in a fresh constant
|
||||||
allocate_buffer();
|
allocate_buffer();
|
||||||
|
@ -87,13 +87,7 @@ namespace ngraph
|
|||||||
|
|
||||||
AttributeAdapter<shared_ptr<runtime::AlignedBuffer>>::AttributeAdapter(
|
AttributeAdapter<shared_ptr<runtime::AlignedBuffer>>::AttributeAdapter(
|
||||||
shared_ptr<runtime::AlignedBuffer>& value)
|
shared_ptr<runtime::AlignedBuffer>& value)
|
||||||
: m_ref(value)
|
: DirectValueAccessor<shared_ptr<runtime::AlignedBuffer>>(value)
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
void* AttributeAdapter<shared_ptr<runtime::AlignedBuffer>>::get_ptr()
|
|
||||||
{
|
|
||||||
return m_ref->get_ptr();
|
|
||||||
}
|
|
||||||
size_t AttributeAdapter<shared_ptr<runtime::AlignedBuffer>>::size() { return m_ref->size(); }
|
|
||||||
}
|
}
|
||||||
|
@ -161,9 +161,18 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
}
|
}
|
||||||
void on_adapter(const std::string& name, ValueAccessor<void>& adapter) override
|
void on_adapter(const std::string& name, ValueAccessor<void>& adapter) override
|
||||||
|
{
|
||||||
|
if (auto a = ::ngraph::as_type<::ngraph::AttributeAdapter<
|
||||||
|
std::shared_ptr<ngraph::runtime::AlignedBuffer>>>(&adapter))
|
||||||
|
{
|
||||||
|
auto& data = m_values.get<HostTensorPtr>(name);
|
||||||
|
data->read(a->get()->get_ptr(), a->get()->size());
|
||||||
|
}
|
||||||
|
else
|
||||||
{
|
{
|
||||||
NGRAPH_CHECK(false, "Attribute \"", name, "\" cannot be unmarshalled");
|
NGRAPH_CHECK(false, "Attribute \"", name, "\" cannot be unmarshalled");
|
||||||
}
|
}
|
||||||
|
}
|
||||||
// The remaining adapter methods fall back on the void adapter if not implemented
|
// The remaining adapter methods fall back on the void adapter if not implemented
|
||||||
void on_adapter(const std::string& name, ValueAccessor<std::string>& adapter) override
|
void on_adapter(const std::string& name, ValueAccessor<std::string>& adapter) override
|
||||||
{
|
{
|
||||||
@ -256,9 +265,20 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
|
|
||||||
void on_adapter(const std::string& name, ValueAccessor<void>& adapter) override
|
void on_adapter(const std::string& name, ValueAccessor<void>& adapter) override
|
||||||
|
{
|
||||||
|
if (auto a = ::ngraph::as_type<::ngraph::AttributeAdapter<
|
||||||
|
std::shared_ptr<ngraph::runtime::AlignedBuffer>>>(&adapter))
|
||||||
|
{
|
||||||
|
HostTensorPtr data =
|
||||||
|
std::make_shared<HostTensor>(element::u8, Shape{a->get()->size()});
|
||||||
|
data->write(a->get()->get_ptr(), a->get()->size());
|
||||||
|
m_values.insert(name, data);
|
||||||
|
}
|
||||||
|
else
|
||||||
{
|
{
|
||||||
NGRAPH_CHECK(false, "Attribute \"", name, "\" cannot be marshalled");
|
NGRAPH_CHECK(false, "Attribute \"", name, "\" cannot be marshalled");
|
||||||
}
|
}
|
||||||
|
}
|
||||||
// The remaining adapter methods fall back on the void adapter if not implemented
|
// The remaining adapter methods fall back on the void adapter if not implemented
|
||||||
void on_adapter(const std::string& name, ValueAccessor<std::string>& adapter) override
|
void on_adapter(const std::string& name, ValueAccessor<std::string>& adapter) override
|
||||||
{
|
{
|
||||||
|
Loading…
Reference in New Issue
Block a user