[GPU] Fix get_default_params & choose_impl not to depend on program_node (#12239)

* Getting rid of dependency from get_default_param for typed_program_node

* Fix bug

* Enable two paths to call choose_impl / does_possible_impl_exists / does_an_impl_exists to be able to use given layout

* Replaced impl factory API to get kernel_impl_param's pointer

* Update for recently added primitives

* Add and apply optional_layout

* fix kernel_param_impl to be handled as unique_ptr

* Applied review comments

* Fix rebase conflict

* Fix CI error
This commit is contained in:
Taylor Yeonbok Lee
2022-07-27 19:36:53 +09:00
committed by GitHub
parent 101e1ea5ad
commit 361ca2078d
118 changed files with 844 additions and 595 deletions

View File

@@ -68,6 +68,7 @@ private:
storage_type storage;
};
/// Converts C++ type to @ref data_types .
template <typename T>
struct type_to_data_type;
@@ -429,6 +430,38 @@ private:
tensor size;
};
class optional_layout {
public:
optional_layout() {}
optional_layout(const layout& lay) {
this->opt_layout_ptr = make_unique<layout>(lay);
}
optional_layout(const optional_layout& new_opt_lay) {
if (new_opt_lay) {
layout copied_lay = *new_opt_lay;
this->opt_layout_ptr = make_unique<layout>(copied_lay);
}
}
operator bool() const {
return this->opt_layout_ptr != nullptr;
}
layout operator*() const {
if (opt_layout_ptr == nullptr)
throw std::runtime_error("Attempt to access uninitialized optional layout!");
return *this->opt_layout_ptr;
}
std::unique_ptr<layout>& get_layout_ptr() {
return opt_layout_ptr;
}
private:
std::unique_ptr<layout> opt_layout_ptr = nullptr;
};
/// @}
/// @}
} // namespace cldnn

View File

