Host tensor 2 vector refactor (#2443)
This commit is contained in:
parent
fd80873fca
commit
55451266c6
@ -274,20 +274,10 @@ bool op::v3::ScatterElementsUpdate::evaluate(const HostTensorVector& outputs,
|
||||
{
|
||||
OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::v3::ScatterElementsUpdate::evaluate");
|
||||
|
||||
int64_t axis = 0;
|
||||
switch (inputs[3]->get_element_type())
|
||||
{
|
||||
case element::Type_t::i8: axis = inputs[3]->get_data_ptr<element::Type_t::i8>()[0]; break;
|
||||
case element::Type_t::i16: axis = inputs[3]->get_data_ptr<element::Type_t::i16>()[0]; break;
|
||||
case element::Type_t::i32: axis = inputs[3]->get_data_ptr<element::Type_t::i32>()[0]; break;
|
||||
case element::Type_t::i64: axis = inputs[3]->get_data_ptr<element::Type_t::i64>()[0]; break;
|
||||
case element::Type_t::u8: axis = inputs[3]->get_data_ptr<element::Type_t::u8>()[0]; break;
|
||||
case element::Type_t::u16: axis = inputs[3]->get_data_ptr<element::Type_t::u16>()[0]; break;
|
||||
case element::Type_t::u32: axis = inputs[3]->get_data_ptr<element::Type_t::u32>()[0]; break;
|
||||
case element::Type_t::u64: axis = inputs[3]->get_data_ptr<element::Type_t::u64>()[0]; break;
|
||||
default: throw ngraph_error("axis element type is not integral data type");
|
||||
}
|
||||
NGRAPH_CHECK(inputs[3]->get_element_type().is_integral_number(),
|
||||
"axis element type is not integral data type");
|
||||
|
||||
int64_t axis = host_tensor_2_vector<int64_t>(inputs[3])[0];
|
||||
const auto& input_rank = get_input_partial_shape(0).rank();
|
||||
int64_t normalized_axis = axis;
|
||||
|
||||
|
@ -222,22 +222,11 @@ namespace
|
||||
const int64_t num_splits,
|
||||
const Node* split_node)
|
||||
{
|
||||
int64_t axis;
|
||||
switch (axis_tensor->get_element_type())
|
||||
{
|
||||
case element::Type_t::i32: axis = read_vector<int32_t>(axis_tensor)[0]; break;
|
||||
case element::Type_t::i64: axis = read_vector<int64_t>(axis_tensor)[0]; break;
|
||||
case element::Type_t::u64:
|
||||
axis = static_cast<int64_t>(read_vector<uint64_t>(axis_tensor)[0]);
|
||||
break;
|
||||
default:
|
||||
NODE_VALIDATION_CHECK(split_node,
|
||||
false,
|
||||
"Not supported axis type: ",
|
||||
axis_tensor->get_element_type(),
|
||||
" during evaluate Split:v1");
|
||||
break;
|
||||
}
|
||||
NGRAPH_CHECK(axis_tensor->get_element_type().is_integral_number(),
|
||||
"axis element type is not integral data type");
|
||||
|
||||
int64_t axis = host_tensor_2_vector<int64_t>(axis_tensor)[0];
|
||||
|
||||
axis = ngraph::normalize_axis(split_node, axis, data_tensor->get_partial_shape().rank());
|
||||
evaluate(data_tensor, outputs, axis, num_splits);
|
||||
return true;
|
||||
|
@ -104,29 +104,10 @@ namespace
|
||||
const HostTensorPtr& arg2,
|
||||
const HostTensorPtr& out)
|
||||
{
|
||||
element::Type_t axis_type = arg2->get_element_type();
|
||||
NGRAPH_CHECK(arg2->get_element_type().is_integral_number(),
|
||||
"axis element type is not integral data type");
|
||||
|
||||
std::vector<int64_t> axis_order;
|
||||
switch (axis_type)
|
||||
{
|
||||
case element::Type_t::i8: axis_order = get_vector<element::Type_t::i8>(arg2); break;
|
||||
|
||||
case element::Type_t::i16: axis_order = get_vector<element::Type_t::i16>(arg2); break;
|
||||
|
||||
case element::Type_t::i32: axis_order = get_vector<element::Type_t::i32>(arg2); break;
|
||||
|
||||
case element::Type_t::i64: axis_order = get_vector<element::Type_t::i64>(arg2); break;
|
||||
|
||||
case element::Type_t::u8: axis_order = get_vector<element::Type_t::u8>(arg2); break;
|
||||
|
||||
case element::Type_t::u16: axis_order = get_vector<element::Type_t::u16>(arg2); break;
|
||||
|
||||
case element::Type_t::u32: axis_order = get_vector<element::Type_t::u32>(arg2); break;
|
||||
|
||||
case element::Type_t::u64: axis_order = get_vector<element::Type_t::u64>(arg2); break;
|
||||
|
||||
default: throw ngraph_error("axis element type is not integral data type");
|
||||
}
|
||||
std::vector<int64_t> axis_order = host_tensor_2_vector<int64_t>(arg2);
|
||||
|
||||
Shape in_shape = arg1->get_shape();
|
||||
AxisVector in_axis_order(shape_size(arg2->get_shape()));
|
||||
|
@ -171,57 +171,17 @@ namespace
|
||||
const HostTensorVector& outputs,
|
||||
const Node* split_node)
|
||||
{
|
||||
int64_t axis;
|
||||
switch (axis_tensor->get_element_type())
|
||||
{
|
||||
case element::Type_t::i16: axis = read_vector<int16_t>(axis_tensor)[0]; break;
|
||||
case element::Type_t::i32: axis = read_vector<int32_t>(axis_tensor)[0]; break;
|
||||
case element::Type_t::i64: axis = read_vector<int64_t>(axis_tensor)[0]; break;
|
||||
case element::Type_t::u64:
|
||||
axis = static_cast<int64_t>(read_vector<uint64_t>(axis_tensor)[0]);
|
||||
break;
|
||||
default:
|
||||
NODE_VALIDATION_CHECK(split_node,
|
||||
false,
|
||||
"Not supported axis type: ",
|
||||
axis_tensor->get_element_type(),
|
||||
" during evaluate Split:v1");
|
||||
break;
|
||||
}
|
||||
NGRAPH_CHECK(axis_tensor->get_element_type().is_integral_number(),
|
||||
"axis element type is not integral data type");
|
||||
|
||||
int64_t axis = host_tensor_2_vector<int64_t>(axis_tensor)[0];
|
||||
|
||||
axis = ngraph::normalize_axis(split_node, axis, data_tensor->get_partial_shape().rank());
|
||||
|
||||
std::vector<int64_t> split_lengths;
|
||||
switch (split_lengths_tensor->get_element_type())
|
||||
{
|
||||
case element::Type_t::i32:
|
||||
{
|
||||
const auto split_lengths_i32 = read_vector<int32_t>(split_lengths_tensor);
|
||||
split_lengths =
|
||||
std::vector<int64_t>(std::begin(split_lengths_i32), std::end(split_lengths_i32));
|
||||
break;
|
||||
}
|
||||
case element::Type_t::i64:
|
||||
{
|
||||
const auto split_lengths_i64 = read_vector<int64_t>(split_lengths_tensor);
|
||||
split_lengths =
|
||||
std::vector<int64_t>(std::begin(split_lengths_i64), std::end(split_lengths_i64));
|
||||
break;
|
||||
}
|
||||
case element::Type_t::u64:
|
||||
{
|
||||
const auto split_lengths_u64 = read_vector<uint64_t>(split_lengths_tensor);
|
||||
split_lengths =
|
||||
std::vector<int64_t>(std::begin(split_lengths_u64), std::end(split_lengths_u64));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
NODE_VALIDATION_CHECK(split_node,
|
||||
false,
|
||||
"Not supported split lengths type: ",
|
||||
split_lengths_tensor->get_element_type(),
|
||||
" during evaluate Split:v1");
|
||||
break;
|
||||
}
|
||||
NGRAPH_CHECK(split_lengths_tensor->get_element_type().is_integral_number(),
|
||||
"axis element type is not integral data type");
|
||||
|
||||
std::vector<int64_t> split_lengths = host_tensor_2_vector<int64_t>(split_lengths_tensor);
|
||||
|
||||
const auto data_shape = data_tensor->get_shape();
|
||||
const auto neg_one = std::find(std::begin(split_lengths), std::end(split_lengths), -1);
|
||||
|
Loading…
Reference in New Issue
Block a user