[GPU] Insert transpose for gemm input if its trans attr is set (#8460)
If gemm input dimensions are not a multiple of 16 and either of the transpose_a/transpose_b attributes is set, cldnn picks the 'gemm_ref' kernel instead of the faster 'gemm_tiled_opt'. By emplacing an explicit permute operation on the gemm input that requires it, we make cldnn pick 'gemm_tiled_opt', which as a result improves performance. For some input shapes, transpose(s) + gemm_tiled_opt can be slower than just gemm_ref. Based on benchmarks, the cutoff point was set at input shapes > (64, 64). Ticket: 67271
This commit is contained in:
parent
378e334ee6
commit
063ff5a5c5
@ -2,6 +2,8 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
|
#include <array>
|
||||||
|
|
||||||
#include "intel_gpu/plugin/program.hpp"
|
#include "intel_gpu/plugin/program.hpp"
|
||||||
#include "intel_gpu/plugin/common_utils.hpp"
|
#include "intel_gpu/plugin/common_utils.hpp"
|
||||||
|
|
||||||
@ -166,6 +168,63 @@ static void CreateMatMulOp(Program& p, const std::shared_ptr<ngraph::op::v0::Mat
|
|||||||
// Add actual gemm
|
// Add actual gemm
|
||||||
auto alpha = 1.0f;
|
auto alpha = 1.0f;
|
||||||
auto beta = 0.0f;
|
auto beta = 0.0f;
|
||||||
|
auto transA = op->get_transpose_a();
|
||||||
|
auto transB = op->get_transpose_b();
|
||||||
|
|
||||||
|
std::array<ngraph::PartialShape, 2> inputShapes{
|
||||||
|
op->get_input_partial_shape(0),
|
||||||
|
op->get_input_partial_shape(1)
|
||||||
|
};
|
||||||
|
|
||||||
|
auto canTransposeInputs = [] (const std::array<ngraph::PartialShape, 2>& shapes, bool transA, bool transB) -> bool {
|
||||||
|
if (!transA && !transB)
|
||||||
|
return false;
|
||||||
|
if (shapes[0].rank().is_dynamic() ||
|
||||||
|
shapes[1].rank().is_dynamic())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
// don't transpose inputs if they're aligned to 16
|
||||||
|
bool inputsAligned = std::all_of(shapes[0].rbegin(), shapes[0].rbegin() + 2,
|
||||||
|
[] (const ngraph::Dimension& dim) { return dim.is_static() && dim.get_length() % 16 == 0; }) &&
|
||||||
|
std::all_of(shapes[1].rbegin(), shapes[1].rbegin() + 2,
|
||||||
|
[] (const ngraph::Dimension& dim) { return dim.is_static() && dim.get_length() % 16 == 0; });
|
||||||
|
if (inputsAligned)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
return std::all_of(shapes[0].rbegin(), shapes[0].rbegin() + 2,
|
||||||
|
[] (const ngraph::Dimension& dim) { return dim.is_static() && dim.get_length() >= 64; }) &&
|
||||||
|
std::all_of(shapes[1].rbegin(), shapes[1].rbegin() + 2,
|
||||||
|
[] (const ngraph::Dimension& dim) { return dim.is_static() && dim.get_length() >= 64; });
|
||||||
|
};
|
||||||
|
|
||||||
|
// Inserts an explicit permute primitive that swaps the two innermost
// dimensions of a gemm input, and returns the permute's primitive id so the
// caller can wire it in as the replacement input.
// Note: the original capture list named layerName, but the body never reads
// it (the permute name is derived from op->get_friendly_name()), so the
// unused capture is dropped here.
auto transposeInput = [] (Program& p, const std::shared_ptr<ngraph::Node>& op, const ngraph::PartialShape& shape,
                          const std::string& suffix, const cldnn::primitive_id& primitiveId) -> std::string {
    // Start from the identity order over the input's real dimensions...
    std::vector<uint16_t> transposeOrder(shape.size());
    std::iota(transposeOrder.begin(), transposeOrder.end(), 0);
    // ...pad the order up to rank 4 with trailing identity entries
    // (presumably matching cldnn's 4-d layout padding -- TODO confirm)...
    for (auto o = transposeOrder.size(); o < 4; o++)
        transposeOrder.push_back((uint16_t)o);
    // ...then swap the last two entries, mirroring transpose_a/transpose_b.
    std::swap(*(transposeOrder.end() - 1), *(transposeOrder.end() - 2));

    auto permuteName = op->get_friendly_name() + suffix;
    auto permutePrim = cldnn::permute(permuteName,
                                      primitiveId,
                                      transposeOrder);
    p.add_primitive(*op, permutePrim);
    return permuteName;
};
|
||||||
|
|
||||||
|
if (canTransposeInputs(inputShapes, transA, transB)) {
|
||||||
|
if (transA) {
|
||||||
|
inputPrimitives[0] = transposeInput(p, op, inputShapes[0], "/transpose_a", inputPrimitives[0]);
|
||||||
|
transA = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (transB) {
|
||||||
|
inputPrimitives[1] = transposeInput(p, op, inputShapes[1], "/transpose_b", inputPrimitives[1]);
|
||||||
|
transB = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
auto gemmPrim = cldnn::gemm(layerName,
|
auto gemmPrim = cldnn::gemm(layerName,
|
||||||
inputPrimitives,
|
inputPrimitives,
|
||||||
cldnn::element_type_to_data_type(op->get_output_element_type(0)),
|
cldnn::element_type_to_data_type(op->get_output_element_type(0)),
|
||||||
|
@ -33,6 +33,9 @@ const std::vector<ShapeRelatedParams> shapeRelatedParams = {
|
|||||||
{ { {2, 1, 2, 3}, true }, { {3, 4, 2}, true } },
|
{ { {2, 1, 2, 3}, true }, { {3, 4, 2}, true } },
|
||||||
{ { {3}, false }, { {2, 2, 3, 1}, false } },
|
{ { {3}, false }, { {2, 2, 3, 1}, false } },
|
||||||
{ { {2, 2, 1, 3}, false }, { {3}, false } },
|
{ { {2, 2, 1, 3}, false }, { {3}, false } },
|
||||||
|
{ { {65, 100}, false }, { {73, 100}, true } },
|
||||||
|
{ { {100, 65}, true }, { {100, 73}, false } },
|
||||||
|
{ { {100, 65}, true }, { {73, 100}, true } },
|
||||||
{ { {1, 5}, false }, { {5, 1}, false } },
|
{ { {1, 5}, false }, { {5, 1}, false } },
|
||||||
{ { {5, 1}, true }, { {5, 1}, false } },
|
{ { {5, 1}, true }, { {5, 1}, false } },
|
||||||
{ { {1, 5}, false }, { {1, 5}, true } },
|
{ { {1, 5}, false }, { {1, 5}, true } },
|
||||||
|
Loading…
Reference in New Issue
Block a user