@@ -25,8 +25,7 @@ void add_onednn_optimization_attributes::run(program& p) {
// Reshape fused ops tensors for OneDNN FC if needed
if (fc_prim->input_size == 3) {
for (auto& fused_prim : node->get_fused_primitives()) {
auto fused_node = fused_prim.node;
if (fused_node->is_type<eltwise>()) {
if (fused_prim.is_type<eltwise>()) {
auto& dependency = node->get_dependency(fused_prim.dep_start_idx);
auto original_layout = dependency.get_output_layout();
onednn::combine_bf_with_first_spatial_dim(original_layout);

View File

@@ -41,7 +41,7 @@ void basic_memory_dependencies::run(program& p) {
&& (node->is_type<convolution>() || node->is_type<deconvolution>())) {
size_t eltw_dep = 0;
for (auto& fused_op : node->get_fused_primitives()) {
if (fused_op.node->is_type<eltwise>() && fused_op.deps.size() == 1) {
if (fused_op.is_type<eltwise>() && fused_op.deps.size() == 1) {
// If it is first sum, reuse the buffer
auto fusing_type = onednn_add_fusing_helpers::get_add_fusing_type(*node, fused_op);
if (fusing_type != add_fusing_type::sum || eltw_dep != 0)

View File

@@ -674,10 +674,10 @@ void prepare_primitive_fusing::fuse_simple_primitives(program &p) {
auto& fused_descs = input_data.get_fused_primitives();
auto origin_input_iter = std::find_if(fused_descs.begin(), fused_descs.end(),
[&](cldnn::fused_primitive_desc& desc) {
return (desc.node->id() == prim_id.first);
return (desc.desc->id == prim_id.first);
});
if (origin_input_iter != fused_descs.end()) {
auto users = get_users_from_fusing_history(origin_input_iter->node->id());
auto users = get_users_from_fusing_history(origin_input_iter->desc->id);
if (users.size() != 1) {
return false;
}
@@ -1167,10 +1167,10 @@ void prepare_primitive_fusing::optimize_fused_ops(program& p) {
auto remove_deps_of_node = [&](cldnn::fused_primitive_desc& desc) {
for (auto& prim : fused_prims) {
if (desc.node->id() == prim.node->id()) {
if (desc.desc->id == prim.desc->id) {
continue;
}
auto rm_iter = prim.fused_deps.find(desc.node->id());
auto rm_iter = prim.fused_deps.find(desc.desc->id);
if (rm_iter != prim.fused_deps.end()) {
prim.fused_deps.erase(rm_iter);
prim.fused_deps.insert(desc.fused_deps.begin(), desc.fused_deps.end());
@@ -1187,16 +1187,13 @@ void prepare_primitive_fusing::optimize_fused_ops(program& p) {
auto& fp = *curr_itr;
auto& fp_next = *fp_itr;
if (fp.is_type<activation>() && fp_next.is_type<quantize>()) {
const auto& act_prim = fp.typed_desc<activation>();;
const auto& quant_param = fp_next.get_typed_fuse_params<kernel_selector::quantize_fuse_params>();
if (fp.node->is_type<activation>() && fp_next.node->is_type<quantize>()) {
auto& activation_node = fp.node->as<activation>();
auto& quantize_node = fp_next.node->as<quantize>();
bool can_skip = activation_node.get_primitive()->activation_function == activation_func::relu &&
activation_node.get_primitive()->additional_params.a == 0.0f &&
fp.deps.empty() &&
data_type_traits::is_i8_u8(quantize_node.get_output_layout().data_type) &&
quantize_node.get_scale_shift_opt() &&
!quantize_node.get_need_pre_shift();
bool can_skip = fp.deps.empty() && data_type_traits::is_i8_u8(fp_next.output_layout.data_type);
can_skip &= ((act_prim->activation_function == activation_func::relu) && (act_prim->additional_params.a == 0.0f));
can_skip &= (quant_param->scale_shift_opt && !quant_param->has_pre_shift);
if (can_skip) {
remove_deps_of_node(fp);

View File

@@ -485,14 +485,14 @@ void remove_redundant_reorders::run(program& p) {
input.set_output_padding(node->get_output_layout().data_padding);
// Add fused_primitive_desc of reorder to convolution which propagate original output layout to jitter
fused_primitive_desc local_desc;
local_desc.node = p.get_node_ptr(node->id());
fused_primitive_desc local_desc(node->get_primitive());
local_desc.input_layout = input.get_dependency(0).get_output_layout(); // original convolution's output layout
node->set_input_layout(local_desc.input_layout);
local_desc.f_param = node->get_fuse_params();
local_desc.dep_start_idx = input.get_fused_primitives().size();
local_desc.output_layout = output_layout;
local_desc.input_layout = input.get_dependency(0).get_output_layout(); // original convolution's output layout
local_desc.activation = activation_func::none;
input.add_fused_primitive(local_desc);
node->set_input_layout(local_desc.input_layout);
// remove reorder node
LOG_NODE_REMOVAL(node->id());

View File

@@ -695,7 +695,7 @@ void reorder_inputs::run(program& p, layout_optimizer& lo, reorder_factory& rf)
// changes the input format of eltwise sum post-op to use binary add.
if (conv_node.get_preferred_impl_type() == impl_types::onednn) {
onednn_add_fusing_helpers::for_eltwise(conv_node, eltwise_mode::sum,
[&](const program_node& p_node, const eltwise_node& e_node, const fused_primitive_desc& desc) {
[&](const program_node& p_node, const fused_primitive_desc& desc) {
auto fusing_type = onednn_add_fusing_helpers::get_add_fusing_type(p_node, desc);
if (fusing_type == add_fusing_type::binary_per_tensor) {
auto& dep_node = p_node.get_dependency(desc.dep_start_idx);

View File

@@ -41,7 +41,7 @@ struct condition_impl : typed_primitive_impl<condition> {
return ev;
}
static primitive_impl* create(const condition_node& arg) { return new condition_impl(arg); }
static primitive_impl* create(const condition_node& arg, std::shared_ptr<kernel_impl_params>) { return new condition_impl(arg); }
void init_kernels() override {}

View File

@@ -164,7 +164,7 @@ struct loop_impl : typed_primitive_impl<loop> {
return ev;
}
static primitive_impl* create(const loop_node& arg) { return new loop_impl(arg); }
static primitive_impl* create(const loop_node& arg, std::shared_ptr<kernel_impl_params>) { return new loop_impl(arg); }
};
namespace detail {

View File

@@ -32,13 +32,13 @@ public:
bool validate(const primitive_inst&) const override { return true; }
static primitive_impl* create_data(const data_node& data) { return new wait_for_events_impl(data); }
static primitive_impl* create_data(const data_node& data, std::shared_ptr<kernel_impl_params>) { return new wait_for_events_impl(data); }
static primitive_impl* create_input_layout(const input_layout_node& input) {
static primitive_impl* create_input_layout(const input_layout_node& input, std::shared_ptr<kernel_impl_params>) {
return new wait_for_events_impl(input);
}
static primitive_impl* create_prior_box(const prior_box_node& prior_box) {
static primitive_impl* create_prior_box(const prior_box_node& prior_box, std::shared_ptr<kernel_impl_params>) {
// This primitive is being executed on CPU during network compilation.
return new wait_for_events_impl(prior_box);
}

View File

@@ -38,7 +38,9 @@ struct assign_impl : public typed_primitive_impl<assign> {
void init_kernels() override {}
public:
static primitive_impl* create(assign_node const& arg) { return new assign_impl{}; }
static primitive_impl* create(const assign_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
return new assign_impl{};
}
};

View File

@@ -833,7 +833,7 @@ struct detection_output_impl : typed_primitive_impl<detection_output> {
void init_kernels() override {}
static primitive_impl* create(const detection_output_node& arg) { return new detection_output_impl(arg); }
static primitive_impl* create(const detection_output_node& arg, std::shared_ptr<kernel_impl_params>) { return new detection_output_impl(arg); }
};
namespace detail {

View File

@@ -401,7 +401,7 @@ struct non_max_suppression_impl : typed_primitive_impl<non_max_suppression> {
return ev;
}
static primitive_impl* create(const non_max_suppression_node&) {
static primitive_impl* create(const non_max_suppression_node&, std::shared_ptr<kernel_impl_params>) {
return new non_max_suppression_impl();
}
void init_kernels() override {}

View File

@@ -427,8 +427,8 @@ struct proposal_impl : typed_primitive_impl<proposal> {
void init_kernels() override {}
static primitive_impl* create(const proposal_node& arg) {
const layout& l = arg.image_info().get_output_layout();
static primitive_impl* create(const proposal_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const layout& l = impl_param->input_layouts[2];
const size_t count = l.feature() == 1 ? static_cast<size_t>(l.batch()) : static_cast<size_t>(l.feature());
// Supported image_info sizes and components meaning:

View File

@@ -39,7 +39,9 @@ struct read_value_impl : public typed_primitive_impl<read_value> {
void init_kernels() override {}
public:
static primitive_impl* create(read_value_node const& arg) { return new read_value_impl{}; }
static primitive_impl* create(const read_value_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
return new read_value_impl{};
}
};
namespace detail {

View File

@@ -11,6 +11,8 @@
#include <string>
#include <sstream>
#include "to_string_utils.h"
#include "kernel_selector_helper.h"
#include "activation_inst.h"
namespace cldnn {
@@ -145,42 +147,39 @@ class implementation_map {
public:
using key_builder = implementation_key<primitive_kind>;
using key_type = typename key_builder::type;
using factory_type = std::function<primitive_impl*(const typed_program_node<primitive_kind>&)>;
using factory_type = std::function<primitive_impl*(const typed_program_node<primitive_kind>&, std::shared_ptr<kernel_impl_params>)>;
using map_type = singleton_map<impl_types, std::pair<std::set<key_type>, factory_type>>;
static factory_type get(const typed_program_node<primitive_kind>& primitive) {
impl_types target_impl_type = primitive.get_preferred_impl_type();
// lookup in database; throw if not found
auto key = key_builder()(primitive);
static factory_type get(std::shared_ptr<kernel_impl_params> impl_param, impl_types preferred_impl_type) {
auto key = key_builder()(impl_param->input_layouts[0]);
for (auto& kv : map_type::instance()) {
impl_types impl_type = kv.first;
if ((target_impl_type & impl_type) != impl_type)
if ((preferred_impl_type & impl_type) != impl_type)
continue;
std::set<key_type>& keys_set = kv.second.first;
auto& factory = kv.second.second;
if (keys_set.empty() || keys_set.find(key) != keys_set.end()) {
if (keys_set.empty() || keys_set.find(key) != keys_set.end()) {
return factory;
}
}
std::stringstream target_impl_type_ss;
target_impl_type_ss << target_impl_type;
target_impl_type_ss << preferred_impl_type;
throw std::runtime_error(std::string("implementation_map for ") + typeid(primitive_kind).name() +
" could not find any implementation to match key: " +
get_key_name(key) + ", impl_type: " + target_impl_type_ss.str() + ", node_id: " + primitive.id());
get_key_name(key) + ", impl_type: " + target_impl_type_ss.str() + ", node_id: " + impl_param->desc->id);
}
// check if for a given engine and type there exist an implementation
static bool check(const typed_program_node<primitive_kind>& primitive) {
static bool check(const typed_program_node<primitive_kind>& primitive, std::shared_ptr<kernel_impl_params> impl_params) {
impl_types target_impl_type = primitive.get_preferred_impl_type();
auto key = key_builder()(primitive);
auto key = key_builder()(impl_params->input_layouts[0]);
return check_key(target_impl_type, key);
}
// check if there exists a kernel implementation of a primitive with output set it primitive's output layout
static bool check_io_eq(const typed_program_node<primitive_kind>& primitive) {
static bool check_io_eq(const typed_program_node<primitive_kind>& primitive, std::shared_ptr<kernel_impl_params> impl_params) {
impl_types target_impl_type = primitive.get_preferred_impl_type();
auto key = key_builder()(primitive.get_output_layout());
auto key = key_builder()(impl_params->output_layout);
return check_key(target_impl_type, key);
}

View File

@@ -30,17 +30,17 @@ struct activation_impl : typed_primitive_impl_ocl<activation> {
return args;
}
static primitive_impl* create(const activation_node& arg) {
auto activation_params = get_default_params<kernel_selector::activation_params>(arg);
static primitive_impl* create(const activation_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& prim = arg.get_primitive();
auto activation_params = get_default_params<kernel_selector::activation_params>(*impl_param);
auto activation_optional_params =
get_default_optional_params<kernel_selector::activation_optional_params>(arg.get_program());
convert_new_activation_func(arg.get_primitive(), activation_params.activations);
convert_new_activation_func(prim, activation_params.activations);
if (arg.is_parameterized()) {
const auto& slope_layout = arg.slope_input().get_output_layout();
const auto& output_layout = arg.get_output_layout();
const auto& slope_layout = impl_param->input_layouts[1];
const auto& output_layout = impl_param->output_layout;
const auto params_num =
kernel_selector::GetActivationAdditionalParamsNumber(activation_params.activations[0].function);

View File

@@ -35,8 +35,8 @@ protected:
}
public:
static primitive_impl* create(const adaptive_pooling_node& arg) {
auto params = get_default_params<kernel_selector::adaptive_pooling_params>(arg);
static primitive_impl* create(const adaptive_pooling_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto params = get_default_params<kernel_selector::adaptive_pooling_params>(*impl_param);
auto optional_params = get_default_optional_params<kernel_selector::adaptive_pooling_optional_params>(arg.get_program());
const auto& primitive = arg.get_primitive();

View File

@@ -34,9 +34,8 @@ protected:
}
public:
static primitive_impl* create(const arg_max_min_node& arg) {
static primitive_impl* create(const arg_max_min_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& primitive = arg.get_primitive();
const auto& axis = primitive->axis;
const auto& top_k = primitive->top_k;
const auto& out_type = primitive->output_type;
@@ -45,7 +44,7 @@ public:
const auto& values_first = primitive->values_first;
const auto& outputs_num = primitive->input.size() == 3 ? 2 : 1; // second output passed as input for TOP_K layer
auto argm_params = get_default_params<kernel_selector::arg_max_min_params>(arg);
auto argm_params = get_default_params<kernel_selector::arg_max_min_params>(*impl_param);
auto argm_optional_params =
get_default_optional_params<kernel_selector::arg_max_min_optional_params>(arg.get_program());
@@ -84,7 +83,7 @@ public:
argm_params.argMaxMinSortType = kernel_selector::argm_sort::INDEX;
if (outputs_num == 2) {
argm_params.inputs.push_back(convert_data_tensor(arg.get_dependency(2).get_output_layout()));
argm_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[2]));
}
argm_params.values_first = values_first;

View File

@@ -28,13 +28,13 @@ protected:
}
public:
static primitive_impl* create(const average_unpooling_node& arg) {
auto average_unpooling_params = get_default_params<kernel_selector::average_unpooling_params>(arg);
static primitive_impl* create(const average_unpooling_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto primitive = arg.get_primitive();
auto average_unpooling_params = get_default_params<kernel_selector::average_unpooling_params>(*impl_param);
auto average_unpooling_optional_params =
get_default_optional_params<kernel_selector::average_unpooling_optional_params>(arg.get_program());
auto& params = average_unpooling_params;
auto primitive = arg.get_primitive();
auto stride = primitive->stride;
params.unpoolSize = {

View File

@@ -25,13 +25,12 @@ struct batch_to_space_impl : typed_primitive_impl_ocl<batch_to_space> {
}
public:
static primitive_impl* create(const batch_to_space_node& arg) {
auto batch_to_space_params = get_default_params<kernel_selector::batch_to_space_params>(arg);
static primitive_impl* create(const batch_to_space_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto primitive = arg.get_primitive();
auto batch_to_space_params = get_default_params<kernel_selector::batch_to_space_params>(*impl_param);
auto batch_to_space_optional_params =
get_default_optional_params<kernel_selector::batch_to_space_optional_params>(arg.get_program());
auto primitive = arg.get_primitive();
batch_to_space_params.block_shape = convert_dim_vector(primitive->block_shape);
batch_to_space_params.crops_begin = convert_dim_vector(primitive->crops_begin);
batch_to_space_params.crops_end = convert_dim_vector(primitive->crops_end);

View File

@@ -60,9 +60,9 @@ protected:
int32_t get_split() const override { return _outer.get_split(); }
public:
static primitive_impl* create(const binary_convolution_node& arg) {
static primitive_impl* create(const binary_convolution_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& primitive = arg.get_primitive();
const auto& weights_layout = arg.weights(0).get_output_layout().convert_to_weights_layout(false);
const auto& weights_layout = (*impl_param->weights_layout).convert_to_weights_layout(false);
const auto& weights_size = weights_layout.get_tensor();
const auto& split = primitive->split();
@@ -74,12 +74,10 @@ public:
const auto depthwise_separable_opt = arg.get_depthwise_sep_opt();
const auto actual_split = depthwise_separable_opt ? (decltype(split))1 : split;
assert(arg.get_output_layout().feature() / primitive->split() == weights_layout.batch());
assert(impl_param->output_layout.feature() / primitive->split() == weights_layout.batch());
auto conv_params =
get_weights_bias_default_params<kernel_selector::binary_convolution_params>(arg, actual_split);
auto conv_optional_params =
get_default_weights_bias_optional_params<kernel_selector::binary_convolution_optional_params>(
auto conv_params = get_weights_bias_default_params<kernel_selector::binary_convolution_params>(*impl_param, actual_split);
auto conv_optional_params = get_default_weights_bias_optional_params<kernel_selector::binary_convolution_optional_params>(
arg.get_program());
conv_params.pad_value = primitive->pad_value;

View File

@@ -22,13 +22,13 @@ struct border_impl : typed_primitive_impl_ocl<border> {
return make_unique<border_impl>(*this);
}
static primitive_impl* create(const border_node& arg) {
auto b_params = get_default_params<kernel_selector::border_params>(arg, 1);
static primitive_impl* create(const border_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto desc = arg.get_primitive();
auto b_params = get_default_params<kernel_selector::border_params>(*impl_param, 1);
auto b_optional_params =
get_default_optional_params<kernel_selector::border_optional_params>(arg.get_program());
auto desc = arg.get_primitive();
b_params.lt_sizes = convert_dim_vector(desc->left_top_sizes);
b_params.rb_sizes = convert_dim_vector(desc->right_bottom_sizes);
b_params.border_value = desc->border_value;

View File

@@ -22,15 +22,16 @@ struct broadcast_impl : typed_primitive_impl_ocl<broadcast> {
return make_unique<broadcast_impl>(*this);
}
static primitive_impl* create(const broadcast_node& arg) {
auto bc_params = get_default_params<kernel_selector::broadcast_params>(arg, 1);
static primitive_impl* create(const broadcast_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& primitive = arg.get_primitive();
auto bc_params = get_default_params<kernel_selector::broadcast_params>(*impl_param, 1);
auto bc_optional_params =
get_default_optional_params<kernel_selector::broadcast_optional_params>(arg.get_program());
const auto format = arg.get_output_layout().format;
const auto format = impl_param->output_layout.format;
size_t max_axes_num = format.dimension();
const auto& broadcast_axes = arg.get_primitive()->broadcast_axes;
const auto& broadcast_axes = primitive->broadcast_axes;
uint16_t index = (uint16_t)0;
uint16_t input_index = (uint16_t)broadcast_axes.size();

View File

@@ -24,8 +24,8 @@ struct bucketize_impl : typed_primitive_impl_ocl<bucketize> {
return make_unique<bucketize_impl>(*this);
}
static primitive_impl* create(const bucketize_node& arg) {
auto params = get_default_params<kernel_selector::bucketize_params>(arg);
static primitive_impl* create(const bucketize_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto params = get_default_params<kernel_selector::bucketize_params>(*impl_param);
auto optional_params =
get_default_optional_params<kernel_selector::bucketize_optional_params>(arg.get_program());

View File

@@ -69,23 +69,22 @@ protected:
}
public:
static primitive_impl* create(const concatenation_node& arg) {
static primitive_impl* create(const concatenation_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
if (arg.can_be_optimized()) {
return new concatenation_impl(arg, {});
}
auto concat_params = get_default_params<kernel_selector::concatenation_params>(arg);
auto concat_optional_params =
get_default_optional_params<kernel_selector::concatenation_optional_params>(arg.get_program());
auto axis = arg.get_primitive()->axis;
const auto& primitive = arg.get_primitive();
auto concat_params = get_default_params<kernel_selector::concatenation_params>(*impl_param);
auto concat_optional_params = get_default_optional_params<kernel_selector::concatenation_optional_params>(arg.get_program());
auto axis = primitive->axis;
concat_params.inputs.resize(arg.inputs_count());
for (size_t i = 0; i < arg.inputs_count(); ++i) {
const layout& input_layout = arg.input(i).get_output_layout();
const layout& input_layout = impl_param->input_layouts[i];
concat_params.inputs[i] = convert_data_tensor(input_layout);
}
concat_params.axis = convert_axis(axis, arg.get_output_layout().get_rank());
concat_params.axis = convert_axis(axis, impl_param->output_layout.get_rank());
concat_optional_params.kernelPerInput = true;
auto& kernel_selector = kernel_selector::concatenation_kernel_selector::Instance();

View File

@@ -31,17 +31,17 @@ protected:
}
public:
static primitive_impl* create(const convert_color_node& arg) {
auto convert_color_params = get_default_params<kernel_selector::convert_color_params>(arg);
static primitive_impl* create(const convert_color_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto primitive = arg.get_primitive();
auto convert_color_params = get_default_params<kernel_selector::convert_color_params>(*impl_param);
auto convert_color_optional_params =
get_default_optional_params<kernel_selector::convert_color_optional_params>(arg.get_program());
for (size_t i = 1; i < arg.inputs_count(); ++i) {
convert_color_params.inputs.push_back(convert_data_tensor(arg.input(i).get_output_layout()));
for (size_t i = 1; i < impl_param->input_layouts.size(); ++i) {
convert_color_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[i]));
}
auto primitive = arg.get_primitive();
convert_color_params.input_color_format = static_cast<kernel_selector::color_format>(primitive->input_color_format);
convert_color_params.output_color_format = static_cast<kernel_selector::color_format>(primitive->output_color_format);
convert_color_params.mem_type = static_cast<kernel_selector::memory_type>(primitive->mem_type);

View File

@@ -60,9 +60,8 @@ protected:
bool get_depthwise_sep_opt() const override { return _outer.get_depthwise_sep_opt(); }
public:
static primitive_impl* create(const convolution_node& arg) {
static primitive_impl* create(const convolution_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& primitive = arg.get_primitive();
const auto& weights_layout = arg.weights(0).get_output_layout().convert_to_weights_layout(primitive->grouped_weights_shape);
const auto &split = primitive->split();
auto stride = primitive->stride;
@@ -73,15 +72,15 @@ public:
const auto transposed = arg.get_transposed();
auto conv_params = get_weight_bias_zero_point_default_params<kernel_selector::convolution_params>(
arg, split, 1, primitive->grouped_weights_shape);
*impl_param, split, 1, primitive->grouped_weights_shape);
auto conv_optional_params =
get_default_weights_bias_optional_params<kernel_selector::convolution_optional_params>(arg.get_program());
if (primitive->deformable_mode) {
conv_params.inputs.push_back(convert_data_tensor(arg.trans().get_output_layout()));
conv_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[1]));
conv_params.deformable_mode = true;
if (primitive->input.size() == 3) {
conv_params.inputs.push_back(convert_data_tensor(arg.mask().get_output_layout()));
conv_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[2]));
conv_params.deformable_mask_enabled = true;
}
conv_params.bilinear_interpolation_pad = arg.bilinear_interpolation_pad();
@@ -93,6 +92,8 @@ public:
conv_params.split = split;
conv_params.groups = groups;
const auto& weights_layout = impl_param->input_layouts[1 + 0 + arg.get_deform_conv_dep_offset()]
.convert_to_weights_layout(primitive->grouped_weights_shape);
uint32_t kx = weights_layout.spatial(0);
uint32_t ky = weights_layout.spatial(1);
uint32_t kz = weights_layout.spatial(2);
@@ -113,9 +114,9 @@ public:
uint32_t dilation_x = dilation.size() >= 1 ? dilation[dilation.size() - 1] : 1;
conv_params.dilation = {dilation_x, dilation_y, dilation_z};
if ((arg.get_dependency(0).get_output_layout().data_type == data_types::u8 ||
arg.get_dependency(0).get_output_layout().data_type == data_types::i8) &&
arg.get_dependency(1).get_output_layout().data_type == data_types::i8) {
if ((impl_param->input_layouts[0].data_type == data_types::u8 ||
impl_param->input_layouts[0].data_type == data_types::i8) &&
impl_param->input_layouts[1].data_type == data_types::i8) {
if (!primitive->weights_zero_points.empty() && !primitive->activations_zero_points.empty()) {
conv_params.quantization = kernel_selector::QuantizationType::ASYMMETRIC_DATA_AND_WEIGHTS;
} else if (!primitive->weights_zero_points.empty()) {
@@ -129,7 +130,7 @@ public:
conv_params.quantization = kernel_selector::QuantizationType::NONE;
}
auto format = arg.get_output_layout().format;
auto format = impl_param->output_layout.format;
if (format == format::b_fs_zyx_fsv16 ||
format == format::bs_fs_zyx_bsv16_fsv16 ||
format == format::bs_fs_yx_bsv16_fsv16 ||

View File

@@ -27,16 +27,17 @@ protected:
}
public:
static primitive_impl* create(const crop_node& arg) {
auto ew_params = get_default_params<kernel_selector::eltwise_params>(arg, 1);
auto ew_optional_params =
get_default_optional_params<kernel_selector::eltwise_optional_params>(arg.get_program());
static primitive_impl* create(const crop_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& primitive = arg.get_primitive();
auto ew_params = get_default_params<kernel_selector::eltwise_params>(*impl_param, 1);
auto ew_optional_params = get_default_optional_params<kernel_selector::eltwise_optional_params>(arg.get_program());
ew_params.operations.push_back(
{{kernel_selector::eltwise_params::InputType::Buffer(0)}, kernel_selector::eltwise_mode::ASSIGN});
const auto& input_layout = arg.input().get_output_layout();
ew_params.inputs[0] = convert_data_tensor(input_layout, 1, arg.get_primitive()->offsets);
const auto& input_layout = impl_param->input_layouts[0];
ew_params.inputs[0] = convert_data_tensor(input_layout, 1, primitive->offsets);
auto& kernel_selector = kernel_selector::eltwise_kernel_selector::Instance();
auto best_kernels = kernel_selector.GetBestKernels(ew_params, ew_optional_params);

View File

@@ -26,20 +26,20 @@ struct ctc_greedy_decoder_impl : typed_primitive_impl_ocl<ctc_greedy_decoder> {
}
public:
static primitive_impl* create(const ctc_greedy_decoder_node& arg) {
auto ctc_gd_params = get_default_params<kernel_selector::ctc_greedy_decoder_params>(arg);
auto ctc_gd_optional_params = get_default_optional_params<kernel_selector::ctc_greedy_decoder_optional_params>(arg.get_program());
static primitive_impl* create(const ctc_greedy_decoder_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto prim = arg.get_primitive();
ctc_gd_params.inputs.push_back(
convert_data_tensor(arg.seq_indicators().get_output_layout()));
auto ctc_gd_params = get_default_params<kernel_selector::ctc_greedy_decoder_params>(*impl_param);
auto ctc_gd_optional_params = get_default_optional_params<kernel_selector::ctc_greedy_decoder_optional_params>(arg.get_program());
ctc_gd_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[1]));
ctc_gd_params.merge_repeated = prim->ctc_merge_repeated;
ctc_gd_params.blank_index = prim->blank_index;
ctc_gd_params.outputs_num = arg.has_second_output() ? 2 : 1;
if (ctc_gd_params.outputs_num == 2) {
ctc_gd_params.inputs.push_back(
convert_data_tensor(arg.second_output().get_output_layout()));
const auto& second_output_layout = impl_param->input_layouts[1];
ctc_gd_params.inputs.push_back(convert_data_tensor(second_output_layout));
}
auto& kernel_selector = kernel_selector::ctc_greedy_decoder_kernel_selector::Instance();

View File

@@ -45,12 +45,14 @@ struct cum_sum_impl : typed_primitive_impl_ocl<cum_sum> {
}
public:
static primitive_impl* create(const cum_sum_node& arg) {
auto cum_sum_params = get_default_params<kernel_selector::cum_sum_params>(arg);
static primitive_impl* create(const cum_sum_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& prim = arg.get_primitive();
auto cum_sum_params = get_default_params<kernel_selector::cum_sum_params>(*impl_param);
auto cum_sum_optional_params =
get_default_optional_params<kernel_selector::cum_sum_optional_params>(arg.get_program());
cum_sum_params.axis = convert_axis(arg.get_primitive()->axis);
cum_sum_params.axis = convert_axis(prim->axis);
cum_sum_params.exclusive = arg.get_primitive()->exclusive;
cum_sum_params.reverse = arg.get_primitive()->reverse;

View File

@@ -179,7 +179,7 @@ static void add_layout_to_jit(kernel_selector::jit_constants& mem_consts, const
mem_consts.AddConstant(kernel_selector::MakeJitConstant(name + "_OFFSET", std::to_string(offset)));
}
static std::string get_jit_constant(const custom_gpu_primitive_node& outer) {
static std::string get_jit_constant(const custom_gpu_primitive_node& outer, const kernel_impl_params& impl_param) {
kernel_selector::jit_constants mem_consts{
kernel_selector::MakeJitConstant("NUM_INPUTS", std::to_string(outer.get_dependencies().size()))};
const auto primitive = outer.get_primitive().get();
@@ -189,11 +189,11 @@ static std::string get_jit_constant(const custom_gpu_primitive_node& outer) {
kernel_selector::MakeJitConstant("LOCAL_WORKSIZE", primitive->lws),
});
for (size_t i = 0; i < outer.get_dependencies().size(); i++) {
add_layout_to_jit(mem_consts, "INPUT" + std::to_string(i), outer.input(i).get_output_layout());
for (size_t i = 0; i < impl_param.input_layouts.size(); i++) {
add_layout_to_jit(mem_consts, "INPUT" + std::to_string(i), impl_param.input_layouts[i]);
}
add_layout_to_jit(mem_consts, "OUTPUT0", outer.get_output_layout());
add_layout_to_jit(mem_consts, "OUTPUT0", impl_param.output_layout);
std::ostringstream oss;
oss << "// Custom Layer Built-ins\n\n";
@@ -204,14 +204,14 @@ static std::string get_jit_constant(const custom_gpu_primitive_node& outer) {
return oss.str();
}
static primitive_impl* create(const custom_gpu_primitive_node& arg) {
static primitive_impl* create(const custom_gpu_primitive_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto primitive = arg.get_primitive().get();
auto cl_kernel = std::make_shared<kernel_selector::cl_kernel_data>();
cl_kernel->code.kernelString = std::make_shared<kernel_selector::kernel_string>();
cl_kernel->code.kernelString->entry_point = primitive->kernel_entry_point;
cl_kernel->code.kernelString->options = primitive->build_options;
cl_kernel->code.kernelString->jit = get_jit_constant(arg);
cl_kernel->code.kernelString->jit = get_jit_constant(arg, *impl_param);
for (const auto& s : primitive->kernels_code) {
cl_kernel->code.kernelString->str += s + "\n";
}

View File

@@ -51,24 +51,18 @@ protected:
uint32_t get_groups() const override { return _outer.get_groups(); }
public:
static primitive_impl* create(const deconvolution_node& arg) {
static primitive_impl* create(const deconvolution_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& primitive = arg.get_primitive();
const auto& weights_layout = arg.weights(0).get_output_layout().convert_to_weights_layout(primitive->grouped_weights_shape);
const auto& split = primitive->split();
const auto& stride = primitive->stride;
#if 0 // TODO: support dilation
const auto& dilation = primitive->dilation;
#else
const ov::Strides dilation(arg.get_output_layout().get_spatial_rank(), 1);
#endif
const ov::Strides dilation(impl_param->output_layout.get_spatial_rank(), 1);
const auto actual_split = split;
const auto& pad = primitive->pad;
const auto& groups = primitive->groups;
auto deconv_params = get_weights_bias_default_params<kernel_selector::deconvolution_params>(
arg,
*impl_param,
(groups > 1) ? 1 : actual_split,
1,
primitive->grouped_weights_shape);
@@ -78,6 +72,8 @@ public:
deconv_params.split = split;
deconv_params.groups = groups;
const auto weights_idx = 1 + 0;
const auto& weights_layout = impl_param->input_layouts[weights_idx].convert_to_weights_layout(primitive->grouped_weights_shape);
uint32_t kx = weights_layout.spatial(0);
uint32_t ky = weights_layout.spatial(1);
uint32_t kz = weights_layout.spatial(2);

View File

@@ -37,11 +37,8 @@ protected:
uint32_t get_groups() const override { return _outer.get_groups(); }
public:
static primitive_impl* create(const deformable_conv_node& arg) {
static primitive_impl* create(const deformable_conv_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& primitive = arg.get_primitive();
const auto& weights_layout = arg.weights(0).get_output_layout().convert_to_weights_layout(false);
const auto& weights_size = weights_layout.get_tensor();
const auto& split = primitive->split();
const auto& groups = primitive->groups;
@@ -49,12 +46,16 @@ public:
const auto actual_split = depthwise_separable_opt ? (decltype(split))1 : split;
auto conv_params = get_weights_bias_default_params<kernel_selector::convolution_params>(
arg,
*impl_param,
(groups > 1 && !depthwise_separable_opt) ? groups : actual_split,
groups);
auto conv_optional_params =
get_default_weights_bias_optional_params<kernel_selector::convolution_optional_params>(arg.get_program());
const auto weight_idx = 1 + 0;
const auto& weights_layout = impl_param->input_layouts[weight_idx].convert_to_weights_layout(false);
const auto& weights_size = weights_layout.get_tensor();
conv_params.depthwise_separable_opt = depthwise_separable_opt;
conv_params.split = split;
conv_params.groups = groups;
@@ -91,9 +92,12 @@ protected:
uint32_t get_groups() const override { return 1; }
public:
static primitive_impl* create(const deformable_interp_node& arg) {
static primitive_impl* create(const deformable_interp_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto input_idx = 0;
const auto trans_idx = 1;
const auto mask_idx = 2;
const auto& primitive = arg.get_primitive();
const auto& input_layout = arg.input().get_output_layout();
const auto& input_layout = impl_param->input_layouts[input_idx];
const auto& kernel_size = primitive->kernel_size;
auto stride = primitive->stride;
@@ -102,7 +106,7 @@ public:
const auto& groups = primitive->groups;
const auto& deformable_groups = primitive->deformable_groups;
auto conv_params = get_default_params<kernel_selector::convolution_params>(arg, groups);
auto conv_params = get_default_params<kernel_selector::convolution_params>(*impl_param, groups);
auto conv_optional_params =
get_default_optional_params<kernel_selector::convolution_optional_params>(arg.get_program());
@@ -110,9 +114,9 @@ public:
auto weights_layout = layout(input_layout.data_type, input_layout.format, kernel_size);
conv_params.weights = convert_weights_tensor(weights_layout);
conv_params.inputs.push_back(convert_data_tensor(arg.trans().get_output_layout()));
conv_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[trans_idx]));
if (primitive->input.size() == 3) {
conv_params.inputs.push_back(convert_data_tensor(arg.mask().get_output_layout()));
conv_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[mask_idx]));
conv_params.deformable_mask_enabled = true;
}
conv_params.bilinear_interpolation_pad = primitive->bilinear_interpolation_pad;

View File

@@ -24,13 +24,15 @@ struct depth_to_space_impl : typed_primitive_impl_ocl<depth_to_space> {
}
public:
static primitive_impl* create(const depth_to_space_node& arg) {
auto depth_to_space_params = get_default_params<kernel_selector::depth_to_space_params>(arg);
static primitive_impl* create(const depth_to_space_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& prim = arg.get_primitive();
auto depth_to_space_params = get_default_params<kernel_selector::depth_to_space_params>(*impl_param);
auto depth_to_space_optional_params =
get_default_optional_params<kernel_selector::depth_to_space_optional_params>(arg.get_program());
depth_to_space_params.block_size = arg.get_primitive()->block_size;
depth_to_space_params.mode = arg.get_primitive()->mode == depth_to_space_mode::blocks_first ? kernel_selector::depth_to_space_mode::BLOCKS_FIRST
depth_to_space_params.block_size = prim->block_size;
depth_to_space_params.mode = prim->mode == depth_to_space_mode::blocks_first ? kernel_selector::depth_to_space_mode::BLOCKS_FIRST
: kernel_selector::depth_to_space_mode::DEPTH_FIRST;
auto& kernel_selector = kernel_selector::depth_to_space_kernel_selector::Instance();

View File

@@ -51,13 +51,15 @@ private:
}
public:
static primitive_impl* create(const detection_output_node& arg) {
auto detect_out_params = get_default_params<kernel_selector::detection_output_params>(arg);
static primitive_impl* create(const detection_output_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto detect_out_params = get_default_params<kernel_selector::detection_output_params>(*impl_param);
auto detect_out_optional_params =
get_default_optional_params<kernel_selector::detection_output_optional_params>(arg.get_program());
detect_out_params.inputs.push_back(convert_data_tensor(arg.confidence().get_output_layout()));
detect_out_params.inputs.push_back(convert_data_tensor(arg.prior_box().get_output_layout()));
const auto confidence_idx = 1;
const auto prior_box_idx = 2;
detect_out_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[confidence_idx]));
detect_out_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[prior_box_idx]));
set_detection_output_specific_params(detect_out_params.detectOutParams, arg);
auto& kernel_selector = kernel_selector::detection_output_kernel_selector::Instance();

View File

@@ -22,8 +22,8 @@ struct dft_impl : typed_primitive_impl_ocl<dft> {
return make_unique<dft_impl>(*this);
}
static primitive_impl* create(const dft_node& arg) {
auto params = get_default_params<kernel_selector::dft_params>(arg);
static primitive_impl* create(const dft_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto params = get_default_params<kernel_selector::dft_params>(*impl_param);
auto primitive = arg.get_primitive();
params.axes = primitive->axes;
if (primitive->kind == dft_kind::inverse) {

View File

@@ -29,16 +29,17 @@ protected:
}
public:
static primitive_impl* create(const eltwise_node& arg) {
auto ew_params = get_default_params<kernel_selector::eltwise_params>(arg);
static primitive_impl* create(const eltwise_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& primitive = arg.get_primitive();
auto ew_params = get_default_params<kernel_selector::eltwise_params>(*impl_param);
auto ew_optional_params =
get_default_optional_params<kernel_selector::eltwise_optional_params>(arg.get_program());
for (size_t i = 1; i < arg.inputs_count(); i++) {
ew_params.inputs.push_back(convert_data_tensor(arg.input(i).get_output_layout()));
ew_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[i]));
}
const auto& primitive = arg.get_primitive();
ew_params.operations.push_back({{kernel_selector::eltwise_params::InputType::Buffer(0),
kernel_selector::eltwise_params::InputType::Buffer(1)},
@@ -56,8 +57,8 @@ public:
for (size_t i = 0; i < ew_params.inputs.size(); i++) {
if (!ew_params.inputs[i].SameDims(ew_params.outputs[0])) {
std::vector<int32_t> input_size = arg.input(i).get_output_layout().get_tensor().raw.vector();
std::vector<int32_t> output_size = arg.get_output_layout().get_tensor().raw.vector();
std::vector<int32_t> input_size = impl_param->input_layouts[i].get_tensor().raw.vector();
std::vector<int32_t> output_size = impl_param->output_layout.get_tensor().raw.vector();
bool broadcast = false;
for (size_t d = 0; d < output_size.size(); d++) {
if (output_size[d] != 1 && input_size[d] == 1)
@@ -98,8 +99,8 @@ public:
// TODO [LOW PRECISION]: check if this parameter's really needed. Maybe data types are enough
bool quantization = true;
for (size_t i = 0; i < arg.inputs_count(); i++) {
if (arg.input(i).get_output_layout().data_type != data_types::u8 &&
arg.input(i).get_output_layout().data_type != data_types::i8) {
if (impl_param->input_layouts[i].data_type != data_types::u8 &&
impl_param->input_layouts[i].data_type != data_types::i8) {
quantization = false;
}
}

View File

@@ -24,12 +24,13 @@ struct embedding_bag_impl : typed_primitive_impl_ocl<embedding_bag> {
}
public:
static primitive_impl* create(const embedding_bag_node& arg) {
auto embedding_bag_params = get_default_params<kernel_selector::embedding_bag_params>(arg);
static primitive_impl* create(const embedding_bag_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& primitive = arg.get_primitive();
auto embedding_bag_params = get_default_params<kernel_selector::embedding_bag_params>(*impl_param);
auto embedding_bag_optional_params =
get_default_optional_params<kernel_selector::embedding_bag_optional_params>(arg.get_program());
switch (arg.get_primitive()->type) {
switch (primitive->type) {
case embedding_bag::packed_sum:
embedding_bag_params.type = kernel_selector::EmbeddingBagType::PACKED_SUM;
break;
@@ -45,7 +46,7 @@ public:
}
for (size_t i = 1; i < arg.inputs_count(); i++) {
embedding_bag_params.inputs.push_back(convert_data_tensor(arg.input(i).get_output_layout()));
embedding_bag_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[i]));
}
embedding_bag_params.default_index = arg.get_primitive()->default_index;

View File

@@ -31,8 +31,8 @@ protected:
}
public:
static primitive_impl* create(const experimental_detectron_detection_output_node& arg) {
auto params = get_default_params<kernel_selector::experimental_detectron_detection_output_params>(arg);
static primitive_impl* create(const experimental_detectron_detection_output_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto params = get_default_params<kernel_selector::experimental_detectron_detection_output_params>(*impl_param);
auto optional_params =
get_default_optional_params<kernel_selector::experimental_detectron_detection_output_optional_params>(
arg.get_program());

View File

@@ -37,8 +37,8 @@ protected:
}
public:
static primitive_impl* create(const experimental_detectron_generate_proposals_single_image_node& arg) {
auto params = get_default_params<kernel_selector::experimental_detectron_generate_proposals_single_image_params>(arg);
static primitive_impl* create(const experimental_detectron_generate_proposals_single_image_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto params = get_default_params<kernel_selector::experimental_detectron_generate_proposals_single_image_params>(*impl_param);
auto optional_params = get_default_optional_params<
kernel_selector::experimental_detectron_generate_proposals_single_image_optional_params>(arg.get_program());

View File

@@ -26,8 +26,8 @@ struct experimental_detectron_prior_grid_generator_impl
return make_unique<experimental_detectron_prior_grid_generator_impl>(*this);
}
static primitive_impl* create(const experimental_detectron_prior_grid_generator_node& arg) {
auto params = get_default_params<kernel_selector::experimental_detectron_prior_grid_generator_params>(arg);
static primitive_impl* create(const experimental_detectron_prior_grid_generator_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto params = get_default_params<kernel_selector::experimental_detectron_prior_grid_generator_params>(*impl_param);
auto primPtr = arg.get_primitive();
auto& prim = *primPtr;

View File

@@ -39,8 +39,9 @@ protected:
}
public:
static primitive_impl* create(const experimental_detectron_roi_feature_extractor_node& arg) {
const auto output_layout = arg.get_output_layout();
static primitive_impl* create(const experimental_detectron_roi_feature_extractor_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& primitive = arg.get_primitive();
const auto output_layout = impl_param->output_layout;
const auto padding_filling_value = output_layout.data_padding.filling_value();
CLDNN_ERROR_NOT_EQUAL(arg.id(),
"experimental_detectron_roi_feature_extractor padding filling value",
@@ -48,14 +49,12 @@ public:
"padding mode",
0.0f,
"Unknown padding mode in experimental_detectron_roi_feature_extractor.");
auto params = get_default_params<kernel_selector::experimental_detectron_roi_feature_extractor_params>(arg);
auto params = get_default_params<kernel_selector::experimental_detectron_roi_feature_extractor_params>(*impl_param);
auto optional_params = get_default_optional_params<kernel_selector::experimental_detectron_roi_feature_extractor_optional_params>(arg.get_program());
const auto& primitive = arg.get_primitive();
size_t number_of_inputs = primitive->input_size() - 1;
for (std::size_t i = 1; i < number_of_inputs; i++) {
params.inputs.push_back(convert_data_tensor(arg.input(i).get_output_layout()));
params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[i]));
}
params.output_dim = primitive->output_dim;

View File

@@ -21,13 +21,12 @@ struct experimental_detectron_topk_rois_impl : typed_primitive_impl_ocl<experime
return make_unique<experimental_detectron_topk_rois_impl>(*this);
}
static primitive_impl *create(const experimental_detectron_topk_rois_node &arg) {
auto params = get_default_params<kernel_selector::experimental_detectron_topk_roi_params>(
arg);
static primitive_impl *create(const experimental_detectron_topk_rois_node &arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& primitive = arg.get_primitive();
auto params = get_default_params<kernel_selector::experimental_detectron_topk_roi_params>(*impl_param);
const auto& experimental_detectron_topk_rois_kernel_selector =
kernel_selector::experimental_detectron_topk_rois_kernel_selector::Instance();
const auto& primitive = arg.get_primitive();
params.inputs.push_back(convert_data_tensor(arg.input(1).get_output_layout()));
params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[1]));
params.max_rois = primitive->max_rois;
auto best_kernels = experimental_detectron_topk_rois_kernel_selector.GetBestKernels(params,
kernel_selector::experimental_detectron_topk_roi_optional_params());

View File

@@ -23,15 +23,16 @@ struct extract_image_patches_impl : typed_primitive_impl_ocl<extract_image_patch
}
public:
static primitive_impl* create(const extract_image_patches_node& arg) {
auto params = get_default_params<kernel_selector::extract_image_patches_params>(arg);
static primitive_impl* create(const extract_image_patches_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& prim = arg.get_primitive();
auto params = get_default_params<kernel_selector::extract_image_patches_params>(*impl_param);
auto optional_params =
get_default_optional_params<kernel_selector::extract_image_patches_optional_params>(arg.get_program());
params.sizes = arg.get_primitive()->sizes;
params.strides = arg.get_primitive()->strides;
params.rates = arg.get_primitive()->rates;
params.auto_pad = arg.get_primitive()->auto_pad;
params.sizes = prim->sizes;
params.strides = prim->strides;
params.rates = prim->rates;
params.auto_pad = prim->auto_pad;
auto& kernel_selector = kernel_selector::extract_image_patches_kernel_selector::Instance();
auto best_kernels = kernel_selector.GetBestKernels(params, optional_params);

View File

@@ -40,21 +40,20 @@ protected:
}
public:
static primitive_impl* create(const fully_connected_node& arg) {
auto fc_params = get_weights_bias_default_params<kernel_selector::fully_connected_params>(arg);
static primitive_impl* create(const fully_connected_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto primitive = arg.get_primitive();
auto fc_params = get_weights_bias_default_params<kernel_selector::fully_connected_params>(*impl_param);
auto fc_optional_params =
get_default_weights_bias_optional_params<kernel_selector::fully_connected_optional_params>(
arg.get_program());
fc_optional_params.allowInputReordering = true;
const auto primitive = arg.get_primitive();
if (primitive->input_size != 3)
fc_params.outputs = { fc_params.outputs[0].FlattenFeatureAndSpatials() };
bool is_quantized = true;
for (auto& input : arg.get_dependencies())
is_quantized &= data_type_traits::is_quantized(input->get_output_layout().data_type);
for (auto& input : impl_param->input_layouts)
is_quantized &= data_type_traits::is_quantized(input.data_type);
if (is_quantized) {
fc_params.quantization = kernel_selector::QuantizationType::SYMMETRIC;

View File

@@ -68,17 +68,18 @@ struct gather_impl : typed_primitive_impl_ocl<gather> {
}
public:
static primitive_impl* create(const gather_node& arg) {
auto gather_params = get_default_params<kernel_selector::gather_params>(arg);
static primitive_impl* create(const gather_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& prim = arg.get_primitive();
auto gather_params = get_default_params<kernel_selector::gather_params>(*impl_param);
auto gather_optional_params =
get_default_optional_params<kernel_selector::gather_optional_params>(arg.get_program());
auto input_layout = arg.get_dependency(0).get_output_layout();
gather_params.axis = convert_axis(arg.get_primitive()->axis, input_layout.get_rank());
gather_params.batch_dim = size_t(arg.get_primitive()->batch_dim);
gather_params.support_neg_ind = arg.get_primitive()->support_neg_ind;
auto input_layout = impl_param->input_layouts[0];
gather_params.axis = convert_axis(prim->axis, input_layout.get_rank());
gather_params.batch_dim = size_t(prim->batch_dim);
gather_params.support_neg_ind = prim->support_neg_ind;
gather_params.inputs.push_back(convert_data_tensor(arg.input(1).get_output_layout()));
gather_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[1]));
auto& kernel_selector = kernel_selector::gather_kernel_selector::Instance();
auto best_kernels = kernel_selector.GetBestKernels(gather_params, gather_optional_params);

View File

@@ -42,14 +42,15 @@ struct gather_elements_impl : typed_primitive_impl_ocl<gather_elements> {
}
public:
static primitive_impl* create(const gather_elements_node& arg) {
auto gather_elements_params = get_default_params<kernel_selector::gather_elements_params>(arg);
static primitive_impl* create(const gather_elements_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& prim = arg.get_primitive();
auto gather_elements_params = get_default_params<kernel_selector::gather_elements_params>(*impl_param);
auto gather_elements_optional_params =
get_default_optional_params<kernel_selector::gather_elements_optional_params>(arg.get_program());
gather_elements_params.axis = convert_axis(arg.get_primitive()->axis);
gather_elements_params.axis = convert_axis(prim->axis);
gather_elements_params.inputs.push_back(convert_data_tensor(arg.input(1).get_output_layout()));
gather_elements_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[1]));
auto& kernel_selector = kernel_selector::gather_elements_kernel_selector::Instance();
auto best_kernels = kernel_selector.GetBestKernels(gather_elements_params, gather_elements_optional_params);

View File

@@ -22,16 +22,17 @@ struct gather_nd_impl : typed_primitive_impl_ocl<gather_nd> {
return make_unique<gather_nd_impl>(*this);
}
static primitive_impl* create(const gather_nd_node& arg) {
auto gather_nd_params = get_default_params<kernel_selector::gather_nd_params>(arg);
static primitive_impl* create(const gather_nd_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& prim = arg.get_primitive();
auto gather_nd_params = get_default_params<kernel_selector::gather_nd_params>(*impl_param);
auto gather_nd_optional_params =
get_default_optional_params<kernel_selector::gather_nd_optional_params>(arg.get_program());
gather_nd_params.indices_rank = arg.get_primitive()->indices_rank;
gather_nd_params.batch_dims = arg.get_primitive()->batch_dims;
gather_nd_params.batch_merged_output = arg.get_primitive()->batch_merged_output;
gather_nd_params.indices_rank = prim->indices_rank;
gather_nd_params.batch_dims = prim->batch_dims;
gather_nd_params.batch_merged_output = prim->batch_merged_output;
gather_nd_params.inputs.push_back(convert_data_tensor(arg.input(1).get_output_layout()));
gather_nd_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[1]));
auto& kernel_selector = kernel_selector::gather_nd_kernel_selector::Instance();
auto best_kernels = kernel_selector.GetBestKernels(gather_nd_params, gather_nd_optional_params);

View File

@@ -22,14 +22,14 @@ struct gather_tree_impl : typed_primitive_impl_ocl<gather_tree> {
return make_unique<gather_tree_impl>(*this);
}
static primitive_impl* create(const gather_tree_node& arg) {
auto b_params = get_default_params<kernel_selector::gather_tree_params>(arg, 1);
static primitive_impl* create(const gather_tree_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto desc = arg.get_primitive();
auto b_params = get_default_params<kernel_selector::gather_tree_params>(*impl_param, 1);
auto b_optional_params = get_default_optional_params<kernel_selector::gather_tree_optional_params>(arg.get_program());
for (size_t i = 1; i < arg.get_dependencies().size(); i++) {
b_params.inputs.push_back(convert_data_tensor(arg.get_dependency(i).get_output_layout(), 1));
b_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[i], 1));
}
auto desc = arg.get_primitive();
auto& kernel_selector = kernel_selector::gather_tree_kernel_selector::Instance();
auto best_kernels = kernel_selector.GetBestKernels(b_params, b_optional_params);

View File

@@ -23,24 +23,24 @@ struct gemm_impl : typed_primitive_impl_ocl<gemm> {
}
public:
static primitive_impl* create(const gemm_node& arg) {
auto gemm_params = get_default_params<kernel_selector::gemm_params>(arg, 1);
static primitive_impl* create(const gemm_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto desc = arg.get_primitive();
auto gemm_params = get_default_params<kernel_selector::gemm_params>(*impl_param, 1);
auto gemm_optional_params =
get_default_optional_params<kernel_selector::gemm_optional_params>(arg.get_program());
for (size_t i = 1; i < arg.inputs_count(); i++) {
gemm_params.inputs.push_back(convert_data_tensor(arg.input(i).get_output_layout()));
gemm_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[i]));
}
auto desc = arg.get_primitive();
gemm_params.alpha = desc->alpha;
gemm_params.beta = desc->beta;
gemm_params.transpose_input0 = desc->transpose_input0;
gemm_params.transpose_input1 = desc->transpose_input1;
bool is_quantized = true;
for (auto& input : arg.get_dependencies())
is_quantized &= data_type_traits::is_quantized(input->get_output_layout().data_type);
for (auto& input : impl_param->input_layouts)
is_quantized &= data_type_traits::is_quantized(input.data_type);
if (is_quantized) {
gemm_params.quantization = kernel_selector::QuantizationType::SYMMETRIC;

View File

@@ -105,7 +105,7 @@ struct generic_layer_cpu : typed_primitive_impl<generic_layer> {
void init_kernels() override {}
};
static primitive_impl* create(const generic_layer_node& arg) {
static primitive_impl* create(const generic_layer_node& arg, std::shared_ptr<kernel_impl_params>) {
if (arg.get_primitive()->generic_params.engine == kernel_selector::generic_kernel_params::Engine::GPU) {
return new generic_layer_impl(arg);
} else {

View File

@@ -26,11 +26,12 @@ struct grn_impl : typed_primitive_impl_ocl<grn> {
}
public:
static primitive_impl* create(const grn_node& arg) {
auto grn_params = get_default_params<kernel_selector::grn_params>(arg);
static primitive_impl* create(const grn_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& prim = arg.get_primitive();
auto grn_params = get_default_params<kernel_selector::grn_params>(*impl_param);
auto grn_optional_params = get_default_optional_params<kernel_selector::grn_optional_params>(arg.get_program());
grn_params.bias = arg.get_primitive()->bias;
grn_params.bias = prim->bias;
auto& kernel_selector = kernel_selector::grn_kernel_selector::Instance();
auto best_kernels = kernel_selector.GetBestKernels(grn_params, grn_optional_params);

View File

@@ -21,11 +21,10 @@ struct lrn_impl : typed_primitive_impl_ocl<lrn> {
return make_unique<lrn_impl>(*this);
}
static primitive_impl* create(const lrn_node& arg) {
auto lrn_params = get_default_params<kernel_selector::lrn_params>(arg);
auto lrn_optional_params = get_default_optional_params<kernel_selector::lrn_optional_params>(arg.get_program());
static primitive_impl* create(const lrn_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& primitive = arg.get_primitive();
auto lrn_params = get_default_params<kernel_selector::lrn_params>(*impl_param);
auto lrn_optional_params = get_default_optional_params<kernel_selector::lrn_optional_params>(arg.get_program());
lrn_params.alpha = primitive->alpha;
lrn_params.beta = primitive->beta;

View File

@@ -34,19 +34,23 @@ protected:
}
public:
static primitive_impl* create(const lstm_dynamic_input_node& arg) {
auto dlstm_input_params = get_default_params<kernel_selector::lstm_dynamic_input_params>(arg);
static primitive_impl* create(const lstm_dynamic_input_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto dlstm_input_params = get_default_params<kernel_selector::lstm_dynamic_input_params>(*impl_param);
const auto& weights_layout = arg.weights().get_output_layout();
const auto dyn_len_idx = 1;
const auto weights_idx = 2;
const auto bias_idx = 3;
const auto& weights_layout = impl_param->input_layouts[weights_idx];
dlstm_input_params.weights = convert_weights_tensor(weights_layout);
if (arg.bias_term()) {
const auto& bias_layout = arg.bias().get_output_layout();
const auto& bias_layout = impl_param->input_layouts[bias_idx];
dlstm_input_params.bias.push_back(convert_data_tensor(bias_layout));
}
// dyn length
const auto& dyn_length_tensor = arg.dyn_length().get_output_layout();
const auto& dyn_length_tensor = impl_param->input_layouts[dyn_len_idx];
dlstm_input_params.inputs.push_back(convert_data_tensor(dyn_length_tensor));
dlstm_input_params.direction = arg.direction();

View File

@@ -39,36 +39,36 @@ protected:
}
public:
static primitive_impl* create(const lstm_dynamic_timeloop_node& arg) {
auto dlstm_timeloop_params = get_default_params<kernel_selector::lstm_dynamic_timeloop_params>(arg);
static primitive_impl* create(const lstm_dynamic_timeloop_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto dlstm_timeloop_params = get_default_params<kernel_selector::lstm_dynamic_timeloop_params>(*impl_param);
// dyn length
const auto& dyn_length_tensor = arg.dyn_length().get_output_layout();
const auto& dyn_length_tensor = impl_param->input_layouts[arg.get_dependency_idx("dyn_length")];
dlstm_timeloop_params.inputs.push_back(convert_data_tensor(dyn_length_tensor));
// recurrent
const auto& recurrent_layout = arg.recurrent().get_output_layout();
const auto& recurrent_layout = impl_param->input_layouts[arg.get_dependency_idx("recurrent")];
dlstm_timeloop_params.recurrent = convert_data_tensor(recurrent_layout);
dlstm_timeloop_params.direction = arg.direction();
if (arg.initial_cell_term()) {
const auto& cell_layout = arg.initial_cell().get_output_layout();
const auto& cell_layout = impl_param->input_layouts[arg.get_dependency_idx("initial_cell")];
dlstm_timeloop_params.set_cell(convert_data_tensor(cell_layout));
}
if (arg.last_hidden_output_term()) {
const auto& last_hidden_output_layout = arg.last_hidden_state().get_output_layout();
const auto& last_hidden_output_layout = impl_param->input_layouts[arg.get_dependency_idx("last_hidden_output")];
dlstm_timeloop_params.set_last_hidden_output(convert_data_tensor(last_hidden_output_layout));
}
if (arg.initial_hidden_term()) {
const auto& hidden_layout = arg.initial_hidden().get_output_layout();
const auto& hidden_layout = impl_param->input_layouts[arg.get_dependency_idx("initial_hidden")];
dlstm_timeloop_params.set_hidden(convert_data_tensor(hidden_layout));
}
if (arg.last_cell_output_term()) {
const auto& last_cell_state_layout = arg.last_cell_state().get_output_layout();
const auto& last_cell_state_layout = impl_param->input_layouts[arg.get_dependency_idx("last_cell_output")];
dlstm_timeloop_params.set_last_cell_output(convert_data_tensor(last_cell_state_layout));
}

View File

@@ -34,13 +34,15 @@ protected:
}
public:
static primitive_impl* create(const lstm_elt_node& arg) {
auto lstm_elt_params = get_default_params<kernel_selector::lstm_elt_params>(arg);
static primitive_impl* create(const lstm_elt_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& prim = arg.get_primitive();
auto lstm_elt_params = get_default_params<kernel_selector::lstm_elt_params>(*impl_param);
auto lstm_elt_optional_params =
get_default_optional_params<kernel_selector::lstm_elt_optional_params>(arg.get_program());
if (arg.cell_term()) {
const auto& cell_layout = arg.cell().get_output_layout();
const auto& cell_idx = 1;
const auto& cell_layout = impl_param->input_layouts[cell_idx];
lstm_elt_params.SetCell(convert_data_tensor(cell_layout));
// TODO: make a generic function to get the direction
if (cell_layout.spatial(1) > 1) {
@@ -48,7 +50,6 @@ public:
}
}
const auto& prim = arg.get_primitive();
if (!prim->activations.empty()) {
auto a_sz = prim->activations.size();
auto param_sz = prim->activation_params.size();

View File

@@ -37,21 +37,26 @@ protected:
}
public:
static primitive_impl* create(const lstm_gemm_node& arg) {
const auto& weights_layout = arg.weights().get_output_layout();
static primitive_impl* create(const lstm_gemm_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto input_idx = 0;
const auto weight_idx = 1;
const auto recurrent_idx = 2;
const auto bias_idx = 3;
const auto hidden_idx = arg.bias_term() ? 4 : 3;
auto lstm_gemm_params = get_default_params<kernel_selector::lstm_gemm_params>(arg);
const auto& weights_layout = impl_param->input_layouts[weight_idx];
auto lstm_gemm_params = get_default_params<kernel_selector::lstm_gemm_params>(*impl_param);
lstm_gemm_params.weights = convert_data_tensor(weights_layout);
if (arg.bias_term()) {
const auto& bias_layout = arg.bias().get_output_layout();
const auto& bias_layout = impl_param->input_layouts[bias_idx];
lstm_gemm_params.SetBias(convert_data_tensor(bias_layout));
}
if (arg.hidden_term()) {
const auto& recurrent_layout = arg.recurrent().get_output_layout();
const auto& recurrent_layout = impl_param->input_layouts[recurrent_idx];
lstm_gemm_params.recurrent = convert_data_tensor(recurrent_layout);
const auto& hidden_layout = arg.hidden().get_output_layout();
const auto& hidden_layout = impl_param->input_layouts[hidden_idx];
lstm_gemm_params.SetHidden(convert_data_tensor(hidden_layout));
// TODO: make a generic function to get the direction
if (hidden_layout.spatial(1) > 1) {
@@ -61,7 +66,7 @@ public:
lstm_gemm_params.direction = arg.direction();
// Update the direction of the input for the gemm kernel
const auto& input_layout = arg.input().get_output_layout();
const auto& input_layout = impl_param->input_layouts[input_idx];
size_t input_directions = input_layout.spatial(1);
if (input_directions > 1) { // For bidirection input, input direction can be 1 or 0

View File

@@ -40,12 +40,13 @@ public:
return parent::execute_impl(tmp_events, instance);
}
static primitive_impl* create(const max_unpooling_node& arg) {
auto max_unpooling_params = get_default_params<kernel_selector::max_unpooling_params>(arg);
static primitive_impl* create(const max_unpooling_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto max_unpooling_params = get_default_params<kernel_selector::max_unpooling_params>(*impl_param);
auto max_unpooling_optional_params =
get_default_optional_params<kernel_selector::max_unpooling_optional_params>(arg.get_program());
max_unpooling_params.inputs.push_back(convert_data_tensor(arg.argmax().get_output_layout()));
const auto max_idx = 1;
max_unpooling_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[max_idx]));
auto& kernel_selector = kernel_selector::max_unpooling_kernel_selector::Instance();
auto best_kernels = kernel_selector.GetBestKernels(max_unpooling_params, max_unpooling_optional_params);

View File

@@ -18,7 +18,7 @@ struct mutable_data_impl : public typed_primitive_impl_ocl<mutable_data> {
}
public:
static primitive_impl* create(mutable_data_node const& arg) { return new mutable_data_impl(arg, {}); }
static primitive_impl* create(mutable_data_node const& arg, std::shared_ptr<kernel_impl_params>) { return new mutable_data_impl(arg, {}); }
};
namespace detail {

View File

@@ -26,16 +26,17 @@ struct mvn_impl : typed_primitive_impl_ocl<mvn> {
}
public:
static primitive_impl* create(const mvn_node& arg) {
auto mvn_params = get_default_params<kernel_selector::mvn_params>(arg);
static primitive_impl* create(const mvn_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& prim = arg.get_primitive();
auto mvn_params = get_default_params<kernel_selector::mvn_params>(*impl_param);
auto mvn_optional_params = get_default_optional_params<kernel_selector::mvn_optional_params>(arg.get_program());
mvn_params.mvnMode = arg.get_primitive()->across_channels ? kernel_selector::mvn_mode::ACROSS_CHANNELS
mvn_params.mvnMode = prim->across_channels ? kernel_selector::mvn_mode::ACROSS_CHANNELS
: kernel_selector::mvn_mode::WITHIN_CHANNELS;
mvn_params.mvnNormalizeVariance = arg.get_primitive()->normalize_variance;
mvn_params.epsilon = arg.get_primitive()->epsilon;
mvn_params.mvnNormalizeVariance = prim->normalize_variance;
mvn_params.epsilon = prim->epsilon;
mvn_params.mvnEpsMode = arg.get_primitive()->eps_inside_sqrt ? kernel_selector::mvn_eps_mode::INSIDE_SQRT
mvn_params.mvnEpsMode = prim->eps_inside_sqrt ? kernel_selector::mvn_eps_mode::INSIDE_SQRT
: kernel_selector::mvn_eps_mode::OUTSIDE_SQRT;
auto& kernel_selector = kernel_selector::mvn_kernel_selector::Instance();

View File

@@ -53,13 +53,14 @@ protected:
}
public:
static primitive_impl* create(const non_max_suppression_node& arg) {
auto params = get_default_params<kernel_selector::non_max_suppression_params>(arg);
static primitive_impl* create(const non_max_suppression_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& primitive = arg.get_primitive();
auto params = get_default_params<kernel_selector::non_max_suppression_params>(*impl_param);
auto optional_params =
get_default_optional_params<kernel_selector::non_max_suppression_optional_params>(arg.get_program());
const auto& primitive = arg.get_primitive();
params.inputs.push_back(convert_data_tensor(arg.input_scores().get_output_layout()));
const auto input_scores_idx = 1;
params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[input_scores_idx]));
if (arg.has_num_select_per_class()) {
cldnn::program_node& node = arg.num_select_per_class_node();
@@ -68,7 +69,7 @@ public:
params.num_select_per_class = get_value<int>(node);
} else {
params.num_select_per_class_type = kernel_selector::NmsArgType::Input;
params.inputs.push_back(convert_data_tensor(node.get_output_layout()));
params.inputs.push_back(convert_data_tensor(impl_param->output_layout));
}
}
@@ -79,7 +80,7 @@ public:
params.iou_threshold = get_value<float>(node);
} else {
params.iou_threshold_type = kernel_selector::NmsArgType::Input;
params.inputs.push_back(convert_data_tensor(node.get_output_layout()));
params.inputs.push_back(convert_data_tensor(impl_param->output_layout));
}
}
@@ -90,7 +91,7 @@ public:
params.score_threshold = get_value<float>(node);
} else {
params.score_threshold_type = kernel_selector::NmsArgType::Input;
params.inputs.push_back(convert_data_tensor(node.get_output_layout()));
params.inputs.push_back(convert_data_tensor(impl_param->output_layout));
}
}
@@ -101,21 +102,28 @@ public:
params.soft_nms_sigma = get_value<float>(node);
} else {
params.soft_nms_sigma_type = kernel_selector::NmsArgType::Input;
params.inputs.push_back(convert_data_tensor(node.get_output_layout()));
params.inputs.push_back(convert_data_tensor(impl_param->output_layout));
}
}
auto get_additional_output_node_idx = [&] (bool is_third) {
size_t offset = 2;
offset += arg.has_num_select_per_class();
offset += arg.has_iou_threshold();
offset += arg.has_score_threshold();
offset += arg.has_soft_nms_sigma();
if (is_third)
offset += arg.has_second_output();
return offset;
};
if (arg.has_second_output()) {
layout second_output_layout = arg.second_output_node().get_output_layout();
second_output_layout.format = arg.input_scores().get_output_layout().format;
params.inputs.push_back(convert_data_tensor(second_output_layout));
params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[get_additional_output_node_idx(false)]));
params.has_second_output = true;
}
if (arg.has_third_output()) {
layout third_output_layout = arg.third_output_node().get_output_layout();
third_output_layout.format = arg.input_scores().get_output_layout().format;
params.inputs.push_back(convert_data_tensor(third_output_layout));
params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[get_additional_output_node_idx(true)]));
params.has_third_output = true;
}

View File

@@ -33,16 +33,17 @@ protected:
}
public:
static primitive_impl* create(const normalize_node& arg) {
auto norm_params = get_default_params<kernel_selector::normalize_params>(arg);
static primitive_impl* create(const normalize_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& prim = arg.get_primitive();
auto norm_params = get_default_params<kernel_selector::normalize_params>(*impl_param);
auto norm_optional_params =
get_default_optional_params<kernel_selector::normalize_optional_params>(arg.get_program());
const auto& scale_layout = arg.scale().get_output_layout();
const auto& scale_layout = impl_param->input_layouts[1];
norm_params.normMode = arg.get_primitive()->across_spatial ? kernel_selector::normalize_mode::ACROSS_SPATIAL
norm_params.normMode = prim->across_spatial ? kernel_selector::normalize_mode::ACROSS_SPATIAL
: kernel_selector::normalize_mode::WITHIN_SPATIAL;
norm_params.epsilon = arg.get_primitive()->epsilon;
norm_params.epsilon = prim->epsilon;
norm_params.scaleTable = convert_data_tensor(scale_layout).FlattenFeatureAndSpatials();
auto& kernel_selector = kernel_selector::normalize_kernel_selector::Instance();

View File

@@ -23,16 +23,17 @@ struct one_hot_impl : typed_primitive_impl_ocl<one_hot> {
return make_unique<one_hot_impl>(*this);
}
static primitive_impl* create(const one_hot_node& arg) {
auto oh_params = get_default_params<kernel_selector::one_hot_params>(arg, 1);
static primitive_impl* create(const one_hot_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& prim = arg.get_primitive();
auto oh_params = get_default_params<kernel_selector::one_hot_params>(*impl_param, 1);
auto oh_optional_params =
get_default_optional_params<kernel_selector::one_hot_optional_params>(arg.get_program());
oh_params.one_hot_axis = arg.get_primitive()->one_hot_axis;
oh_params.on_value = arg.get_primitive()->on_value;
oh_params.off_value = arg.get_primitive()->off_value;
oh_params.one_hot_axis = prim->one_hot_axis;
oh_params.on_value = prim->on_value;
oh_params.off_value = prim->off_value;
auto output_sizes = arg.get_output_layout().get_dims();
auto output_sizes = impl_param->output_layout.get_dims();
oh_params.one_hot_limit = output_sizes[oh_params.one_hot_axis];

View File

@@ -50,13 +50,14 @@ struct permute_impl : typed_primitive_impl_ocl<permute> {
return make_unique<permute_impl>(*this);
}
static primitive_impl* create(const permute_node& arg) {
auto permute_params = get_default_params<kernel_selector::permute_params>(arg);
static primitive_impl* create(const permute_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& prim = arg.get_primitive();
auto permute_params = get_default_params<kernel_selector::permute_params>(*impl_param);
auto permute_optional_params =
get_default_optional_params<kernel_selector::permute_optional_params>(arg.get_program());
auto in_rank = arg.get_dependency(0).get_output_layout().get_rank();
auto permute_order = convert_permute_order(arg.get_primitive()->permute_order, in_rank);
auto in_rank = impl_param->input_layouts[0].get_rank();
auto permute_order = convert_permute_order(prim->permute_order, in_rank);
permute_params.order = permute_order;
auto& kernel_selector = kernel_selector::permute_kernel_selector::Instance();
auto best_kernels = kernel_selector.GetBestKernels(permute_params, permute_optional_params);

View File

@@ -77,15 +77,13 @@ protected:
}
public:
static primitive_impl* create(const pooling_node& arg) {
static primitive_impl* create(const pooling_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
validate_args(arg);
auto pool_params = get_default_params<kernel_selector::pooling_params>(arg);
const auto primitive = arg.get_primitive();
auto pool_params = get_default_params<kernel_selector::pooling_params>(*impl_param);
auto pool_optional_params =
get_default_optional_params<kernel_selector::pooling_optional_params>(arg.get_program());
const auto primitive = arg.get_primitive();
pool_params.maxPoolOpset8Features = primitive->maxPoolOpset8Features;
if (pool_params.maxPoolOpset8Features) {
switch (primitive->index_element_type) {
@@ -107,8 +105,8 @@ public:
const auto& pad = primitive->pad;
const auto& dilation = primitive->dilation;
auto kernel = primitive->size;
const auto& input_layout = arg.input().get_output_layout();
const auto& output_layout = arg.get_output_layout();
const auto& input_layout = impl_param->input_layouts[0];
const auto& output_layout = impl_param->output_layout;
auto spatial_rank = output_layout.get_spatial_rank();
auto& pp = pool_params;

View File

@@ -23,23 +23,27 @@ struct pyramid_roi_align_impl : typed_primitive_impl_ocl<pyramid_roi_align> {
return make_unique<pyramid_roi_align_impl>(*this);
}
static primitive_impl* create(const pyramid_roi_align_node& arg) {
static primitive_impl* create(const pyramid_roi_align_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto prim = arg.get_primitive();
auto params = get_default_params<kernel_selector::PyramidROIAlign_params>(arg, 1);
auto params = get_default_params<kernel_selector::PyramidROIAlign_params>(*impl_param, 1);
auto optional_params =
get_default_optional_params<kernel_selector::PyramidROIAlign_optional_params>(arg.get_program());
params.inputs.push_back(convert_data_tensor(arg.P2().get_output_layout()));
params.inputs.push_back(convert_data_tensor(arg.P3().get_output_layout()));
params.inputs.push_back(convert_data_tensor(arg.P4().get_output_layout()));
params.inputs.push_back(convert_data_tensor(arg.P5().get_output_layout()));
const auto P2_idx = 1;
const auto P3_idx = 2;
const auto P4_idx = 3;
const auto P5_idx = 4;
params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[P2_idx]));
params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[P3_idx]));
params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[P4_idx]));
params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[P5_idx]));
params.sampling_ratio_x = prim->sampling_ratio;
params.sampling_ratio_y = prim->sampling_ratio;
auto first_layer_scale = prim->pyramid_scales[0];
auto image_size_x = arg.P2().get_output_layout().spatial(0) * first_layer_scale;
auto image_size_y = arg.P2().get_output_layout().spatial(1) * first_layer_scale;
auto image_size_x = impl_param->input_layouts[P2_idx].spatial(0) * first_layer_scale;
auto image_size_y = impl_param->input_layouts[P2_idx].spatial(1) * first_layer_scale;
params.image_size_x = image_size_x;
params.image_size_y = image_size_y;

View File

@@ -43,8 +43,8 @@ protected:
}
public:
static primitive_impl* create(const quantize_node& arg) {
auto quantize_params = get_default_params<kernel_selector::quantize_params>(arg);
static primitive_impl* create(const quantize_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto quantize_params = get_default_params<kernel_selector::quantize_params>(*impl_param);
auto quantize_optional_params =
get_default_optional_params<kernel_selector::quantize_optional_params>(arg.get_program());
@@ -75,9 +75,9 @@ public:
quantize_params.out_shift = arg.get_output_shift_val();
for (size_t i = 1; i < arg.inputs_count(); i++) {
quantize_params.inputs.push_back(convert_data_tensor(arg.input(i).get_output_layout()));
quantize_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[i]));
}
const auto& output_layout = arg.get_output_layout();
const auto& output_layout = impl_param->output_layout;
quantize_params.outputs = { convert_data_tensor(output_layout) };
auto& kernel_selector = kernel_selector::quantize_kernel_selector::Instance();

View File

@@ -21,16 +21,15 @@ struct random_uniform_impl : typed_primitive_impl_ocl<random_uniform> {
return make_unique<random_uniform_impl>(*this);
}
static primitive_impl *create(const random_uniform_node &arg) {
auto params = get_default_params<kernel_selector::random_uniform_params>(
arg);
static primitive_impl *create(const random_uniform_node &arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto &primitive = arg.get_primitive();
auto params = get_default_params<kernel_selector::random_uniform_params>(*impl_param);
auto &random_uniform_kernel_selector =
kernel_selector::random_uniform_kernel_selector::Instance();
const auto &primitive = arg.get_primitive();
params.global_seed = primitive->global_seed;
params.op_seed = primitive->op_seed;
params.inputs.push_back(convert_data_tensor(arg.input(1).get_output_layout()));
params.inputs.push_back(convert_data_tensor(arg.input(2).get_output_layout()));
params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[1]));
params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[2]));
auto best_kernels = random_uniform_kernel_selector.GetBestKernels(params,
kernel_selector::random_uniform_optional_params());
CLDNN_ERROR_BOOL(arg.id(),

View File

@@ -20,10 +20,10 @@ struct range_impl : typed_primitive_impl_ocl<range> {
return make_unique<range_impl>(*this);
}
static primitive_impl* create(const range_node& arg) {
auto params = get_default_params<kernel_selector::range_params>(arg);
static primitive_impl* create(const range_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto params = get_default_params<kernel_selector::range_params>(*impl_param);
for (int i : {1, 2})
params.inputs.push_back(convert_data_tensor(arg.input(i).get_output_layout()));
params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[i]));
auto optional_params =
get_default_optional_params<kernel_selector::range_optional_params>(arg.get_program());

View File

@@ -58,13 +58,14 @@ struct reduce_impl : typed_primitive_impl_ocl<reduce> {
}
public:
static primitive_impl* create(const reduce_node& arg) {
auto reduce_params = get_default_params<kernel_selector::reduce_params>(arg);
static primitive_impl* create(const reduce_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& prim = arg.get_primitive();
auto reduce_params = get_default_params<kernel_selector::reduce_params>(*impl_param);
auto reduce_optional_params = get_default_optional_params<kernel_selector::reduce_optional_params>(arg.get_program());
reduce_params.reduceAxes = arg.get_primitive()->axes;
reduce_params.keepDims = arg.get_primitive()->keep_dims;
reduce_params.reduceMode = cldnn_2_reduce_mode(arg.get_primitive()->mode);
reduce_params.reduceAxes = prim->axes;
reduce_params.keepDims = prim->keep_dims;
reduce_params.reduceMode = cldnn_2_reduce_mode(prim->mode);
auto& kernel_selector = kernel_selector::reduce_kernel_selector::Instance();
auto best_kernels = kernel_selector.GetBestKernels(reduce_params, reduce_optional_params);

View File

@@ -21,8 +21,8 @@ struct region_yolo_impl : typed_primitive_impl_ocl<region_yolo> {
return make_unique<region_yolo_impl>(*this);
}
static primitive_impl* create(const region_yolo_node& arg) {
auto ry_params = get_default_params<kernel_selector::region_yolo_params>(arg);
static primitive_impl* create(const region_yolo_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto ry_params = get_default_params<kernel_selector::region_yolo_params>(*impl_param);
auto ry_optional_params =
get_default_optional_params<kernel_selector::region_yolo_optional_params>(arg.get_program());

View File

@@ -41,40 +41,40 @@ protected:
}
public:
static primitive_impl* create(const reorder_node& arg) {
auto&& input_layout = arg.input().get_output_layout();
auto&& output_layout = arg.get_output_layout();
auto reorder_params = get_default_params<kernel_selector::reorder_params>(arg);
static primitive_impl* create(const reorder_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& prim = arg.get_primitive();
auto&& output_layout = impl_param->output_layout;
auto reorder_params = get_default_params<kernel_selector::reorder_params>(*impl_param);
auto reorder_optional_params =
get_default_optional_params<kernel_selector::reorder_optional_params>(arg.get_program());
for (size_t i = 1; i < arg.inputs_count(); i++) {
reorder_params.inputs.push_back(convert_data_tensor(arg.input(i).get_output_layout()));
reorder_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[i]));
}
if (arg.get_output_layout().data_padding) {
if (impl_param->output_layout.data_padding) {
reorder_params.has_padded_output = true;
}
if (arg.has_mean()) {
if (input_layout.format == cldnn::format::nv12) {
if (impl_param->input_layouts[0].format == cldnn::format::nv12) {
const auto& mean_layout = arg.mean_nv12().get_output_layout();
reorder_params.mean = convert_data_tensor(mean_layout);
reorder_params.mode = kernel_selector::mean_subtruct_mode::IN_BUFFER;
} else {
const auto& mean_layout = arg.mean().get_output_layout();
const auto mean_idx = 1;
const auto& mean_layout = impl_param->input_layouts[mean_idx];
reorder_params.mean = convert_data_tensor(mean_layout);
reorder_params.mode = kernel_selector::mean_subtruct_mode::IN_BUFFER;
}
} else if (arg.get_primitive()->subtract_per_feature.empty() == false) {
} else if (prim->subtract_per_feature.empty() == false) {
reorder_params.mode = kernel_selector::mean_subtruct_mode::INSIDE_PARAMS;
reorder_params.meanValues = arg.get_primitive()->subtract_per_feature;
reorder_params.meanValues = prim->subtract_per_feature;
} else {
reorder_params.mode = kernel_selector::mean_subtruct_mode::NONE;
}
if (reorder_params.mode != kernel_selector::mean_subtruct_mode::NONE) {
switch (arg.get_primitive()->mean_mode) {
switch (prim->mean_mode) {
case reorder_mean_mode::none:
reorder_params.mean_op = kernel_selector::mean_op::NONE;
break;
@@ -98,7 +98,7 @@ public:
reorder_params.winograd_nr_tiles_x = ceil_div(output_layout.spatial(0), 4);
}
reorder_params.winograd = input_layout.format.is_winograd() || output_layout.format.is_winograd();
reorder_params.winograd = impl_param->input_layouts[0].format.is_winograd() || output_layout.format.is_winograd();
auto& kernel_selector = kernel_selector::reorder_kernel_selector::Instance();
auto best_kernels = kernel_selector.GetBestKernels(reorder_params, reorder_optional_params);

View File

@@ -21,13 +21,12 @@ struct reorg_yolo_impl : typed_primitive_impl_ocl<reorg_yolo> {
return make_unique<reorg_yolo_impl>(*this);
}
static primitive_impl* create(const reorg_yolo_node& arg) {
auto ry_params = get_default_params<kernel_selector::reorg_yolo_params>(arg);
static primitive_impl* create(const reorg_yolo_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& primitive = arg.get_primitive();
auto ry_params = get_default_params<kernel_selector::reorg_yolo_params>(*impl_param);
auto ry_optional_params =
get_default_optional_params<kernel_selector::reorg_yolo_optional_params>(arg.get_program());
const auto& primitive = arg.get_primitive();
ry_params.stride = primitive->stride;
auto& kernel_selector = kernel_selector::reorg_yolo_kernel_selector::Instance();

View File

@@ -104,13 +104,13 @@ struct resample_impl : typed_primitive_impl_ocl<resample> {
return make_unique<resample_impl>(*this);
}
static primitive_impl* create(const resample_node& arg) {
auto us_params = get_default_params<kernel_selector::resample_params>(arg);
static primitive_impl* create(const resample_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& primitive = arg.get_primitive();
auto us_params = get_default_params<kernel_selector::resample_params>(*impl_param);
auto us_optional_params =
get_default_optional_params<kernel_selector::resample_optional_params>(arg.get_program());
const auto& primitive = arg.get_primitive();
size_t dimsNum = arg.get_output_layout().format.dimension();
size_t dimsNum = impl_param->output_layout.format.dimension();
us_params.resampleType = convert_to_sample_type(primitive->operation_type);
us_params.nearestMode = convert_to_nearest_mode(primitive->round_mode);
us_params.coordTransMode = convert_to_coord_transform_mode(primitive->coord_trans_mode);

View File

@@ -22,12 +22,11 @@ struct reshape_impl : public typed_primitive_impl_ocl<reshape> {
}
public:
static primitive_impl* create(reshape_node const& arg) {
static primitive_impl* create(reshape_node const& arg, std::shared_ptr<kernel_impl_params> impl_param) {
if (arg.can_be_optimized()) {
return new reshape_impl(arg, {});
}
auto reorder_params = get_default_params<kernel_selector::reshape_params>(arg);
auto reorder_params = get_default_params<kernel_selector::reshape_params>(*impl_param);
auto reorder_optional_params =
get_default_optional_params<kernel_selector::reshape_optional_params>(arg.get_program());

View File

@@ -24,8 +24,8 @@ struct reverse_impl : typed_primitive_impl_ocl<reverse> {
}
public:
static primitive_impl* create(const reverse_node& arg) {
auto params = get_default_params<kernel_selector::reverse_params>(arg);
static primitive_impl* create(const reverse_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto params = get_default_params<kernel_selector::reverse_params>(*impl_param);
const auto optional_params =
get_default_optional_params<kernel_selector::reverse_optional_params>(arg.get_program());

View File

@@ -23,15 +23,16 @@ struct reverse_sequence_impl : typed_primitive_impl_ocl<reverse_sequence> {
}
public:
static primitive_impl* create(const reverse_sequence_node& arg) {
auto reverse_sequence_params = get_default_params<kernel_selector::reverse_sequence_params>(arg);
static primitive_impl* create(const reverse_sequence_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& prim = arg.get_primitive();
auto reverse_sequence_params = get_default_params<kernel_selector::reverse_sequence_params>(*impl_param);
auto reverse_sequence_optional_params =
get_default_optional_params<kernel_selector::reverse_sequence_optional_params>(arg.get_program());
reverse_sequence_params.seq_axis = arg.get_primitive()->seq_axis;
reverse_sequence_params.batch_axis = arg.get_primitive()->batch_axis;
reverse_sequence_params.seq_axis = prim->seq_axis;
reverse_sequence_params.batch_axis = prim->batch_axis;
reverse_sequence_params.inputs.push_back(convert_data_tensor(arg.input(1).get_output_layout()));
reverse_sequence_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[1]));
auto& kernel_selector = kernel_selector::reverse_sequence_kernel_selector::Instance();
auto best_kernels = kernel_selector.GetBestKernels(reverse_sequence_params, reverse_sequence_optional_params);

View File

@@ -55,11 +55,11 @@ protected:
}
public:
static primitive_impl* create(const roi_align_node& arg) {
const auto& input_layout = arg.input().get_output_layout();
const auto& output_layout = arg.get_output_layout();
const auto& rois_layout = arg.input(1).get_output_layout();
const auto& batches_layout = arg.input(2).get_output_layout();
static primitive_impl* create(const roi_align_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& input_layout = impl_param->input_layouts[0];
const auto& output_layout = impl_param->output_layout;
const auto& rois_layout = impl_param->input_layouts[1];
const auto& batches_layout = impl_param->input_layouts[2];
const auto& primitive = arg.get_primitive();
const auto padding_filling_value = output_layout.data_padding.filling_value();
@@ -75,8 +75,7 @@ public:
input_layout.format.value,
"output_layout.format",
output_layout.format);
auto roi_align_params = get_default_params<kernel_selector::roi_align_params>(arg);
auto roi_align_params = get_default_params<kernel_selector::roi_align_params>(*impl_param);
auto roi_align_optional_params =
get_default_optional_params<kernel_selector::roi_align_optional_params>(arg.get_program());

View File

@@ -59,10 +59,10 @@ protected:
}
public:
static primitive_impl* create(const roi_pooling_node& arg) {
const auto& input_layout = arg.input().get_output_layout();
const auto& output_layout = arg.get_output_layout();
const auto& rois_layout = arg.rois().get_output_layout();
static primitive_impl* create(const roi_pooling_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& input_layout = impl_param->input_layouts[0];
const auto& output_layout = impl_param->output_layout;
const auto& rois_layout = impl_param->input_layouts[1];
const auto& primitive = arg.get_primitive();
const auto padding_filling_value = output_layout.data_padding.filling_value();
@@ -78,8 +78,7 @@ public:
input_layout.format.value,
"output_layout.format",
output_layout.format);
auto roi_params = get_default_params<kernel_selector::roi_pooling_params>(arg);
auto roi_params = get_default_params<kernel_selector::roi_pooling_params>(*impl_param);
auto roi_optional_params =
get_default_optional_params<kernel_selector::roi_pooling_optional_params>(arg.get_program());
@@ -87,7 +86,7 @@ public:
const auto roi_bf = roi_bfyx.FlattenFeatureAndSpatials();
roi_params.inputs.push_back(roi_bf);
if (primitive->mode == pooling_mode::deformable_bilinear && !primitive->no_trans)
roi_params.inputs.push_back(convert_data_tensor(arg.trans().get_output_layout()));
roi_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[2]));
roi_params.mode = cldnn_2_pool_type(primitive->mode);
roi_params.position_sensitive = primitive->position_sensitive;
roi_params.pooled_width = primitive->pooled_width;

View File

@@ -24,8 +24,8 @@ struct roll_impl : typed_primitive_impl_ocl<roll> {
return make_unique<roll_impl>(*this);
}
static primitive_impl* create(const roll_node& arg) {
auto roll_params = get_default_params<kernel_selector::roll_params>(arg);
static primitive_impl* create(const roll_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto roll_params = get_default_params<kernel_selector::roll_params>(*impl_param);
auto roll_optional_params =
get_default_optional_params<kernel_selector::roll_optional_params>(arg.get_program());

View File

@@ -36,19 +36,19 @@ protected:
}
public:
static primitive_impl* create(const scale_node& arg) {
auto ew_params = get_default_params<kernel_selector::eltwise_params>(arg);
static primitive_impl* create(const scale_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto ew_params = get_default_params<kernel_selector::eltwise_params>(*impl_param);
auto ew_optional_params =
get_default_optional_params<kernel_selector::eltwise_optional_params>(arg.get_program());
ew_params.inputs.push_back(convert_data_tensor(arg.scale_in().get_output_layout()));
ew_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[1]));
ew_params.operations.push_back({{kernel_selector::eltwise_params::InputType::Buffer(0),
kernel_selector::eltwise_params::InputType::Buffer(1)},
kernel_selector::eltwise_mode::MUL});
if (arg.bias_term()) {
ew_params.inputs.push_back(convert_data_tensor(arg.bias().get_output_layout()));
ew_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[2]));
ew_params.operations.push_back({{kernel_selector::eltwise_params::InputType::Intermediate(0),
kernel_selector::eltwise_params::InputType::Buffer(2)},
kernel_selector::eltwise_mode::ADD});

View File

@@ -43,15 +43,16 @@ struct scatter_elements_update_impl : typed_primitive_impl_ocl<scatter_elements_
}
public:
static primitive_impl* create(const scatter_elements_update_node& arg) {
auto scatter_elements_update_params = get_default_params<kernel_selector::scatter_elements_update_params>(arg);
static primitive_impl* create(const scatter_elements_update_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& prim = arg.get_primitive();
auto scatter_elements_update_params = get_default_params<kernel_selector::scatter_elements_update_params>(*impl_param);
auto scatter_elements_update_optional_params =
get_default_optional_params<kernel_selector::scatter_elements_update_optional_params>(arg.get_program());
scatter_elements_update_params.axis = convert_axis(arg.get_primitive()->axis, arg);
scatter_elements_update_params.axis = convert_axis(prim->axis, arg);
scatter_elements_update_params.inputs.push_back(convert_data_tensor(arg.input(1).get_output_layout()));
scatter_elements_update_params.inputs.push_back(convert_data_tensor(arg.input(2).get_output_layout()));
scatter_elements_update_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[1]));
scatter_elements_update_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[2]));
auto& kernel_selector = kernel_selector::scatter_elements_update_kernel_selector::Instance();
auto best_kernels = kernel_selector.GetBestKernels(scatter_elements_update_params, scatter_elements_update_optional_params);

View File

@@ -24,15 +24,15 @@ struct scatter_nd_update_impl : typed_primitive_impl_ocl<scatter_nd_update> {
}
public:
static primitive_impl* create(const scatter_nd_update_node& arg) {
auto scatter_nd_update_params = get_default_params<kernel_selector::scatter_nd_update_params>(arg);
static primitive_impl* create(const scatter_nd_update_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto scatter_nd_update_params = get_default_params<kernel_selector::scatter_nd_update_params>(*impl_param);
auto scatter_nd_update_optional_params =
get_default_optional_params<kernel_selector::scatter_nd_update_optional_params>(arg.get_program());
scatter_nd_update_params.indices_rank = arg.get_primitive()->indices_rank;
scatter_nd_update_params.inputs.push_back(convert_data_tensor(arg.input(1).get_output_layout()));
scatter_nd_update_params.inputs.push_back(convert_data_tensor(arg.input(2).get_output_layout()));
scatter_nd_update_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[1]));
scatter_nd_update_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[2]));
auto& kernel_selector = kernel_selector::scatter_nd_update_kernel_selector::Instance();
auto best_kernels = kernel_selector.GetBestKernels(scatter_nd_update_params, scatter_nd_update_optional_params);

View File

@@ -43,15 +43,15 @@ struct scatter_update_impl : typed_primitive_impl_ocl<scatter_update> {
}
public:
static primitive_impl* create(const scatter_update_node& arg) {
auto scatter_update_params = get_default_params<kernel_selector::scatter_update_params>(arg);
static primitive_impl* create(const scatter_update_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto scatter_update_params = get_default_params<kernel_selector::scatter_update_params>(*impl_param);
auto scatter_update_optional_params =
get_default_optional_params<kernel_selector::scatter_update_optional_params>(arg.get_program());
scatter_update_params.axis = convert_axis(arg.get_primitive()->axis, arg);
scatter_update_params.inputs.push_back(convert_data_tensor(arg.input(1).get_output_layout()));
scatter_update_params.inputs.push_back(convert_data_tensor(arg.input(2).get_output_layout()));
scatter_update_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[1]));
scatter_update_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[2]));
auto& kernel_selector = kernel_selector::scatter_update_kernel_selector::Instance();
auto best_kernels = kernel_selector.GetBestKernels(scatter_update_params, scatter_update_optional_params);

View File

@@ -22,13 +22,13 @@ struct select_impl : typed_primitive_impl_ocl<select> {
}
public:
static primitive_impl* create(const select_node& arg) {
auto select_params = get_default_params<kernel_selector::select_params>(arg);
static primitive_impl* create(const select_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto select_params = get_default_params<kernel_selector::select_params>(*impl_param);
auto select_optional_params =
get_default_optional_params<kernel_selector::select_optional_params>(arg.get_program());
for (size_t i = 1; i < arg.inputs_count(); i++) {
select_params.inputs.push_back(convert_data_tensor(arg.input(i).get_output_layout()));
select_params.inputs.push_back(convert_data_tensor(impl_param->input_layouts[i]));
}
auto& kernel_selector = kernel_selector::select_kernel_selector::Instance();

View File

@@ -21,13 +21,14 @@ struct shape_of_impl : typed_primitive_impl_ocl<shape_of> {
return make_unique<shape_of_impl>(*this);
}
static primitive_impl* create(const shape_of_node& arg) {
auto shape_of_params = get_default_params<kernel_selector::shape_of_params>(arg);
static primitive_impl* create(const shape_of_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto shape_of_params = get_default_params<kernel_selector::shape_of_params>(*impl_param);
auto shape_of_optional_params =
get_default_optional_params<kernel_selector::shape_of_optional_params>(arg.get_program());
shape_of_params.input_rank = arg.get_dependency(0).get_output_layout().get_rank();
shape_of_params.input_dims = arg.get_dependency(0).get_output_layout().get_dims();
auto input_layout = impl_param->input_layouts[0];
shape_of_params.input_rank = input_layout.get_rank();
shape_of_params.input_dims = input_layout.get_dims();
auto& kernel_selector = kernel_selector::shape_of_instance();
auto best_kernels = kernel_selector.GetBestKernels(shape_of_params, shape_of_optional_params);

View File

@@ -24,18 +24,19 @@ struct shuffle_channels_impl : typed_primitive_impl_ocl<shuffle_channels> {
}
public:
static primitive_impl* create(const shuffle_channels_node& arg) {
auto shuffle_channels_params = get_default_params<kernel_selector::shuffle_channels_params>(arg);
static primitive_impl* create(const shuffle_channels_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& prim = arg.get_primitive();
auto shuffle_channels_params = get_default_params<kernel_selector::shuffle_channels_params>(*impl_param);
auto shuffle_channels_optional_params =
get_default_optional_params<kernel_selector::shuffle_channels_optional_params>(arg.get_program());
const int32_t number_of_dims = 4;
int32_t axis = arg.get_primitive()->axis;
int32_t axis = prim->axis;
if (axis < 0)
axis += number_of_dims;
shuffle_channels_params.group = arg.get_primitive()->group;
shuffle_channels_params.group = prim->group;
shuffle_channels_params.axis = axis;
auto& kernel_selector = kernel_selector::shuffle_channels_kernel_selector::Instance();

View File

@@ -70,12 +70,9 @@ struct slice_impl : typed_primitive_impl_ocl<slice> {
return make_unique<slice_impl>(*this);
}
static primitive_impl* create(const slice_node& arg) {
auto params = get_default_params<kernel_selector::slice_params>(
arg);
auto op_params = get_default_optional_params<
kernel_selector::slice_optional_params>(
arg.get_program());
static primitive_impl* create(const slice_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto params = get_default_params<kernel_selector::slice_params>(*impl_param);
auto op_params = get_default_optional_params<kernel_selector::slice_optional_params>(arg.get_program());
const auto& inputs = arg.get_dependencies();
const stream& stream = arg.get_program().get_stream();
auto start_elts = extractIntegerData(inputs[InputIndices::kStart]->as<data>(), stream);

View File

@@ -43,13 +43,12 @@ struct softmax_impl : typed_primitive_impl_ocl<softmax> {
return make_unique<softmax_impl>(*this);
}
static primitive_impl* create(const softmax_node& arg) {
auto sm_params = get_default_params<kernel_selector::softmax_params>(arg);
static primitive_impl* create(const softmax_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto primitive = arg.get_primitive();
auto sm_params = get_default_params<kernel_selector::softmax_params>(*impl_param);
auto sm_optional_params =
get_default_optional_params<kernel_selector::softmax_optional_params>(arg.get_program());
const auto primitive = arg.get_primitive();
size_t rank = arg.get_output_layout().get_rank();
sm_params.dim = GetSoftmaxDim(primitive->dimension, rank);

View File

@@ -25,13 +25,12 @@ struct space_to_batch_impl : typed_primitive_impl_ocl<space_to_batch> {
}
public:
static primitive_impl* create(const space_to_batch_node& arg) {
auto space_to_batch_params = get_default_params<kernel_selector::space_to_batch_params>(arg);
static primitive_impl* create(const space_to_batch_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto primitive = arg.get_primitive();
auto space_to_batch_params = get_default_params<kernel_selector::space_to_batch_params>(*impl_param);
auto space_to_batch_optional_params =
get_default_optional_params<kernel_selector::space_to_batch_optional_params>(arg.get_program());
auto primitive = arg.get_primitive();
space_to_batch_params.block_shape = convert_dim_vector(primitive->block_shape);
space_to_batch_params.pads_begin = convert_dim_vector(primitive->pads_begin);
space_to_batch_params.pads_end = convert_dim_vector(primitive->pads_end);

View File

@@ -23,16 +23,17 @@ struct space_to_depth_impl : typed_primitive_impl_ocl<space_to_depth> {
}
public:
static primitive_impl* create(const space_to_depth_node& arg) {
auto space_to_depth_params = get_default_params<kernel_selector::space_to_depth_params>(arg);
static primitive_impl* create(const space_to_depth_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& prim = arg.get_primitive();
auto space_to_depth_params = get_default_params<kernel_selector::space_to_depth_params>(*impl_param);
auto space_to_depth_optional_params =
get_default_optional_params<kernel_selector::space_to_depth_optional_params>(arg.get_program());
space_to_depth_params.depth_mode = (arg.get_primitive()->mode == space_to_depth::blocks_first) ?
space_to_depth_params.depth_mode = (prim->mode == space_to_depth::blocks_first) ?
kernel_selector::SpaceToDepthMode::BLOCKS_FIRST :
kernel_selector::SpaceToDepthMode::DEPTH_FIRST;
space_to_depth_params.block_size = arg.get_primitive()->block_size;
space_to_depth_params.block_size = prim->block_size;
auto& kernel_selector = kernel_selector::space_to_depth_kernel_selector::Instance();
auto best_kernels = kernel_selector.GetBestKernels(space_to_depth_params, space_to_depth_optional_params);

View File

@@ -26,8 +26,9 @@ struct strided_slice_impl : typed_primitive_impl_ocl<strided_slice> {
}
public:
static primitive_impl* create(const strided_slice_node& arg) {
auto params = get_default_params<kernel_selector::strided_slice_params>(arg);
static primitive_impl* create(const strided_slice_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
const auto& prim = arg.get_primitive();
auto params = get_default_params<kernel_selector::strided_slice_params>(*impl_param);
auto op_params = get_default_optional_params<kernel_selector::strided_slice_optional_params>(arg.get_program());
const size_t dims_num = params.inputs[0].Dimentions();
@@ -52,10 +53,10 @@ public:
params.striding_params.push_back(sizes);
}
auto begin_mask_ = arg.get_primitive()->begin_mask;
auto end_mask_ = arg.get_primitive()->end_mask;
auto new_axis_mask_ = arg.get_primitive()->new_axis_mask;
auto shrink_axis_mask_ = arg.get_primitive()->shrink_axis_mask;
auto begin_mask_ = prim->begin_mask;
auto end_mask_ = prim->end_mask;
auto new_axis_mask_ = prim->new_axis_mask;
auto shrink_axis_mask_ = prim->shrink_axis_mask;
std::vector<uint8_t> begin_mask(begin_mask_.begin(), begin_mask_.end());
std::vector<uint8_t> end_mask(end_mask_.begin(), end_mask_.end());
@@ -72,6 +73,7 @@ public:
pad_vector_to_size(params.end_mask, dims_num, 1);
params.begin_mask = begin_mask;
pad_vector_to_size(params.begin_mask, dims_num, 1);
params.new_axis_mask = new_axis_mask;
params.shrink_axis_mask = shrink_axis_mask;
pad_vector_to_size(params.shrink_axis_mask, dims_num, 0);

View File

@@ -24,8 +24,8 @@ struct tile_impl : typed_primitive_impl_ocl<tile> {
}
public:
static primitive_impl* create(const tile_node& arg) {
auto tile_params = get_default_params<kernel_selector::tile_params>(arg);
static primitive_impl* create(const tile_node& arg, std::shared_ptr<kernel_impl_params> impl_param) {
auto tile_params = get_default_params<kernel_selector::tile_params>(*impl_param);
auto tile_optional_params =
get_default_optional_params<kernel_selector::tile_optional_params>(arg.get_program());

View File

@@ -62,7 +62,7 @@ protected:
}
public:
static primitive_impl* create(const concatenation_node& arg) {
static primitive_impl* create(const concatenation_node& arg, std::shared_ptr<kernel_impl_params>) {
auto desc = get_concatenation_descriptor(arg);
auto attr = arg.get_onednn_primitive_attributes();

View File

@@ -158,7 +158,20 @@ protected:
cldnn::format out_fmt = onednn::find_format(pd.weights_desc(0), grouped_weights);
kernel_selector::WeightsLayout reqLayout = to_weights_layout(out_fmt, cldnn_prim->grouped_weights_shape);
set_params(arg, r_params);
const auto& param_info = kernel_impl_params(arg.get_program(), cldnn_prim, arg.get_unique_id(),
arg.get_input_layouts(), arg.get_output_layout(),
arg.get_fused_primitives(),
arg.get_fused_activations_funcs(), arg.get_fused_activations_params(),
optional_layout(weights_layout),
arg.bias_term() ? optional_layout(arg.bias().get_output_layout()) : optional_layout(),
arg.weights_zero_points_term() ? optional_layout(arg.weights_zero_points().get_output_layout())
: optional_layout(),
arg.activations_zero_points_term() ? optional_layout(arg.activations_zero_points().get_output_layout())
: optional_layout(),
arg.compensation_term() ? optional_layout(arg.compensation().get_output_layout())
: optional_layout());
set_params(param_info, r_params);
r_params.layerID = arg.id() + "_reorder_";
r_params.input = convert_weights_tensor(weights_layout, cldnn_prim->grouped_weights_shape);
r_params.output = r_params.input.TransformIgnorePadding(reqLayout, r_params.input.GetDType(), arg.get_groups(), false);
@@ -234,7 +247,7 @@ protected:
}
public:
static primitive_impl* create(const convolution_node& arg) {
static primitive_impl* create(const convolution_node& arg, std::shared_ptr<kernel_impl_params>) {
auto& engine = arg.get_program().get_engine();
auto desc = get_convolution_descriptor(arg);
auto attr = get_primitive_attributes(arg);

View File

@@ -78,7 +78,14 @@ protected:
cldnn::format out_fmt = onednn::find_format(pd.weights_desc(0), grouped_weights);
kernel_selector::WeightsLayout reqLayout = to_weights_layout(out_fmt, cldnn_prim->grouped_weights_shape);
set_params(arg, r_params);
const auto& param_info = kernel_impl_params(arg.get_program(), cldnn_prim, arg.get_unique_id(),
arg.get_input_layouts(), arg.get_output_layout(),
arg.get_fused_primitives(),
arg.get_fused_activations_funcs(), arg.get_fused_activations_params(),
optional_layout(weights_layout),
arg.bias_term() ? optional_layout(arg.bias().get_output_layout()) : optional_layout());
set_params(param_info, r_params);
r_params.layerID = arg.id() + "_reorder_";
r_params.input = convert_weights_tensor(weights_layout, cldnn_prim->grouped_weights_shape);
r_params.output = r_params.input.TransformIgnorePadding(reqLayout, r_params.input.GetDType(), arg.get_groups(), false);
@@ -154,7 +161,7 @@ protected:
}
public:
static primitive_impl* create(const deconvolution_node& arg) {
static primitive_impl* create(const deconvolution_node& arg, std::shared_ptr<kernel_impl_params>) {
auto& engine = arg.get_program().get_engine();
auto desc = get_deconvolution_descriptor(arg);
auto attr = get_primitive_attributes(arg);

View File

@@ -58,17 +58,24 @@ protected:
}
static kernel_selector::WeightsReorderParams get_weights_reorder(const fully_connected_node& arg, const dnnl::primitive_desc& pd) {
auto weights_layout = arg.get_dependency(1).get_output_layout();
auto cldnn_prim = arg.get_primitive();
const auto& param_info = kernel_impl_params(arg.get_program(), cldnn_prim, arg.get_unique_id(),
arg.get_input_layouts(), arg.get_output_layout(),
arg.get_fused_primitives(),
arg.get_fused_activations_funcs(), arg.get_fused_activations_params(),
optional_layout(weights_layout),
arg.bias_term() ? optional_layout(arg.bias().get_output_layout()) : optional_layout());
kernel_selector::WeightsReorderParams weights_reorder_params;
auto& reorderKS = kernel_selector::ReorderWeightsKernelSelctor::Instance();
kernel_selector::reorder_weights_params r_params;
auto cldnn_prim = arg.get_primitive();
auto weights_layout = arg.get_dependency(1).get_output_layout();
cldnn::format out_fmt = onednn::find_format(pd.weights_desc(0));
kernel_selector::WeightsLayout req_layout = to_weights_layout(out_fmt, false);
// set engine info & forcing
set_params(arg, r_params);
set_params(param_info, r_params);
r_params.layerID = arg.id() + "_reorder_";
r_params.input = convert_weights_tensor(weights_layout, false);
r_params.output = r_params.input.TransformIgnorePadding(req_layout, r_params.input.GetDType(), 1, false);
@@ -125,7 +132,7 @@ protected:
}
public:
static primitive_impl* create(const fully_connected_node& arg) {
static primitive_impl* create(const fully_connected_node& arg, std::shared_ptr<kernel_impl_params>) {
auto& engine = arg.get_program().get_engine();
auto desc = get_fully_connected_descriptor(arg);
auto attr = arg.get_onednn_primitive_attributes();

View File

@@ -123,7 +123,7 @@ protected:
}
public:
static primitive_impl* create(const gemm_node& arg) {
static primitive_impl* create(const gemm_node& arg, std::shared_ptr<kernel_impl_params>) {
auto& engine = arg.get_program().get_engine();
auto desc = get_gemm_descriptor(arg);
auto attr = arg.get_onednn_primitive_attributes();

View File

@@ -66,7 +66,7 @@ protected:
}
public:
static primitive_impl* create(const pooling_node& arg) {
static primitive_impl* create(const pooling_node& arg, std::shared_ptr<kernel_impl_params>) {
auto& engine = arg.get_program().get_engine();
auto desc = get_pooling_descriptor(arg);
auto attr = arg.get_onednn_primitive_attributes();

View File

@@ -63,7 +63,7 @@ protected:
}
public:
static primitive_impl* create(const reduce_node& arg) {
static primitive_impl* create(const reduce_node& arg, std::shared_ptr<kernel_impl_params>) {
auto& engine = arg.get_program().get_engine();
auto desc = get_reduction_descriptor(arg);
auto attr = arg.get_onednn_primitive_attributes();

Some files were not shown because too many files have changed in this diff Show More