diff --git a/ngraph/src/ngraph/CMakeLists.txt b/ngraph/src/ngraph/CMakeLists.txt index 881c4299b73..10b95747da1 100644 --- a/ngraph/src/ngraph/CMakeLists.txt +++ b/ngraph/src/ngraph/CMakeLists.txt @@ -440,7 +440,6 @@ set (SRC pass/constant_folding_arithmetic_reduction.cpp pass/constant_folding_convert.cpp pass/constant_folding_dequantize.cpp - pass/constant_folding_dyn_reshape.cpp pass/constant_folding_gather.cpp pass/constant_folding_scatter.cpp pass/constant_folding_logical_reduction.cpp @@ -453,7 +452,6 @@ set (SRC pass/constant_folding_split.cpp pass/constant_folding_variadic_split.cpp pass/constant_folding_tile.cpp - pass/constant_folding_transpose.cpp pass/constant_folding.cpp pass/constant_folding.hpp pass/convert_fp32_to_fp16.hpp @@ -512,8 +510,12 @@ set (SRC runtime/host_tensor.hpp runtime/tensor.cpp runtime/tensor.hpp + runtime/opt_kernel/reshape.cpp + runtime/opt_kernel/reshape.hpp runtime/reference/eval_helpers.cpp runtime/reference/eval_helpers.hpp + runtime/reference/reshape.cpp + runtime/reference/reshape.hpp shape.cpp shape.hpp shape_util.cpp diff --git a/ngraph/src/ngraph/op/reshape.cpp b/ngraph/src/ngraph/op/reshape.cpp index a73d640f232..3fc0eea437d 100644 --- a/ngraph/src/ngraph/op/reshape.cpp +++ b/ngraph/src/ngraph/op/reshape.cpp @@ -28,52 +28,17 @@ using namespace ngraph; namespace { - template - bool evaluate(const HostTensorPtr& arg0, const HostTensorPtr& out, const AxisVector& order) - { - auto data_ptr = out->get_data_ptr(); - runtime::opt_kernel::reshape::value_type>( - arg0->get_data_ptr(), data_ptr, arg0->get_shape(), order, out->get_shape()); - return true; - } - bool evaluate_reshape(const HostTensorPtr& arg0, const HostTensorPtr& out, const AxisVector& order) { - bool rc = true; - switch (arg0->get_element_type()) - { - case element::Type_t::undefined: rc = false; break; - case element::Type_t::dynamic: rc = false; break; - case element::Type_t::u1: - rc = false; - break; - TYPE_CASE(f16)(arg0, out, order); - break; - TYPE_CASE(f32)(arg0, out, order); - break; - TYPE_CASE(i8)(arg0, out, order); - break; - TYPE_CASE(i16)(arg0, out, order); - break; - TYPE_CASE(i32)(arg0, out, order); - break; - TYPE_CASE(i64)(arg0, out, order); - break; - TYPE_CASE(u8)(arg0, out, order); - break; - TYPE_CASE(u16)(arg0, out, order); - break; - TYPE_CASE(u32)(arg0, out, order); - break; - TYPE_CASE(u64)(arg0, out, order); - break; - TYPE_CASE(boolean)(arg0, out, order); - break; - default: rc = false; break; - } - return rc; + runtime::opt_kernel::reshape(arg0->get_data_ptr(), + out->get_data_ptr(), + arg0->get_shape(), + order, + out->get_shape(), + arg0->get_element_type().size()); + return true; } template @@ -477,6 +442,6 @@ bool op::v1::Reshape::evaluate(const HostTensorVector& outputs, const HostTensor } outputs[0]->set_shape(output_shape); } - const AxisVector order = get_default_order(outputs[0]->get_shape()); + const AxisVector order = get_default_order(inputs[0]->get_shape()); return evaluate_reshape(inputs[0], outputs[0], order); } diff --git a/ngraph/src/ngraph/op/transpose.cpp b/ngraph/src/ngraph/op/transpose.cpp index 391c4b28a8c..9a84be319a2 100644 --- a/ngraph/src/ngraph/op/transpose.cpp +++ b/ngraph/src/ngraph/op/transpose.cpp @@ -18,7 +18,7 @@ #include "ngraph/op/constant.hpp" #include "ngraph/op/transpose.hpp" -#include "ngraph/runtime/reference/reshape.hpp" +#include "ngraph/runtime/opt_kernel/reshape.hpp" using namespace std; using namespace ngraph; @@ -92,8 +92,9 @@ namespace return rc; } - template - bool evaluate(const HostTensorPtr& arg1, const 
HostTensorPtr& arg2, const HostTensorPtr& out) + bool evaluate_transpose(const HostTensorPtr& arg1, + const HostTensorPtr& arg2, + const HostTensorPtr& out) { element::Type_t axis_type = arg2->get_element_type(); @@ -132,38 +133,13 @@ namespace [&](const int64_t& v) { return in_shape[v]; }); out->set_shape(out_shape); - return (INPUT_ET == arg1->get_element_type()) && - (runtime::reference::reshape(arg1->get_data_ptr(), - out->get_data_ptr(), - arg1->get_shape(), - in_axis_order, - out->get_shape()), - true); - } - - bool evaluate_transpose(const HostTensorPtr& arg1, - const HostTensorPtr& arg2, - const HostTensorPtr& out) - { - bool rc = true; - - switch (arg1->get_element_type()) - { - TYPE_CASE(i32)(arg1, arg2, out); - break; - TYPE_CASE(i64)(arg1, arg2, out); - break; - TYPE_CASE(u32)(arg1, arg2, out); - break; - TYPE_CASE(u64)(arg1, arg2, out); - break; - TYPE_CASE(f16)(arg1, arg2, out); - break; - TYPE_CASE(f32)(arg1, arg2, out); - break; - default: rc = false; break; - } - return rc; + runtime::opt_kernel::reshape(arg1->get_data_ptr(), + out->get_data_ptr(), + arg1->get_shape(), + in_axis_order, + out->get_shape(), + arg1->get_element_type().size()); + return true; } } bool op::v1::Transpose::evaluate(const HostTensorVector& output_values, diff --git a/ngraph/src/ngraph/pass/constant_folding.hpp b/ngraph/src/ngraph/pass/constant_folding.hpp index ed5bb629e2a..2de3be5a92e 100644 --- a/ngraph/src/ngraph/pass/constant_folding.hpp +++ b/ngraph/src/ngraph/pass/constant_folding.hpp @@ -33,37 +33,6 @@ namespace ngraph class NGRAPH_API ngraph::pass::ConstantFolding : public ngraph::pass::GraphRewrite { public: - enum class CFTransformations - { - RESHAPE, - BROADCAST, - PAD, - DEQUANTIZE, - UNARY, - BINARY, - QUANTIZE, - CONVERT, - SHAPE_OF, - REVERSE, - ARITHMETIC_REDUCTION, - LOGICAL_REDUCTION, - CONCAT, - GATHER, - SCATTER, - SLICE, - DYN_RESHAPE, - TRANSPOSE, - RANGE, - SELECT, - SQUEEZE, - UNSQUEEZE, - SPLIT, - VARIADIC_SPLIT, - ONE_HOT, - TILE, - NON_ZERO - }; - ConstantFolding(const ngraph::BuildNodeExecutorMap& cfmap = ngraph::BuildNodeExecutorMap()) : GraphRewrite() { @@ -81,8 +50,6 @@ public: construct_constant_gather_with_subgraph(); construct_constant_scatter_elements_update(); construct_constant_slice(); - construct_constant_dyn_reshape(); - construct_constant_transpose(); construct_constant_select(); construct_constant_one_hot(); construct_constant_tile(); @@ -100,8 +67,6 @@ private: void construct_constant_gather_with_subgraph(); void construct_constant_scatter_elements_update(); void construct_constant_slice(); - void construct_constant_dyn_reshape(); - void construct_constant_transpose(); void construct_constant_select(); void construct_constant_split(); void construct_constant_variadic_split(); diff --git a/ngraph/src/ngraph/pass/constant_folding_dyn_reshape.cpp b/ngraph/src/ngraph/pass/constant_folding_dyn_reshape.cpp deleted file mode 100644 index 609cbea7994..00000000000 --- a/ngraph/src/ngraph/pass/constant_folding_dyn_reshape.cpp +++ /dev/null @@ -1,130 +0,0 @@ -//***************************************************************************** -// Copyright 2017-2020 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -//***************************************************************************** - -#include - -#include "constant_folding.hpp" -#include "ngraph/op/reshape.hpp" -#include "ngraph/runtime/reference/reshape.hpp" -#include "ngraph/type/element_type.hpp" - -using namespace std; -using namespace ngraph; - -template -shared_ptr fold_constant_dyn_reshape(shared_ptr constant_data, - R dyn_reshape) -{ - // v1::Reshape and v0::DynReshape do not allow data transposes. - return make_shared(dyn_reshape->get_element_type(), - dyn_reshape->get_shape(), - constant_data->get_data_ptr()); -} - -template -std::shared_ptr do_fold(R dyn_reshape_match, shared_ptr constant_data_match) -{ - std::shared_ptr replacement; - auto type = dyn_reshape_match->get_element_type(); - switch (type) - { - case element::Type_t::undefined: - NGRAPH_CHECK(false, - "Encountered 'undefined' element type in constant_dyn_reshape_callback"); - break; - case element::Type_t::dynamic: - NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_dyn_reshape_callback"); - break; - case element::Type_t::u1: - NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_dyn_reshape_callback"); - break; - case element::Type_t::boolean: - replacement = fold_constant_dyn_reshape(constant_data_match, dyn_reshape_match); - break; - case element::Type_t::bf16: - replacement = fold_constant_dyn_reshape(constant_data_match, dyn_reshape_match); - break; - case element::Type_t::f16: - replacement = fold_constant_dyn_reshape(constant_data_match, dyn_reshape_match); - break; - case element::Type_t::f32: - replacement = fold_constant_dyn_reshape(constant_data_match, dyn_reshape_match); - break; - case element::Type_t::f64: - replacement = fold_constant_dyn_reshape(constant_data_match, dyn_reshape_match); - break; - case element::Type_t::i8: - replacement = fold_constant_dyn_reshape(constant_data_match, dyn_reshape_match); - break; - case element::Type_t::i16: - replacement = fold_constant_dyn_reshape(constant_data_match, dyn_reshape_match); - break; - case element::Type_t::i32: - replacement = fold_constant_dyn_reshape(constant_data_match, dyn_reshape_match); - break; - case element::Type_t::i64: - replacement = fold_constant_dyn_reshape(constant_data_match, dyn_reshape_match); - break; - case element::Type_t::u8: - replacement = fold_constant_dyn_reshape(constant_data_match, dyn_reshape_match); - break; - case element::Type_t::u16: - replacement = fold_constant_dyn_reshape(constant_data_match, dyn_reshape_match); - break; - case element::Type_t::u32: - replacement = fold_constant_dyn_reshape(constant_data_match, dyn_reshape_match); - break; - case element::Type_t::u64: - replacement = fold_constant_dyn_reshape(constant_data_match, dyn_reshape_match); - break; - } - return replacement; -} - -void pass::ConstantFolding::construct_constant_dyn_reshape() -{ - auto constant_data_label = make_shared( - element::f32, Shape{2, 4}, pattern::has_class()); - auto constant_shape_label = - make_shared(element::i64, Shape{1}, pattern::has_class()); - auto reshape_v1 = - make_shared(constant_data_label, constant_shape_label, false); - - 
// Note: No need to capture or consider constant_shape_label, because - // shape propagation will have transferred the info to dyn_reshape's - // output. - auto constant_reshape_v1_callback = [constant_data_label](pattern::Matcher& m) { - NGRAPH_DEBUG << "In callback for constant_reshape_v1_callback against node = " - << m.get_match_root()->get_name(); - - auto pattern_map = m.get_pattern_map(); - - auto constant_data_match = - static_pointer_cast(pattern_map[constant_data_label]); - auto match_root = m.get_match_root(); - NGRAPH_CHECK(revalidate_and_ensure_static(match_root)); - shared_ptr replacement; - replacement = - do_fold(static_pointer_cast(match_root), constant_data_match); - replace_node(m.get_match_root(), replacement); - return true; - }; - - auto reshape_v1_matcher = - make_shared(reshape_v1, "ConstantFolding.ConstantReshapev1"); - this->add_matcher( - reshape_v1_matcher, constant_reshape_v1_callback, PassProperty::CHANGE_DYNAMIC_STATE); -} diff --git a/ngraph/src/ngraph/pass/constant_folding_transpose.cpp b/ngraph/src/ngraph/pass/constant_folding_transpose.cpp deleted file mode 100644 index b34436ecfea..00000000000 --- a/ngraph/src/ngraph/pass/constant_folding_transpose.cpp +++ /dev/null @@ -1,143 +0,0 @@ -//***************************************************************************** -// Copyright 2017-2020 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-//***************************************************************************** - -#include "constant_folding.hpp" -#include "ngraph/op/transpose.hpp" -#include "ngraph/runtime/opt_kernel/reshape.hpp" - -using namespace std; -using namespace ngraph; - -template -shared_ptr fold_constant_transpose(shared_ptr constant_data, - shared_ptr constant_perm, - shared_ptr transpose) -{ - const Shape& out_shape = transpose->get_shape(); - auto input_order = constant_perm->get_axis_vector_val(); - - runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T)); - - runtime::opt_kernel::reshape(constant_data->get_data_ptr(), - buffer.get_ptr(), - constant_data->get_shape(), - input_order, - out_shape); - - return make_shared(transpose->get_element_type(), out_shape, buffer.get_ptr()); -} - -void pass::ConstantFolding::construct_constant_transpose() -{ - auto constant_data_label = make_shared( - element::f32, Shape{2, 4}, pattern::has_class()); - auto constant_perm_label = - make_shared(element::i64, Shape{2}, pattern::has_class()); - auto transpose = make_shared(constant_data_label, constant_perm_label); - - auto constant_transpose_callback = [constant_data_label, - constant_perm_label](pattern::Matcher& m) { - NGRAPH_DEBUG << "In callback for constant_transpose_callback against node = " - << m.get_match_root()->get_name(); - - auto pattern_map = m.get_pattern_map(); - - auto constant_data_match = - static_pointer_cast(pattern_map[constant_data_label]); - auto constant_perm_match = - static_pointer_cast(pattern_map[constant_perm_label]); - auto transpose_match = static_pointer_cast(m.get_match_root()); - - NGRAPH_CHECK(revalidate_and_ensure_static(transpose_match)); - - std::shared_ptr replacement; - auto type = transpose_match->get_element_type(); - switch (type) - { - case element::Type_t::undefined: - NGRAPH_CHECK(false, - "Encountered 'undefined' element type in constant_transpose_callback"); - break; - case element::Type_t::dynamic: - NGRAPH_CHECK(false, - "Encountered 'dynamic' element type in constant_transpose_callback"); - break; - case element::Type_t::u1: - NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_transpose_callback"); - break; - case element::Type_t::boolean: - replacement = fold_constant_transpose( - constant_data_match, constant_perm_match, transpose_match); - break; - case element::Type_t::bf16: - replacement = fold_constant_transpose( - constant_data_match, constant_perm_match, transpose_match); - break; - case element::Type_t::f16: - replacement = fold_constant_transpose( - constant_data_match, constant_perm_match, transpose_match); - break; - case element::Type_t::f32: - replacement = fold_constant_transpose( - constant_data_match, constant_perm_match, transpose_match); - break; - case element::Type_t::f64: - replacement = fold_constant_transpose( - constant_data_match, constant_perm_match, transpose_match); - break; - case element::Type_t::i8: - replacement = fold_constant_transpose( - constant_data_match, constant_perm_match, transpose_match); - break; - case element::Type_t::i16: - replacement = fold_constant_transpose( - constant_data_match, constant_perm_match, transpose_match); - break; - case element::Type_t::i32: - replacement = fold_constant_transpose( - constant_data_match, constant_perm_match, transpose_match); - break; - case element::Type_t::i64: - replacement = fold_constant_transpose( - constant_data_match, constant_perm_match, transpose_match); - break; - case element::Type_t::u8: - replacement = fold_constant_transpose( - constant_data_match, 
constant_perm_match, transpose_match); - break; - case element::Type_t::u16: - replacement = fold_constant_transpose( - constant_data_match, constant_perm_match, transpose_match); - break; - case element::Type_t::u32: - replacement = fold_constant_transpose( - constant_data_match, constant_perm_match, transpose_match); - break; - case element::Type_t::u64: - replacement = fold_constant_transpose( - constant_data_match, constant_perm_match, transpose_match); - break; - } - - replace_node(m.get_match_root(), replacement); - return true; - }; - - auto transpose_matcher = - make_shared(transpose, "ConstantFolding.ConstantTranspose"); - this->add_matcher( - transpose_matcher, constant_transpose_callback, PassProperty::CHANGE_DYNAMIC_STATE); -} diff --git a/ngraph/src/ngraph/runtime/opt_kernel/reshape.cpp b/ngraph/src/ngraph/runtime/opt_kernel/reshape.cpp new file mode 100644 index 00000000000..ba784939b5c --- /dev/null +++ b/ngraph/src/ngraph/runtime/opt_kernel/reshape.cpp @@ -0,0 +1,267 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#include +#include + +#include "ngraph/check.hpp" +#include "ngraph/runtime/opt_kernel/reshape.hpp" + +using namespace ngraph; + +namespace +{ + void reshape_in0(const char* in, + char* out, + const Shape& in_shape, + const AxisVector& in_axis_order, + const Shape& out_shape, + size_t elem_size) + { + memcpy(out, in, elem_size); + } + + void reshape_in1(const char* in, + char* out, + const Shape& in_shape, + const AxisVector& in_axis_order, + const Shape& out_shape, + size_t elem_size) + { + size_t size[1]; + size_t in_index[1]; + size_t* map_index[1]; + for (size_t i = 0; i < 1; i++) + { + size[i] = in_shape[in_axis_order[i]]; + map_index[in_axis_order[i]] = &in_index[i]; + } + for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0]) + { + memcpy(out, in + *map_index[0] * elem_size, elem_size); + out += elem_size; + } + } + + void reshape_in2(const char* in, + char* out, + const Shape& in_shape, + const AxisVector& in_axis_order, + const Shape& out_shape, + size_t elem_size) + { + size_t size[2]; + size_t in_index[2]; + size_t* map_index[2]; + for (size_t i = 0; i < 2; i++) + { + size[i] = in_shape[in_axis_order[i]]; + map_index[in_axis_order[i]] = &in_index[i]; + } + for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0]) + { + for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1]) + { + // clang-format off + memcpy(out, + in + (*map_index[0] * in_shape[1] + + *map_index[1]) * elem_size, + elem_size); + out += elem_size; + // clang-format on + } + } + } + + void reshape_in3(const char* in, + char* out, + const Shape& in_shape, + const AxisVector& in_axis_order, + const Shape& out_shape, + size_t elem_size) + { + size_t size[3]; + size_t in_index[3]; + size_t* map_index[3]; + for (size_t i = 0; i < 3; i++) + { + size[i] = 
in_shape[in_axis_order[i]]; + map_index[in_axis_order[i]] = &in_index[i]; + } + for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0]) + { + for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1]) + { + for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2]) + { + // clang-format off + memcpy(out, + in + (*map_index[0] * in_shape[1] * in_shape[2] + + *map_index[1] * in_shape[2] + + *map_index[2]) * elem_size, + elem_size); + out += elem_size; + // clang-format on + } + } + } + } + + void reshape_in4(const char* in, + char* out, + const Shape& in_shape, + const AxisVector& in_axis_order, + const Shape& out_shape, + size_t elem_size) + { + size_t size[4]; + size_t in_index[4]; + size_t* map_index[4]; + for (size_t i = 0; i < 4; i++) + { + size[i] = in_shape[in_axis_order[i]]; + map_index[in_axis_order[i]] = &in_index[i]; + } + for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0]) + { + for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1]) + { + for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2]) + { + for (in_index[3] = 0; in_index[3] < size[3]; ++in_index[3]) + { + // clang-format off + memcpy(out, + in + (*map_index[0] * in_shape[1] * in_shape[2] * in_shape[3] + + *map_index[1] * in_shape[2] * in_shape[3] + + *map_index[2] * in_shape[3] + + *map_index[3]) * elem_size, + elem_size); + out += elem_size; + // clang-format on + } + } + } + } + } + + void reshape_in5(const char* in, + char* out, + const Shape& in_shape, + const AxisVector& in_axis_order, + const Shape& out_shape, + size_t elem_size) + { + size_t size[5]; + size_t in_index[5]; + size_t* map_index[5]; + for (size_t i = 0; i < 5; i++) + { + size[i] = in_shape[in_axis_order[i]]; + map_index[in_axis_order[i]] = &in_index[i]; + } + for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0]) + { + for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1]) + { + for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2]) + { + for (in_index[3] = 0; in_index[3] < size[3]; ++in_index[3]) + { + for (in_index[4] = 0; in_index[4] < size[4]; ++in_index[4]) + { + // clang-format off + memcpy(out, + in + (*map_index[0] * in_shape[1] * in_shape[2] * in_shape[3] * in_shape[4] + + *map_index[1] * in_shape[2] * in_shape[3] * in_shape[4] + + *map_index[2] * in_shape[3] * in_shape[4] + + *map_index[3] * in_shape[4] + + *map_index[4]) * elem_size, + elem_size); + out += elem_size; + // clang-format on + } + } + } + } + } + } + + void reshape_in6(const char* in, + char* out, + const Shape& in_shape, + const AxisVector& in_axis_order, + const Shape& out_shape, + size_t elem_size) + { + size_t size[6]; + size_t in_index[6]; + size_t* map_index[6]; + for (size_t i = 0; i < 6; i++) + { + size[i] = in_shape[in_axis_order[i]]; + map_index[in_axis_order[i]] = &in_index[i]; + } + for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0]) + { + for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1]) + { + for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2]) + { + for (in_index[3] = 0; in_index[3] < size[3]; ++in_index[3]) + { + for (in_index[4] = 0; in_index[4] < size[4]; ++in_index[4]) + { + for (in_index[5] = 0; in_index[5] < size[5]; ++in_index[5]) + { + // clang-format off + memcpy(out, + in + (*map_index[0] * in_shape[1] * in_shape[2] * in_shape[3] * in_shape[4] * in_shape[5] + + *map_index[1] * in_shape[2] * in_shape[3] * in_shape[4] * in_shape[5] + + *map_index[2] * in_shape[3] * in_shape[4] * in_shape[5] + + *map_index[3] * in_shape[4] * in_shape[5] + + *map_index[4] * in_shape[5] + + *map_index[5]) * 
elem_size, + elem_size); + out += elem_size; + // clang-format on + } + } + } + } + } + } + } +} +void runtime::opt_kernel::reshape(const char* in, + char* out, + const Shape& in_shape, + const AxisVector& in_axis_order, + const Shape& out_shape, + size_t elem_size) +{ + switch (in_shape.size()) + { + case 0: reshape_in0(in, out, in_shape, in_axis_order, out_shape, elem_size); break; + case 1: reshape_in1(in, out, in_shape, in_axis_order, out_shape, elem_size); break; + case 2: reshape_in2(in, out, in_shape, in_axis_order, out_shape, elem_size); break; + case 3: reshape_in3(in, out, in_shape, in_axis_order, out_shape, elem_size); break; + case 4: reshape_in4(in, out, in_shape, in_axis_order, out_shape, elem_size); break; + case 5: reshape_in5(in, out, in_shape, in_axis_order, out_shape, elem_size); break; + case 6: reshape_in6(in, out, in_shape, in_axis_order, out_shape, elem_size); break; + default: reference::reshape(in, out, in_shape, in_axis_order, out_shape, elem_size); break; + } +} diff --git a/ngraph/src/ngraph/runtime/opt_kernel/reshape.hpp b/ngraph/src/ngraph/runtime/opt_kernel/reshape.hpp index f1dff29aad0..948ec9d278f 100644 --- a/ngraph/src/ngraph/runtime/opt_kernel/reshape.hpp +++ b/ngraph/src/ngraph/runtime/opt_kernel/reshape.hpp @@ -26,232 +26,12 @@ namespace ngraph { namespace opt_kernel { - template - void reshape_in0(const T* in, - T* out, - const Shape& in_shape, - const AxisVector& in_axis_order, - const Shape& out_shape) - { - *out = *in; - } - - template - void reshape_in1(const T* in, - T* out, - const Shape& in_shape, - const AxisVector& in_axis_order, - const Shape& out_shape) - { - size_t size[1]; - size_t in_index[1]; - size_t* map_index[1]; - for (size_t i = 0; i < 1; i++) - { - size[i] = in_shape[in_axis_order[i]]; - map_index[in_axis_order[i]] = &in_index[i]; - } - for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0]) - { - *out++ = in[*map_index[0]]; - } - } - - template - void reshape_in2(const T* in, - T* out, - const Shape& in_shape, - const AxisVector& in_axis_order, - const Shape& out_shape) - { - size_t size[2]; - size_t in_index[2]; - size_t* map_index[2]; - for (size_t i = 0; i < 2; i++) - { - size[i] = in_shape[in_axis_order[i]]; - map_index[in_axis_order[i]] = &in_index[i]; - } - for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0]) - { - for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1]) - { - // clang-format off - *out++ = in[*map_index[0] * in_shape[1] + - *map_index[1]]; - // clang-format on - } - } - } - - template - void reshape_in3(const T* in, - T* out, - const Shape& in_shape, - const AxisVector& in_axis_order, - const Shape& out_shape) - { - size_t size[3]; - size_t in_index[3]; - size_t* map_index[3]; - for (size_t i = 0; i < 3; i++) - { - size[i] = in_shape[in_axis_order[i]]; - map_index[in_axis_order[i]] = &in_index[i]; - } - for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0]) - { - for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1]) - { - for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2]) - { - // clang-format off - *out++ = in[*map_index[0] * in_shape[1] * in_shape[2] + - *map_index[1] * in_shape[2] + - *map_index[2]]; - // clang-format on - } - } - } - } - - template - void reshape_in4(const T* in, - T* out, - const Shape& in_shape, - const AxisVector& in_axis_order, - const Shape& out_shape) - { - size_t size[4]; - size_t in_index[4]; - size_t* map_index[4]; - for (size_t i = 0; i < 4; i++) - { - size[i] = in_shape[in_axis_order[i]]; - map_index[in_axis_order[i]] = &in_index[i]; - } - 
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0]) - { - for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1]) - { - for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2]) - { - for (in_index[3] = 0; in_index[3] < size[3]; ++in_index[3]) - { - // clang-format off - *out++ = - in[*map_index[0] * in_shape[1] * in_shape[2] * in_shape[3] + - *map_index[1] * in_shape[2] * in_shape[3] + - *map_index[2] * in_shape[3] + - *map_index[3]]; - // clang-format on - } - } - } - } - } - - template - void reshape_in5(const T* in, - T* out, - const Shape& in_shape, - const AxisVector& in_axis_order, - const Shape& out_shape) - { - size_t size[5]; - size_t in_index[5]; - size_t* map_index[5]; - for (size_t i = 0; i < 5; i++) - { - size[i] = in_shape[in_axis_order[i]]; - map_index[in_axis_order[i]] = &in_index[i]; - } - for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0]) - { - for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1]) - { - for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2]) - { - for (in_index[3] = 0; in_index[3] < size[3]; ++in_index[3]) - { - for (in_index[4] = 0; in_index[4] < size[4]; ++in_index[4]) - { - // clang-format off - *out++ = - in[*map_index[0] * in_shape[1] * in_shape[2] * in_shape[3] * in_shape[4] + - *map_index[1] * in_shape[2] * in_shape[3] * in_shape[4] + - *map_index[2] * in_shape[3] * in_shape[4] + - *map_index[3] * in_shape[4] + - *map_index[4]]; - // clang-format on - } - } - } - } - } - } - - template - void reshape_in6(const T* in, - T* out, - const Shape& in_shape, - const AxisVector& in_axis_order, - const Shape& out_shape) - { - size_t size[6]; - size_t in_index[6]; - size_t* map_index[6]; - for (size_t i = 0; i < 6; i++) - { - size[i] = in_shape[in_axis_order[i]]; - map_index[in_axis_order[i]] = &in_index[i]; - } - for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0]) - { - for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1]) - { - for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2]) - { - for (in_index[3] = 0; in_index[3] < size[3]; ++in_index[3]) - { - for (in_index[4] = 0; in_index[4] < size[4]; ++in_index[4]) - { - for (in_index[5] = 0; in_index[5] < size[5]; ++in_index[5]) - { - // clang-format off - *out++ = in[*map_index[0] * in_shape[1] * in_shape[2] * in_shape[3] * in_shape[4] * in_shape[5] + - *map_index[1] * in_shape[2] * in_shape[3] * in_shape[4] * in_shape[5] + - *map_index[2] * in_shape[3] * in_shape[4] * in_shape[5] + - *map_index[3] * in_shape[4] * in_shape[5] + - *map_index[4] * in_shape[5] + - *map_index[5]]; - // clang-format on - } - } - } - } - } - } - } - template - void reshape(const T* in, - T* out, + void reshape(const char* in, + char* out, const Shape& in_shape, const AxisVector& in_axis_order, - const Shape& out_shape) - { - switch (in_shape.size()) - { - case 0: reshape_in0(in, out, in_shape, in_axis_order, out_shape); break; - case 1: reshape_in1(in, out, in_shape, in_axis_order, out_shape); break; - case 2: reshape_in2(in, out, in_shape, in_axis_order, out_shape); break; - case 3: reshape_in3(in, out, in_shape, in_axis_order, out_shape); break; - case 4: reshape_in4(in, out, in_shape, in_axis_order, out_shape); break; - case 5: reshape_in5(in, out, in_shape, in_axis_order, out_shape); break; - case 6: reshape_in6(in, out, in_shape, in_axis_order, out_shape); break; - default: reference::reshape(in, out, in_shape, in_axis_order, out_shape); break; - } - } + const Shape& out_shape, + size_t elem_size); } } } diff --git a/ngraph/src/ngraph/runtime/reference/matmul.hpp 
b/ngraph/src/ngraph/runtime/reference/matmul.hpp index 7f8c5dfc5f3..17de94b12c5 100644 --- a/ngraph/src/ngraph/runtime/reference/matmul.hpp +++ b/ngraph/src/ngraph/runtime/reference/matmul.hpp @@ -23,9 +23,9 @@ #include "ngraph/axis_vector.hpp" #include "ngraph/builder/autobroadcast.hpp" +#include "ngraph/runtime/opt_kernel/reshape.hpp" #include "ngraph/runtime/reference/broadcast.hpp" #include "ngraph/runtime/reference/dot.hpp" -#include "ngraph/runtime/reference/reshape.hpp" #include "ngraph/shape_util.hpp" using namespace std; @@ -116,8 +116,12 @@ namespace ngraph arg0_transpose_vec.reserve(shape_size(arg0_shape)); auto axis_vector = get_transpose_order(arg0_shape); swap(wip_arg0_shape[arg0_rank - 1], wip_arg0_shape[arg0_rank - 2]); - reference::reshape( - arg0, arg0_transpose_vec.data(), arg0_shape, axis_vector, wip_arg0_shape); + opt_kernel::reshape(reinterpret_cast(arg0), + reinterpret_cast(arg0_transpose_vec.data()), + arg0_shape, + axis_vector, + wip_arg0_shape, + sizeof(T)); arg0_update = arg0_transpose_vec.data(); } @@ -127,8 +131,12 @@ namespace ngraph arg1_transpose_vec.reserve(shape_size(arg1_shape)); auto axis_vector = get_transpose_order(arg1_shape); swap(wip_arg1_shape[arg1_rank - 1], wip_arg1_shape[arg1_rank - 2]); - reference::reshape( - arg1, arg1_transpose_vec.data(), arg1_shape, axis_vector, wip_arg1_shape); + opt_kernel::reshape(reinterpret_cast(arg1), + reinterpret_cast(arg1_transpose_vec.data()), + arg1_shape, + axis_vector, + wip_arg1_shape, + sizeof(T)); arg1_update = arg1_transpose_vec.data(); } diff --git a/ngraph/src/ngraph/runtime/reference/reshape.cpp b/ngraph/src/ngraph/runtime/reference/reshape.cpp new file mode 100644 index 00000000000..5872ee3e77e --- /dev/null +++ b/ngraph/src/ngraph/runtime/reference/reshape.cpp @@ -0,0 +1,57 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#include +#include + +#include "ngraph/check.hpp" +#include "ngraph/runtime/reference/reshape.hpp" + +using namespace ngraph; + +void runtime::reference::reshape(const char* arg, + char* out, + const Shape& in_shape, + const AxisVector& in_axis_order, + const Shape& out_shape, + size_t elem_size) +{ + // Unfortunately we don't yet have a constructor for CoordinateTransform that lets + // us pass only source_space_shape + // and source_axis_order so we have to construct the defaults here. 
+ Shape in_start_corner(in_shape.size(), 0); // (0,...0) + Strides in_strides(in_shape.size(), 1); // (1,...,1) + + CoordinateTransform input_transform( + in_shape, in_start_corner, in_shape, in_strides, in_axis_order); + CoordinateTransform output_transform(out_shape); + + NGRAPH_CHECK(shape_size(input_transform.get_target_shape()) == + shape_size(output_transform.get_target_shape())); + + CoordinateTransform::Iterator output_it = output_transform.begin(); + + for (const Coordinate& input_coord : input_transform) + { + const Coordinate& output_coord = *output_it; + + memcpy(out + output_transform.index(output_coord) * elem_size, + arg + input_transform.index(input_coord) * elem_size, + elem_size); + + ++output_it; + } +} diff --git a/ngraph/src/ngraph/runtime/reference/reshape.hpp b/ngraph/src/ngraph/runtime/reference/reshape.hpp index 7d7b3e88dd3..e4711d634a6 100644 --- a/ngraph/src/ngraph/runtime/reference/reshape.hpp +++ b/ngraph/src/ngraph/runtime/reference/reshape.hpp @@ -21,6 +21,7 @@ #include "ngraph/axis_vector.hpp" #include "ngraph/check.hpp" #include "ngraph/coordinate_transform.hpp" +#include "ngraph/type/element_type.hpp" namespace ngraph { @@ -28,38 +29,12 @@ namespace ngraph { namespace reference { - template - void reshape(const T* arg, - T* out, + void reshape(const char* arg, + char* out, const Shape& in_shape, const AxisVector& in_axis_order, - const Shape& out_shape) - { - // Unfortunately we don't yet have a constructor for CoordinateTransform that lets - // us pass only source_space_shape - // and source_axis_order so we have to construct the defaults here. - Shape in_start_corner(in_shape.size(), 0); // (0,...0) - Strides in_strides(in_shape.size(), 1); // (1,...,1) - - CoordinateTransform input_transform( - in_shape, in_start_corner, in_shape, in_strides, in_axis_order); - CoordinateTransform output_transform(out_shape); - - NGRAPH_CHECK(shape_size(input_transform.get_target_shape()) == - shape_size(output_transform.get_target_shape())); - - CoordinateTransform::Iterator output_it = output_transform.begin(); - - for (const Coordinate& input_coord : input_transform) - { - const Coordinate& output_coord = *output_it; - - out[output_transform.index(output_coord)] = - arg[input_transform.index(input_coord)]; - - ++output_it; - } - } + const Shape& out_shape, + size_t elem_size); } } } diff --git a/ngraph/src/ngraph/runtime/reference/strided_slice.hpp b/ngraph/src/ngraph/runtime/reference/strided_slice.hpp index 599d840367d..76745a50d16 100644 --- a/ngraph/src/ngraph/runtime/reference/strided_slice.hpp +++ b/ngraph/src/ngraph/runtime/reference/strided_slice.hpp @@ -21,7 +21,7 @@ #include "ngraph/check.hpp" #include "ngraph/coordinate_transform.hpp" #include "ngraph/runtime/host_tensor.hpp" -#include "ngraph/runtime/reference/reshape.hpp" +#include "ngraph/runtime/opt_kernel/reshape.hpp" #include "ngraph/runtime/reference/reverse.hpp" #include "ngraph/runtime/reference/slice.hpp" #include "ngraph/slice_plan.hpp" @@ -47,11 +47,12 @@ namespace ngraph runtime::AlignedBuffer reshape_out_buffer(shape_size(sp.reshape_out_shape) * sizeof(T)); - reshape(slice_out_buffer.get_ptr(), - reshape_out_buffer.get_ptr(), - sp.reshape_in_shape, - get_default_order(sp.reshape_in_shape.size()), - sp.reshape_out_shape); + opt_kernel::reshape(slice_out_buffer.get_ptr(), + reshape_out_buffer.get_ptr(), + sp.reshape_in_shape, + get_default_order(sp.reshape_in_shape.size()), + sp.reshape_out_shape, + sizeof(T)); reverse(reshape_out_buffer.get_ptr(), out, @@ -61,4 +62,4 @@ namespace ngraph } } } 
-}
\ No newline at end of file
+}
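
Editor's note on the pattern this patch adopts: the per-element-type template kernels and TYPE_CASE switches (in op::Reshape::evaluate, op::Transpose::evaluate, and the deleted constant-folding passes) are replaced by a single type-erased kernel, runtime::opt_kernel::reshape(const char*, char*, const Shape&, const AxisVector&, const Shape&, size_t elem_size), with rank-specialized loops for ranks 0-6 and a reference::reshape fallback. Below is a minimal standalone sketch of that byte-wise technique, not the nGraph sources; the helper name byte_reshape, the local Shape/AxisVector aliases, and the row-major index math are illustrative assumptions only.

// Standalone sketch: type-erased "reshape with axis order" over raw bytes,
// keyed on element size instead of element type.
#include <cstddef>
#include <cstring>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

using Shape = std::vector<size_t>;
using AxisVector = std::vector<size_t>;

// Copies elem_size bytes per element from 'in' to 'out', visiting the input
// in the order given by in_axis_order (a permutation of the input axes).
void byte_reshape(const char* in,
                  char* out,
                  const Shape& in_shape,
                  const AxisVector& in_axis_order,
                  size_t elem_size)
{
    const size_t rank = in_shape.size();

    // Row-major strides of the input tensor, counted in elements.
    std::vector<size_t> in_strides(rank, 1);
    for (size_t i = rank; i-- > 1;)
        in_strides[i - 1] = in_strides[i] * in_shape[i];

    // Dimensions of the permuted (output) index space.
    std::vector<size_t> out_dims(rank);
    for (size_t i = 0; i < rank; ++i)
        out_dims[i] = in_shape[in_axis_order[i]];

    const size_t count = std::accumulate(
        out_dims.begin(), out_dims.end(), size_t{1}, std::multiplies<size_t>());

    std::vector<size_t> idx(rank, 0); // multi-index over the permuted space
    for (size_t n = 0; n < count; ++n)
    {
        // Gather offset into the input, in elements.
        size_t in_offset = 0;
        for (size_t i = 0; i < rank; ++i)
            in_offset += idx[i] * in_strides[in_axis_order[i]];

        std::memcpy(out + n * elem_size, in + in_offset * elem_size, elem_size);

        // Increment the multi-index, last axis fastest.
        for (size_t i = rank; i-- > 0;)
        {
            if (++idx[i] < out_dims[i]) break;
            idx[i] = 0;
        }
    }
}

int main()
{
    // Transpose a 2x3 float tensor. Callers pass raw bytes plus sizeof(T),
    // mirroring how the matmul and strided_slice hunks above now call
    // opt_kernel::reshape with reinterpret_cast<const char*> and sizeof(T).
    std::vector<float> in = {1, 2, 3, 4, 5, 6};
    std::vector<float> out(in.size());
    byte_reshape(reinterpret_cast<const char*>(in.data()),
                 reinterpret_cast<char*>(out.data()),
                 Shape{2, 3},
                 AxisVector{1, 0},
                 sizeof(float));
    for (float v : out)
        std::cout << v << ' '; // prints: 1 4 2 5 3 6
    std::cout << '\n';
}

The trade-off illustrated here matches the patch: erasing the element type removes the N-way type switches from every caller (so whole constant-folding passes and template helpers can be deleted), at the cost of a memcpy of elem_size bytes per element instead of a typed assignment.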