Removed BatchNotmTraining (#1185)

This commit is contained in:
Ilya Churaev 2020-07-06 11:22:27 +03:00 committed by GitHub
parent 84f7cd2c02
commit 293b72151d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 0 additions and 967 deletions

View File

@ -66,12 +66,6 @@ namespace ngraph
throw ngraph_error(
"Cannot create nGraph batch norm with unsupported number of inputs");
// return {std::make_shared<ngraph::opset0::BatchNormTraining>(
// x, scale, bias, epsilon),
// after_bn_mean,
// after_bn_var,
// saved_mean,
// saved_var};
}
} // namespace set_1

View File

@ -24,64 +24,6 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::BatchNormTraining::type_info;
op::BatchNormTraining::BatchNormTraining(const Output<Node>& input,
const Output<Node>& gamma,
const Output<Node>& beta,
double epsilon)
: Op({gamma, beta, input})
, m_epsilon(epsilon)
{
constructor_validate_and_infer_types();
}
// DEPRECATED
op::BatchNormTraining::BatchNormTraining(double eps,
const Output<Node>& gamma,
const Output<Node>& beta,
const Output<Node>& input)
: Op({gamma, beta, input})
, m_epsilon(eps)
{
constructor_validate_and_infer_types();
}
bool op::BatchNormTraining::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("epsilon", m_epsilon);
return true;
}
void op::BatchNormTraining::validate_and_infer_types()
{
element::Type result_et;
PartialShape result_batch_shape;
PartialShape result_channel_shape;
set_output_size(3);
std::tie(result_et, result_batch_shape, result_channel_shape) =
infer_batch_norm_forward(this,
get_input_element_type(INPUT_DATA),
get_input_element_type(INPUT_GAMMA),
get_input_element_type(INPUT_BETA),
get_input_partial_shape(INPUT_DATA),
get_input_partial_shape(INPUT_GAMMA),
get_input_partial_shape(INPUT_BETA));
set_output_type(0, result_et, result_batch_shape);
set_output_type(1, result_et, result_channel_shape);
set_output_type(2, result_et, result_channel_shape);
}
std::shared_ptr<Node>
op::BatchNormTraining::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return std::make_shared<BatchNormTraining>(
new_args.at(2), new_args.at(0), new_args.at(1), m_epsilon);
}
constexpr NodeTypeInfo op::BatchNormInference::type_info;
op::BatchNormInference::BatchNormInference(const Output<Node>& input,

View File

@ -28,67 +28,6 @@ namespace ngraph
{
namespace v0
{
/// \brief Batchnorm for training operation
class NGRAPH_API BatchNormTraining : public Op
{
public:
static constexpr NodeTypeInfo type_info{"BatchNormTraining", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
BatchNormTraining() = default;
/// \param input Must have rank >= 2, [., C, ...]
/// \param gamma gamma scaling for normalized value. [C]
/// \param beta bias added to the scaled normalized value [C]
/// \param epsilon Avoids divsion by 0 if input has 0 variance
BatchNormTraining(const Output<Node>& input,
const Output<Node>& gamma,
const Output<Node>& beta,
double epsilon);
bool visit_attributes(AttributeVisitor& visitor) override;
NGRAPH_DEPRECATED_DOC
/// In this version of BatchNorm:
///
/// MEAN AND VARIANCE: computed directly from the content of 'input'.
///
/// OUTPUT VALUE: A tuple with the following structure:
/// [0] - The normalization of 'input'.
/// [1] - The per-channel means of (pre-normalized) 'input'.
/// [2] - The per-channel variances of (pre-normalized) 'input'.
///
/// AUTODIFF SUPPORT: yes: 'generate_adjoints(...)' works as expected.
///
/// SHAPE DETAILS:
/// gamma: must have rank 1, with the same span as input's channel axis.
/// beta: must have rank 1, with the same span as input's channel axis.
/// input: must have rank >= 2. The second dimension represents the channel
/// axis
/// and must have a span of at least 1.
/// output[0]: shall have the same shape as 'input'.
/// output[1]: shall have rank 1, with the same span as input's channel axis.
/// output[2]: shall have rank 1, with the same span as input's channel axis.
NGRAPH_DEPRECATED("Use another constructor")
BatchNormTraining(double eps,
const Output<Node>& gamma,
const Output<Node>& beta,
const Output<Node>& input);
void validate_and_infer_types() override;
double get_eps_value() const { return m_epsilon; }
void set_eps_value(double epsilon) { m_epsilon = epsilon; }
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
protected:
static constexpr size_t INPUT_GAMMA = 0;
static constexpr size_t INPUT_BETA = 1;
static constexpr size_t INPUT_DATA = 2;
private:
double m_epsilon;
};
class NGRAPH_API BatchNormInference : public Op
{
public:
@ -154,8 +93,6 @@ namespace ngraph
double m_epsilon;
};
} // namespace v0
using v0::BatchNormInference;
using v0::BatchNormTraining;
}
}

View File

@ -45,7 +45,6 @@ NGRAPH_OP(AvgPool, ngraph::op::v1, 1)
NGRAPH_OP(BatchMatMul, ngraph::op::v0, 0)
NGRAPH_OP(BatchMatMulTranspose, ngraph::op::v0, 0)
NGRAPH_OP(BatchNormInference, ngraph::op::v0, 0)
NGRAPH_OP(BatchNormTraining, ngraph::op::v0, 0)
NGRAPH_OP(BatchToSpace, ngraph::op::v1, 1)
NGRAPH_OP(BinaryConvolution, ngraph::op::v1, 1)
NGRAPH_OP(Broadcast, ngraph::op::v0, 0)

View File

@ -1019,13 +1019,6 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
make_shared<op::BatchMatMulTranspose>(args[0], args[1], transpose_0, transpose_1);
break;
}
case OP_TYPEID::BatchNormTraining:
{
auto epsilon = node_js.at("eps").get<double>();
// Odd order for back-compatibility
node = make_shared<op::BatchNormTraining>(args[2], args[0], args[1], epsilon);
break;
}
case OP_TYPEID::BatchNormInference:
{
auto epsilon = node_js.at("eps").get<double>();
@ -2557,12 +2550,6 @@ json JSONSerializer::serialize_node(const Node& n)
node["transpose_1"] = tmp->get_transpose_arg1();
break;
}
case OP_TYPEID::BatchNormTraining:
{
auto tmp = static_cast<const op::BatchNormTraining*>(&n);
node["eps"] = tmp->get_eps_value();
break;
}
case OP_TYPEID::BatchNormInference:
{
auto tmp = static_cast<const op::BatchNormInference*>(&n);

View File

@ -275,196 +275,6 @@ NGRAPH_TEST(${BACKEND_NAME}, batch_norm_inference_f32)
EXPECT_TRUE(bnt.test_variance()) << "Variance test";
}
template <typename T>
class BatchNormTrainingTester
{
public:
BatchNormTrainingTester(const std::shared_ptr<ngraph::runtime::Backend>& backend,
const Shape& input_shape,
element::Type etype,
double epsilon)
: m_backend(backend)
{
Shape channel_shape{input_shape.at(1)};
auto Input = make_shared<op::Parameter>(etype, input_shape);
auto Gamma = make_shared<op::Parameter>(etype, channel_shape);
auto Beta = make_shared<op::Parameter>(etype, channel_shape);
auto BN = make_shared<op::BatchNormTraining>(Input, Gamma, Beta, epsilon);
auto NormedInput = make_shared<op::Result>(make_shared<op::GetOutputElement>(BN, 0));
auto Mean = make_shared<op::Result>(make_shared<op::GetOutputElement>(BN, 1));
auto Variance = make_shared<op::Result>(make_shared<op::GetOutputElement>(BN, 2));
m_function = make_shared<Function>(ResultVector{NormedInput, Mean, Variance},
ParameterVector{Input, Gamma, Beta});
m_input = backend->create_tensor(etype, input_shape);
m_gamma = backend->create_tensor(etype, channel_shape);
m_beta = backend->create_tensor(etype, channel_shape);
m_normed_input = backend->create_tensor(etype, input_shape);
m_mean = backend->create_tensor(etype, channel_shape);
m_variance = backend->create_tensor(etype, channel_shape);
}
std::tuple<bool, bool, bool> call(const std::vector<T>& input,
const std::vector<T>& gamma,
const std::vector<T>& beta,
const std::vector<T>& normed_input,
const std::vector<T>& mean,
const std::vector<T>& variance)
{
copy_data(m_input, input);
copy_data(m_gamma, gamma);
copy_data(m_beta, beta);
auto handle = m_backend->compile(m_function);
handle->call_with_validate({m_normed_input, m_mean, m_variance},
{m_input, m_gamma, m_beta});
auto res_normed_input = read_vector<T>(m_normed_input);
bool normed_input_test = test::all_close(normed_input, res_normed_input);
auto res_mean = read_vector<T>(m_mean);
bool mean_test = test::all_close(mean, res_mean);
auto res_variance = read_vector<T>(m_variance);
bool variance_test = test::all_close(variance, res_variance);
return std::tuple<bool, bool, bool>(normed_input_test, mean_test, variance_test);
}
protected:
const std::shared_ptr<ngraph::runtime::Backend>& m_backend;
std::shared_ptr<Function> m_function;
std::shared_ptr<ngraph::runtime::Tensor> m_input;
std::shared_ptr<ngraph::runtime::Tensor> m_gamma;
std::shared_ptr<ngraph::runtime::Tensor> m_beta;
std::shared_ptr<ngraph::runtime::Tensor> m_normed_input;
std::shared_ptr<ngraph::runtime::Tensor> m_mean;
std::shared_ptr<ngraph::runtime::Tensor> m_variance;
};
template <typename T>
class BatchNormTrainingTesterZeroEpsilon : public BatchNormTrainingTester<T>
{
public:
// These are for documentation purposes only below
using Input = test::NDArray<T, 2>;
using Gamma = test::NDArray<T, 1>;
using Beta = test::NDArray<T, 1>;
using NormedInput = test::NDArray<T, 2>;
using Mean = test::NDArray<T, 1>;
using Variance = test::NDArray<T, 1>;
BatchNormTrainingTesterZeroEpsilon(const std::shared_ptr<ngraph::runtime::Backend>& backend,
element::Type etype)
: BatchNormTrainingTester<T>(backend, Shape{10, 3}, etype, 0.0)
{
}
std::tuple<bool, bool, bool> test(const Input& input,
const Gamma& gamma,
const Beta& beta,
const NormedInput& normed_input,
const Mean& mean,
const Variance& variance)
{
return BatchNormTrainingTester<T>::call(input.get_vector(),
gamma.get_vector(),
beta.get_vector(),
normed_input.get_vector(),
mean.get_vector(),
variance.get_vector());
}
std::tuple<bool, bool, bool> test_mean_variance()
{
return test(Input{{0.0, 1.0, 0.0},
{1.0, 2.0, 0.25},
{1.0, 2.0, 0.25},
{3.0, 4.0, 0.75},
{3.0, 4.0, 0.75},
{0.0, 1.0, 0.0},
{-1.0, 0.0, -0.25},
{-1.0, 0.0, -0.25},
{-3.0, -2.0, -0.75},
{-3.0, -2.0, -0.75}},
Gamma{1.0, 1.0, 1.0},
Beta{0.0, 0.0, 0.0},
NormedInput{{0.0, 0.0, 0.0},
{0.5, 0.5, 0.5},
{0.5, 0.5, 0.5},
{1.5, 1.5, 1.5},
{1.5, 1.5, 1.5},
{0.0, 0.0, 0.0},
{-0.5, -0.5, -0.5},
{-0.5, -0.5, -0.5},
{-1.5, -1.5, -1.5},
{-1.5, -1.5, -1.5}},
Mean{0.0, 1.0, 0.0},
Variance{4.0, 4.0, 0.25});
}
std::tuple<bool, bool, bool> test_gamma_beta()
{
return test(Input{{0.0, 1.0, 0.0},
{1.0, 2.0, 0.25},
{1.0, 2.0, 0.25},
{3.0, 4.0, 0.75},
{3.0, 4.0, 0.75},
{0.0, 1.0, 0.0},
{-1.0, 0.0, -0.25},
{-1.0, 0.0, -0.25},
{-3.0, -2.0, -0.75},
{-3.0, -2.0, -0.75}},
Gamma{2.0, 1.0, 2.0},
Beta{0.0, 1.0, 1.0},
NormedInput{{0.0, 1.0, 1.0},
{1.0, 1.5, 2.0},
{1.0, 1.5, 2.0},
{3.0, 2.5, 4.0},
{3.0, 2.5, 4.0},
{0.0, 1.0, 1.0},
{-1.0, 0.5, 0.0},
{-1.0, 0.5, 0.0},
{-3.0, -0.5, -2.0},
{-3.0, -0.5, -2.0}},
Mean{0.0, 1.0, 0.0},
Variance{4.0, 4.0, 0.25});
}
};
NGRAPH_TEST(${BACKEND_NAME}, batch_norm_training_0eps_f64)
{
using T = double;
auto& et = element::f64;
auto backend = runtime::Backend::create("${BACKEND_NAME}");
BatchNormTrainingTesterZeroEpsilon<T> bnt(backend, et);
std::tuple<bool, bool, bool> result;
result = bnt.test_mean_variance();
EXPECT_TRUE(std::get<0>(result)) << "Mean variance test normed input";
EXPECT_TRUE(std::get<1>(result)) << "Mean variance test mean";
EXPECT_TRUE(std::get<2>(result)) << "Mean variance test variance";
result = bnt.test_gamma_beta();
EXPECT_TRUE(std::get<0>(result)) << "Gamma beta test normed input";
EXPECT_TRUE(std::get<1>(result)) << "Gamma beta test mean";
EXPECT_TRUE(std::get<2>(result)) << "Gamma test variance";
}
NGRAPH_TEST(${BACKEND_NAME}, batch_norm_training_0eps_f32)
{
using T = float;
auto& et = element::f32;
auto backend = runtime::Backend::create("${BACKEND_NAME}");
BatchNormTrainingTesterZeroEpsilon<T> bnt(backend, et);
std::tuple<bool, bool, bool> result;
result = bnt.test_mean_variance();
EXPECT_TRUE(std::get<0>(result)) << "Mean variance test normed input";
EXPECT_TRUE(std::get<1>(result)) << "Mean variance test mean";
EXPECT_TRUE(std::get<2>(result)) << "Mean variance test variance";
result = bnt.test_gamma_beta();
EXPECT_TRUE(std::get<0>(result)) << "Gamma beta test normed input";
EXPECT_TRUE(std::get<1>(result)) << "Gamma beta test mean";
EXPECT_TRUE(std::get<2>(result)) << "Gamma beta test variance";
}
NGRAPH_TEST(${BACKEND_NAME}, batch_norm_inference_parameters_duplication)
{
auto input_shape = Shape{2, 2, 2, 1};
@ -510,177 +320,6 @@ NGRAPH_TEST(${BACKEND_NAME}, batch_norm_inference_parameters_duplication)
ngraph::test::all_close(expected_result, read_vector<float>(bn_output), 1e-3f, 1e-4f));
}
NGRAPH_TEST(${BACKEND_NAME}, batch_norm_fprop_b1c2h2w2)
{
auto input_shape = Shape{1, 2, 2, 2};
auto input = make_shared<op::Parameter>(element::f32, input_shape);
auto mean_shape = Shape{2};
auto var_shape = Shape{2};
auto gamma_shape = Shape{2};
auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
auto beta_shape = Shape{2};
auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
double eps = 0.001;
auto shape_r = Shape{1, 2, 2, 2};
auto bn = make_shared<op::BatchNormTraining>(input, gamma, beta, eps);
auto output_rt = std::make_shared<op::GetOutputElement>(bn, 0);
auto mean_rt = std::make_shared<op::GetOutputElement>(bn, 1);
auto variance_rt = std::make_shared<op::GetOutputElement>(bn, 2);
auto f = make_shared<Function>(NodeVector{output_rt, mean_rt, variance_rt},
ParameterVector{input, gamma, beta});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto _input = backend->create_tensor(element::f32, Shape{1, 2, 2, 2});
copy_data(_input,
vector<float>{0.54881352f,
0.71518934f,
0.60276335f,
0.54488319f,
0.42365479f,
0.64589411f,
0.4375872f,
0.89177299f});
auto _gamma = backend->create_tensor(element::f32, gamma_shape);
copy_data(_gamma, vector<float>{1.0f, 1.0f});
auto _beta = backend->create_tensor(element::f32, beta_shape);
copy_data(_beta, vector<float>{0.0f, 0.0f});
auto bn_output = backend->create_tensor(element::f32, shape_r);
auto result_mean = backend->create_tensor(element::f32, mean_shape);
auto result_variance = backend->create_tensor(element::f32, var_shape);
vector<float> expected_result{-0.71498716f,
1.48388731f,
-0.00196938f,
-0.76693159f,
-0.91316032f,
0.23943391f,
-0.84090298f,
1.51462936f};
vector<float> expected_mean{0.602912f, 0.599727f};
vector<float> expected_variance{0.00472505f, 0.0361782f};
auto handle = backend->compile(f);
handle->call_with_validate({bn_output, result_mean, result_variance}, {_input, _gamma, _beta});
EXPECT_TRUE(test::all_close(expected_result, read_vector<float>(bn_output), 1e-5f, 1e-6f));
EXPECT_TRUE(test::all_close(expected_mean, read_vector<float>(result_mean), 1e-5f, 1e-6f));
EXPECT_TRUE(
test::all_close(expected_variance, read_vector<float>(result_variance), 1e-5f, 1e-6f));
}
NGRAPH_TEST(${BACKEND_NAME}, batch_norm_fprop_b2c2h2w1)
{
auto input_shape = Shape{2, 2, 2, 1};
auto input = make_shared<op::Parameter>(element::f32, input_shape);
auto mean_shape = Shape{2};
auto var_shape = Shape{2};
auto gamma_shape = Shape{2};
auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
auto beta_shape = Shape{2};
auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
double eps = 0.001;
auto shape_r = Shape{2, 2, 2, 1};
auto bn = make_shared<op::BatchNormTraining>(input, gamma, beta, eps);
auto output_rt = std::make_shared<op::GetOutputElement>(bn, 0);
auto mean_rt = std::make_shared<op::GetOutputElement>(bn, 1);
auto variance_rt = std::make_shared<op::GetOutputElement>(bn, 2);
auto f = make_shared<Function>(NodeVector{output_rt, mean_rt, variance_rt},
ParameterVector{input, gamma, beta});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto _input = backend->create_tensor(element::f32, input_shape);
copy_data(_input,
vector<float>{0.54881352f,
0.71518934f,
0.60276335f,
0.54488319f,
0.42365479f,
0.64589411f,
0.4375872f,
0.89177299f});
auto _gamma = backend->create_tensor(element::f32, gamma_shape);
copy_data(_gamma, vector<float>{1.0f, 1.0f});
auto _beta = backend->create_tensor(element::f32, beta_shape);
copy_data(_beta, vector<float>{0.0f, 0.0f});
auto bn_output = backend->create_tensor(element::f32, shape_r);
auto result_mean = backend->create_tensor(element::f32, mean_shape);
auto result_variance = backend->create_tensor(element::f32, var_shape);
vector<float> expected_result{
-0.30327f, 1.1561f, -0.0963782f, -0.434702f, -1.4011f, 0.548275f, -1.06187f, 1.59295f};
vector<float> expected_mean{0.583388f, 0.619252f};
vector<float> expected_variance{0.0119972f, 0.0282681f};
auto handle = backend->compile(f);
handle->call_with_validate({bn_output, result_mean, result_variance}, {_input, _gamma, _beta});
EXPECT_TRUE(test::all_close(expected_result, read_vector<float>(bn_output)));
EXPECT_TRUE(test::all_close(expected_mean, read_vector<float>(result_mean)));
EXPECT_TRUE(
test::all_close(expected_variance, read_vector<float>(result_variance), 1e-5f, 1e-6f));
}
NGRAPH_TEST(${BACKEND_NAME}, batch_norm_fprop_b2c2d2h1w1)
{
auto input_shape = Shape{2, 2, 2, 1, 1};
auto input = make_shared<op::Parameter>(element::f32, input_shape);
auto mean_shape = Shape{2};
auto var_shape = Shape{2};
auto gamma_shape = Shape{2};
auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
auto beta_shape = Shape{2};
auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
double eps = 0.001;
auto shape_r = Shape{2, 2, 2, 1, 1};
auto bn = make_shared<op::BatchNormTraining>(eps, gamma, beta, input);
auto output_rt = std::make_shared<op::GetOutputElement>(bn, 0);
auto mean_rt = std::make_shared<op::GetOutputElement>(bn, 1);
auto variance_rt = std::make_shared<op::GetOutputElement>(bn, 2);
auto f = make_shared<Function>(NodeVector{output_rt, mean_rt, variance_rt},
ParameterVector{input, gamma, beta});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto _input = backend->create_tensor(element::f32, input_shape);
copy_data(_input,
vector<float>{0.54881352f,
0.71518934f,
0.60276335f,
0.54488319f,
0.42365479f,
0.64589411f,
0.4375872f,
0.89177299f});
auto _gamma = backend->create_tensor(element::f32, gamma_shape);
copy_data(_gamma, vector<float>{1.0f, 1.0f});
auto _beta = backend->create_tensor(element::f32, beta_shape);
copy_data(_beta, vector<float>{0.0f, 0.0f});
auto bn_output = backend->create_tensor(element::f32, shape_r);
auto result_mean = backend->create_tensor(element::f32, mean_shape);
auto result_variance = backend->create_tensor(element::f32, var_shape);
vector<float> expected_result{
-0.30327f, 1.1561f, -0.0963782f, -0.434702f, -1.4011f, 0.548275f, -1.06187f, 1.59295f};
vector<float> expected_mean{0.583388f, 0.619252f};
vector<float> expected_variance{0.0119972f, 0.0282681f};
auto handle = backend->compile(f);
handle->call_with_validate({bn_output, result_mean, result_variance}, {_input, _gamma, _beta});
EXPECT_TRUE(test::all_close(expected_result, read_vector<float>(bn_output)));
EXPECT_TRUE(test::all_close(expected_mean, read_vector<float>(result_mean)));
EXPECT_TRUE(
test::all_close(expected_variance, read_vector<float>(result_variance), 1e-5f, 1e-6f));
}
NGRAPH_TEST(${BACKEND_NAME}, batch_norm_fprop_inference_b2c2h2w1)
{
auto input_shape = Shape{2, 2, 2, 1};
@ -729,82 +368,3 @@ NGRAPH_TEST(${BACKEND_NAME}, batch_norm_fprop_inference_b2c2h2w1)
ASSERT_TRUE(
ngraph::test::all_close(expected_result, read_vector<float>(bn_output), 1e-3f, 1e-4f));
}
NGRAPH_TEST(DISABLED_${BACKEND_NAME}, dyn_batch_norm_fprop_b1c2h2w2)
{
// auto input_shape = Shape{1, 2, 2, 2};
auto input = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto mean_shape = Shape{2};
auto var_shape = Shape{2};
auto gamma_shape = Shape{2};
auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
auto beta_shape = Shape{2};
auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
double eps = 0.001;
auto shape_r = Shape{1, 2, 2, 2};
auto bn = make_shared<op::BatchNormTraining>(input, gamma, beta, eps);
auto output_rt = std::make_shared<op::GetOutputElement>(bn, 0);
auto mean_rt = std::make_shared<op::GetOutputElement>(bn, 1);
auto variance_rt = std::make_shared<op::GetOutputElement>(bn, 2);
auto shapeof_mean_rt = std::make_shared<ngraph::op::ShapeOf>(mean_rt);
auto rankof_mean_rt = std::make_shared<ngraph::op::ShapeOf>(shapeof_mean_rt);
auto rank_scalar = std::make_shared<ngraph::op::Reshape>(
rankof_mean_rt, ngraph::AxisVector{0}, ngraph::Shape{});
auto range = std::make_shared<ngraph::op::Range>(
ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0}),
rank_scalar,
ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{}, {1}));
auto one_bcast = std::make_shared<ngraph::op::DynBroadcast>(
ngraph::op::Constant::create(mean_rt->get_element_type(), ngraph::Shape{}, {1}),
shapeof_mean_rt,
range);
auto mean_rt_multiplied = std::make_shared<ngraph::op::Multiply>(one_bcast, mean_rt);
auto f = make_shared<Function>(NodeVector{output_rt, mean_rt_multiplied, variance_rt},
ParameterVector{input, gamma, beta});
auto backend = runtime::Backend::create("${BACKEND_NAME}", true);
// Create some tensors for input/output
auto _input = backend->create_tensor(element::f32, Shape{1, 2, 2, 2});
copy_data(_input,
vector<float>{0.54881352f,
0.71518934f,
0.60276335f,
0.54488319f,
0.42365479f,
0.64589411f,
0.4375872f,
0.89177299f});
auto _gamma = backend->create_tensor(element::f32, gamma_shape);
copy_data(_gamma, vector<float>{1.0f, 1.0f});
auto _beta = backend->create_tensor(element::f32, beta_shape);
copy_data(_beta, vector<float>{0.0f, 0.0f});
auto bn_output = backend->create_dynamic_tensor(element::f32, PartialShape::dynamic());
auto result_mean = backend->create_dynamic_tensor(element::f32, PartialShape::dynamic());
auto result_variance = backend->create_dynamic_tensor(element::f32, PartialShape::dynamic());
vector<float> expected_result{-0.71498716f,
1.48388731f,
-0.00196938f,
-0.76693159f,
-0.91316032f,
0.23943391f,
-0.84090298f,
1.51462936f};
vector<float> expected_mean{0.602912f, 0.599727f};
vector<float> expected_variance{0.00472505f, 0.0361782f};
auto handle = backend->compile(f);
handle->call_with_validate({bn_output, result_mean, result_variance}, {_input, _gamma, _beta});
EXPECT_TRUE(test::all_close(expected_result, read_vector<float>(bn_output), 1e-5f, 1e-6f));
EXPECT_TRUE(test::all_close(expected_mean, read_vector<float>(result_mean), 1e-5f, 1e-6f));
EXPECT_TRUE(
test::all_close(expected_variance, read_vector<float>(result_variance), 1e-5f, 1e-6f));
}

