Host tensor 2 vector refactor (#2443)

This commit is contained in:
Piotr Szmelczynski 2020-09-30 16:20:41 +02:00 committed by GitHub
parent fd80873fca
commit 55451266c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 20 additions and 100 deletions

View File

@ -274,20 +274,10 @@ bool op::v3::ScatterElementsUpdate::evaluate(const HostTensorVector& outputs,
{
OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::v3::ScatterElementsUpdate::evaluate");
int64_t axis = 0;
switch (inputs[3]->get_element_type())
{
case element::Type_t::i8: axis = inputs[3]->get_data_ptr<element::Type_t::i8>()[0]; break;
case element::Type_t::i16: axis = inputs[3]->get_data_ptr<element::Type_t::i16>()[0]; break;
case element::Type_t::i32: axis = inputs[3]->get_data_ptr<element::Type_t::i32>()[0]; break;
case element::Type_t::i64: axis = inputs[3]->get_data_ptr<element::Type_t::i64>()[0]; break;
case element::Type_t::u8: axis = inputs[3]->get_data_ptr<element::Type_t::u8>()[0]; break;
case element::Type_t::u16: axis = inputs[3]->get_data_ptr<element::Type_t::u16>()[0]; break;
case element::Type_t::u32: axis = inputs[3]->get_data_ptr<element::Type_t::u32>()[0]; break;
case element::Type_t::u64: axis = inputs[3]->get_data_ptr<element::Type_t::u64>()[0]; break;
default: throw ngraph_error("axis element type is not integral data type");
}
NGRAPH_CHECK(inputs[3]->get_element_type().is_integral_number(),
"axis element type is not integral data type");
int64_t axis = host_tensor_2_vector<int64_t>(inputs[3])[0];
const auto& input_rank = get_input_partial_shape(0).rank();
int64_t normalized_axis = axis;

View File

@ -222,22 +222,11 @@ namespace
const int64_t num_splits,
const Node* split_node)
{
int64_t axis;
switch (axis_tensor->get_element_type())
{
case element::Type_t::i32: axis = read_vector<int32_t>(axis_tensor)[0]; break;
case element::Type_t::i64: axis = read_vector<int64_t>(axis_tensor)[0]; break;
case element::Type_t::u64:
axis = static_cast<int64_t>(read_vector<uint64_t>(axis_tensor)[0]);
break;
default:
NODE_VALIDATION_CHECK(split_node,
false,
"Not supported axis type: ",
axis_tensor->get_element_type(),
" during evaluate Split:v1");
break;
}
NGRAPH_CHECK(axis_tensor->get_element_type().is_integral_number(),
"axis element type is not integral data type");
int64_t axis = host_tensor_2_vector<int64_t>(axis_tensor)[0];
axis = ngraph::normalize_axis(split_node, axis, data_tensor->get_partial_shape().rank());
evaluate(data_tensor, outputs, axis, num_splits);
return true;

View File

@ -104,29 +104,10 @@ namespace
const HostTensorPtr& arg2,
const HostTensorPtr& out)
{
element::Type_t axis_type = arg2->get_element_type();
NGRAPH_CHECK(arg2->get_element_type().is_integral_number(),
"axis element type is not integral data type");
std::vector<int64_t> axis_order;
switch (axis_type)
{
case element::Type_t::i8: axis_order = get_vector<element::Type_t::i8>(arg2); break;
case element::Type_t::i16: axis_order = get_vector<element::Type_t::i16>(arg2); break;
case element::Type_t::i32: axis_order = get_vector<element::Type_t::i32>(arg2); break;
case element::Type_t::i64: axis_order = get_vector<element::Type_t::i64>(arg2); break;
case element::Type_t::u8: axis_order = get_vector<element::Type_t::u8>(arg2); break;
case element::Type_t::u16: axis_order = get_vector<element::Type_t::u16>(arg2); break;
case element::Type_t::u32: axis_order = get_vector<element::Type_t::u32>(arg2); break;
case element::Type_t::u64: axis_order = get_vector<element::Type_t::u64>(arg2); break;
default: throw ngraph_error("axis element type is not integral data type");
}
std::vector<int64_t> axis_order = host_tensor_2_vector<int64_t>(arg2);
Shape in_shape = arg1->get_shape();
AxisVector in_axis_order(shape_size(arg2->get_shape()));

View File

@ -171,57 +171,17 @@ namespace
const HostTensorVector& outputs,
const Node* split_node)
{
int64_t axis;
switch (axis_tensor->get_element_type())
{
case element::Type_t::i16: axis = read_vector<int16_t>(axis_tensor)[0]; break;
case element::Type_t::i32: axis = read_vector<int32_t>(axis_tensor)[0]; break;
case element::Type_t::i64: axis = read_vector<int64_t>(axis_tensor)[0]; break;
case element::Type_t::u64:
axis = static_cast<int64_t>(read_vector<uint64_t>(axis_tensor)[0]);
break;
default:
NODE_VALIDATION_CHECK(split_node,
false,
"Not supported axis type: ",
axis_tensor->get_element_type(),
" during evaluate Split:v1");
break;
}
NGRAPH_CHECK(axis_tensor->get_element_type().is_integral_number(),
"axis element type is not integral data type");
int64_t axis = host_tensor_2_vector<int64_t>(axis_tensor)[0];
axis = ngraph::normalize_axis(split_node, axis, data_tensor->get_partial_shape().rank());
std::vector<int64_t> split_lengths;
switch (split_lengths_tensor->get_element_type())
{
case element::Type_t::i32:
{
const auto split_lengths_i32 = read_vector<int32_t>(split_lengths_tensor);
split_lengths =
std::vector<int64_t>(std::begin(split_lengths_i32), std::end(split_lengths_i32));
break;
}
case element::Type_t::i64:
{
const auto split_lengths_i64 = read_vector<int64_t>(split_lengths_tensor);
split_lengths =
std::vector<int64_t>(std::begin(split_lengths_i64), std::end(split_lengths_i64));
break;
}
case element::Type_t::u64:
{
const auto split_lengths_u64 = read_vector<uint64_t>(split_lengths_tensor);
split_lengths =
std::vector<int64_t>(std::begin(split_lengths_u64), std::end(split_lengths_u64));
break;
}
default:
NODE_VALIDATION_CHECK(split_node,
false,
"Not supported split lengths type: ",
split_lengths_tensor->get_element_type(),
" during evaluate Split:v1");
break;
}
NGRAPH_CHECK(split_lengths_tensor->get_element_type().is_integral_number(),
"axis element type is not integral data type");
std::vector<int64_t> split_lengths = host_tensor_2_vector<int64_t>(split_lengths_tensor);
const auto data_shape = data_tensor->get_shape();
const auto neg_one = std::find(std::begin(split_lengths), std::end(split_lengths), -1);