[GPU] Pass concat unit tests on DG2 (#12142)
* check optimized * skip kernel compile when optimized
This commit is contained in:
parent
029f94fad9
commit
9d5e799c62
@ -63,6 +63,8 @@ protected:
|
||||
|
||||
public:
|
||||
static primitive_impl* create(const concatenation_node& arg, std::shared_ptr<kernel_impl_params>) {
|
||||
if (arg.can_be_optimized())
|
||||
return new concatenation_onednn(arg);
|
||||
auto desc = get_concatenation_descriptor(arg);
|
||||
auto attr = arg.get_onednn_primitive_attributes();
|
||||
|
||||
|
@ -51,6 +51,14 @@ struct typed_primitive_onednn_impl : public typed_primitive_impl<PType> {
|
||||
build_primitive();
|
||||
}
|
||||
|
||||
typed_primitive_onednn_impl(const typed_program_node<PType>& arg)
|
||||
: typed_primitive_impl<PType>({}, "undef"),
|
||||
_outer(arg),
|
||||
_pd(),
|
||||
_prim() {
|
||||
assert(arg.can_be_optimized());
|
||||
}
|
||||
|
||||
bool is_cpu() const override { return false; }
|
||||
|
||||
private:
|
||||
@ -240,6 +248,8 @@ protected:
|
||||
}
|
||||
|
||||
void set_arguments_impl(typed_primitive_inst<PType>& instance) override {
|
||||
if (instance.can_be_optimized())
|
||||
return;
|
||||
uint32_t net_id = instance.get_network().get_id();
|
||||
_args[net_id] = get_arguments(instance);
|
||||
}
|
||||
|
@ -5,6 +5,7 @@
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "test_utils.h"
|
||||
#include "concatenation_inst.h"
|
||||
|
||||
#include <intel_gpu/primitives/input_layout.hpp>
|
||||
#include <intel_gpu/primitives/convolution.hpp>
|
||||
@ -860,16 +861,9 @@ public:
|
||||
}
|
||||
concat_network.execute();
|
||||
|
||||
for (auto i : concat_network.get_primitives_info()) {
|
||||
// std::cout << " " << i.original_id << " " << i.kernel_id << std::endl;
|
||||
if (i.original_id == "concat") {
|
||||
if (options.get<build_option_type::optimize_data>()->enabled()) {
|
||||
EXPECT_TRUE(i.kernel_id == "undef");
|
||||
} else {
|
||||
EXPECT_FALSE(i.kernel_id == "undef");
|
||||
}
|
||||
}
|
||||
}
|
||||
bool concat_opt_enabled = options.get<build_option_type::optimize_data>()->enabled();
|
||||
bool concat_opt_result = std::static_pointer_cast<concatenation_inst>(concat_network.get_primitive("concat"))->node.can_be_optimized();
|
||||
EXPECT_TRUE(concat_opt_enabled==concat_opt_result);
|
||||
|
||||
return concat_network.get_output("reorder").get_memory();
|
||||
}
|
||||
@ -1079,16 +1073,9 @@ public:
|
||||
}
|
||||
concat_network.execute();
|
||||
|
||||
for (auto i : concat_network.get_primitives_info()) {
|
||||
// std::cout << " " << i.original_id << " " << i.kernel_id << std::endl;
|
||||
if (i.original_id == "concat") {
|
||||
if (options.get<build_option_type::optimize_data>()->enabled()) {
|
||||
EXPECT_TRUE(i.kernel_id == "undef");
|
||||
} else {
|
||||
EXPECT_FALSE(i.kernel_id == "undef");
|
||||
}
|
||||
}
|
||||
}
|
||||
bool concat_opt_enabled = options.get<build_option_type::optimize_data>()->enabled();
|
||||
bool concat_opt_result = std::static_pointer_cast<concatenation_inst>(concat_network.get_primitive("concat"))->node.can_be_optimized();
|
||||
EXPECT_TRUE(concat_opt_enabled==concat_opt_result);
|
||||
|
||||
return concat_network.get_output("reorder").get_memory();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user