View File

@ -169,15 +169,6 @@ namespace
EXPECT_FALSE(node.is_binary_elementwise_logical());
}
void op_is_BatchNormTraining()
{
op::BatchNormTraining node;
EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
EXPECT_FALSE(node.is_binary_elementwise_comparison());
EXPECT_FALSE(node.is_binary_elementwise_logical());
}
void op_is_Broadcast()
{
op::Broadcast node;

View File

@ -389,20 +389,6 @@ protected:
node.get_output_shape(0));
break;
}
case OP_TYPEID::BatchNormTraining:
{
const ngraph::op::BatchNormTraining* bn =
static_cast<const ngraph::op::BatchNormTraining*>(&node);
reference::batch_norm_training<T>(bn->get_eps_value(),
args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
args[2]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
out[1]->get_data_ptr<T>(),
out[2]->get_data_ptr<T>(),
node.get_input_shape(2));
break;
}
case OP_TYPEID::BatchNormInference:
{
const ngraph::op::BatchNormInference* bn =

View File

@ -66,7 +66,6 @@ NGRAPH_OP(AvgPool, ngraph::op::v0)
NGRAPH_OP(BatchMatMul, ngraph::op)
NGRAPH_OP(BatchMatMulTranspose, ngraph::op)
NGRAPH_OP(BatchNormInference, ngraph::op)
NGRAPH_OP(BatchNormTraining, ngraph::op)
NGRAPH_OP(Broadcast, ngraph::op)
NGRAPH_OP(BroadcastDistributed, ngraph::op)
NGRAPH_OP(BroadcastLike, ngraph::op)

View File

@ -21,87 +21,6 @@
using namespace std;
using namespace ngraph;
TEST(type_prop, batch_norm_training_rank_less_than_2)
{
auto dummy = make_shared<op::Parameter>(element::f32, Shape{1});
try
{
auto bc = make_shared<op::BatchNormTraining>(dummy, dummy, dummy, 0.001);
FAIL() << "BatchNorm c-tor should throw for tensors whose rank is less than 2";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Input argument must have rank of at least 2"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, batch_norm_training_zero_channel_check)
{
auto data_batch = make_shared<op::Parameter>(element::f32, Shape{1, 0, 2, 3});
auto gamma = make_shared<op::Parameter>(element::f32, Shape{0});
auto beta = make_shared<op::Parameter>(element::f32, Shape{0});
try
{
auto bc = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, 0.001);
FAIL() << "BatchNorm c-tor should throw for tensors w/ zero-dimension channels";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Channel count must be at least 1"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, batch_norm_training_et_check)
{
auto data_batch = make_shared<op::Parameter>(element::f32, Shape{4, 3, 2, 2});
auto gamma = make_shared<op::Parameter>(element::f64, Shape{3});
auto beta = make_shared<op::Parameter>(element::f32, Shape{3});
try
{
auto bc = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, 0.001);
FAIL() << "BatchNorm c-tor should throw for different element types";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input element types do not match"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, batch_norm_training_shape_check)
{
auto data_batch = make_shared<op::Parameter>(element::f32, Shape{4, 3, 2, 2});
auto gamma = make_shared<op::Parameter>(element::f32, Shape{4});
auto beta = make_shared<op::Parameter>(element::f32, Shape{3});
try
{
auto bc = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, 0.001);
FAIL() << "BatchNorm c-tor should throw if gamma and beta shapes don't match";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Shapes for gamma/beta do not match"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, batch_norm_inference_partial_all_rank_dynamic)
{
PartialShape data_batch_shape{PartialShape::dynamic()};
@ -402,284 +321,3 @@ TEST(type_prop,
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, batch_norm_training_partial_all_rank_dynamic)
{
PartialShape data_batch_shape{PartialShape::dynamic()};
PartialShape gamma_shape{PartialShape::dynamic()};
PartialShape beta_shape{PartialShape::dynamic()};
double epsilon = 0.001;
element::Type data_batch_et = element::f32;
element::Type gamma_et = element::f32;
element::Type beta_et = element::f32;
auto data_batch = make_shared<op::Parameter>(data_batch_et, data_batch_shape);
auto gamma = make_shared<op::Parameter>(gamma_et, gamma_shape);
auto beta = make_shared<op::Parameter>(beta_et, beta_shape);
auto bn = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, epsilon);
ASSERT_EQ(bn->get_output_size(), 3);
ASSERT_EQ(bn->get_output_element_type(0), data_batch_et);
ASSERT_EQ(bn->get_output_element_type(1), data_batch_et);
ASSERT_EQ(bn->get_output_element_type(2), data_batch_et);
ASSERT_TRUE(bn->get_output_partial_shape(0).rank().is_dynamic());
ASSERT_TRUE(bn->get_output_partial_shape(1).same_scheme(PartialShape::dynamic(1)));
ASSERT_TRUE(bn->get_output_partial_shape(2).same_scheme(PartialShape::dynamic(1)));
}
TEST(type_prop, batch_norm_training_partial_input_rank_static_dynamic_batch_size_known_ok)
{
PartialShape data_batch_shape{
64, Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()};
PartialShape gamma_shape{PartialShape::dynamic()};
PartialShape beta_shape{PartialShape::dynamic()};
double epsilon = 0.001;
element::Type data_batch_et = element::f32;
element::Type gamma_et = element::f32;
element::Type beta_et = element::f32;
auto data_batch = make_shared<op::Parameter>(data_batch_et, data_batch_shape);
auto gamma = make_shared<op::Parameter>(gamma_et, gamma_shape);
auto beta = make_shared<op::Parameter>(beta_et, beta_shape);
auto bn = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, epsilon);
ASSERT_EQ(bn->get_output_size(), 3);
ASSERT_EQ(bn->get_output_element_type(0), data_batch_et);
ASSERT_EQ(bn->get_output_element_type(1), data_batch_et);
ASSERT_EQ(bn->get_output_element_type(2), data_batch_et);
ASSERT_TRUE(bn->get_output_partial_shape(0).same_scheme(
PartialShape{64, Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
ASSERT_TRUE(bn->get_output_partial_shape(1).same_scheme(PartialShape::dynamic(1)));
ASSERT_TRUE(bn->get_output_partial_shape(2).same_scheme(PartialShape::dynamic(1)));
}
TEST(type_prop, batch_norm_training_partial_input_rank_static_dynamic_channel_count_known_ok)
{
PartialShape data_batch_shape{
Dimension::dynamic(), 3, Dimension::dynamic(), Dimension::dynamic()};
PartialShape gamma_shape{PartialShape::dynamic()};
PartialShape beta_shape{PartialShape::dynamic()};
double epsilon = 0.001;
element::Type data_batch_et = element::f32;
element::Type gamma_et = element::f32;
element::Type beta_et = element::f32;
auto data_batch = make_shared<op::Parameter>(data_batch_et, data_batch_shape);
auto gamma = make_shared<op::Parameter>(gamma_et, gamma_shape);
auto beta = make_shared<op::Parameter>(beta_et, beta_shape);
auto bn = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, epsilon);
ASSERT_EQ(bn->get_output_size(), 3);
ASSERT_EQ(bn->get_output_element_type(0), data_batch_et);
ASSERT_EQ(bn->get_output_element_type(1), data_batch_et);
ASSERT_EQ(bn->get_output_element_type(2), data_batch_et);
ASSERT_TRUE(bn->get_output_partial_shape(0).same_scheme(
PartialShape{Dimension::dynamic(), 3, Dimension::dynamic(), Dimension::dynamic()}));
ASSERT_TRUE(bn->get_output_partial_shape(1).same_scheme(PartialShape{3}));
ASSERT_TRUE(bn->get_output_partial_shape(2).same_scheme(PartialShape{3}));
}
TEST(type_prop, batch_norm_training_partial_input_rank_static_dynamic_zero_channels)
{
PartialShape data_batch_shape{
Dimension::dynamic(), 0, Dimension::dynamic(), Dimension::dynamic()};
PartialShape gamma_shape{PartialShape::dynamic()};
PartialShape beta_shape{PartialShape::dynamic()};
double epsilon = 0.001;
element::Type data_batch_et = element::f32;
element::Type gamma_et = element::f32;
element::Type beta_et = element::f32;
auto data_batch = make_shared<op::Parameter>(data_batch_et, data_batch_shape);
auto gamma = make_shared<op::Parameter>(gamma_et, gamma_shape);
auto beta = make_shared<op::Parameter>(beta_et, beta_shape);
try
{
auto bn = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, epsilon);
FAIL() << "Zero channel count not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Channel count must be at least 1"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, batch_norm_training_partial_input_rank_dynamic_some_rank_static_dynamic_ok)
{
PartialShape data_batch_shape{PartialShape::dynamic()};
PartialShape gamma_shape{Dimension::dynamic()};
PartialShape beta_shape{PartialShape::dynamic()};
double epsilon = 0.001;
element::Type data_batch_et = element::f32;
element::Type gamma_et = element::f32;
element::Type beta_et = element::f32;
auto data_batch = make_shared<op::Parameter>(data_batch_et, data_batch_shape);
auto gamma = make_shared<op::Parameter>(gamma_et, gamma_shape);
auto beta = make_shared<op::Parameter>(beta_et, beta_shape);
auto bn = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, epsilon);
ASSERT_EQ(bn->get_output_size(), 3);
ASSERT_EQ(bn->get_output_element_type(0), data_batch_et);
ASSERT_EQ(bn->get_output_element_type(1), data_batch_et);
ASSERT_EQ(bn->get_output_element_type(2), data_batch_et);
ASSERT_TRUE(bn->get_output_partial_shape(0).rank().is_dynamic());
ASSERT_TRUE(bn->get_output_partial_shape(1).same_scheme(PartialShape::dynamic(1)));
ASSERT_TRUE(bn->get_output_partial_shape(2).same_scheme(PartialShape::dynamic(1)));
}
TEST(type_prop, batch_norm_training_partial_input_rank_dynamic_some_rank_static_dynamic_wrong_rank)
{
PartialShape data_batch_shape{PartialShape::dynamic()};
PartialShape gamma_shape{Dimension::dynamic(), Dimension::dynamic()};
PartialShape beta_shape{PartialShape::dynamic()};
double epsilon = 0.001;
element::Type data_batch_et = element::f32;
element::Type gamma_et = element::f32;
element::Type beta_et = element::f32;
auto data_batch = make_shared<op::Parameter>(data_batch_et, data_batch_shape);
auto gamma = make_shared<op::Parameter>(gamma_et, gamma_shape);
auto beta = make_shared<op::Parameter>(beta_et, beta_shape);
try
{
auto bn = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, epsilon);
FAIL() << "Wrong gamma/beta shape not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Shape for gamma/beta ({?,?}) does not have rank 1"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop,
batch_norm_training_partial_input_rank_dynamic_some_rank_static_dynamic_inconsistent_rank)
{
PartialShape data_batch_shape{PartialShape::dynamic()};
PartialShape gamma_shape{3, Dimension::dynamic()};
PartialShape beta_shape{Dimension::dynamic()};
double epsilon = 0.001;
element::Type data_batch_et = element::f32;
element::Type gamma_et = element::f32;
element::Type beta_et = element::f32;
auto data_batch = make_shared<op::Parameter>(data_batch_et, data_batch_shape);
auto gamma = make_shared<op::Parameter>(gamma_et, gamma_shape);
auto beta = make_shared<op::Parameter>(beta_et, beta_shape);
try
{
auto bn = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, epsilon);
FAIL() << "Inconsistent gamma/beta shape not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Shapes for gamma/beta do not match"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop,
batch_norm_training_partial_input_rank_dynamic_some_static_inconsistent_channel_count)
{
PartialShape data_batch_shape{PartialShape::dynamic()};
PartialShape gamma_shape{3};
PartialShape beta_shape{4};
double epsilon = 0.001;
element::Type data_batch_et = element::f32;
element::Type gamma_et = element::f32;
element::Type beta_et = element::f32;
auto data_batch = make_shared<op::Parameter>(data_batch_et, data_batch_shape);
auto gamma = make_shared<op::Parameter>(gamma_et, gamma_shape);
auto beta = make_shared<op::Parameter>(beta_et, beta_shape);
try
{
auto bn = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, epsilon);
FAIL() << "Inconsistent gamma/beta channel count not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Shapes for gamma/beta do not match"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, batch_norm_training_partial_input_rank_static_dynamic_some_static_ok)
{
PartialShape data_batch_shape{64, Dimension::dynamic(), Dimension::dynamic(), 224};
PartialShape gamma_shape{3};
PartialShape beta_shape{3};
double epsilon = 0.001;
element::Type data_batch_et = element::f32;
element::Type gamma_et = element::f32;
element::Type beta_et = element::f32;
auto data_batch = make_shared<op::Parameter>(data_batch_et, data_batch_shape);
auto gamma = make_shared<op::Parameter>(gamma_et, gamma_shape);
auto beta = make_shared<op::Parameter>(beta_et, beta_shape);
auto bn = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, epsilon);
ASSERT_EQ(bn->get_output_size(), 3);
ASSERT_EQ(bn->get_output_element_type(0), data_batch_et);
ASSERT_EQ(bn->get_output_element_type(1), data_batch_et);
ASSERT_EQ(bn->get_output_element_type(2), data_batch_et);
ASSERT_TRUE(bn->get_output_partial_shape(0).same_scheme(
PartialShape{64, 3, Dimension::dynamic(), 224}));
ASSERT_TRUE(bn->get_output_partial_shape(1).same_scheme(PartialShape{3}));
ASSERT_TRUE(bn->get_output_partial_shape(2).same_scheme(PartialShape{3}));
}
TEST(type_prop,
batch_norm_training_partial_input_rank_static_dynamic_some_static_inconsistent_channel_count)
{
PartialShape data_batch_shape{64, 4, Dimension::dynamic(), 224};
PartialShape gamma_shape{3};
PartialShape beta_shape{PartialShape::dynamic()};
double epsilon = 0.001;
element::Type data_batch_et = element::f32;
element::Type gamma_et = element::f32;
element::Type beta_et = element::f32;
auto data_batch = make_shared<op::Parameter>(data_batch_et, data_batch_shape);
auto gamma = make_shared<op::Parameter>(gamma_et, gamma_shape);
auto beta = make_shared<op::Parameter>(beta_et, beta_shape);
try
{
auto bn = make_shared<op::BatchNormTraining>(data_batch, gamma, beta, epsilon);
FAIL() << "Inconsistent input/gamma/beta channel count not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Input channel dimension (4) does not match shape for gamma/beta ({3})"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}