Review comments applied (#17168)
This commit is contained in:
parent
8e5b0650a0
commit
a6b1544acf
@ -253,6 +253,9 @@ void TileBroadcastCommon::optimizedExecute(const MemoryPtr& srcMemory, const Mem
|
||||
|
||||
if (srcMemory->getStaticDims() == dstMemory->getStaticDims()) {
|
||||
const auto prc = dstMemory->getDesc().getPrecision();
|
||||
// TODO: 109204
|
||||
// cpu_convert have to be used here because its implementation faster than cpu_memcpy
|
||||
// in the case when copySize exceeds L2 cache size
|
||||
cpu_convert(srcData, dstData, prc, prc, optimizedParams.copySize / prc.size());
|
||||
} else if (optimizedParams.srcStrides[5] == 0) {
|
||||
if (optimizedParams.dstStrides[0] == optimizedParams.dims[5] * optimizedParams.dstStrides[5]) {
|
||||
|
@ -74,17 +74,41 @@ protected:
|
||||
inputs.insert({funcInputs[1].get_node_shared_ptr(), shape_tensor});
|
||||
}
|
||||
|
||||
void CheckLastNode(const ov::CompiledModel& execNet) {
|
||||
const auto model = execNet.get_runtime_model();
|
||||
const auto last_node = model->get_result()->get_input_node_shared_ptr(0);
|
||||
const auto& rt_info = last_node->get_rt_info();
|
||||
const auto layer_type = rt_info.find("layerType")->second.as<std::string>();
|
||||
EXPECT_EQ(layer_type, "Broadcast");
|
||||
|
||||
const bool data_shape_exceeds_target = [&]() {
|
||||
const auto& in_data_shape = model->get_parameters()[0]->get_output_shape(0);
|
||||
if (in_data_shape.size() < target_shape.size()) {
|
||||
return false;
|
||||
}
|
||||
auto bcasted_target = target_shape;
|
||||
bcasted_target.insert(bcasted_target.begin(), in_data_shape.size() - bcasted_target.size(), 1);
|
||||
for (size_t i = 0; i < in_data_shape.size(); ++i) {
|
||||
if (in_data_shape[i] < bcasted_target[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}();
|
||||
|
||||
// If data shape exceeds original target shape, Broadcast must have equal input and output shapes after transition
|
||||
if (data_shape_exceeds_target) {
|
||||
EXPECT_EQ(last_node->get_input_shape(0), last_node->get_output_shape(0));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
ov::Shape target_shape;
|
||||
};
|
||||
|
||||
TEST_P(BroadcastEltwise, smoke_CompareWithRefs) {
|
||||
run();
|
||||
|
||||
const auto model = compiledModel.get_runtime_model();
|
||||
const auto last_node = model->get_result()->get_input_node_shared_ptr(0);
|
||||
const auto& rt_info = last_node->get_rt_info();
|
||||
const auto layerType = rt_info.find("layerType")->second.as<std::string>();
|
||||
EXPECT_EQ(layerType, "Broadcast");
|
||||
CheckLastNode(compiledModel);
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -94,6 +118,7 @@ const std::vector<InputShape> input_shapes = {
|
||||
};
|
||||
|
||||
const std::vector<ov::Shape> target_shapes = {
|
||||
{1, 3, 16, 16},
|
||||
{1, 3, 16, 1},
|
||||
{16, 16},
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user