[GPU] Pass concat unit tests on DG2 (#12142)

* check optimized
* skip kernel compile when optimized
This commit is contained in:
Felix Dohyun Kim 2022-08-03 17:30:41 +09:00 committed by GitHub
parent 029f94fad9
commit 9d5e799c62
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 20 deletions

View File

@ -63,6 +63,8 @@ protected:
public:
static primitive_impl* create(const concatenation_node& arg, std::shared_ptr<kernel_impl_params>) {
if (arg.can_be_optimized())
return new concatenation_onednn(arg);
auto desc = get_concatenation_descriptor(arg);
auto attr = arg.get_onednn_primitive_attributes();

View File

@ -51,6 +51,14 @@ struct typed_primitive_onednn_impl : public typed_primitive_impl<PType> {
build_primitive();
}
typed_primitive_onednn_impl(const typed_program_node<PType>& arg)
: typed_primitive_impl<PType>({}, "undef"),
_outer(arg),
_pd(),
_prim() {
assert(arg.can_be_optimized());
}
bool is_cpu() const override { return false; }
private:
@ -240,6 +248,8 @@ protected:
}
void set_arguments_impl(typed_primitive_inst<PType>& instance) override {
if (instance.can_be_optimized())
return;
uint32_t net_id = instance.get_network().get_id();
_args[net_id] = get_arguments(instance);
}

View File

@ -5,6 +5,7 @@
///////////////////////////////////////////////////////////////////////////////////////////////////
#include "test_utils.h"
#include "concatenation_inst.h"
#include <intel_gpu/primitives/input_layout.hpp>
#include <intel_gpu/primitives/convolution.hpp>
@ -860,16 +861,9 @@ public:
}
concat_network.execute();
for (auto i : concat_network.get_primitives_info()) {
// std::cout << " " << i.original_id << " " << i.kernel_id << std::endl;
if (i.original_id == "concat") {
if (options.get<build_option_type::optimize_data>()->enabled()) {
EXPECT_TRUE(i.kernel_id == "undef");
} else {
EXPECT_FALSE(i.kernel_id == "undef");
}
}
}
bool concat_opt_enabled = options.get<build_option_type::optimize_data>()->enabled();
bool concat_opt_result = std::static_pointer_cast<concatenation_inst>(concat_network.get_primitive("concat"))->node.can_be_optimized();
EXPECT_TRUE(concat_opt_enabled==concat_opt_result);
return concat_network.get_output("reorder").get_memory();
}
@ -1079,16 +1073,9 @@ public:
}
concat_network.execute();
for (auto i : concat_network.get_primitives_info()) {
// std::cout << " " << i.original_id << " " << i.kernel_id << std::endl;
if (i.original_id == "concat") {
if (options.get<build_option_type::optimize_data>()->enabled()) {
EXPECT_TRUE(i.kernel_id == "undef");
} else {
EXPECT_FALSE(i.kernel_id == "undef");
}
}
}
bool concat_opt_enabled = options.get<build_option_type::optimize_data>()->enabled();
bool concat_opt_result = std::static_pointer_cast<concatenation_inst>(concat_network.get_primitive("concat"))->node.can_be_optimized();
EXPECT_TRUE(concat_opt_enabled==concat_opt_result);
return concat_network.get_output("reorder").get_memory();
}