Files
openvino/ngraph/test/util/visitor.hpp

382 lines
16 KiB
C++

//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <string>
#include <unordered_map>
#include <vector>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/factory.hpp"
#include "ngraph/runtime/host_tensor.hpp"
namespace ngraph
{
namespace test
{
class ValueHolder
{
template <typename T>
T& invalid()
{
NGRAPH_CHECK(false, "Invalid type access");
}
public:
virtual ~ValueHolder() {}
virtual operator bool&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator float&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator double&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator std::string&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator int8_t&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator int16_t&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator int32_t&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator int64_t&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator uint8_t&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator uint16_t&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator uint32_t&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator uint64_t&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator std::vector<std::string>&()
{
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator std::vector<float>&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator std::vector<double>&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator std::vector<int8_t>&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator std::vector<int16_t>&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator std::vector<int32_t>&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator std::vector<int64_t>&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator std::vector<uint8_t>&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator std::vector<uint16_t>&()
{
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator std::vector<uint32_t>&()
{
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator std::vector<uint64_t>&()
{
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator HostTensorPtr&() { NGRAPH_CHECK(false, "Invalid type access"); }
uint64_t get_index() { return m_index; }
protected:
uint64_t m_index{0};
};
template <typename T>
class ValueHolderImp : public ValueHolder
{
public:
ValueHolderImp(const T& value, uint64_t index)
: m_value(value)
{
m_index = index;
}
operator T&() override { return m_value; }
protected:
T m_value;
};
class ValueMap
{
using map_type = std::unordered_map<std::string, std::shared_ptr<ValueHolder>>;
public:
/// \brief Set to print serialization information
void set_print(bool value) { m_print = value; }
template <typename T>
void insert(const std::string& name, const T& value)
{
std::pair<map_type::iterator, bool> result = m_values.insert(map_type::value_type(
name, std::make_shared<ValueHolderImp<T>>(value, m_write_count++)));
NGRAPH_CHECK(result.second, name, " is already in use");
}
template <typename T>
void insert_scalar(const std::string& name, const T& value)
{
std::pair<map_type::iterator, bool> result = m_values.insert(map_type::value_type(
name, std::make_shared<ValueHolderImp<T>>(value, m_write_count++)));
NGRAPH_CHECK(result.second, name, " is already in use");
if (m_print)
{
std::cerr << "SER: " << name << " = " << value << std::endl;
}
}
template <typename T>
void insert_vector(const std::string& name, const T& value)
{
std::pair<map_type::iterator, bool> result = m_values.insert(map_type::value_type(
name, std::make_shared<ValueHolderImp<T>>(value, m_write_count++)));
NGRAPH_CHECK(result.second, name, " is already in use");
if (m_print)
{
std::cerr << "SER: " << name << " = [";
std::string comma = "";
for (auto val : value)
{
std::cerr << comma << val;
comma = ", ";
}
std::cerr << "]" << std::endl;
}
}
template <typename T>
T& get(const std::string& name)
{
auto& value_holder = *m_values.at(name);
NGRAPH_CHECK(m_read_count++ == value_holder.get_index());
return static_cast<T&>(*m_values.at(name));
}
protected:
map_type m_values;
uint64_t m_write_count{0};
uint64_t m_read_count{0};
bool m_print{false};
};
class DeserializeAttributeVisitor : public AttributeVisitor
{
public:
DeserializeAttributeVisitor(ValueMap& value_map)
: m_values(value_map)
{
}
void on_adapter(const std::string& name, ValueAccessor<void>& adapter) override
{
NGRAPH_CHECK(false, "Attribute \"", name, "\" cannot be unmarshalled");
}
// The remaining adapter methods fall back on the void adapter if not implemented
void on_adapter(const std::string& name, ValueAccessor<std::string>& adapter) override
{
adapter.set(m_values.get<std::string>(name));
};
void on_adapter(const std::string& name, ValueAccessor<bool>& adapter) override
{
adapter.set(m_values.get<bool>(name));
};
void on_adapter(const std::string& name, ValueAccessor<int64_t>& adapter) override
{
adapter.set(m_values.get<int64_t>(name));
}
void on_adapter(const std::string& name, ValueAccessor<double>& adapter) override
{
adapter.set(m_values.get<double>(name));
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<int8_t>>& adapter) override
{
adapter.set(m_values.get<std::vector<int8_t>>(name));
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<int16_t>>& adapter) override
{
adapter.set(m_values.get<std::vector<int16_t>>(name));
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<int32_t>>& adapter) override
{
adapter.set(m_values.get<std::vector<int32_t>>(name));
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<int64_t>>& adapter) override
{
adapter.set(m_values.get<std::vector<int64_t>>(name));
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<uint8_t>>& adapter) override
{
adapter.set(m_values.get<std::vector<uint8_t>>(name));
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<uint16_t>>& adapter) override
{
adapter.set(m_values.get<std::vector<uint16_t>>(name));
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<uint32_t>>& adapter) override
{
adapter.set(m_values.get<std::vector<uint32_t>>(name));
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<uint64_t>>& adapter) override
{
adapter.set(m_values.get<std::vector<uint64_t>>(name));
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<std::string>>& adapter) override
{
adapter.set(m_values.get<std::vector<std::string>>(name));
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<float>>& adapter) override
{
adapter.set(m_values.get<std::vector<float>>(name));
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<double>>& adapter) override
{
adapter.set(m_values.get<std::vector<double>>(name));
}
void on_adapter(const std::string& name, ValueAccessor<void*>& adapter) override
{
HostTensorPtr& data = m_values.get<HostTensorPtr>(name);
data->read(adapter.get_ptr(), adapter.size());
}
protected:
ValueMap& m_values;
};
class SerializeAttributeVisitor : public AttributeVisitor
{
public:
SerializeAttributeVisitor(ValueMap& value_map)
: m_values(value_map)
{
}
void on_adapter(const std::string& name, ValueAccessor<void>& adapter) override
{
NGRAPH_CHECK(false, "Attribute \"", name, "\" cannot be marshalled");
}
// The remaining adapter methods fall back on the void adapter if not implemented
void on_adapter(const std::string& name, ValueAccessor<std::string>& adapter) override
{
m_values.insert_scalar(name, adapter.get());
};
void on_adapter(const std::string& name, ValueAccessor<bool>& adapter) override
{
m_values.insert_scalar(name, adapter.get());
};
void on_adapter(const std::string& name, ValueAccessor<int64_t>& adapter) override
{
m_values.insert_scalar(name, adapter.get());
}
void on_adapter(const std::string& name, ValueAccessor<double>& adapter) override
{
m_values.insert_scalar(name, adapter.get());
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<std::string>>& adapter) override
{
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<float>>& adapter) override
{
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<double>>& adapter) override
{
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<int8_t>>& adapter) override
{
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<int16_t>>& adapter) override
{
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<int32_t>>& adapter) override
{
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<int64_t>>& adapter) override
{
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<uint8_t>>& adapter) override
{
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<uint16_t>>& adapter) override
{
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<uint32_t>>& adapter) override
{
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<uint64_t>>& adapter) override
{
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name, ValueAccessor<void*>& adapter) override
{
HostTensorPtr data =
std::make_shared<HostTensor>(element::u8, Shape{adapter.size()});
data->write(adapter.get_ptr(), adapter.size());
m_values.insert(name, data);
}
protected:
ValueMap& m_values;
};
class NodeBuilder : public ValueMap, public DeserializeAttributeVisitor
{
public:
NodeBuilder()
: DeserializeAttributeVisitor(static_cast<ValueMap&>(*this))
, m_serializer(*this)
{
}
NodeBuilder(const std::shared_ptr<Node>& node)
: DeserializeAttributeVisitor(static_cast<ValueMap&>(*this))
, m_serializer(*this)
{
save_node(node);
}
void save_node(std::shared_ptr<Node> node)
{
m_node_type_info = node->get_type_info();
node->visit_attributes(m_serializer);
}
// Does not validate, since inputs aren't set
std::shared_ptr<Node> create()
{
std::shared_ptr<Node> node(FactoryRegistry<Node>::get().create(m_node_type_info));
node->visit_attributes(*this);
return node;
}
AttributeVisitor& get_node_saver() { return m_serializer; }
AttributeVisitor& get_node_loader() { return *this; }
protected:
Node::type_info_t m_node_type_info;
SerializeAttributeVisitor m_serializer;
};
}
}