[CPU] Optimized shapeInfer() for OneHot (#8739)

This commit is contained in:
Egor Shulman 2021-11-23 15:55:52 +03:00 committed by GitHub
parent c29569ecbd
commit 9ce8ac536f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -84,25 +84,12 @@ bool MKLDNNOneHotNode::needShapeInfer() const {
}
std::vector<VectorDims> MKLDNNOneHotNode::shapeInfer() const {
std::vector<ov::StaticShape> input_shapes = {
getParentEdgesAtPort(0)[0]->getMemory().GetShape().getStaticDims(),
getParentEdgesAtPort(1)[0]->getMemory().GetShape().getStaticDims(),
getParentEdgesAtPort(2)[0]->getMemory().GetShape().getStaticDims(),
getParentEdgesAtPort(3)[0]->getMemory().GetShape().getStaticDims()
};
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> input_values = {
{1, std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, VectorDims{ }, getParentEdgesAtPort(1)[0]->getMemory().GetPtr())},
{2, std::make_shared<ngraph::runtime::HostTensor>(opToShapeInfer->get_input_node_shared_ptr(2))},
{3, std::make_shared<ngraph::runtime::HostTensor>(opToShapeInfer->get_input_node_shared_ptr(3))}
};
std::vector<ov::StaticShape> output_shapes = {{}};
shape_inference(opToShapeInfer.get(), input_shapes, output_shapes, input_values);
std::vector<VectorDims> result(output_shapes.size());
std::transform(output_shapes.begin(), output_shapes.end(), result.begin(), [](const ov::StaticShape& s){ return s.to_shape(); });
depth = reinterpret_cast<int32_t *>(getParentEdgesAtPort(1)[0]->getMemoryPtr()->GetPtr())[0];
return result;
auto result = getParentEdgesAtPort(0)[0]->getMemory().getStaticDims();
result.insert(result.begin() + axis, depth);
return { result };
}
void MKLDNNOneHotNode::initSupportedPrimitiveDescriptors() {