Review comments applied (#17168)

This commit is contained in:
Vladislav Golubev 2023-04-24 22:59:03 +02:00 committed by GitHub
parent 8e5b0650a0
commit a6b1544acf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 6 deletions

View File

@ -253,6 +253,9 @@ void TileBroadcastCommon::optimizedExecute(const MemoryPtr& srcMemory, const Mem
if (srcMemory->getStaticDims() == dstMemory->getStaticDims()) {
const auto prc = dstMemory->getDesc().getPrecision();
// TODO: 109204
// cpu_convert have to be used here because its implementation faster than cpu_memcpy
// in the case when copySize exceeds L2 cache size
cpu_convert(srcData, dstData, prc, prc, optimizedParams.copySize / prc.size());
} else if (optimizedParams.srcStrides[5] == 0) {
if (optimizedParams.dstStrides[0] == optimizedParams.dims[5] * optimizedParams.dstStrides[5]) {

View File

@ -74,17 +74,41 @@ protected:
inputs.insert({funcInputs[1].get_node_shared_ptr(), shape_tensor});
}
void CheckLastNode(const ov::CompiledModel& execNet) {
const auto model = execNet.get_runtime_model();
const auto last_node = model->get_result()->get_input_node_shared_ptr(0);
const auto& rt_info = last_node->get_rt_info();
const auto layer_type = rt_info.find("layerType")->second.as<std::string>();
EXPECT_EQ(layer_type, "Broadcast");
const bool data_shape_exceeds_target = [&]() {
const auto& in_data_shape = model->get_parameters()[0]->get_output_shape(0);
if (in_data_shape.size() < target_shape.size()) {
return false;
}
auto bcasted_target = target_shape;
bcasted_target.insert(bcasted_target.begin(), in_data_shape.size() - bcasted_target.size(), 1);
for (size_t i = 0; i < in_data_shape.size(); ++i) {
if (in_data_shape[i] < bcasted_target[i]) {
return false;
}
}
return true;
}();
// If data shape exceeds original target shape, Broadcast must have equal input and output shapes after transition
if (data_shape_exceeds_target) {
EXPECT_EQ(last_node->get_input_shape(0), last_node->get_output_shape(0));
}
}
private:
ov::Shape target_shape;
};
TEST_P(BroadcastEltwise, smoke_CompareWithRefs) {
run();
const auto model = compiledModel.get_runtime_model();
const auto last_node = model->get_result()->get_input_node_shared_ptr(0);
const auto& rt_info = last_node->get_rt_info();
const auto layerType = rt_info.find("layerType")->second.as<std::string>();
EXPECT_EQ(layerType, "Broadcast");
CheckLastNode(compiledModel);
}
namespace {
@ -94,6 +118,7 @@ const std::vector<InputShape> input_shapes = {
};
const std::vector<ov::Shape> target_shapes = {
{1, 3, 16, 16},
{1, 3, 16, 1},
{16, 16},
};