[GPU] Use transformed gemm shapes for layout optimizer checks (#14407)
This commit is contained in:
parent
f91d3d1d04
commit
6df011c9f6
@ -113,6 +113,99 @@ std::vector<layout> gemm_inst::calc_output_layouts(gemm_node const& /*node*/, co
|
||||
|
||||
template std::vector<layout> gemm_inst::calc_output_layouts<ov::PartialShape>(gemm_node const& node, const kernel_impl_params& impl_param);
|
||||
|
||||
std::vector<layout> gemm_inst::transform_input_layouts(const std::shared_ptr<const gemm> primitive,
|
||||
const std::vector<layout>& input_layouts,
|
||||
const layout& output_layout) {
|
||||
auto get_updated_input_shape = [&](const ov::PartialShape& input_pshape, size_t input_rank, size_t output_rank, bool transpose, bool first_input) {
|
||||
ov::PartialShape updated_input_pshape;
|
||||
|
||||
if (input_rank == 1) {
|
||||
if (input_pshape.is_static()) {
|
||||
auto input_shape = input_pshape.to_shape();
|
||||
updated_input_pshape = ov::PartialShape{ static_cast<int64_t>(*std::max_element(input_shape.begin(), input_shape.end())) };
|
||||
} else {
|
||||
updated_input_pshape = ov::PartialShape::dynamic(input_rank);
|
||||
}
|
||||
} else {
|
||||
if (input_pshape.is_static()) {
|
||||
OPENVINO_ASSERT(input_pshape.size() >= input_rank, "[GPU] Requested input rank in gemm primitive is greater than actual shape");
|
||||
std::vector<ov::Dimension> dims(input_pshape.begin(), input_pshape.begin() + input_rank);
|
||||
updated_input_pshape = ov::PartialShape(dims);
|
||||
} else {
|
||||
updated_input_pshape = input_pshape;
|
||||
}
|
||||
}
|
||||
|
||||
if (updated_input_pshape.size() == 1) {
|
||||
first_input ? updated_input_pshape.insert(updated_input_pshape.begin(), 1)
|
||||
: updated_input_pshape.insert(updated_input_pshape.end(), 1);
|
||||
|
||||
if (transpose) {
|
||||
std::swap(updated_input_pshape[0], updated_input_pshape[1]);
|
||||
}
|
||||
}
|
||||
size_t ones_to_add = std::max(output_rank, static_cast<size_t>(4)) - updated_input_pshape.size();
|
||||
updated_input_pshape.insert(updated_input_pshape.begin(), ones_to_add, 1ul);
|
||||
|
||||
return updated_input_pshape;
|
||||
};
|
||||
|
||||
auto input0_pshape = input_layouts[0].get_partial_shape();
|
||||
auto input1_pshape = input_layouts[1].get_partial_shape();
|
||||
|
||||
bool reordered = primitive->input_rank > 4 || primitive->weight_rank > 4;
|
||||
size_t output_rank = std::max(primitive->input_rank, primitive->weight_rank);
|
||||
size_t input_rank = reordered ? output_rank : primitive->input_rank;
|
||||
size_t weight_rank = reordered ? output_rank : primitive->weight_rank;
|
||||
|
||||
auto updated_input0_pshape = get_updated_input_shape(input0_pshape, input_rank, output_rank, primitive->transpose_input0, true);
|
||||
auto updated_input1_pshape = get_updated_input_shape(input1_pshape, weight_rank, output_rank, primitive->transpose_input1, false);
|
||||
|
||||
std::vector<layout> layouts = input_layouts;
|
||||
layouts[0].set_partial_shape(updated_input0_pshape);
|
||||
layouts[1].set_partial_shape(updated_input1_pshape);
|
||||
|
||||
if (input_layouts.size() == 3) {
|
||||
auto bias_pshape = input_layouts[2].get_partial_shape();
|
||||
auto updated_bias_pshape = get_updated_input_shape(bias_pshape, weight_rank, output_rank, primitive->transpose_input1, false);
|
||||
layouts[2].set_partial_shape(updated_bias_pshape);
|
||||
}
|
||||
|
||||
return layouts;
|
||||
}
|
||||
|
||||
layout gemm_inst::transform_output_layout(const std::shared_ptr<const gemm> primitive,
|
||||
const std::vector<layout>& input_layouts,
|
||||
const layout& output_layout) {
|
||||
auto updated_output_layout = output_layout;
|
||||
auto output_rank = output_layout.get_partial_shape().size();
|
||||
if (output_rank < 4) {
|
||||
auto input0_pshape = input_layouts[0].get_partial_shape();
|
||||
auto input1_pshape = input_layouts[1].get_partial_shape();
|
||||
|
||||
auto M = !primitive->transpose_input0 ? input0_pshape[input0_pshape.size() - 2] : input0_pshape[input0_pshape.size() - 1];
|
||||
auto N = !primitive->transpose_input1 ? input1_pshape[input1_pshape.size() - 1] : input1_pshape[input1_pshape.size() - 2];
|
||||
|
||||
auto output_pshape = input_layouts[0].get_partial_shape();
|
||||
for (size_t i = 0; i != input_layouts.size(); ++i) {
|
||||
auto input_pshape = input_layouts[i].get_partial_shape();
|
||||
for (size_t j = 0; j != input_pshape.size(); ++j) {
|
||||
ov::Dimension::merge(output_pshape[j], output_pshape[j], input_pshape[j]);
|
||||
}
|
||||
}
|
||||
|
||||
auto get_spatial_idx = [](cldnn::format format, size_t spatial_idx) {
|
||||
const size_t idx = (format::is_grouped(format) ? 3 : 2) + (format.spatial_num() - 1 - spatial_idx);
|
||||
return idx;
|
||||
};
|
||||
|
||||
output_pshape[get_spatial_idx(updated_output_layout.format, 0)] = N;
|
||||
output_pshape[get_spatial_idx(updated_output_layout.format, 1)] = M;
|
||||
updated_output_layout.set_partial_shape(output_pshape);
|
||||
}
|
||||
return updated_output_layout;
|
||||
}
|
||||
|
||||
std::string gemm_inst::to_string(gemm_node const& node) {
|
||||
auto desc = node.get_primitive();
|
||||
auto node_info = node.desc_to_json();
|
||||
|
@ -29,92 +29,8 @@ struct gemm_impl : typed_primitive_impl_ocl<gemm> {
|
||||
public:
|
||||
static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param) {
|
||||
const auto& primitive = impl_param.typed_desc<gemm>();
|
||||
auto get_gemm_input_layouts = [primitive](const std::vector<layout>& input_layouts, const layout& output_layout) {
|
||||
auto get_updated_input_shape = [&](const ov::PartialShape& input_pshape, size_t input_rank, bool transpose, bool first_input) {
|
||||
ov::PartialShape updated_input_pshape;
|
||||
|
||||
if (input_rank == 1) {
|
||||
if (input_pshape.is_static()) {
|
||||
auto input_shape = input_pshape.to_shape();
|
||||
updated_input_pshape = ov::PartialShape{ static_cast<int64_t>(*std::max_element(input_shape.begin(), input_shape.end())) };
|
||||
} else {
|
||||
updated_input_pshape = ov::PartialShape::dynamic(input_rank);
|
||||
}
|
||||
} else {
|
||||
if (input_pshape.is_static()) {
|
||||
OPENVINO_ASSERT(input_pshape.size() >= input_rank, "[GPU] Requested input rank in gemm primitive is greater than actual shape");
|
||||
std::vector<ov::Dimension> dims(input_pshape.begin(), input_pshape.begin() + input_rank);
|
||||
updated_input_pshape = ov::PartialShape(dims);
|
||||
} else {
|
||||
updated_input_pshape = input_pshape;
|
||||
}
|
||||
}
|
||||
|
||||
if (updated_input_pshape.size() == 1) {
|
||||
first_input ? updated_input_pshape.insert(updated_input_pshape.begin(), 1)
|
||||
: updated_input_pshape.insert(updated_input_pshape.end(), 1);
|
||||
|
||||
if (transpose) {
|
||||
std::swap(updated_input_pshape[0], updated_input_pshape[1]);
|
||||
}
|
||||
}
|
||||
size_t ones_to_add = std::max(output_layout.get_partial_shape().size(), static_cast<size_t>(4)) - updated_input_pshape.size();
|
||||
updated_input_pshape.insert(updated_input_pshape.begin(), ones_to_add, 1ul);
|
||||
|
||||
return updated_input_pshape;
|
||||
};
|
||||
|
||||
auto input0_pshape = input_layouts[0].get_partial_shape();
|
||||
auto input1_pshape = input_layouts[1].get_partial_shape();
|
||||
|
||||
auto updated_input0_pshape = get_updated_input_shape(input0_pshape, primitive->input_rank, primitive->transpose_input0, true);
|
||||
auto updated_input1_pshape = get_updated_input_shape(input1_pshape, primitive->weight_rank, primitive->transpose_input1, false);
|
||||
|
||||
std::vector<layout> layouts = input_layouts;
|
||||
layouts[0].set_partial_shape(updated_input0_pshape);
|
||||
layouts[1].set_partial_shape(updated_input1_pshape);
|
||||
|
||||
if (input_layouts.size() == 3) {
|
||||
auto bias_pshape = input_layouts[2].get_partial_shape();
|
||||
auto updated_bias_pshape = get_updated_input_shape(bias_pshape, primitive->weight_rank, primitive->transpose_input1, false);
|
||||
layouts[2].set_partial_shape(updated_bias_pshape);
|
||||
}
|
||||
|
||||
return layouts;
|
||||
};
|
||||
|
||||
auto get_gemm_output_layout = [primitive](const std::vector<layout>& input_layouts, const layout& output_layout) {
|
||||
auto updated_output_layout = output_layout;
|
||||
auto output_rank = output_layout.get_partial_shape().size();
|
||||
if (output_rank < 4) {
|
||||
auto input0_pshape = input_layouts[0].get_partial_shape();
|
||||
auto input1_pshape = input_layouts[1].get_partial_shape();
|
||||
|
||||
auto M = !primitive->transpose_input0 ? input0_pshape[input0_pshape.size() - 2] : input0_pshape[input0_pshape.size() - 1];
|
||||
auto N = !primitive->transpose_input1 ? input1_pshape[input1_pshape.size() - 1] : input1_pshape[input1_pshape.size() - 2];
|
||||
|
||||
auto output_pshape = input_layouts[0].get_partial_shape();
|
||||
for (size_t i = 0; i != input_layouts.size(); ++i) {
|
||||
auto input_pshape = input_layouts[i].get_partial_shape();
|
||||
for (size_t j = 0; j != input_pshape.size(); ++j) {
|
||||
ov::Dimension::merge(output_pshape[j], output_pshape[j], input_pshape[j]);
|
||||
}
|
||||
}
|
||||
|
||||
auto get_spatial_idx = [](cldnn::format format, size_t spatial_idx) {
|
||||
const size_t idx = (format::is_grouped(format) ? 3 : 2) + (format.spatial_num() - 1 - spatial_idx);
|
||||
return idx;
|
||||
};
|
||||
|
||||
output_pshape[get_spatial_idx(updated_output_layout.format, 0)] = N;
|
||||
output_pshape[get_spatial_idx(updated_output_layout.format, 1)] = M;
|
||||
updated_output_layout.set_partial_shape(output_pshape);
|
||||
}
|
||||
return updated_output_layout;
|
||||
};
|
||||
|
||||
const auto input_layouts = get_gemm_input_layouts(impl_param.input_layouts, impl_param.output_layouts[0]);
|
||||
const auto output_layout = get_gemm_output_layout(input_layouts, impl_param.output_layouts[0]);
|
||||
const auto input_layouts = gemm_inst::transform_input_layouts(primitive, impl_param.input_layouts, impl_param.output_layouts[0]);
|
||||
const auto output_layout = gemm_inst::transform_output_layout(primitive, input_layouts, impl_param.output_layouts[0]);
|
||||
|
||||
auto params = get_default_params<kernel_selector::gemm_params>(impl_param, 1);
|
||||
auto optional_params = get_default_optional_params<kernel_selector::gemm_optional_params>(impl_param.get_program());
|
||||
|
@ -55,85 +55,6 @@ protected:
|
||||
|
||||
static std::shared_ptr<dnnl::matmul::desc> get_gemm_descriptor(const kernel_impl_params& impl_params) {
|
||||
auto prim = impl_params.typed_desc<gemm>();
|
||||
|
||||
auto get_gemm_input_layouts = [prim](const std::vector<layout>& input_layouts) {
|
||||
auto get_updated_input_shape = [&](const ov::Shape& input_shape, size_t input_rank, size_t output_rank, bool transpose, bool first_input) {
|
||||
ov::Shape updated_input_shape;
|
||||
|
||||
if (input_rank == 1) {
|
||||
updated_input_shape = { *std::max_element(input_shape.begin(), input_shape.end()) };
|
||||
} else {
|
||||
updated_input_shape = ov::Shape(input_shape.begin(), input_shape.begin() + input_rank);
|
||||
}
|
||||
|
||||
if (updated_input_shape.size() == 1) {
|
||||
first_input ? updated_input_shape.insert(updated_input_shape.begin(), 1)
|
||||
: updated_input_shape.insert(updated_input_shape.end(), 1);
|
||||
|
||||
if (transpose) {
|
||||
std::swap(updated_input_shape[0], updated_input_shape[1]);
|
||||
}
|
||||
}
|
||||
size_t ones_to_add = std::max(output_rank, static_cast<size_t>(4)) - updated_input_shape.size();
|
||||
updated_input_shape.insert(updated_input_shape.begin(), ones_to_add, 1ul);
|
||||
|
||||
return updated_input_shape;
|
||||
};
|
||||
|
||||
auto input0_shape = input_layouts[0].get_shape();
|
||||
auto input1_shape = input_layouts[1].get_shape();
|
||||
|
||||
bool reordered = prim->input_rank > 4 || prim->weight_rank > 4;
|
||||
size_t output_rank = std::max(prim->input_rank, prim->weight_rank);
|
||||
size_t input_rank = reordered ? output_rank : prim->input_rank;
|
||||
size_t weight_rank = reordered ? output_rank : prim->weight_rank;
|
||||
|
||||
auto updated_input0_shape = get_updated_input_shape(input0_shape, input_rank, output_rank, prim->transpose_input0, true);
|
||||
auto updated_input1_shape = get_updated_input_shape(input1_shape, weight_rank, output_rank, prim->transpose_input1, false);
|
||||
|
||||
std::vector<layout> layouts = input_layouts;
|
||||
layouts[0].set_partial_shape(updated_input0_shape);
|
||||
layouts[1].set_partial_shape(updated_input1_shape);
|
||||
|
||||
if (input_layouts.size() == 3) {
|
||||
auto bias_shape = input_layouts[2].get_shape();
|
||||
auto updated_bias_shape = get_updated_input_shape(bias_shape, prim->weight_rank, output_rank, prim->transpose_input1, false);
|
||||
layouts[2].set_partial_shape(updated_bias_shape);
|
||||
}
|
||||
|
||||
return layouts;
|
||||
};
|
||||
|
||||
auto get_gemm_output_layout = [prim](const std::vector<layout>& input_layouts, const layout& output_layout) {
|
||||
auto updated_output_layout = output_layout;
|
||||
auto output_rank = output_layout.get_shape().size();
|
||||
if (output_rank < 4) {
|
||||
const auto& input0_layout = input_layouts[0];
|
||||
const auto& input1_layout = input_layouts[1];
|
||||
|
||||
auto M = !prim->transpose_input0 ? input0_layout.spatial(1) : input0_layout.spatial(0);
|
||||
auto N = !prim->transpose_input1 ? input1_layout.spatial(0) : input1_layout.spatial(1);
|
||||
|
||||
auto output_shape = input0_layout.get_shape();
|
||||
for (const auto& input_layout : input_layouts) {
|
||||
auto input_shape = input_layout.get_shape();
|
||||
for (size_t i = 0; i != input_shape.size(); ++i) {
|
||||
output_shape[i] = std::max(output_shape[i], input_shape[i]);
|
||||
}
|
||||
}
|
||||
|
||||
auto get_spatial_idx = [](cldnn::format format, size_t spatial_idx) {
|
||||
const size_t idx = (format::is_grouped(format) ? 3 : 2) + (format.spatial_num() - 1 - spatial_idx);
|
||||
return idx;
|
||||
};
|
||||
|
||||
output_shape[get_spatial_idx(updated_output_layout.format, 0)] = N;
|
||||
output_shape[get_spatial_idx(updated_output_layout.format, 1)] = M;
|
||||
updated_output_layout.set_partial_shape(output_shape);
|
||||
}
|
||||
return updated_output_layout;
|
||||
};
|
||||
|
||||
auto gemm_with_bias = prim->dependencies().size() == 3;
|
||||
auto out_l = impl_params.get_output_layout();
|
||||
|
||||
@ -142,8 +63,8 @@ protected:
|
||||
in_layouts.emplace_back(impl_params.get_input_layout(2));
|
||||
}
|
||||
|
||||
in_layouts = get_gemm_input_layouts(in_layouts);
|
||||
out_l = get_gemm_output_layout(in_layouts, out_l);
|
||||
in_layouts = gemm_inst::transform_input_layouts(prim, in_layouts, out_l);
|
||||
out_l = gemm_inst::transform_output_layout(prim, in_layouts, out_l);
|
||||
|
||||
const auto& in0_l = in_layouts[0];
|
||||
const auto& in1_l = in_layouts[1];
|
||||
|
@ -35,7 +35,11 @@ public:
|
||||
static layout calc_output_layout(gemm_node const& node, kernel_impl_params const& impl_param);
|
||||
static std::string to_string(gemm_node const& node);
|
||||
|
||||
public:
|
||||
static std::vector<layout> transform_input_layouts(const std::shared_ptr<const gemm> primitive,
|
||||
const std::vector<layout>& input_layouts,
|
||||
const layout& output_layout);
|
||||
static layout transform_output_layout(const std::shared_ptr<const gemm> primitive, const std::vector<layout>& input_layouts, const layout& output_layout);
|
||||
|
||||
typed_primitive_inst(network& network, gemm_node const& node);
|
||||
};
|
||||
|
||||
|
@ -1592,13 +1592,22 @@ impl_types layout_optimizer::get_preferred_impl_type(program_node& node, format
|
||||
if (node.is_dynamic()) {
|
||||
impl_candidate = impl_types::ocl;
|
||||
} else {
|
||||
auto in0_l = node.get_dependency(0).get_output_layout();
|
||||
auto in1_l = node.get_dependency(1).get_output_layout();
|
||||
auto out_l = node.get_output_layout();
|
||||
auto has_input2 = gemm_prim->dependencies().size() == 3;
|
||||
std::vector<layout> in_layouts { node.get_dependency(0).get_output_layout(), node.get_dependency(1).get_output_layout() };
|
||||
if (has_input2) {
|
||||
in_layouts.emplace_back(node.get_dependency(2).get_output_layout());
|
||||
}
|
||||
auto out_l = node.get_output_layout();
|
||||
|
||||
in_layouts = gemm_inst::transform_input_layouts(gemm_prim, in_layouts, out_l);
|
||||
out_l = gemm_inst::transform_output_layout(gemm_prim, in_layouts, out_l);
|
||||
|
||||
auto in0_l = in_layouts[0];
|
||||
auto in1_l = in_layouts[1];
|
||||
|
||||
size_t in2_batched_size = 0;
|
||||
if (has_input2) {
|
||||
auto in2_l = node.get_dependency(2).get_output_layout();
|
||||
auto in2_l = in_layouts[2];
|
||||
in2_batched_size = in2_l.count() / (in2_l.spatial(0) * in2_l.spatial(1));
|
||||
}
|
||||
size_t size_k = gemm_prim->transpose_input0 ? in0_l.spatial(1) : in0_l.spatial(0);
|
||||
|
Loading…
Reference in New Issue
Block a user