Removed reshape and transpose constant folding passes (#1598)
* Removed template code from reshape implementation * Removed constant foldyng for transpose and dyn reshape
This commit is contained in:
parent
c518667e0a
commit
478d0368d0
@ -440,7 +440,6 @@ set (SRC
|
||||
pass/constant_folding_arithmetic_reduction.cpp
|
||||
pass/constant_folding_convert.cpp
|
||||
pass/constant_folding_dequantize.cpp
|
||||
pass/constant_folding_dyn_reshape.cpp
|
||||
pass/constant_folding_gather.cpp
|
||||
pass/constant_folding_scatter.cpp
|
||||
pass/constant_folding_logical_reduction.cpp
|
||||
@ -453,7 +452,6 @@ set (SRC
|
||||
pass/constant_folding_split.cpp
|
||||
pass/constant_folding_variadic_split.cpp
|
||||
pass/constant_folding_tile.cpp
|
||||
pass/constant_folding_transpose.cpp
|
||||
pass/constant_folding.cpp
|
||||
pass/constant_folding.hpp
|
||||
pass/convert_fp32_to_fp16.hpp
|
||||
@ -512,8 +510,12 @@ set (SRC
|
||||
runtime/host_tensor.hpp
|
||||
runtime/tensor.cpp
|
||||
runtime/tensor.hpp
|
||||
runtime/opt_kernel/reshape.cpp
|
||||
runtime/opt_kernel/reshape.hpp
|
||||
runtime/reference/eval_helpers.cpp
|
||||
runtime/reference/eval_helpers.hpp
|
||||
runtime/reference/reshape.cpp
|
||||
runtime/reference/reshape.hpp
|
||||
shape.cpp
|
||||
shape.hpp
|
||||
shape_util.cpp
|
||||
|
@ -28,52 +28,17 @@ using namespace ngraph;
|
||||
|
||||
namespace
|
||||
{
|
||||
template <element::Type_t ET>
|
||||
bool evaluate(const HostTensorPtr& arg0, const HostTensorPtr& out, const AxisVector& order)
|
||||
{
|
||||
auto data_ptr = out->get_data_ptr<ET>();
|
||||
runtime::opt_kernel::reshape<typename element_type_traits<ET>::value_type>(
|
||||
arg0->get_data_ptr<ET>(), data_ptr, arg0->get_shape(), order, out->get_shape());
|
||||
return true;
|
||||
}
|
||||
|
||||
bool evaluate_reshape(const HostTensorPtr& arg0,
|
||||
const HostTensorPtr& out,
|
||||
const AxisVector& order)
|
||||
{
|
||||
bool rc = true;
|
||||
switch (arg0->get_element_type())
|
||||
{
|
||||
case element::Type_t::undefined: rc = false; break;
|
||||
case element::Type_t::dynamic: rc = false; break;
|
||||
case element::Type_t::u1:
|
||||
rc = false;
|
||||
break;
|
||||
TYPE_CASE(f16)(arg0, out, order);
|
||||
break;
|
||||
TYPE_CASE(f32)(arg0, out, order);
|
||||
break;
|
||||
TYPE_CASE(i8)(arg0, out, order);
|
||||
break;
|
||||
TYPE_CASE(i16)(arg0, out, order);
|
||||
break;
|
||||
TYPE_CASE(i32)(arg0, out, order);
|
||||
break;
|
||||
TYPE_CASE(i64)(arg0, out, order);
|
||||
break;
|
||||
TYPE_CASE(u8)(arg0, out, order);
|
||||
break;
|
||||
TYPE_CASE(u16)(arg0, out, order);
|
||||
break;
|
||||
TYPE_CASE(u32)(arg0, out, order);
|
||||
break;
|
||||
TYPE_CASE(u64)(arg0, out, order);
|
||||
break;
|
||||
TYPE_CASE(boolean)(arg0, out, order);
|
||||
break;
|
||||
default: rc = false; break;
|
||||
}
|
||||
return rc;
|
||||
runtime::opt_kernel::reshape(arg0->get_data_ptr<char>(),
|
||||
out->get_data_ptr<char>(),
|
||||
arg0->get_shape(),
|
||||
order,
|
||||
out->get_shape(),
|
||||
arg0->get_element_type().size());
|
||||
return true;
|
||||
}
|
||||
|
||||
template <element::Type_t ET>
|
||||
@ -477,6 +442,6 @@ bool op::v1::Reshape::evaluate(const HostTensorVector& outputs, const HostTensor
|
||||
}
|
||||
outputs[0]->set_shape(output_shape);
|
||||
}
|
||||
const AxisVector order = get_default_order(outputs[0]->get_shape());
|
||||
const AxisVector order = get_default_order(inputs[0]->get_shape());
|
||||
return evaluate_reshape(inputs[0], outputs[0], order);
|
||||
}
|
||||
|
@ -18,7 +18,7 @@
|
||||
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/transpose.hpp"
|
||||
#include "ngraph/runtime/reference/reshape.hpp"
|
||||
#include "ngraph/runtime/opt_kernel/reshape.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
@ -92,8 +92,9 @@ namespace
|
||||
return rc;
|
||||
}
|
||||
|
||||
template <element::Type_t INPUT_ET>
|
||||
bool evaluate(const HostTensorPtr& arg1, const HostTensorPtr& arg2, const HostTensorPtr& out)
|
||||
bool evaluate_transpose(const HostTensorPtr& arg1,
|
||||
const HostTensorPtr& arg2,
|
||||
const HostTensorPtr& out)
|
||||
{
|
||||
element::Type_t axis_type = arg2->get_element_type();
|
||||
|
||||
@ -132,38 +133,13 @@ namespace
|
||||
[&](const int64_t& v) { return in_shape[v]; });
|
||||
|
||||
out->set_shape(out_shape);
|
||||
return (INPUT_ET == arg1->get_element_type()) &&
|
||||
(runtime::reference::reshape(arg1->get_data_ptr<INPUT_ET>(),
|
||||
out->get_data_ptr<INPUT_ET>(),
|
||||
arg1->get_shape(),
|
||||
in_axis_order,
|
||||
out->get_shape()),
|
||||
true);
|
||||
}
|
||||
|
||||
bool evaluate_transpose(const HostTensorPtr& arg1,
|
||||
const HostTensorPtr& arg2,
|
||||
const HostTensorPtr& out)
|
||||
{
|
||||
bool rc = true;
|
||||
|
||||
switch (arg1->get_element_type())
|
||||
{
|
||||
TYPE_CASE(i32)(arg1, arg2, out);
|
||||
break;
|
||||
TYPE_CASE(i64)(arg1, arg2, out);
|
||||
break;
|
||||
TYPE_CASE(u32)(arg1, arg2, out);
|
||||
break;
|
||||
TYPE_CASE(u64)(arg1, arg2, out);
|
||||
break;
|
||||
TYPE_CASE(f16)(arg1, arg2, out);
|
||||
break;
|
||||
TYPE_CASE(f32)(arg1, arg2, out);
|
||||
break;
|
||||
default: rc = false; break;
|
||||
}
|
||||
return rc;
|
||||
runtime::opt_kernel::reshape(arg1->get_data_ptr<char>(),
|
||||
out->get_data_ptr<char>(),
|
||||
arg1->get_shape(),
|
||||
in_axis_order,
|
||||
out->get_shape(),
|
||||
arg1->get_element_type().size());
|
||||
return true;
|
||||
}
|
||||
}
|
||||
bool op::v1::Transpose::evaluate(const HostTensorVector& output_values,
|
||||
|
@ -33,37 +33,6 @@ namespace ngraph
|
||||
class NGRAPH_API ngraph::pass::ConstantFolding : public ngraph::pass::GraphRewrite
|
||||
{
|
||||
public:
|
||||
enum class CFTransformations
|
||||
{
|
||||
RESHAPE,
|
||||
BROADCAST,
|
||||
PAD,
|
||||
DEQUANTIZE,
|
||||
UNARY,
|
||||
BINARY,
|
||||
QUANTIZE,
|
||||
CONVERT,
|
||||
SHAPE_OF,
|
||||
REVERSE,
|
||||
ARITHMETIC_REDUCTION,
|
||||
LOGICAL_REDUCTION,
|
||||
CONCAT,
|
||||
GATHER,
|
||||
SCATTER,
|
||||
SLICE,
|
||||
DYN_RESHAPE,
|
||||
TRANSPOSE,
|
||||
RANGE,
|
||||
SELECT,
|
||||
SQUEEZE,
|
||||
UNSQUEEZE,
|
||||
SPLIT,
|
||||
VARIADIC_SPLIT,
|
||||
ONE_HOT,
|
||||
TILE,
|
||||
NON_ZERO
|
||||
};
|
||||
|
||||
ConstantFolding(const ngraph::BuildNodeExecutorMap& cfmap = ngraph::BuildNodeExecutorMap())
|
||||
: GraphRewrite()
|
||||
{
|
||||
@ -81,8 +50,6 @@ public:
|
||||
construct_constant_gather_with_subgraph();
|
||||
construct_constant_scatter_elements_update();
|
||||
construct_constant_slice();
|
||||
construct_constant_dyn_reshape();
|
||||
construct_constant_transpose();
|
||||
construct_constant_select();
|
||||
construct_constant_one_hot();
|
||||
construct_constant_tile();
|
||||
@ -100,8 +67,6 @@ private:
|
||||
void construct_constant_gather_with_subgraph();
|
||||
void construct_constant_scatter_elements_update();
|
||||
void construct_constant_slice();
|
||||
void construct_constant_dyn_reshape();
|
||||
void construct_constant_transpose();
|
||||
void construct_constant_select();
|
||||
void construct_constant_split();
|
||||
void construct_constant_variadic_split();
|
||||
|
@ -1,130 +0,0 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "constant_folding.hpp"
|
||||
#include "ngraph/op/reshape.hpp"
|
||||
#include "ngraph/runtime/reference/reshape.hpp"
|
||||
#include "ngraph/type/element_type.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
template <typename T, typename R>
|
||||
shared_ptr<op::Constant> fold_constant_dyn_reshape(shared_ptr<op::Constant> constant_data,
|
||||
R dyn_reshape)
|
||||
{
|
||||
// v1::Reshape and v0::DynReshape do not allow data transposes.
|
||||
return make_shared<op::Constant>(dyn_reshape->get_element_type(),
|
||||
dyn_reshape->get_shape(),
|
||||
constant_data->get_data_ptr<T>());
|
||||
}
|
||||
|
||||
template <typename R>
|
||||
std::shared_ptr<Node> do_fold(R dyn_reshape_match, shared_ptr<op::Constant> constant_data_match)
|
||||
{
|
||||
std::shared_ptr<Node> replacement;
|
||||
auto type = dyn_reshape_match->get_element_type();
|
||||
switch (type)
|
||||
{
|
||||
case element::Type_t::undefined:
|
||||
NGRAPH_CHECK(false,
|
||||
"Encountered 'undefined' element type in constant_dyn_reshape_callback");
|
||||
break;
|
||||
case element::Type_t::dynamic:
|
||||
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_dyn_reshape_callback");
|
||||
break;
|
||||
case element::Type_t::u1:
|
||||
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_dyn_reshape_callback");
|
||||
break;
|
||||
case element::Type_t::boolean:
|
||||
replacement = fold_constant_dyn_reshape<char>(constant_data_match, dyn_reshape_match);
|
||||
break;
|
||||
case element::Type_t::bf16:
|
||||
replacement = fold_constant_dyn_reshape<bfloat16>(constant_data_match, dyn_reshape_match);
|
||||
break;
|
||||
case element::Type_t::f16:
|
||||
replacement = fold_constant_dyn_reshape<float16>(constant_data_match, dyn_reshape_match);
|
||||
break;
|
||||
case element::Type_t::f32:
|
||||
replacement = fold_constant_dyn_reshape<float>(constant_data_match, dyn_reshape_match);
|
||||
break;
|
||||
case element::Type_t::f64:
|
||||
replacement = fold_constant_dyn_reshape<double>(constant_data_match, dyn_reshape_match);
|
||||
break;
|
||||
case element::Type_t::i8:
|
||||
replacement = fold_constant_dyn_reshape<int8_t>(constant_data_match, dyn_reshape_match);
|
||||
break;
|
||||
case element::Type_t::i16:
|
||||
replacement = fold_constant_dyn_reshape<int16_t>(constant_data_match, dyn_reshape_match);
|
||||
break;
|
||||
case element::Type_t::i32:
|
||||
replacement = fold_constant_dyn_reshape<int32_t>(constant_data_match, dyn_reshape_match);
|
||||
break;
|
||||
case element::Type_t::i64:
|
||||
replacement = fold_constant_dyn_reshape<int64_t>(constant_data_match, dyn_reshape_match);
|
||||
break;
|
||||
case element::Type_t::u8:
|
||||
replacement = fold_constant_dyn_reshape<uint8_t>(constant_data_match, dyn_reshape_match);
|
||||
break;
|
||||
case element::Type_t::u16:
|
||||
replacement = fold_constant_dyn_reshape<uint16_t>(constant_data_match, dyn_reshape_match);
|
||||
break;
|
||||
case element::Type_t::u32:
|
||||
replacement = fold_constant_dyn_reshape<uint32_t>(constant_data_match, dyn_reshape_match);
|
||||
break;
|
||||
case element::Type_t::u64:
|
||||
replacement = fold_constant_dyn_reshape<uint64_t>(constant_data_match, dyn_reshape_match);
|
||||
break;
|
||||
}
|
||||
return replacement;
|
||||
}
|
||||
|
||||
void pass::ConstantFolding::construct_constant_dyn_reshape()
|
||||
{
|
||||
auto constant_data_label = make_shared<pattern::op::Label>(
|
||||
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
|
||||
auto constant_shape_label =
|
||||
make_shared<pattern::op::Label>(element::i64, Shape{1}, pattern::has_class<op::Constant>());
|
||||
auto reshape_v1 =
|
||||
make_shared<op::v1::Reshape>(constant_data_label, constant_shape_label, false);
|
||||
|
||||
// Note: No need to capture or consider constant_shape_label, because
|
||||
// shape propagation will have transferred the info to dyn_reshape's
|
||||
// output.
|
||||
auto constant_reshape_v1_callback = [constant_data_label](pattern::Matcher& m) {
|
||||
NGRAPH_DEBUG << "In callback for constant_reshape_v1_callback against node = "
|
||||
<< m.get_match_root()->get_name();
|
||||
|
||||
auto pattern_map = m.get_pattern_map();
|
||||
|
||||
auto constant_data_match =
|
||||
static_pointer_cast<op::Constant>(pattern_map[constant_data_label]);
|
||||
auto match_root = m.get_match_root();
|
||||
NGRAPH_CHECK(revalidate_and_ensure_static(match_root));
|
||||
shared_ptr<Node> replacement;
|
||||
replacement =
|
||||
do_fold(static_pointer_cast<op::v1::Reshape>(match_root), constant_data_match);
|
||||
replace_node(m.get_match_root(), replacement);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto reshape_v1_matcher =
|
||||
make_shared<pattern::Matcher>(reshape_v1, "ConstantFolding.ConstantReshapev1");
|
||||
this->add_matcher(
|
||||
reshape_v1_matcher, constant_reshape_v1_callback, PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
}
|
@ -1,143 +0,0 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "constant_folding.hpp"
|
||||
#include "ngraph/op/transpose.hpp"
|
||||
#include "ngraph/runtime/opt_kernel/reshape.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
template <class T>
|
||||
shared_ptr<op::Constant> fold_constant_transpose(shared_ptr<op::Constant> constant_data,
|
||||
shared_ptr<op::Constant> constant_perm,
|
||||
shared_ptr<op::Transpose> transpose)
|
||||
{
|
||||
const Shape& out_shape = transpose->get_shape();
|
||||
auto input_order = constant_perm->get_axis_vector_val();
|
||||
|
||||
runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
|
||||
|
||||
runtime::opt_kernel::reshape<T>(constant_data->get_data_ptr<T>(),
|
||||
buffer.get_ptr<T>(),
|
||||
constant_data->get_shape(),
|
||||
input_order,
|
||||
out_shape);
|
||||
|
||||
return make_shared<op::Constant>(transpose->get_element_type(), out_shape, buffer.get_ptr<T>());
|
||||
}
|
||||
|
||||
void pass::ConstantFolding::construct_constant_transpose()
|
||||
{
|
||||
auto constant_data_label = make_shared<pattern::op::Label>(
|
||||
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
|
||||
auto constant_perm_label =
|
||||
make_shared<pattern::op::Label>(element::i64, Shape{2}, pattern::has_class<op::Constant>());
|
||||
auto transpose = make_shared<op::Transpose>(constant_data_label, constant_perm_label);
|
||||
|
||||
auto constant_transpose_callback = [constant_data_label,
|
||||
constant_perm_label](pattern::Matcher& m) {
|
||||
NGRAPH_DEBUG << "In callback for constant_transpose_callback against node = "
|
||||
<< m.get_match_root()->get_name();
|
||||
|
||||
auto pattern_map = m.get_pattern_map();
|
||||
|
||||
auto constant_data_match =
|
||||
static_pointer_cast<op::Constant>(pattern_map[constant_data_label]);
|
||||
auto constant_perm_match =
|
||||
static_pointer_cast<op::Constant>(pattern_map[constant_perm_label]);
|
||||
auto transpose_match = static_pointer_cast<op::Transpose>(m.get_match_root());
|
||||
|
||||
NGRAPH_CHECK(revalidate_and_ensure_static(transpose_match));
|
||||
|
||||
std::shared_ptr<Node> replacement;
|
||||
auto type = transpose_match->get_element_type();
|
||||
switch (type)
|
||||
{
|
||||
case element::Type_t::undefined:
|
||||
NGRAPH_CHECK(false,
|
||||
"Encountered 'undefined' element type in constant_transpose_callback");
|
||||
break;
|
||||
case element::Type_t::dynamic:
|
||||
NGRAPH_CHECK(false,
|
||||
"Encountered 'dynamic' element type in constant_transpose_callback");
|
||||
break;
|
||||
case element::Type_t::u1:
|
||||
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_transpose_callback");
|
||||
break;
|
||||
case element::Type_t::boolean:
|
||||
replacement = fold_constant_transpose<char>(
|
||||
constant_data_match, constant_perm_match, transpose_match);
|
||||
break;
|
||||
case element::Type_t::bf16:
|
||||
replacement = fold_constant_transpose<bfloat16>(
|
||||
constant_data_match, constant_perm_match, transpose_match);
|
||||
break;
|
||||
case element::Type_t::f16:
|
||||
replacement = fold_constant_transpose<float16>(
|
||||
constant_data_match, constant_perm_match, transpose_match);
|
||||
break;
|
||||
case element::Type_t::f32:
|
||||
replacement = fold_constant_transpose<float>(
|
||||
constant_data_match, constant_perm_match, transpose_match);
|
||||
break;
|
||||
case element::Type_t::f64:
|
||||
replacement = fold_constant_transpose<double>(
|
||||
constant_data_match, constant_perm_match, transpose_match);
|
||||
break;
|
||||
case element::Type_t::i8:
|
||||
replacement = fold_constant_transpose<int8_t>(
|
||||
constant_data_match, constant_perm_match, transpose_match);
|
||||
break;
|
||||
case element::Type_t::i16:
|
||||
replacement = fold_constant_transpose<int16_t>(
|
||||
constant_data_match, constant_perm_match, transpose_match);
|
||||
break;
|
||||
case element::Type_t::i32:
|
||||
replacement = fold_constant_transpose<int32_t>(
|
||||
constant_data_match, constant_perm_match, transpose_match);
|
||||
break;
|
||||
case element::Type_t::i64:
|
||||
replacement = fold_constant_transpose<int64_t>(
|
||||
constant_data_match, constant_perm_match, transpose_match);
|
||||
break;
|
||||
case element::Type_t::u8:
|
||||
replacement = fold_constant_transpose<uint8_t>(
|
||||
constant_data_match, constant_perm_match, transpose_match);
|
||||
break;
|
||||
case element::Type_t::u16:
|
||||
replacement = fold_constant_transpose<uint16_t>(
|
||||
constant_data_match, constant_perm_match, transpose_match);
|
||||
break;
|
||||
case element::Type_t::u32:
|
||||
replacement = fold_constant_transpose<uint32_t>(
|
||||
constant_data_match, constant_perm_match, transpose_match);
|
||||
break;
|
||||
case element::Type_t::u64:
|
||||
replacement = fold_constant_transpose<uint64_t>(
|
||||
constant_data_match, constant_perm_match, transpose_match);
|
||||
break;
|
||||
}
|
||||
|
||||
replace_node(m.get_match_root(), replacement);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto transpose_matcher =
|
||||
make_shared<pattern::Matcher>(transpose, "ConstantFolding.ConstantTranspose");
|
||||
this->add_matcher(
|
||||
transpose_matcher, constant_transpose_callback, PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
}
|
267
ngraph/src/ngraph/runtime/opt_kernel/reshape.cpp
Normal file
267
ngraph/src/ngraph/runtime/opt_kernel/reshape.cpp
Normal file
@ -0,0 +1,267 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include <cmath>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "ngraph/check.hpp"
|
||||
#include "ngraph/runtime/opt_kernel/reshape.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
|
||||
namespace
|
||||
{
|
||||
void reshape_in0(const char* in,
|
||||
char* out,
|
||||
const Shape& in_shape,
|
||||
const AxisVector& in_axis_order,
|
||||
const Shape& out_shape,
|
||||
size_t elem_size)
|
||||
{
|
||||
memcpy(out, in, elem_size);
|
||||
}
|
||||
|
||||
void reshape_in1(const char* in,
|
||||
char* out,
|
||||
const Shape& in_shape,
|
||||
const AxisVector& in_axis_order,
|
||||
const Shape& out_shape,
|
||||
size_t elem_size)
|
||||
{
|
||||
size_t size[1];
|
||||
size_t in_index[1];
|
||||
size_t* map_index[1];
|
||||
for (size_t i = 0; i < 1; i++)
|
||||
{
|
||||
size[i] = in_shape[in_axis_order[i]];
|
||||
map_index[in_axis_order[i]] = &in_index[i];
|
||||
}
|
||||
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
|
||||
{
|
||||
memcpy(out, in + *map_index[0] * elem_size, elem_size);
|
||||
out += elem_size;
|
||||
}
|
||||
}
|
||||
|
||||
void reshape_in2(const char* in,
|
||||
char* out,
|
||||
const Shape& in_shape,
|
||||
const AxisVector& in_axis_order,
|
||||
const Shape& out_shape,
|
||||
size_t elem_size)
|
||||
{
|
||||
size_t size[2];
|
||||
size_t in_index[2];
|
||||
size_t* map_index[2];
|
||||
for (size_t i = 0; i < 2; i++)
|
||||
{
|
||||
size[i] = in_shape[in_axis_order[i]];
|
||||
map_index[in_axis_order[i]] = &in_index[i];
|
||||
}
|
||||
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
|
||||
{
|
||||
for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1])
|
||||
{
|
||||
// clang-format off
|
||||
memcpy(out,
|
||||
in + (*map_index[0] * in_shape[1] +
|
||||
*map_index[1]) * elem_size,
|
||||
elem_size);
|
||||
out += elem_size;
|
||||
// clang-format on
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void reshape_in3(const char* in,
|
||||
char* out,
|
||||
const Shape& in_shape,
|
||||
const AxisVector& in_axis_order,
|
||||
const Shape& out_shape,
|
||||
size_t elem_size)
|
||||
{
|
||||
size_t size[3];
|
||||
size_t in_index[3];
|
||||
size_t* map_index[3];
|
||||
for (size_t i = 0; i < 3; i++)
|
||||
{
|
||||
size[i] = in_shape[in_axis_order[i]];
|
||||
map_index[in_axis_order[i]] = &in_index[i];
|
||||
}
|
||||
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
|
||||
{
|
||||
for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1])
|
||||
{
|
||||
for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2])
|
||||
{
|
||||
// clang-format off
|
||||
memcpy(out,
|
||||
in + (*map_index[0] * in_shape[1] * in_shape[2] +
|
||||
*map_index[1] * in_shape[2] +
|
||||
*map_index[2]) * elem_size,
|
||||
elem_size);
|
||||
out += elem_size;
|
||||
// clang-format on
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void reshape_in4(const char* in,
|
||||
char* out,
|
||||
const Shape& in_shape,
|
||||
const AxisVector& in_axis_order,
|
||||
const Shape& out_shape,
|
||||
size_t elem_size)
|
||||
{
|
||||
size_t size[4];
|
||||
size_t in_index[4];
|
||||
size_t* map_index[4];
|
||||
for (size_t i = 0; i < 4; i++)
|
||||
{
|
||||
size[i] = in_shape[in_axis_order[i]];
|
||||
map_index[in_axis_order[i]] = &in_index[i];
|
||||
}
|
||||
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
|
||||
{
|
||||
for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1])
|
||||
{
|
||||
for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2])
|
||||
{
|
||||
for (in_index[3] = 0; in_index[3] < size[3]; ++in_index[3])
|
||||
{
|
||||
// clang-format off
|
||||
memcpy(out,
|
||||
in + (*map_index[0] * in_shape[1] * in_shape[2] * in_shape[3] +
|
||||
*map_index[1] * in_shape[2] * in_shape[3] +
|
||||
*map_index[2] * in_shape[3] +
|
||||
*map_index[3]) * elem_size,
|
||||
elem_size);
|
||||
out += elem_size;
|
||||
// clang-format on
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void reshape_in5(const char* in,
|
||||
char* out,
|
||||
const Shape& in_shape,
|
||||
const AxisVector& in_axis_order,
|
||||
const Shape& out_shape,
|
||||
size_t elem_size)
|
||||
{
|
||||
size_t size[5];
|
||||
size_t in_index[5];
|
||||
size_t* map_index[5];
|
||||
for (size_t i = 0; i < 5; i++)
|
||||
{
|
||||
size[i] = in_shape[in_axis_order[i]];
|
||||
map_index[in_axis_order[i]] = &in_index[i];
|
||||
}
|
||||
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
|
||||
{
|
||||
for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1])
|
||||
{
|
||||
for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2])
|
||||
{
|
||||
for (in_index[3] = 0; in_index[3] < size[3]; ++in_index[3])
|
||||
{
|
||||
for (in_index[4] = 0; in_index[4] < size[4]; ++in_index[4])
|
||||
{
|
||||
// clang-format off
|
||||
memcpy(out,
|
||||
in + (*map_index[0] * in_shape[1] * in_shape[2] * in_shape[3] * in_shape[4] +
|
||||
*map_index[1] * in_shape[2] * in_shape[3] * in_shape[4] +
|
||||
*map_index[2] * in_shape[3] * in_shape[4] +
|
||||
*map_index[3] * in_shape[4] +
|
||||
*map_index[4]) * elem_size,
|
||||
elem_size);
|
||||
out += elem_size;
|
||||
// clang-format on
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void reshape_in6(const char* in,
|
||||
char* out,
|
||||
const Shape& in_shape,
|
||||
const AxisVector& in_axis_order,
|
||||
const Shape& out_shape,
|
||||
size_t elem_size)
|
||||
{
|
||||
size_t size[6];
|
||||
size_t in_index[6];
|
||||
size_t* map_index[6];
|
||||
for (size_t i = 0; i < 6; i++)
|
||||
{
|
||||
size[i] = in_shape[in_axis_order[i]];
|
||||
map_index[in_axis_order[i]] = &in_index[i];
|
||||
}
|
||||
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
|
||||
{
|
||||
for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1])
|
||||
{
|
||||
for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2])
|
||||
{
|
||||
for (in_index[3] = 0; in_index[3] < size[3]; ++in_index[3])
|
||||
{
|
||||
for (in_index[4] = 0; in_index[4] < size[4]; ++in_index[4])
|
||||
{
|
||||
for (in_index[5] = 0; in_index[5] < size[5]; ++in_index[5])
|
||||
{
|
||||
// clang-format off
|
||||
memcpy(out,
|
||||
in + (*map_index[0] * in_shape[1] * in_shape[2] * in_shape[3] * in_shape[4] * in_shape[5] +
|
||||
*map_index[1] * in_shape[2] * in_shape[3] * in_shape[4] * in_shape[5] +
|
||||
*map_index[2] * in_shape[3] * in_shape[4] * in_shape[5] +
|
||||
*map_index[3] * in_shape[4] * in_shape[5] +
|
||||
*map_index[4] * in_shape[5] +
|
||||
*map_index[5]) * elem_size,
|
||||
elem_size);
|
||||
out += elem_size;
|
||||
// clang-format on
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
void runtime::opt_kernel::reshape(const char* in,
|
||||
char* out,
|
||||
const Shape& in_shape,
|
||||
const AxisVector& in_axis_order,
|
||||
const Shape& out_shape,
|
||||
size_t elem_size)
|
||||
{
|
||||
switch (in_shape.size())
|
||||
{
|
||||
case 0: reshape_in0(in, out, in_shape, in_axis_order, out_shape, elem_size); break;
|
||||
case 1: reshape_in1(in, out, in_shape, in_axis_order, out_shape, elem_size); break;
|
||||
case 2: reshape_in2(in, out, in_shape, in_axis_order, out_shape, elem_size); break;
|
||||
case 3: reshape_in3(in, out, in_shape, in_axis_order, out_shape, elem_size); break;
|
||||
case 4: reshape_in4(in, out, in_shape, in_axis_order, out_shape, elem_size); break;
|
||||
case 5: reshape_in5(in, out, in_shape, in_axis_order, out_shape, elem_size); break;
|
||||
case 6: reshape_in6(in, out, in_shape, in_axis_order, out_shape, elem_size); break;
|
||||
default: reference::reshape(in, out, in_shape, in_axis_order, out_shape, elem_size); break;
|
||||
}
|
||||
}
|
@ -26,232 +26,12 @@ namespace ngraph
|
||||
{
|
||||
namespace opt_kernel
|
||||
{
|
||||
template <typename T>
|
||||
void reshape_in0(const T* in,
|
||||
T* out,
|
||||
const Shape& in_shape,
|
||||
const AxisVector& in_axis_order,
|
||||
const Shape& out_shape)
|
||||
{
|
||||
*out = *in;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void reshape_in1(const T* in,
|
||||
T* out,
|
||||
const Shape& in_shape,
|
||||
const AxisVector& in_axis_order,
|
||||
const Shape& out_shape)
|
||||
{
|
||||
size_t size[1];
|
||||
size_t in_index[1];
|
||||
size_t* map_index[1];
|
||||
for (size_t i = 0; i < 1; i++)
|
||||
{
|
||||
size[i] = in_shape[in_axis_order[i]];
|
||||
map_index[in_axis_order[i]] = &in_index[i];
|
||||
}
|
||||
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
|
||||
{
|
||||
*out++ = in[*map_index[0]];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void reshape_in2(const T* in,
|
||||
T* out,
|
||||
const Shape& in_shape,
|
||||
const AxisVector& in_axis_order,
|
||||
const Shape& out_shape)
|
||||
{
|
||||
size_t size[2];
|
||||
size_t in_index[2];
|
||||
size_t* map_index[2];
|
||||
for (size_t i = 0; i < 2; i++)
|
||||
{
|
||||
size[i] = in_shape[in_axis_order[i]];
|
||||
map_index[in_axis_order[i]] = &in_index[i];
|
||||
}
|
||||
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
|
||||
{
|
||||
for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1])
|
||||
{
|
||||
// clang-format off
|
||||
*out++ = in[*map_index[0] * in_shape[1] +
|
||||
*map_index[1]];
|
||||
// clang-format on
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void reshape_in3(const T* in,
|
||||
T* out,
|
||||
const Shape& in_shape,
|
||||
const AxisVector& in_axis_order,
|
||||
const Shape& out_shape)
|
||||
{
|
||||
size_t size[3];
|
||||
size_t in_index[3];
|
||||
size_t* map_index[3];
|
||||
for (size_t i = 0; i < 3; i++)
|
||||
{
|
||||
size[i] = in_shape[in_axis_order[i]];
|
||||
map_index[in_axis_order[i]] = &in_index[i];
|
||||
}
|
||||
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
|
||||
{
|
||||
for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1])
|
||||
{
|
||||
for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2])
|
||||
{
|
||||
// clang-format off
|
||||
*out++ = in[*map_index[0] * in_shape[1] * in_shape[2] +
|
||||
*map_index[1] * in_shape[2] +
|
||||
*map_index[2]];
|
||||
// clang-format on
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void reshape_in4(const T* in,
|
||||
T* out,
|
||||
const Shape& in_shape,
|
||||
const AxisVector& in_axis_order,
|
||||
const Shape& out_shape)
|
||||
{
|
||||
size_t size[4];
|
||||
size_t in_index[4];
|
||||
size_t* map_index[4];
|
||||
for (size_t i = 0; i < 4; i++)
|
||||
{
|
||||
size[i] = in_shape[in_axis_order[i]];
|
||||
map_index[in_axis_order[i]] = &in_index[i];
|
||||
}
|
||||
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
|
||||
{
|
||||
for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1])
|
||||
{
|
||||
for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2])
|
||||
{
|
||||
for (in_index[3] = 0; in_index[3] < size[3]; ++in_index[3])
|
||||
{
|
||||
// clang-format off
|
||||
*out++ =
|
||||
in[*map_index[0] * in_shape[1] * in_shape[2] * in_shape[3] +
|
||||
*map_index[1] * in_shape[2] * in_shape[3] +
|
||||
*map_index[2] * in_shape[3] +
|
||||
*map_index[3]];
|
||||
// clang-format on
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void reshape_in5(const T* in,
|
||||
T* out,
|
||||
const Shape& in_shape,
|
||||
const AxisVector& in_axis_order,
|
||||
const Shape& out_shape)
|
||||
{
|
||||
size_t size[5];
|
||||
size_t in_index[5];
|
||||
size_t* map_index[5];
|
||||
for (size_t i = 0; i < 5; i++)
|
||||
{
|
||||
size[i] = in_shape[in_axis_order[i]];
|
||||
map_index[in_axis_order[i]] = &in_index[i];
|
||||
}
|
||||
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
|
||||
{
|
||||
for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1])
|
||||
{
|
||||
for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2])
|
||||
{
|
||||
for (in_index[3] = 0; in_index[3] < size[3]; ++in_index[3])
|
||||
{
|
||||
for (in_index[4] = 0; in_index[4] < size[4]; ++in_index[4])
|
||||
{
|
||||
// clang-format off
|
||||
*out++ =
|
||||
in[*map_index[0] * in_shape[1] * in_shape[2] * in_shape[3] * in_shape[4] +
|
||||
*map_index[1] * in_shape[2] * in_shape[3] * in_shape[4] +
|
||||
*map_index[2] * in_shape[3] * in_shape[4] +
|
||||
*map_index[3] * in_shape[4] +
|
||||
*map_index[4]];
|
||||
// clang-format on
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void reshape_in6(const T* in,
|
||||
T* out,
|
||||
const Shape& in_shape,
|
||||
const AxisVector& in_axis_order,
|
||||
const Shape& out_shape)
|
||||
{
|
||||
size_t size[6];
|
||||
size_t in_index[6];
|
||||
size_t* map_index[6];
|
||||
for (size_t i = 0; i < 6; i++)
|
||||
{
|
||||
size[i] = in_shape[in_axis_order[i]];
|
||||
map_index[in_axis_order[i]] = &in_index[i];
|
||||
}
|
||||
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
|
||||
{
|
||||
for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1])
|
||||
{
|
||||
for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2])
|
||||
{
|
||||
for (in_index[3] = 0; in_index[3] < size[3]; ++in_index[3])
|
||||
{
|
||||
for (in_index[4] = 0; in_index[4] < size[4]; ++in_index[4])
|
||||
{
|
||||
for (in_index[5] = 0; in_index[5] < size[5]; ++in_index[5])
|
||||
{
|
||||
// clang-format off
|
||||
*out++ = in[*map_index[0] * in_shape[1] * in_shape[2] * in_shape[3] * in_shape[4] * in_shape[5] +
|
||||
*map_index[1] * in_shape[2] * in_shape[3] * in_shape[4] * in_shape[5] +
|
||||
*map_index[2] * in_shape[3] * in_shape[4] * in_shape[5] +
|
||||
*map_index[3] * in_shape[4] * in_shape[5] +
|
||||
*map_index[4] * in_shape[5] +
|
||||
*map_index[5]];
|
||||
// clang-format on
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
void reshape(const T* in,
|
||||
T* out,
|
||||
void reshape(const char* in,
|
||||
char* out,
|
||||
const Shape& in_shape,
|
||||
const AxisVector& in_axis_order,
|
||||
const Shape& out_shape)
|
||||
{
|
||||
switch (in_shape.size())
|
||||
{
|
||||
case 0: reshape_in0<T>(in, out, in_shape, in_axis_order, out_shape); break;
|
||||
case 1: reshape_in1<T>(in, out, in_shape, in_axis_order, out_shape); break;
|
||||
case 2: reshape_in2<T>(in, out, in_shape, in_axis_order, out_shape); break;
|
||||
case 3: reshape_in3<T>(in, out, in_shape, in_axis_order, out_shape); break;
|
||||
case 4: reshape_in4<T>(in, out, in_shape, in_axis_order, out_shape); break;
|
||||
case 5: reshape_in5<T>(in, out, in_shape, in_axis_order, out_shape); break;
|
||||
case 6: reshape_in6<T>(in, out, in_shape, in_axis_order, out_shape); break;
|
||||
default: reference::reshape(in, out, in_shape, in_axis_order, out_shape); break;
|
||||
}
|
||||
}
|
||||
const Shape& out_shape,
|
||||
size_t elem_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -23,9 +23,9 @@
|
||||
|
||||
#include "ngraph/axis_vector.hpp"
|
||||
#include "ngraph/builder/autobroadcast.hpp"
|
||||
#include "ngraph/runtime/opt_kernel/reshape.hpp"
|
||||
#include "ngraph/runtime/reference/broadcast.hpp"
|
||||
#include "ngraph/runtime/reference/dot.hpp"
|
||||
#include "ngraph/runtime/reference/reshape.hpp"
|
||||
#include "ngraph/shape_util.hpp"
|
||||
|
||||
using namespace std;
|
||||
@ -116,8 +116,12 @@ namespace ngraph
|
||||
arg0_transpose_vec.reserve(shape_size(arg0_shape));
|
||||
auto axis_vector = get_transpose_order(arg0_shape);
|
||||
swap(wip_arg0_shape[arg0_rank - 1], wip_arg0_shape[arg0_rank - 2]);
|
||||
reference::reshape(
|
||||
arg0, arg0_transpose_vec.data(), arg0_shape, axis_vector, wip_arg0_shape);
|
||||
opt_kernel::reshape(reinterpret_cast<const char*>(arg0),
|
||||
reinterpret_cast<char*>(arg0_transpose_vec.data()),
|
||||
arg0_shape,
|
||||
axis_vector,
|
||||
wip_arg0_shape,
|
||||
sizeof(T));
|
||||
|
||||
arg0_update = arg0_transpose_vec.data();
|
||||
}
|
||||
@ -127,8 +131,12 @@ namespace ngraph
|
||||
arg1_transpose_vec.reserve(shape_size(arg1_shape));
|
||||
auto axis_vector = get_transpose_order(arg1_shape);
|
||||
swap(wip_arg1_shape[arg1_rank - 1], wip_arg1_shape[arg1_rank - 2]);
|
||||
reference::reshape(
|
||||
arg1, arg1_transpose_vec.data(), arg1_shape, axis_vector, wip_arg1_shape);
|
||||
opt_kernel::reshape(reinterpret_cast<const char*>(arg1),
|
||||
reinterpret_cast<char*>(arg1_transpose_vec.data()),
|
||||
arg1_shape,
|
||||
axis_vector,
|
||||
wip_arg1_shape,
|
||||
sizeof(T));
|
||||
|
||||
arg1_update = arg1_transpose_vec.data();
|
||||
}
|
||||
|
57
ngraph/src/ngraph/runtime/reference/reshape.cpp
Normal file
57
ngraph/src/ngraph/runtime/reference/reshape.cpp
Normal file
@ -0,0 +1,57 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include <cmath>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "ngraph/check.hpp"
|
||||
#include "ngraph/runtime/reference/reshape.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
|
||||
void runtime::reference::reshape(const char* arg,
|
||||
char* out,
|
||||
const Shape& in_shape,
|
||||
const AxisVector& in_axis_order,
|
||||
const Shape& out_shape,
|
||||
size_t elem_size)
|
||||
{
|
||||
// Unfortunately we don't yet have a constructor for CoordinateTransform that lets
|
||||
// us pass only source_space_shape
|
||||
// and source_axis_order so we have to construct the defaults here.
|
||||
Shape in_start_corner(in_shape.size(), 0); // (0,...0)
|
||||
Strides in_strides(in_shape.size(), 1); // (1,...,1)
|
||||
|
||||
CoordinateTransform input_transform(
|
||||
in_shape, in_start_corner, in_shape, in_strides, in_axis_order);
|
||||
CoordinateTransform output_transform(out_shape);
|
||||
|
||||
NGRAPH_CHECK(shape_size(input_transform.get_target_shape()) ==
|
||||
shape_size(output_transform.get_target_shape()));
|
||||
|
||||
CoordinateTransform::Iterator output_it = output_transform.begin();
|
||||
|
||||
for (const Coordinate& input_coord : input_transform)
|
||||
{
|
||||
const Coordinate& output_coord = *output_it;
|
||||
|
||||
memcpy(out + output_transform.index(output_coord) * elem_size,
|
||||
arg + input_transform.index(input_coord) * elem_size,
|
||||
elem_size);
|
||||
|
||||
++output_it;
|
||||
}
|
||||
}
|
@ -21,6 +21,7 @@
|
||||
#include "ngraph/axis_vector.hpp"
|
||||
#include "ngraph/check.hpp"
|
||||
#include "ngraph/coordinate_transform.hpp"
|
||||
#include "ngraph/type/element_type.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
@ -28,38 +29,12 @@ namespace ngraph
|
||||
{
|
||||
namespace reference
|
||||
{
|
||||
template <typename T>
|
||||
void reshape(const T* arg,
|
||||
T* out,
|
||||
void reshape(const char* arg,
|
||||
char* out,
|
||||
const Shape& in_shape,
|
||||
const AxisVector& in_axis_order,
|
||||
const Shape& out_shape)
|
||||
{
|
||||
// Unfortunately we don't yet have a constructor for CoordinateTransform that lets
|
||||
// us pass only source_space_shape
|
||||
// and source_axis_order so we have to construct the defaults here.
|
||||
Shape in_start_corner(in_shape.size(), 0); // (0,...0)
|
||||
Strides in_strides(in_shape.size(), 1); // (1,...,1)
|
||||
|
||||
CoordinateTransform input_transform(
|
||||
in_shape, in_start_corner, in_shape, in_strides, in_axis_order);
|
||||
CoordinateTransform output_transform(out_shape);
|
||||
|
||||
NGRAPH_CHECK(shape_size(input_transform.get_target_shape()) ==
|
||||
shape_size(output_transform.get_target_shape()));
|
||||
|
||||
CoordinateTransform::Iterator output_it = output_transform.begin();
|
||||
|
||||
for (const Coordinate& input_coord : input_transform)
|
||||
{
|
||||
const Coordinate& output_coord = *output_it;
|
||||
|
||||
out[output_transform.index(output_coord)] =
|
||||
arg[input_transform.index(input_coord)];
|
||||
|
||||
++output_it;
|
||||
}
|
||||
}
|
||||
const Shape& out_shape,
|
||||
size_t elem_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -21,7 +21,7 @@
|
||||
#include "ngraph/check.hpp"
|
||||
#include "ngraph/coordinate_transform.hpp"
|
||||
#include "ngraph/runtime/host_tensor.hpp"
|
||||
#include "ngraph/runtime/reference/reshape.hpp"
|
||||
#include "ngraph/runtime/opt_kernel/reshape.hpp"
|
||||
#include "ngraph/runtime/reference/reverse.hpp"
|
||||
#include "ngraph/runtime/reference/slice.hpp"
|
||||
#include "ngraph/slice_plan.hpp"
|
||||
@ -47,11 +47,12 @@ namespace ngraph
|
||||
|
||||
runtime::AlignedBuffer reshape_out_buffer(shape_size(sp.reshape_out_shape) *
|
||||
sizeof(T));
|
||||
reshape<T>(slice_out_buffer.get_ptr<T>(),
|
||||
reshape_out_buffer.get_ptr<T>(),
|
||||
sp.reshape_in_shape,
|
||||
get_default_order(sp.reshape_in_shape.size()),
|
||||
sp.reshape_out_shape);
|
||||
opt_kernel::reshape(slice_out_buffer.get_ptr<char>(),
|
||||
reshape_out_buffer.get_ptr<char>(),
|
||||
sp.reshape_in_shape,
|
||||
get_default_order(sp.reshape_in_shape.size()),
|
||||
sp.reshape_out_shape,
|
||||
sizeof(T));
|
||||
|
||||
reverse<T>(reshape_out_buffer.get_ptr<T>(),
|
||||
out,
|
||||
@ -61,4 +62,4 @@ namespace ngraph
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user