Reference Implementation for RegionYolo operator (#2474)
This commit is contained in:
parent
db85069713
commit
c9b16a79f5
@ -0,0 +1,85 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "single_layer_tests/region_yolo.hpp"
|
||||
#include "common_test_utils/test_constants.hpp"
|
||||
|
||||
using namespace LayerTestsDefinitions;
|
||||
|
||||
const std::vector<ngraph::Shape> inShapes_caffe = {
|
||||
{1, 125, 13, 13}
|
||||
};
|
||||
|
||||
const std::vector<ngraph::Shape> inShapes_mxnet = {
|
||||
{1, 75, 52, 52},
|
||||
{1, 75, 32, 32},
|
||||
{1, 75, 26, 26},
|
||||
{1, 75, 16, 16},
|
||||
{1, 75, 13, 13},
|
||||
{1, 75, 8, 8}
|
||||
};
|
||||
|
||||
const std::vector<ngraph::Shape> inShapes_v3 = {
|
||||
{1, 255, 52, 52},
|
||||
{1, 255, 26, 26},
|
||||
{1, 255, 13, 13}
|
||||
};
|
||||
|
||||
const std::vector<std::vector<int64_t>> masks = {
|
||||
{0, 1, 2},
|
||||
{3, 4, 5},
|
||||
{6, 7, 8}
|
||||
};
|
||||
|
||||
const std::vector<bool> do_softmax = {true, false};
|
||||
const std::vector<size_t> classes = {80, 20};
|
||||
const std::vector<size_t> num_regions = {5, 9};
|
||||
const size_t coords = 4;
|
||||
const int start_axis = 1;
|
||||
const int end_axis = 3;
|
||||
|
||||
const auto testCase_yolov3 = ::testing::Combine(
|
||||
::testing::ValuesIn(inShapes_v3),
|
||||
::testing::Values(classes[0]),
|
||||
::testing::Values(coords),
|
||||
::testing::Values(num_regions[1]),
|
||||
::testing::Values(do_softmax[1]),
|
||||
::testing::Values(masks[2]),
|
||||
::testing::Values(start_axis),
|
||||
::testing::Values(end_axis),
|
||||
::testing::Values(InferenceEngine::Precision::FP32),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)
|
||||
);
|
||||
|
||||
const auto testCase_yolov3_mxnet = ::testing::Combine(
|
||||
::testing::ValuesIn(inShapes_mxnet),
|
||||
::testing::Values(classes[1]),
|
||||
::testing::Values(coords),
|
||||
::testing::Values(num_regions[1]),
|
||||
::testing::Values(do_softmax[1]),
|
||||
::testing::Values(masks[1]),
|
||||
::testing::Values(start_axis),
|
||||
::testing::Values(end_axis),
|
||||
::testing::Values(InferenceEngine::Precision::FP32),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)
|
||||
);
|
||||
|
||||
const auto testCase_yolov2_caffe = ::testing::Combine(
|
||||
::testing::ValuesIn(inShapes_caffe),
|
||||
::testing::Values(classes[1]),
|
||||
::testing::Values(coords),
|
||||
::testing::Values(num_regions[0]),
|
||||
::testing::Values(do_softmax[0]),
|
||||
::testing::Values(masks[0]),
|
||||
::testing::Values(start_axis),
|
||||
::testing::Values(end_axis),
|
||||
::testing::Values(InferenceEngine::Precision::FP32),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)
|
||||
);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_TestsRegionYolov3, RegionYoloLayerTest, testCase_yolov3, RegionYoloLayerTest::getTestCaseName);
|
||||
INSTANTIATE_TEST_CASE_P(smoke_TestsRegionYoloMxnet, RegionYoloLayerTest, testCase_yolov3_mxnet, RegionYoloLayerTest::getTestCaseName);
|
||||
INSTANTIATE_TEST_CASE_P(smoke_TestsRegionYoloCaffe, RegionYoloLayerTest, testCase_yolov2_caffe, RegionYoloLayerTest::getTestCaseName);
|
@ -0,0 +1,38 @@
|
||||
// Copyright (C) 2019 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <tuple>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "functional_test_utils/layer_test_utils.hpp"
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
#include "ngraph_functions/utils/ngraph_helpers.hpp"
|
||||
|
||||
namespace LayerTestsDefinitions {
|
||||
|
||||
using regionYoloParamsTuple = std::tuple<
|
||||
ngraph::Shape, // Input Shape
|
||||
size_t, // classes
|
||||
size_t, // coordinates
|
||||
size_t, // num regions
|
||||
bool, // do softmax
|
||||
std::vector<int64_t>, // mask
|
||||
int, // start axis
|
||||
int, // end axis
|
||||
InferenceEngine::Precision, // Network precision
|
||||
std::string>; // Device name
|
||||
|
||||
class RegionYoloLayerTest : public testing::WithParamInterface<regionYoloParamsTuple>,
|
||||
virtual public LayerTestsUtils::LayerTestsCommon {
|
||||
public:
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<regionYoloParamsTuple> &obj);
|
||||
|
||||
protected:
|
||||
void SetUp() override;
|
||||
};
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
@ -0,0 +1,63 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "ie_core.hpp"
|
||||
|
||||
#include "common_test_utils/common_utils.hpp"
|
||||
#include "functional_test_utils/blob_utils.hpp"
|
||||
#include "functional_test_utils/precision_utils.hpp"
|
||||
#include "functional_test_utils/plugin_cache.hpp"
|
||||
#include "functional_test_utils/skip_tests_config.hpp"
|
||||
|
||||
#include "single_layer_tests/region_yolo.hpp"
|
||||
|
||||
namespace LayerTestsDefinitions {
|
||||
|
||||
std::string RegionYoloLayerTest::getTestCaseName(const testing::TestParamInfo<regionYoloParamsTuple> &obj) {
|
||||
ngraph::Shape inputShape;
|
||||
size_t classes;
|
||||
size_t coords;
|
||||
size_t num_regions;
|
||||
bool do_softmax;
|
||||
std::vector<int64_t> mask;
|
||||
int start_axis;
|
||||
int end_axis;
|
||||
InferenceEngine::Precision netPrecision;
|
||||
std::string targetName;
|
||||
std::tie(inputShape, classes, coords, num_regions, do_softmax , mask, start_axis, end_axis, netPrecision, targetName) = obj.param;
|
||||
std::ostringstream result;
|
||||
result << "IS=" << inputShape << "_";
|
||||
result << "classes=" << classes << "_";
|
||||
result << "coords=" << coords << "_";
|
||||
result << "num=" << num_regions << "_";
|
||||
result << "doSoftmax=" << do_softmax << "_";
|
||||
result << "axis=" << start_axis << "_";
|
||||
result << "endAxis=" << end_axis << "_";
|
||||
result << "netPRC=" << netPrecision.name() << "_";
|
||||
result << "targetDevice=" << targetName << "_";
|
||||
return result.str();
|
||||
}
|
||||
|
||||
void RegionYoloLayerTest::SetUp() {
|
||||
ngraph::Shape inputShape;
|
||||
size_t classes;
|
||||
size_t coords;
|
||||
size_t num_regions;
|
||||
bool do_softmax;
|
||||
std::vector<int64_t> mask;
|
||||
int start_axis;
|
||||
int end_axis;
|
||||
InferenceEngine::Precision netPrecision;
|
||||
std::tie(inputShape, classes, coords, num_regions, do_softmax, mask, start_axis, end_axis, netPrecision, targetDevice) = this->GetParam();
|
||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
|
||||
auto param = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, inputShape);
|
||||
auto region_yolo = std::make_shared<ngraph::op::v0::RegionYolo>(param, coords, classes, num_regions, do_softmax, mask, start_axis, end_axis);
|
||||
function = std::make_shared<ngraph::Function>(std::make_shared<ngraph::opset1::Result>(region_yolo), ngraph::ParameterVector{param}, "RegionYolo");
|
||||
}
|
||||
|
||||
TEST_P(RegionYoloLayerTest, CompareWithRefs) {
|
||||
Run();
|
||||
};
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
@ -79,7 +79,7 @@ namespace ngraph
|
||||
int m_axis;
|
||||
int m_end_axis;
|
||||
};
|
||||
}
|
||||
} // namespace v0
|
||||
using v0::RegionYolo;
|
||||
}
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
@ -0,0 +1,175 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
|
||||
#include "ngraph/shape.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace runtime
|
||||
{
|
||||
namespace reference
|
||||
{
|
||||
static inline int entry_index(int width,
|
||||
int height,
|
||||
int coords,
|
||||
int classes,
|
||||
int outputs,
|
||||
int batch,
|
||||
int location,
|
||||
int entry)
|
||||
{
|
||||
int n = location / (width * height);
|
||||
int loc = location % (width * height);
|
||||
return batch * outputs + n * width * height * (coords + classes + 1) +
|
||||
entry * width * height + loc;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static inline T sigmoid(float x)
|
||||
{
|
||||
return static_cast<T>(1.f / (1.f + std::exp(-x)));
|
||||
}
|
||||
template <typename T>
|
||||
static inline void softmax_generic(
|
||||
const T* src_data, T* dst_data, int batches, int channels, int height, int width)
|
||||
{
|
||||
const int area = height * width;
|
||||
for (unsigned int batch_idx = 0; batch_idx < batches; batch_idx++)
|
||||
{
|
||||
const int offset = batch_idx * channels * area;
|
||||
for (unsigned int i = 0; i < height * width; i++)
|
||||
{
|
||||
T max = src_data[batch_idx * channels * area + i];
|
||||
for (unsigned int channel_idx = 0; channel_idx < channels; channel_idx++)
|
||||
{
|
||||
T val = src_data[offset + channel_idx * area + i];
|
||||
max = std::max(max, val);
|
||||
}
|
||||
|
||||
T sum = 0;
|
||||
for (unsigned int channel_idx = 0; channel_idx < channels; channel_idx++)
|
||||
{
|
||||
dst_data[offset + channel_idx * area + i] =
|
||||
std::exp(src_data[offset + channel_idx * area + i] - max);
|
||||
sum += dst_data[offset + channel_idx * area + i];
|
||||
}
|
||||
|
||||
for (unsigned int channel_idx = 0; channel_idx < channels; channel_idx++)
|
||||
{
|
||||
dst_data[offset + channel_idx * area + i] /= sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void region_yolo(const T* input,
|
||||
T* output,
|
||||
const Shape& input_shape,
|
||||
const int coords,
|
||||
const int classes,
|
||||
const int regions,
|
||||
const bool do_softmax,
|
||||
const std::vector<int64_t>& mask)
|
||||
{
|
||||
NGRAPH_CHECK(input_shape.size() == 4);
|
||||
|
||||
const int batches = input_shape[0];
|
||||
const int channels = input_shape[1];
|
||||
const int height = input_shape[2];
|
||||
const int width = input_shape[3];
|
||||
|
||||
const auto mask_size = mask.size();
|
||||
|
||||
std::copy(input, input + shape_size(input_shape), output);
|
||||
|
||||
int num_regions = 0;
|
||||
int end_index = 0;
|
||||
|
||||
if (do_softmax)
|
||||
{
|
||||
// Region layer (Yolo v2)
|
||||
num_regions = regions;
|
||||
end_index = width * height;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Yolo layer (Yolo v3)
|
||||
num_regions = mask_size;
|
||||
end_index = width * height * (classes + 1);
|
||||
}
|
||||
|
||||
const int inputs_size = width * height * num_regions * (classes + coords + 1);
|
||||
|
||||
for (unsigned int batch_idx = 0; batch_idx < batches; batch_idx++)
|
||||
{
|
||||
for (unsigned int n = 0; n < num_regions; n++)
|
||||
{
|
||||
int index = entry_index(width,
|
||||
height,
|
||||
coords,
|
||||
classes,
|
||||
inputs_size,
|
||||
batch_idx,
|
||||
n * width * height,
|
||||
0);
|
||||
std::transform(output + index,
|
||||
output + index + 2 * width * height,
|
||||
output + index,
|
||||
[](T elem) { return sigmoid<T>(elem); });
|
||||
|
||||
index = entry_index(width,
|
||||
height,
|
||||
coords,
|
||||
classes,
|
||||
inputs_size,
|
||||
batch_idx,
|
||||
n * width * height,
|
||||
coords);
|
||||
std::transform(output + index,
|
||||
output + index + end_index,
|
||||
output + index,
|
||||
[](T elem) { return sigmoid<T>(elem); });
|
||||
}
|
||||
}
|
||||
|
||||
if (do_softmax)
|
||||
{
|
||||
int index =
|
||||
entry_index(width, height, coords, classes, inputs_size, 0, 0, coords + 1);
|
||||
int batch_offset = inputs_size / regions;
|
||||
for (unsigned int batch_idx = 0; batch_idx < batches * regions; batch_idx++)
|
||||
{
|
||||
softmax_generic<T>(input + index + batch_idx * batch_offset,
|
||||
output + index + batch_idx * batch_offset,
|
||||
1,
|
||||
classes,
|
||||
height,
|
||||
width);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace reference
|
||||
|
||||
} // namespace runtime
|
||||
|
||||
} // namespace ngraph
|
@ -60,6 +60,12 @@ bool ngraph::op::v0::RegionYolo::visit_attributes(AttributeVisitor& visitor)
|
||||
void op::RegionYolo::validate_and_infer_types()
|
||||
{
|
||||
auto input_et = get_input_element_type(0);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_et.is_real(),
|
||||
"Type of input is expected to be a floating point type. Got: ",
|
||||
input_et);
|
||||
|
||||
if (get_input_partial_shape(0).is_static())
|
||||
{
|
||||
Shape input_shape = get_input_partial_shape(0).to_shape();
|
||||
|
@ -325,6 +325,7 @@ set(MULTI_TEST_SRC
|
||||
backend/reduce_min.in.cpp
|
||||
backend/reduce_prod.in.cpp
|
||||
backend/reduce_sum.in.cpp
|
||||
backend/region_yolo.in.cpp
|
||||
backend/relu.in.cpp
|
||||
backend/reorg_yolo.in.cpp
|
||||
backend/replace_slice.in.cpp
|
||||
|
@ -787,7 +787,7 @@ TEST(attributes, reduce_sum_op)
|
||||
TEST(attributes, region_yolo_op)
|
||||
{
|
||||
FactoryRegistry<Node>::get().register_factory<opset1::RegionYolo>();
|
||||
auto data = make_shared<op::Parameter>(element::i64, Shape{1, 255, 26, 26});
|
||||
auto data = make_shared<op::Parameter>(element::f32, Shape{1, 255, 26, 26});
|
||||
|
||||
size_t num_coords = 4;
|
||||
size_t num_classes = 1;
|
||||
|
86
ngraph/test/backend/region_yolo.in.cpp
Normal file
86
ngraph/test/backend/region_yolo.in.cpp
Normal file
@ -0,0 +1,86 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "util/engine/test_engines.hpp"
|
||||
#include "util/test_case.hpp"
|
||||
#include "util/test_control.hpp"
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
static string s_manifest = "${MANIFEST}";
|
||||
using TestEngine = test::ENGINE_CLASS_NAME(${BACKEND_NAME});
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, region_yolo_v2_caffe)
|
||||
{
|
||||
const size_t num = 5;
|
||||
const size_t coords = 4;
|
||||
const size_t classes = 20;
|
||||
const size_t batch = 1;
|
||||
const size_t channels = 125;
|
||||
const size_t width = 13;
|
||||
const size_t height = 13;
|
||||
const size_t count = width * height * channels;
|
||||
const std::vector<int64_t> mask{0, 1, 2};
|
||||
|
||||
Shape input_shape{batch, channels, height, width};
|
||||
Shape output_shape{batch, channels * height * width};
|
||||
|
||||
auto A = make_shared<op::Parameter>(element::f32, input_shape);
|
||||
auto R = make_shared<op::v0::RegionYolo>(A, coords, classes, num, true, mask, 1, 3);
|
||||
auto f = make_shared<Function>(R, ParameterVector{A});
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(f);
|
||||
|
||||
test_case.add_input_from_file<float>(input_shape, TEST_FILES, "region_in_yolov2_caffe.data");
|
||||
test_case.add_expected_output_from_file<float>(
|
||||
output_shape, TEST_FILES, "region_out_yolov2_caffe.data");
|
||||
test_case.run_with_tolerance_as_fp(1.0e-4f);
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, region_yolo_v3_mxnet)
|
||||
{
|
||||
const size_t num = 9;
|
||||
const size_t coords = 4;
|
||||
const size_t classes = 20;
|
||||
const size_t batch = 1;
|
||||
const size_t channels = 75;
|
||||
const size_t width = 32;
|
||||
const size_t height = 32;
|
||||
const std::vector<int64_t> mask{0, 1, 2};
|
||||
|
||||
Shape shape{batch, channels, height, width};
|
||||
const auto count = shape_size(shape);
|
||||
|
||||
const auto A = make_shared<op::Parameter>(element::f32, shape);
|
||||
const auto R = make_shared<op::v0::RegionYolo>(A, coords, classes, num, false, mask, 1, 3);
|
||||
const auto f = make_shared<Function>(R, ParameterVector{A});
|
||||
|
||||
EXPECT_EQ(R->get_output_shape(0), shape);
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(f);
|
||||
|
||||
test_case.add_input_from_file<float>(shape, TEST_FILES, "region_in_yolov3_mxnet.data");
|
||||
test_case.add_expected_output_from_file<float>(
|
||||
shape, TEST_FILES, "region_out_yolov3_mxnet.data");
|
||||
test_case.run_with_tolerance_as_fp(1.0e-4f);
|
||||
}
|
BIN
ngraph/test/files/region_in_yolov2_caffe.data
Normal file
BIN
ngraph/test/files/region_in_yolov2_caffe.data
Normal file
Binary file not shown.
BIN
ngraph/test/files/region_in_yolov3_mxnet.data
Normal file
BIN
ngraph/test/files/region_in_yolov3_mxnet.data
Normal file
Binary file not shown.
BIN
ngraph/test/files/region_out_yolov2_caffe.data
Normal file
BIN
ngraph/test/files/region_out_yolov2_caffe.data
Normal file
Binary file not shown.
BIN
ngraph/test/files/region_out_yolov3_mxnet.data
Normal file
BIN
ngraph/test/files/region_out_yolov3_mxnet.data
Normal file
Binary file not shown.
@ -1466,6 +1466,8 @@ IE_GPU.matmul_2x2_2x2
|
||||
IE_GPU.matmul_2x3_3x3
|
||||
IE_GPU.matmul_3x2_3x3_transpose
|
||||
IE_GPU.matmul_3x2_2x3_transpose
|
||||
IE_GPU.region_yolo_v2_caffe
|
||||
IE_GPU.region_yolo_v3_mxnet
|
||||
|
||||
# Unsupported collapse op with dynamic shape
|
||||
IE_GPU.builder_opset1_collapse_dyn_shape
|
||||
|
@ -77,6 +77,7 @@
|
||||
#include "ngraph/runtime/reference/prior_box.hpp"
|
||||
#include "ngraph/runtime/reference/product.hpp"
|
||||
#include "ngraph/runtime/reference/quantize.hpp"
|
||||
#include "ngraph/runtime/reference/region_yolo.hpp"
|
||||
#include "ngraph/runtime/reference/relu.hpp"
|
||||
#include "ngraph/runtime/reference/reorg_yolo.hpp"
|
||||
#include "ngraph/runtime/reference/replace_slice.hpp"
|
||||
@ -1187,6 +1188,19 @@ protected:
|
||||
|
||||
break;
|
||||
}
|
||||
case OP_TYPEID::RegionYolo_v0:
|
||||
{
|
||||
const op::RegionYolo* region_yolo = static_cast<const op::RegionYolo*>(&node);
|
||||
reference::region_yolo<T>(args[0]->get_data_ptr<const T>(),
|
||||
out[0]->get_data_ptr<T>(),
|
||||
args[0]->get_shape(),
|
||||
region_yolo->get_num_coords(),
|
||||
region_yolo->get_num_classes(),
|
||||
region_yolo->get_num_regions(),
|
||||
region_yolo->get_do_softmax(),
|
||||
region_yolo->get_mask());
|
||||
break;
|
||||
}
|
||||
case OP_TYPEID::Relu:
|
||||
{
|
||||
size_t element_count = shape_size(node.get_output_shape(0));
|
||||
|
@ -21,6 +21,7 @@
|
||||
#define ID_SUFFIX(NAME) NAME##_v0
|
||||
NGRAPH_OP(CTCGreedyDecoder, ngraph::op::v0)
|
||||
NGRAPH_OP(DetectionOutput, op::v0)
|
||||
NGRAPH_OP(RegionYolo, op::v0)
|
||||
NGRAPH_OP(ReorgYolo, op::v0)
|
||||
NGRAPH_OP(RNNCell, op::v0)
|
||||
#undef ID_SUFFIX
|
||||
|
Loading…
Reference in New Issue
Block a user