Dynamic shape: remaining changes to run BERT (#13306)

* Initial dynamic shape smoke test for GPU

* Bert dynamic runs without crash

* Additional fix to resolve an error in bert-large-uncased-whole-word-masking-squad-emb-0001

* Fix error in the unfusion function: input nodes of the current (fused) node need to be updated with the latest dependency if they are fused to other nodes

* Several fixes
(1) Fix program to clear _kernels after all build_program steps have finished
(2) Fix update_kernel not to call init_kernels on an impl_cache hit
(3) Fix update_kernel to clear kernels_cache::_kernels after adding the new impl to the impl_cache
(4) No longer need to remove kernels from kernels_cache::_kernels after the corresponding impl is dropped from the impl_cache

* Fix crash in bert_emd_4layer

* Applied review comment

* Applied review comment: fix add_required_reorders

* Fix broadcast to propagate dynamic shapes properly & revert the change to constant

* Added a new unfusion unittest

* Fix the broadcast single-input case to use the predefined shape properly

* Fixed count_nonzero output to produce only a single element, the count

* Removed the output_layout string from gather_nonzero's to_string
Removed the unused ov_input_rank from count_nonzero

* Fixed create_host_blob to use USM if the target layout is not dynamic
(the previous implementation failed when the network is dynamic but the output is static)
Moved the dynamic-shape smoke test under the dynamic directory

* Fix lint error
Taylor Yeonbok Lee 2022-10-26 21:20:18 -07:00 committed by GitHub
parent a7a14a89c8
commit 728c9631b7
22 changed files with 339 additions and 60 deletions

@ -74,7 +74,7 @@ private:
std::vector<cldnn::event::ptr>& dependencies);
void prepare_output(const cldnn::primitive_id& outputName, InferenceEngine::Blob::Ptr& outputBlob);
InferenceEngine::Blob::Ptr create_host_blob(const InferenceEngine::TensorDesc& desc);
InferenceEngine::Blob::Ptr create_host_blob(const InferenceEngine::TensorDesc& desc, bool is_dynamic);
InferenceEngine::Blob::Ptr create_device_blob(const InferenceEngine::TensorDesc& desc);
void copy_output_data(cldnn::memory::ptr outputMemory, InferenceEngine::Blob::Ptr bptr);

@ -79,11 +79,16 @@ std::vector<layout> broadcast_inst::calc_output_layouts(broadcast_node const& /*
cldnn::mem_lock<uint8_t, mem_lock_type::read> target_shape_lock(target_shape_mem, impl_param.prog.get_stream());
const_data.emplace(1, make_host_tensor(target_shape_mem->get_layout(), target_shape_lock.data()));
ov::op::v3::shape_infer(&op, input_shapes, output_shapes, const_data);
} else {
} else if (impl_param.input_layouts.size() == 1) {
// predefined pattern shape
auto target_shape_tensor = make_host_tensor({pattern_shape, data_types::i64, format::bfyx},
static_cast<void*>(target_shape.data()));
const_data.emplace(1, target_shape_tensor);
ov::op::v3::shape_infer(&op, input_shapes, output_shapes, const_data);
} else {
// The pattern shape is set as the second input. Even though the input is a scalar, the shape should be propagated as dynamic
auto output_rank = input_shapes[0].size();
output_shapes[0] = ShapeType::dynamic(std::max(output_rank, static_cast<size_t>(1)));
}
format output_format = format::adjust_to_rank(input0_layout.format, output_shapes[0].size());
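
The hunk above gives broadcast's calc_output_layouts three paths: a target shape read from a constant memory dependency, a predefined pattern stored in the primitive, and a runtime second input whose values are unknown at shape-inference time. A minimal sketch of that last fallback, assuming only the input rank is known (the helper name is illustrative, not part of the change):

    #include <algorithm>
    #include <cstddef>
    #include "openvino/core/partial_shape.hpp"

    // When the target shape arrives as a runtime tensor, its values are not yet
    // available, so the best we can propagate is a fully dynamic shape of the
    // input's rank, clamped to at least 1 to cover the scalar case.
    ov::PartialShape propagate_broadcast_rank(const ov::PartialShape& input_shape) {
        std::size_t output_rank = input_shape.size();
        return ov::PartialShape::dynamic(std::max<std::size_t>(output_rank, 1));
    }
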
@ -121,7 +126,8 @@ std::string broadcast_inst::to_string(broadcast_node const& node) {
broadcast_inst::typed_primitive_inst(network& network, broadcast_node const& node) : parent(network, node) {
auto input_layout = node.input().get_output_layout();
if (input_layout.is_dynamic())
return;
const auto& output_sizes = argument.broadcast_sizes;
std::vector<tensor::value_type> input_dims = input_layout.get_dims();

@ -82,7 +82,10 @@ std::string concatenation_inst::to_string(concatenation_node const& node) {
for (size_t i = 0; i < node.inputs_count(); ++i) {
ss_inputs << node.input(i).id();
ss_inputs << ", count: " << node.input(i).get_output_layout().count();
if (node.input(i).get_output_layout().is_static())
ss_inputs << ", count: " << node.input(i).get_output_layout().count();
else
ss_inputs << ", count: " << "?";
i != (node.inputs_count() - 1) ? ss_inputs << ", " : ss_inputs << "";
}
@ -100,6 +103,8 @@ std::string concatenation_inst::to_string(concatenation_node const& node) {
concatenation_inst::typed_primitive_inst(network& network, concatenation_node const& node)
: parent(network, node) {
auto input_layout = node.input().get_output_layout();
if (input_layout.is_dynamic()) return;
auto output_layout = node.get_output_layout();
tensor::value_type concat_count = 0;

@ -197,7 +197,7 @@ void add_required_reorders::run(program& p) {
if (!correct_layout_selected) {
for (auto new_layout_format : preferred_layout_formats) {
layout current_layout(original_layout.data_type, new_layout_format, original_layout.get_tensor());
layout current_layout(original_layout.get_partial_shape(), original_layout.data_type, new_layout_format);
usr->set_output_layout(current_layout, false);
if (usr->type()->does_possible_implementation_exist(*usr)) {
correct_layout_selected = true;
@ -210,21 +210,19 @@ void add_required_reorders::run(program& p) {
// goal of this section is to use int32 implementation
// if int64 is not available for usr primitive
if (original_layout.data_type == data_types::i64) {
layout original_layout_i32(data_types::i32,
original_layout.format,
original_layout.get_tensor());
layout original_layout_i32(original_layout.get_partial_shape(),
data_types::i32,
original_layout.format);
usr->set_output_layout(original_layout_i32, false);
if (usr->type()->does_possible_implementation_exist(*usr)) {
correct_layout_selected = true;
}
if (!correct_layout_selected) {
for (auto new_layout_format : preferred_layout_formats) {
layout current_layout_i32(original_layout_i32.data_type,
new_layout_format,
original_layout_i32.get_tensor());
layout current_layout_i32(original_layout_i32.get_partial_shape(),
original_layout_i32.data_type,
new_layout_format);
usr->set_output_layout(current_layout_i32, false);
if (usr->type()->does_possible_implementation_exist(*usr)) {
correct_layout_selected = true;
@ -232,7 +230,6 @@ void add_required_reorders::run(program& p) {
}
}
}
if (!correct_layout_selected) {
throw std::runtime_error("Internal Error: no implementation for " + usr->id() +
" kernel which satisfies output format dependencies.");

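The recurring change in this pass swaps the tensor-based layout constructor for the partial-shape one: get_tensor() materializes concrete dimensions and cannot represent a dynamic layout, while get_partial_shape() carries dynamic dimensions through unchanged. A small illustration using the cldnn types visible in the hunk above (the concrete shape is made up):

    // A dynamic-batch i64 layout retyped to i32. The old constructor form
    // layout{data_type, format, original.get_tensor()} requires static dims;
    // the partial-shape form keeps the dynamic dimension intact.
    layout original{ov::PartialShape{ov::Dimension::dynamic(), 3, 224, 224},
                    data_types::i64, format::bfyx};
    layout retyped{original.get_partial_shape(), data_types::i32, original.format};
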
@ -120,7 +120,7 @@ void concat_input_order::run(program& p) {
// 4. Not already aligned
// 5. Users can accept shuffled features
// 6. No fused primitives
if (!node->is_type<concatenation>() || node->is_output())
if (!node->is_type<concatenation>() || node->is_output() || node->is_dynamic())
continue;
auto& concat_node = node->as<concatenation>();

@ -311,7 +311,7 @@ void prepare_buffer_fusing::run(program& p) {
If crop is before concat there can be a padding mismatch, since concat changes padding.
*/
auto can_optimize = [](const program_node* node) {
if (node->is_output() || (!node->get_fused_activations_funcs().empty())) {
if (node->is_dynamic() || node->is_output() || (!node->get_fused_activations_funcs().empty())) {
return false;
}
return true;

@ -30,8 +30,6 @@ struct count_nonzero_impl : typed_primitive_impl_ocl<count_nonzero> {
auto nonzero_optional_params =
get_default_optional_params<kernel_selector::count_nonzero_optional_params>(arg.get_program());
nonzero_params.ov_input_rank = impl_param.get_input_layout().get_shape().size();
auto& kernel_selector = kernel_selector::count_nonzero_kernel_selector::Instance();
auto best_kernels = kernel_selector.GetBestKernels(nonzero_params, nonzero_optional_params);

@ -35,6 +35,8 @@ class typed_primitive_inst<count_nonzero> : public typed_primitive_inst_base<cou
using parent = typed_primitive_inst_base<count_nonzero>;
public:
template <typename ShapeType>
static std::vector<layout> calc_output_layouts(count_nonzero_node const& /*node*/, kernel_impl_params const& impl_param);
static layout calc_output_layout(count_nonzero_node const& node, kernel_impl_params const& impl_param);
static std::string to_string(count_nonzero_node const& node);
@ -62,6 +64,7 @@ public:
OPENVINO_ASSERT(dependencies.size() == 2, "[GPU] Primitive ", id(), " has invalid number of dependencies");
return get_dependency(index);
}
std::vector<size_t> get_shape_infer_dependencies() const override { return {1}; }
};
using gather_nonzero_node = typed_program_node<gather_nonzero>;
@ -71,6 +74,8 @@ class typed_primitive_inst<gather_nonzero> : public typed_primitive_inst_base<ga
using parent = typed_primitive_inst_base<gather_nonzero>;
public:
template <typename ShapeType>
static std::vector<layout> calc_output_layouts(gather_nonzero_node const& /*node*/, kernel_impl_params const& impl_param);
static layout calc_output_layout(gather_nonzero_node const& node, kernel_impl_params const& impl_param);
static std::string to_string(gather_nonzero_node const& node);

@ -366,7 +366,7 @@ protected:
private:
bool do_allocate_memory(typed_node const& typ_node) {
if (typ_node.is_dynamic())
if (typ_node.get_output_layout().is_dynamic())
return false;
if (typ_node.template have_user_with_type<concatenation>() && typ_node.get_users().size() == 1 &&

@ -19,6 +19,7 @@
#include "intel_gpu/graph/network.hpp"
#include "assign_inst.h"
#include "read_value_inst.h"
#include "reshape_inst.h"
#include "to_string_utils.h"
#include "primitive_inst.h"
@ -994,6 +995,12 @@ void network::transfer_memory_to_device(std::shared_ptr<primitive_inst> instance
auto& inst_mem = instance->output_memory();
auto alloc_type = inst_mem.get_allocation_type();
auto users = node.get_users();
if (users.size() == 1
&& users.front()->is_type<reshape>()
&& users.front()->is_dynamic())
return;
// Do not transfer memory if a user requires lockable memory.
// If memory is used in both gpu and cpu implementations, primitive itself is responsible for correct allocation type
if (node.need_lockable_memory())

@ -24,7 +24,14 @@ primitive_type_id count_nonzero::type_id() {
layout count_nonzero_inst::calc_output_layout(count_nonzero_node const& node, kernel_impl_params const& impl_param) {
assert(static_cast<bool>(node.get_primitive()->output_data_type) == false &&
"Output data type forcing is not supported for count_nonzero_node!");
return layout{cldnn::data_types::i32, cldnn::format::bfyx, tensor{1, 1, 1, 4}};
return layout{cldnn::data_types::i32, cldnn::format::bfyx, tensor{1, 1, 1, 1}};
}
template<typename ShapeType>
std::vector<layout> count_nonzero_inst::calc_output_layouts(count_nonzero_node const& /*node*/, kernel_impl_params const& impl_param) {
assert(static_cast<bool>(impl_param.desc->output_data_type) == false &&
"Output data type forcing is not supported for count_nonzero_node!");
return {layout{ov::PartialShape{1}, cldnn::data_types::i32, cldnn::format::bfyx}};
}
std::string count_nonzero_inst::to_string(count_nonzero_node const& node) {
@ -71,6 +78,25 @@ layout gather_nonzero_inst::calc_output_layout(gather_nonzero_node const& node,
}
}
template<typename ShapeType>
std::vector<layout> gather_nonzero_inst::calc_output_layouts(gather_nonzero_node const& /*node*/, kernel_impl_params const& impl_param) {
auto desc = impl_param.typed_desc<gather_nonzero>();
assert(static_cast<bool>(desc->output_data_type) == false &&
"Output data type forcing is not supported for gather_nonzero_node!");
if (impl_param.memory_deps.count(1)) {
auto out_size = read_vector<int64_t>(impl_param.memory_deps.at(1), impl_param.prog.get_stream());
// output shape of nonzero is [input_rank, count_non_zero]
auto rank = static_cast<size_t>(impl_param.get_input_layout(0).get<ShapeType>().rank().get_length());
auto count = static_cast<size_t>(out_size[0]);
ov::Shape output_shape({rank, count});
ov::PartialShape output_pshape(output_shape);
auto out_layout = layout{output_pshape, cldnn::data_types::i32, cldnn::format::bfyx};
return {out_layout};
} else {
return {layout{ov::PartialShape({ov::Dimension::dynamic(), ov::Dimension::dynamic()}), cldnn::data_types::i32, cldnn::format::bfyx}};
}
}
std::string gather_nonzero_inst::to_string(gather_nonzero_node const& node) {
auto desc = node.get_primitive();
auto node_info = node.desc_to_json();
@ -80,7 +106,6 @@ std::string gather_nonzero_inst::to_string(gather_nonzero_node const& node) {
json_composite gather_nonzero_info;
gather_nonzero_info.add("input id", input.id());
gather_nonzero_info.add("output layout", node.get_output_layout().to_string());
node_info->add("gather_nonzero info", gather_nonzero_info);
node_info->dump(primitive_description);
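
Taken together, the two primitives implement NonZero in two stages: count_nonzero now emits a single scalar (layout {1} rather than the old 4-element tensor), and gather_nonzero declares input 1 as a shape-infer dependency (get_shape_infer_dependencies() returning {1}) so the runtime count can be read back before the output layout is fixed. A hedged sketch of the resulting shape flow, with an illustrative helper name:

    #include <cstddef>
    #include <cstdint>
    #include "openvino/core/partial_shape.hpp"

    // Until count_nonzero has executed, gather_nonzero's output stays {?, ?};
    // once the scalar count is available it becomes {input_rank, count}.
    ov::PartialShape gather_nonzero_shape(std::size_t input_rank, const int64_t* count) {
        if (count == nullptr)  // memory dependency not produced yet
            return ov::PartialShape{ov::Dimension::dynamic(), ov::Dimension::dynamic()};
        return ov::PartialShape{static_cast<int64_t>(input_rank),
                                static_cast<int64_t>(*count)};
    }
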

@ -154,9 +154,8 @@ void primitive_inst::update_shape() {
}
}
// We assume that tensor ranks are static, thus shape_of doesn't need to update anything even if input shape is dynamic
if (_node.is_type<shape_of>() && !input_shape_changed)
return;
if (input_shape_changed)
set_shape_change();
// Even though the predecessors' shapes are not changed, the output shape might be updated by the mem_dep
auto memory_deps = _node.get_const_memory_deps();
@ -167,6 +166,10 @@ void primitive_inst::update_shape() {
input_shape_changed = true;
}
// We assume that tensor ranks are static, thus shape_of doesn't need to update anything even if input shape is dynamic
if (_node.is_type<shape_of>() && !input_shape_changed)
return;
// Strided slice loads data from {1,2,3} dependencies in impl::create method.
// It means that this data must be put into impl_params map
// Thus we treat it as "dynamic" case
@ -182,9 +185,6 @@ void primitive_inst::update_shape() {
if (!strided_slice_wa && !input_shape_changed && !_node.generates_dynamic_output() && _impl_params->output_layout.is_static())
return;
if (input_shape_changed)
set_shape_change();
std::vector<event::ptr> dependencies_events;
auto queue_type = get_network().get_stream().get_queue_type();
bool has_runtime_deps = false;
@ -296,14 +296,11 @@ void primitive_inst::update_impl() {
} else {
auto lru = cache.get_lru_element();
_impl = _node.type()->choose_impl(_node, *_impl_params);
bool lru_popped = cache.add(layout_key, _impl->clone());
if (lru_popped) {
for (auto& id : lru->get_kernel_ids())
_network.get_program()->remove_kernel(id);
}
_network.get_program()->compile();
_impl->init_kernels(_network.get_program()->get_kernels_cache());
cache.add(layout_key, _impl->clone());
_network.get_program()->get_kernels_cache().reset();
}
_impl->init_kernels(_network.get_program()->get_kernels_cache());
reset_shape_change();
GPU_DEBUG_GET_INSTANCE(debug_config);
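
With this change the in-memory impl cache and the program-wide kernels_cache have clearly separated roles: compiled impls live in the LRU impl cache, while the kernels_cache is only a transient compilation buffer that is reset after every dynamic recompilation (and in program::cleanup(), below), so stale kernels no longer have to be pruned per evicted impl. A condensed sketch of the cache-miss path, using the names from the hunk above with error handling omitted:

    // Choose an impl for the new shape, compile and bind its kernels, store a
    // reusable clone keyed by the output layout, then drop transient sources.
    _impl = _node.type()->choose_impl(_node, *_impl_params);
    _network.get_program()->compile();                    // build OpenCL kernels
    _impl->init_kernels(_network.get_program()->get_kernels_cache());
    cache.add(layout_key, _impl->clone());                // impl cache, LRU
    _network.get_program()->get_kernels_cache().reset();  // transient buffer cleared
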
@ -796,14 +793,21 @@ cldnn::network::ptr primitive_inst::get_unfused_subgraph() {
// which doesn't exist anymore in the graph
// Thus we update dependency name used dependencies idx stored in fused descriptor.
if (std::find(dep_ids.begin(), dep_ids.end(), in) == dep_ids.end()) {
size_t dep_id = fd.dep_start_idx + i;
size_t dep_id = fd.dep_start_idx;
in = _node.get_dependency(dep_id).id();
}
}
t.add_primitive(prim);
dep_ids.push_back(prim->id);
}
// Similarly, the current fused node's input primitive ids need to be updated to the ids present in the current program
auto prim_of_fused_node = std::const_pointer_cast<primitive>(_impl_params->desc);
for (size_t i = 0; i < prim_of_fused_node->input.size(); ++i) {
auto& in = prim_of_fused_node->input[i];
if (std::find(dep_ids.begin(), dep_ids.end(), in) == dep_ids.end()) {
in = _node.get_dependency(i).id();
}
}
build_options bo;
bo.set_option(build_option::allow_static_input_reorder(true));
bo.set_option(build_option::allow_new_shape_infer(true));

@ -702,6 +702,7 @@ void program::cleanup() {
}
}
}
_kernels_cache->reset();
}
void program::add_split_outputs() {

@ -146,7 +146,8 @@ std::string reshape_inst::to_string(reshape_node const& node) {
return primitive_description.str();
}
reshape_inst::typed_primitive_inst(network& network, reshape_node const& node) : parent(network, node, false) {
reshape_inst::typed_primitive_inst(network& network, reshape_node const& node) :
parent(network, node, (!node.can_be_optimized() && node.get_output_layout().is_static()) ? true : false) {
auto input_layout = node.input().get_output_layout();
auto output_layout = node.get_output_layout();
CLDNN_ERROR_DATA_TYPES_MISMATCH(node.id(),
@ -155,7 +156,7 @@ reshape_inst::typed_primitive_inst(network& network, reshape_node const& node) :
"output layout data type",
output_layout.data_type,
"");
if (output_layout.is_static())
if (output_layout.is_static() && input_layout.is_static())
CLDNN_ERROR_NOT_EQUAL(node.id(),
"Output layout count",
output_layout.count(),
@ -165,7 +166,7 @@ reshape_inst::typed_primitive_inst(network& network, reshape_node const& node) :
// if reshape operated in-place, postpone creation of the output until network run,
// then create new memory object as the reinterpreted output of the previous primitive
if (_node.get_output_layout().is_static()) {
if (input_layout.is_static() && output_layout.is_static()) {
if (!node.can_be_optimized())
_outputs = allocate_outputs();
else

@ -14,6 +14,11 @@ KERNEL (count_nonzero_ref)(const __global INPUT0_TYPE* input,
const uint gdim1 = (uint)get_global_id(1);
const uint gdim2 = (uint)get_global_id(2);
if (gdim0 == 0 && gdim1 == 0 && gdim2 == 0) {
output[0] = 0;
}
barrier(CLK_GLOBAL_MEM_FENCE);
#if INPUT0_DIMS == 6
#define INPUT_ORDER b,f,w,z,y,x
const uint x = gdim0 % INPUT0_SIZE_X;
@ -37,13 +42,7 @@ KERNEL (count_nonzero_ref)(const __global INPUT0_TYPE* input,
count = sub_group_reduce_add(count);
if (get_sub_group_local_id() == 0)
atomic_add(&(output[1]), count);
if (gdim0 == 0 && gdim1 == 0 && gdim2 == 0) {
output[0] = OV_INPUT_RANK;
output[2] = 1;
output[3] = 1;
}
atomic_add(&(output[0]), count);
}
#undef INPUT0_GET_INDEX1
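
After this change the kernel's output buffer holds a single element: one work item zeroes output[0], the memory fence orders that store, and each sub-group leader atomically adds its partial count; the old {rank, count, 1, 1} quadruple is gone. A sequential C++ equivalent of the kernel's contract (illustrative only, not plugin code):

    #include <cstddef>
    #include <cstdint>

    // Semantics of the revised count_nonzero kernel: the result is simply the
    // number of non-zero elements, written to output[0] on the device.
    template <typename T>
    int32_t count_nonzero_reference(const T* input, std::size_t total) {
        int32_t count = 0;
        for (std::size_t i = 0; i < total; ++i)
            count += (input[i] != T(0)) ? 1 : 0;
        return count;
    }
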

@ -23,7 +23,7 @@ KERNEL (gather_nonzero_ref)(const __global INPUT0_TYPE* input,
__global OUTPUT_TYPE* out_mem = output;
#endif
int count_nzero = output_shape[1];
int count_nzero = output_shape[0];
#if OV_INPUT_RANK == 1 // b
#define ADD_IDXS \
int b = input_idx_v / INPUT0_BATCH_PITCH; \
@ -116,6 +116,7 @@ KERNEL (gather_nonzero_ref)(const __global INPUT0_TYPE* input,
}
}
}
// leftovers
for (;input_idx < TOTAL_DATA_SIZE; ++input_idx) {
int input_idx_v = input_idx;
@ -124,7 +125,6 @@ KERNEL (gather_nonzero_ref)(const __global INPUT0_TYPE* input,
ADD_IDXS;
}
}
#ifdef USE_LOCAL_MEM
// write back to global mem
int local_out_iter = 0;

@ -41,7 +41,6 @@ KernelsData CountNonzeroKernelRef::GetKernelsData(const Params& params, const op
auto entry_point = GetEntryPoint(kernelName, newParams.layerID, params, options);
auto cldnn_jit = MakeBaseParamsJitConstants(newParams);
cldnn_jit.AddConstant(MakeJitConstant("OV_INPUT_RANK", newParams.ov_input_rank));
auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
const auto& in = newParams.inputs[0];

@ -550,7 +550,7 @@ void InferRequest::wait() {
auto layout = layout_by_rank(out_rank);
auto tensorDesc = InferenceEngine::TensorDesc(precision, dims, layout);
if (_outputs.find(no.first) == _outputs.end()) {
_outputs[no.first] = create_host_blob(tensorDesc);
_outputs[no.first] = create_host_blob(tensorDesc, false);
} else {
_outputs[no.first]->setShape(dims);
}
@ -592,11 +592,11 @@ void InferRequest::setup_stream_graph() {
m_graph = streamGraphs[streamID];
}
Blob::Ptr InferRequest::create_host_blob(const TensorDesc& desc) {
Blob::Ptr InferRequest::create_host_blob(const TensorDesc& desc, bool is_dynamic) {
OV_ITT_SCOPED_TASK(itt::domains::intel_gpu_plugin, "InferRequest::create_host_blob");
// Disable USM usage as USMHostAllocator may fail when attempting to allocate 0 bytes
// If we add a WA for such a case to avoid the driver call, then the deallocate method will return false and the Blob::setShape call will throw an exception
bool use_usm = m_graph->GetEngine()->use_unified_shared_memory() && !m_graph->GetNetwork()->is_dynamic();
bool use_usm = m_graph->GetEngine()->use_unified_shared_memory() && !is_dynamic;
auto alloc = use_usm ? std::make_shared<USMHostAllocator>(m_graph->GetContext().get()) : CreateDefaultAllocator();
auto blob = make_blob_with_precision(desc, alloc);
blob->allocate();
@ -727,11 +727,11 @@ void InferRequest::allocate_inputs() {
if (desc.getPrecision() == Precision::I16 || desc.getPrecision() == Precision::U16) {
TensorDesc desc_fp32 = desc;
desc_fp32.setPrecision(Precision::FP32);
_inputs[name] = create_host_blob(desc);
_inputs[name] = create_host_blob(desc, input_layout.is_dynamic());
if (input_layout.is_static())
_deviceInputs[name] = create_device_blob(desc_fp32);
} else {
_inputs[name] = create_host_blob(desc);
_inputs[name] = create_host_blob(desc, input_layout.is_dynamic());
if (input_layout.is_static()) {
if (m_graph->GetEngine()->use_unified_shared_memory()) {
// For USM case we create host blob using custom USM host allocator
@ -777,11 +777,11 @@ void InferRequest::allocate_outputs() {
else
device_blob_desc.setPrecision(Precision::FP32);
_outputs[no.first] = create_host_blob(desc);
_outputs[no.first] = create_host_blob(desc, output_layout.is_dynamic());
if (output_layout.is_static())
_deviceOutputs[no.first] = create_device_blob(device_blob_desc);
} else {
_outputs[no.first] = create_host_blob(desc);
_outputs[no.first] = create_host_blob(desc, output_layout.is_dynamic());
if (output_layout.is_static()) {
if (m_graph->GetEngine()->use_unified_shared_memory()) {
// For USM case we create host blob using custom USM host allocator
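
create_host_blob now receives the dynamism of the specific blob's layout instead of consulting the whole network, so a static output of an otherwise dynamic network can still be backed by USM host memory. A simplified sketch of the decision, condensed from the hunk above (engine and context stand in for the m_graph accessors):

    // USM is skipped only when this particular blob is dynamic, since a dynamic
    // blob may later be resized via Blob::setShape, which the zero-byte USM
    // allocation path does not survive (see the comment in the hunk above).
    bool use_usm = engine->use_unified_shared_memory() && !is_dynamic;
    std::shared_ptr<InferenceEngine::IAllocator> alloc;
    if (use_usm)
        alloc = std::make_shared<USMHostAllocator>(context.get());
    else
        alloc = InferenceEngine::CreateDefaultAllocator();
    auto blob = make_blob_with_precision(desc, alloc);  // InferenceEngine factory
    blob->allocate();
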

@ -317,7 +317,6 @@ std::shared_ptr<cldnn::program> Program::BuildProgram(const std::vector<std::sha
if (!m_config.graph_dumps_dir.empty()) {
options.set_option(cldnn::build_option::graph_dumps_dir(m_config.graph_dumps_dir));
}
for (const auto& op : ops) {
if (op->is_dynamic()) {
allow_new_shape_infer = true;

@ -304,3 +304,72 @@ TEST(prepare_primitive_fusing, fuse_eltwise_to_fc_dyn_illegal_1) {
ASSERT_EQ(lock[2], 285 + 30);
ASSERT_EQ(lock[3], 285 + 40);
}
TEST(prepare_primitive_fusing, fuse_eltwise_to_fc_dyn_illegal_2) {
auto& engine = get_test_engine();
auto weights0 = engine.allocate_memory({ ov::PartialShape{ 2, 10 }, data_types::i8, format::bfyx });
auto weights1 = engine.allocate_memory({ ov::PartialShape{ 1, 2 }, data_types::i8, format::bfyx });
auto in_layout = layout{ ov::PartialShape::dynamic(2), data_types::i8, format::bfyx };
auto in_eltw_layout = layout{ ov::PartialShape::dynamic(2), data_types::f32, format::bfyx };
set_values<uint8_t>(weights0, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
set_values<uint8_t>(weights1, {1, 1});
// The topology below is intended to check the following tricky things:
// 1. Cases where the original eltwise input is also optimized (act_e2 is fused into act_e1)
// 2. There are other layers in the fusion pattern (activations before & after the eltwise)
// 3. The input (act_fc1) of the node the eltwise is fused into (i.e., fc2) is itself fused to another node (fc1)
topology topology;
topology.add(data("weights0", weights0));
topology.add(data("weights1", weights1));
topology.add(input_layout("input", in_layout));
topology.add(fully_connected("fc1", "input", { "weights0" }, "", data_types::i8));
topology.add(activation("act_fc1", "fc1", activation_func::relu));
topology.add(fully_connected("fc2", "act_fc1", { "weights1" }, "", data_types::i8));
topology.add(activation("act_fc2", "fc2", activation_func::relu));
topology.add(input_layout("extra_input", in_eltw_layout));
topology.add(activation("act_e1", "extra_input", activation_func::abs));
topology.add(activation("act_e2", "act_e1", activation_func::relu));
topology.add(eltwise("eltw", {"act_fc2", "act_e2"}, eltwise_mode::sum));
topology.add(activation("act_fc3", "eltw", activation_func::relu));
topology.add(reorder("reorder", "act_fc3", format::bfyx, data_types::f32));
build_options build_opts;
build_opts.set_option(build_option::optimize_data(true));
build_opts.set_option(build_option::allow_new_shape_infer(true));
auto prog = program::build_program(engine, topology, build_opts, false, true);
layout_optimizer lo(true);
program_wrapper::apply_opt_pass<prepare_primitive_fusing>(*prog, lo);
ASSERT_NE(prog, nullptr);
ASSERT_FALSE(has_node_with_type<eltwise>(*prog));
cldnn::network net(prog, 0);
auto input_memory = engine.allocate_memory(layout{ ov::PartialShape{1, 10}, data_types::i8, format::bfyx });
auto extra_input_memory = engine.allocate_memory(layout{ ov::PartialShape{4, 4}, data_types::f32, format::bfyx });
set_values<int8_t>(input_memory, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -10});
set_values<float>(extra_input_memory, {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4});
net.set_input_data("input", input_memory);
net.set_input_data("extra_input", extra_input_memory);
auto output = net.execute();
auto out_mem = output.at("reorder").get_memory();
ASSERT_NE(out_mem, nullptr);
ASSERT_EQ(out_mem->count(),16);
ASSERT_EQ(out_mem->size(), 16 * sizeof(float));
mem_lock<float> lock(out_mem, net.get_stream());
ASSERT_EQ(lock[0], 91);
ASSERT_EQ(lock[1], 92);
ASSERT_EQ(lock[2], 93);
ASSERT_EQ(lock[3], 94);
}

@ -34,7 +34,7 @@ void test_count_non_zero(layout in_layout, std::vector<T> in_data) {
auto output = outputs.at("count_nonzero").get_memory();
cldnn::mem_lock<int32_t> output_ptr(output, get_test_stream());
EXPECT_EQ(count_non_zero, output_ptr[1]);
EXPECT_EQ(count_non_zero, output_ptr[0]);
}
TEST(test_count_non_zero, 4d_fp32_1_2_1_5) {
@ -45,7 +45,7 @@ TEST(test_count_non_zero, 4d_fp32_1_2_1_5) {
test_count_non_zero<float>(layout{ov::PartialShape{1, 2, 1, 5}, data_types::f32, format::bfyx}, in_data);
}
TEST(test_gather_non_zero, 5d_fp16_1_3_2_1_2) {
TEST(test_count_non_zero, 5d_fp16_1_3_2_1_2) {
std::vector<FLOAT16> in_data = {
0.1f, 0.2f, 0.3f, 0.0f, 12.1f, 11.1f,
0.0f, 0.0f, 0.1f, 0.9f, 0.10f, 0.001f
@ -62,11 +62,11 @@ void test_gather_non_zero(layout in_layout, std::vector<T> in_data) {
std::vector<int32_t> expected_results(count_non_zero * in_rank);
ngraph::runtime::reference::non_zero<T, int32_t>(in_data.data(), expected_results.data(), in_layout.get_shape());
auto output_shape_layout = layout{ov::PartialShape{4}, data_types::i32, format::bfyx};
auto output_shape_layout = layout{ov::PartialShape{1}, data_types::i32, format::bfyx};
auto output_shape_mem = engine.allocate_memory(output_shape_layout);
set_values(input_mem, in_data);
std::vector<int32_t> output_shape_data = {(int32_t)in_rank, (int32_t)count_non_zero, 1, 1};
std::vector<int32_t> output_shape_data = {(int32_t)count_non_zero};
set_values(output_shape_mem, output_shape_data);

@ -0,0 +1,164 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <tuple>
#include <string>
#include <vector>
#include <memory>
#include "ngraph_functions/utils/ngraph_helpers.hpp"
#include "ngraph_functions/builders.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "shared_test_classes/single_layer/shape_of.hpp"
#include "shared_test_classes/single_layer/strided_slice.hpp"
#include <shared_test_classes/single_layer/eltwise.hpp>
#include <common_test_utils/ov_tensor_utils.hpp>
using namespace ngraph;
using namespace InferenceEngine;
using namespace ov::test;
namespace GPULayerTestsDefinitions {
typedef std::tuple<
std::vector<InputShape>, // input shapes
ElementType, // Network precision
TargetDevice, // Device name
std::map<std::string, std::string> // Additional network configuration
> shapeOfReshapeReduceDynamicGPUTestParamsSet;
const std::vector<ElementType> netPrecisions = {
ElementType::f16,
ElementType::f32,
ElementType::i32,
ElementType::i64,
};
class ShapeOfReshapeReduceDynamicGPUTest : public testing::WithParamInterface<shapeOfReshapeReduceDynamicGPUTestParamsSet>,
virtual public SubgraphBaseTest {
public:
static std::string getTestCaseName(const testing::TestParamInfo<shapeOfReshapeReduceDynamicGPUTestParamsSet>& obj) {
shapeOfReshapeReduceDynamicGPUTestParamsSet basicParamsSet = obj.param;
std::ostringstream result;
std::vector<InputShape> inputShapes;
ElementType netType;
TargetDevice targetDevice;
std::map<std::string, std::string> additionalConfig;
std::tie(inputShapes, netType, targetDevice, additionalConfig) = basicParamsSet;
result << "IS=";
for (const auto& shape : inputShapes) {
result << CommonTestUtils::partialShape2str({shape.first}) << "_";
for (const auto& actual_shape : shape.second) {
result << CommonTestUtils::partialShape2str({actual_shape}) << "_";
}
}
result << "NetType=" << netType << "_";
result << "targetDevice=" << targetDevice;
return result.str();
}
protected:
void generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) override {
inputs.clear();
const auto& funcInputs = function->inputs();
for (int i = 0; i < funcInputs.size(); ++i) {
const auto& funcInput = funcInputs[i];
ov::Tensor tensor;
tensor = ov::test::utils::create_and_fill_tensor(funcInput.get_element_type(),
targetInputStaticShapes[i],
80,
0,
8);
inputs.insert({funcInput.get_node_shared_ptr(), tensor});
}
}
void SetUp() override {
shapeOfReshapeReduceDynamicGPUTestParamsSet basicParamsSet = this->GetParam();
std::vector<InputShape> inputShapes;
ElementType netType;
std::map<std::string, std::string> additionalConfig;
std::tie(inputShapes, netType, targetDevice, additionalConfig) = basicParamsSet;
init_input_shapes(inputShapes);
const auto inShapeShapeOf = inputDynamicShapes[0];
const auto inShapeElt = inputDynamicShapes[1];
auto params = builder::makeDynamicParams(netType, {inShapeShapeOf, inShapeElt});
auto paramOuts = helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes<ngraph::opset3::Parameter>(params));
auto addOp = ngraph::builder::makeEltwise(paramOuts[1], paramOuts[1], ngraph::helpers::EltwiseTypes::ADD);
addOp->set_friendly_name("add");
auto shapeOfOp1 = std::make_shared<ngraph::opset3::ShapeOf>(paramOuts[0], ElementType::i64);
shapeOfOp1->set_friendly_name("shapeof1");
std::vector<int> reduce_axes = {0};
auto reduceAxesNode = std::dynamic_pointer_cast<ngraph::Node>(
std::make_shared<ngraph::opset3::Constant>(ngraph::element::Type_t::i64, ngraph::Shape({1}), reduce_axes));
auto reduceOp = ngraph::builder::makeReduce(shapeOfOp1, reduceAxesNode, true, ngraph::helpers::ReductionType::Prod);
reduceOp->set_friendly_name("reduce");
std::vector<int64_t> shapePatternFill = {-1};
auto reshapePatternComp = std::make_shared<ngraph::opset3::Constant>(ngraph::element::Type_t::i64,
ngraph::Shape{1}, shapePatternFill);
auto concatOp = ngraph::builder::makeConcat({reduceOp, reshapePatternComp}, 0);
concatOp->set_friendly_name("concat");
auto reshapeOp = std::make_shared<ngraph::opset1::Reshape>(addOp, concatOp, false);
auto shapeOf2 = std::make_shared<ngraph::opset3::ShapeOf>(reshapeOp, ElementType::i64);
shapeOf2->set_friendly_name("shapeof2");
ngraph::ResultVector results = {std::make_shared<ngraph::opset1::Result>(shapeOf2)};
function = std::make_shared<ngraph::Function>(results, params, "shapeof_out");
}
};
TEST_P(ShapeOfReshapeReduceDynamicGPUTest, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
run();
}
namespace {
std::map<std::string, std::string> emptyAdditionalConfig;
const std::vector<std::vector<ov::test::InputShape>> dynInputShapes = {
// 1D
{
// Input for ShapeOf
{{ov::Dimension::dynamic()}, {{30}, {40}, {50}}},
// Input for Add
{{ov::Dimension::dynamic(), ov::Dimension::dynamic()}, {{3, 10}, {2, 20}, {25, 2}}}
},
// 2D
{
// Input for ShapeOf
{{ov::Dimension::dynamic(), ov::Dimension::dynamic()}, {{1, 10}, {2, 20}}},
// Input for Add
{{ov::Dimension::dynamic(), ov::Dimension::dynamic(), ov::Dimension::dynamic()}, {{1, 1, 10}, {2, 10, 2}}}
},
// 3D
{
// Input for ShapeOf
{{ov::Dimension::dynamic(), ov::Dimension::dynamic(), ov::Dimension::dynamic()}, {{1, 10, 4}, {1, 4, 12}}},
// Input for Add
{{ov::Dimension::dynamic()}, {{1, 10, 4}, {2, 2, 12}}}
},
// 4D
{
// Input for ShapeOf
{{ov::Dimension::dynamic(), ov::Dimension::dynamic(), ov::Dimension::dynamic(), ov::Dimension::dynamic()}, {{3, 1, 10, 4}, {2, 4, 23, 12}}},
// Input for Add
{{ov::Dimension::dynamic(), ov::Dimension::dynamic()}, {{30, 4}, {24, 92}}}
}
};
const auto testParams_smoke = ::testing::Combine(::testing::ValuesIn(dynInputShapes),
::testing::ValuesIn(netPrecisions), // netprec
::testing::Values(CommonTestUtils::DEVICE_GPU),
::testing::Values(emptyAdditionalConfig));
INSTANTIATE_TEST_SUITE_P(smoke_dynamic_shapeof_reshape, ShapeOfReshapeReduceDynamicGPUTest,
testParams_smoke, ShapeOfReshapeReduceDynamicGPUTest::getTestCaseName);
} // namespace
} // namespace GPULayerTestsDefinitions