[GPU] Minor fixes for dynamic BERT models (#16158)

* [GPU] Minor fix for dynamic bert-base-uncased-qqp

Signed-off-by: Andrew Park <andrew.park@intel.com>

* Fix: check the full-tensor condition only for static shapes when creating the onednn gemm

Signed-off-by: Andrew Park <andrew.park@intel.com>

---------

Signed-off-by: Andrew Park <andrew.park@intel.com>
This commit is contained in:
Andrew Kwangwoong Park
2023-03-10 07:48:08 +09:00
committed by GitHub
parent dff7f2451b
commit 3ec386a741
2 changed files with 9 additions and 2 deletions

View File

@@ -46,6 +46,11 @@ void prepare_primitive_fusing_through::run(program& p) {
if (node->is_type<reshape>() && node->get_dependencies().front().first->is_type<reduce>())
return false;
// Do not raise the target node through a reshape that changes the number of dimensions (e.g. Unsqueeze)
if (node->is_type<reshape>() &&
node->get_output_layout().get_partial_shape().size() != node->get_dependency(0).get_output_layout().get_partial_shape().size())
return false;
return true;
};

View File

@@ -303,8 +303,10 @@ public:
static std::unique_ptr<primitive_impl> create(const gemm_node& arg, const kernel_impl_params& impl_params) {
bool full_tensor_or_per_tensor = true;
for (auto prim : arg.get_fused_primitives()) {
full_tensor_or_per_tensor &=
prim.input_layout.count() == prim.output_layout.count() || prim.input_layout.count() == 1;
if (prim.input_layout.is_static() && prim.output_layout.is_static()) {
full_tensor_or_per_tensor &=
prim.input_layout.count() == prim.output_layout.count() || prim.input_layout.count() == 1;
}
}
if (!full_tensor_or_per_tensor) {
IE_THROW() << "Unimplemented: per channel binary post-operation is not supported for onednn gemm. Refer PR(#15353) message.";