[CPU] Optimized shapeInfer() for OneHot (#8739)
parent c29569ecbd
commit 9ce8ac536f
@@ -84,25 +84,12 @@ bool MKLDNNOneHotNode::needShapeInfer() const {
 }
 
 std::vector<VectorDims> MKLDNNOneHotNode::shapeInfer() const {
-    std::vector<ov::StaticShape> input_shapes = {
-        getParentEdgesAtPort(0)[0]->getMemory().GetShape().getStaticDims(),
-        getParentEdgesAtPort(1)[0]->getMemory().GetShape().getStaticDims(),
-        getParentEdgesAtPort(2)[0]->getMemory().GetShape().getStaticDims(),
-        getParentEdgesAtPort(3)[0]->getMemory().GetShape().getStaticDims()
-    };
-    std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> input_values = {
-        {1, std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, VectorDims{ }, getParentEdgesAtPort(1)[0]->getMemory().GetPtr())},
-        {2, std::make_shared<ngraph::runtime::HostTensor>(opToShapeInfer->get_input_node_shared_ptr(2))},
-        {3, std::make_shared<ngraph::runtime::HostTensor>(opToShapeInfer->get_input_node_shared_ptr(3))}
-    };
-    std::vector<ov::StaticShape> output_shapes = {{}};
-    shape_inference(opToShapeInfer.get(), input_shapes, output_shapes, input_values);
-
-    std::vector<VectorDims> result(output_shapes.size());
-    std::transform(output_shapes.begin(), output_shapes.end(), result.begin(), [](const ov::StaticShape& s){ return s.to_shape(); });
-
     depth = reinterpret_cast<int32_t *>(getParentEdgesAtPort(1)[0]->getMemoryPtr()->GetPtr())[0];
-    return result;
+
+    auto result = getParentEdgesAtPort(0)[0]->getMemory().getStaticDims();
+    result.insert(result.begin() + axis, depth);
+
+    return { result };
 }
 
 void MKLDNNOneHotNode::initSupportedPrimitiveDescriptors() {