[GPU] Insert transpose for gemm input if its trans attr is set (#8460)
If the gemm input dimensions are not multiples of 16 and either the transpose_a or transpose_b attribute is set, cldnn picks the 'gemm_ref' kernel instead of the faster 'gemm_tiled_opt'. By emplacing an explicit permute operation on whichever gemm input requires it, we make cldnn pick 'gemm_tiled_opt', which improves performance. For some input shapes, however, transpose(s) + gemm_tiled_opt can be slower than plain gemm_ref; based on benchmarks, the cutoff was set at input shapes larger than (64, 64). Ticket: 67271
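As a rough standalone illustration of that heuristic (a sketch, not the plugin code: shouldInsertTranspose and the plain integer shapes are simplifications of the canTransposeInputs check in the diff below):

    #include <array>
    #include <cstdio>

    // Sketch of the cutoff heuristic: transpose only unaligned inputs
    // whose two innermost dims are at or above the (64, 64) cutoff.
    static bool shouldInsertTranspose(const std::array<long, 2>& dimsA,
                                      const std::array<long, 2>& dimsB,
                                      bool transA, bool transB) {
        if (!transA && !transB)
            return false;  // nothing to rewrite
        // Dims that are multiples of 16 already get 'gemm_tiled_opt';
        // an extra permute would only add cost.
        bool aligned = true;
        for (long d : dimsA) aligned = aligned && d % 16 == 0;
        for (long d : dimsB) aligned = aligned && d % 16 == 0;
        if (aligned)
            return false;
        // Below the cutoff, permute + gemm_tiled_opt loses to plain gemm_ref.
        bool largeEnough = true;
        for (long d : dimsA) largeEnough = largeEnough && d >= 64;
        for (long d : dimsB) largeEnough = largeEnough && d >= 64;
        return largeEnough;
    }

    int main() {
        // (65, 100) x (73, 100)^T: unaligned and above the cutoff -> 1 (transpose pays off)
        std::printf("%d\n", shouldInsertTranspose({65, 100}, {73, 100}, false, true));
        // (1, 5) x (1, 5)^T: below the cutoff -> 0 (keep gemm_ref)
        std::printf("%d\n", shouldInsertTranspose({1, 5}, {1, 5}, false, true));
    }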
commit 063ff5a5c5 (parent 378e334ee6)
@@ -2,6 +2,8 @@
// SPDX-License-Identifier: Apache-2.0
//

#include <array>

#include "intel_gpu/plugin/program.hpp"
#include "intel_gpu/plugin/common_utils.hpp"
@@ -166,6 +168,63 @@ static void CreateMatMulOp(Program& p, const std::shared_ptr<ngraph::op::v0::MatMul>& op) {
    // Add actual gemm
    auto alpha = 1.0f;
    auto beta = 0.0f;
    auto transA = op->get_transpose_a();
    auto transB = op->get_transpose_b();

    std::array<ngraph::PartialShape, 2> inputShapes{
        op->get_input_partial_shape(0),
        op->get_input_partial_shape(1)
    };

    auto canTransposeInputs = [] (const std::array<ngraph::PartialShape, 2>& shapes, bool transA, bool transB) -> bool {
        if (!transA && !transB)
            return false;
        if (shapes[0].rank().is_dynamic() ||
            shapes[1].rank().is_dynamic())
            return false;

        // don't transpose inputs if they're aligned to 16
        bool inputsAligned = std::all_of(shapes[0].rbegin(), shapes[0].rbegin() + 2,
                                         [] (const ngraph::Dimension& dim) { return dim.is_static() && dim.get_length() % 16 == 0; }) &&
                             std::all_of(shapes[1].rbegin(), shapes[1].rbegin() + 2,
                                         [] (const ngraph::Dimension& dim) { return dim.is_static() && dim.get_length() % 16 == 0; });
        if (inputsAligned)
            return false;

        return std::all_of(shapes[0].rbegin(), shapes[0].rbegin() + 2,
                           [] (const ngraph::Dimension& dim) { return dim.is_static() && dim.get_length() >= 64; }) &&
               std::all_of(shapes[1].rbegin(), shapes[1].rbegin() + 2,
                           [] (const ngraph::Dimension& dim) { return dim.is_static() && dim.get_length() >= 64; });
    };

    auto transposeInput = [&layerName] (Program& p, const std::shared_ptr<ngraph::Node>& op, const ngraph::PartialShape& shape,
                                        const std::string& suffix, const cldnn::primitive_id& primitiveId) -> std::string {
        std::vector<uint16_t> transposeOrder(shape.size());
        std::iota(transposeOrder.begin(), transposeOrder.end(), 0);
        for (auto o = transposeOrder.size(); o < 4; o++)
            transposeOrder.push_back((uint16_t)o);
        std::swap(*(transposeOrder.end() - 1), *(transposeOrder.end() - 2));

        auto permuteName = op->get_friendly_name() + suffix;
        auto permutePrim = cldnn::permute(permuteName,
                                          primitiveId,
                                          transposeOrder);
        p.add_primitive(*op, permutePrim);
        return permuteName;
    };

    if (canTransposeInputs(inputShapes, transA, transB)) {
        if (transA) {
            inputPrimitives[0] = transposeInput(p, op, inputShapes[0], "/transpose_a", inputPrimitives[0]);
            transA = false;
        }

        if (transB) {
            inputPrimitives[1] = transposeInput(p, op, inputShapes[1], "/transpose_b", inputPrimitives[1]);
            transB = false;
        }
    }

    auto gemmPrim = cldnn::gemm(layerName,
                                inputPrimitives,
                                cldnn::element_type_to_data_type(op->get_output_element_type(0)),
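For reference, a standalone sketch of how transposeInput above builds its permute order (makeTransposeOrder is an illustrative extraction, not a helper that exists in the diff): an identity order, padded to cldnn's 4-D layout, with the two innermost axes swapped.

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <numeric>
    #include <utility>
    #include <vector>

    // Identity order 0..rank-1, padded to 4 axes, innermost two swapped.
    static std::vector<uint16_t> makeTransposeOrder(std::size_t rank) {
        std::vector<uint16_t> order(rank);
        std::iota(order.begin(), order.end(), 0);
        for (auto o = order.size(); o < 4; o++)
            order.push_back(static_cast<uint16_t>(o));
        std::swap(*(order.end() - 1), *(order.end() - 2));
        return order;
    }

    int main() {
        // Ranks 2, 3, and 4 all pad/settle to the same order "0 1 3 2",
        // i.e. only the two innermost (matrix) axes are exchanged.
        for (std::size_t rank : {2, 3, 4}) {
            for (auto axis : makeTransposeOrder(rank))
                std::printf("%u ", static_cast<unsigned>(axis));
            std::printf("\n");
        }
    }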
@@ -33,6 +33,9 @@ const std::vector<ShapeRelatedParams> shapeRelatedParams = {
    { { {2, 1, 2, 3}, true }, { {3, 4, 2}, true } },
    { { {3}, false }, { {2, 2, 3, 1}, false } },
    { { {2, 2, 1, 3}, false }, { {3}, false } },
    { { {65, 100}, false }, { {73, 100}, true } },
    { { {100, 65}, true }, { {100, 73}, false } },
    { { {100, 65}, true }, { {73, 100}, true } },
    { { {1, 5}, false }, { {5, 1}, false } },
    { { {5, 1}, true }, { {5, 1}, false } },
    { { {1, 5}, false }, { {1, 5}, true } },
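The three added shape pairs built from 65, 73, and 100 appear chosen to exercise the new path: those dimensions are not multiples of 16 yet are at or above the (64, 64) cutoff, so the plugin should now insert the explicit transpose and select 'gemm_tiled_opt', while the surrounding smaller transposed shapes keep covering the gemm_ref fallback.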