Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Tetragramm
GitHub Repository: Tetragramm/opencv
Path: blob/master/modules/dnn/src/layers/batch_norm_layer.cpp
16337 views
1
// This file is part of OpenCV project.
2
// It is subject to the license terms in the LICENSE file found in the top-level directory
3
// of this distribution and at http://opencv.org/license.html.
4
5
// Copyright (C) 2016, Intel Corporation, all rights reserved.
6
// Third party copyrights are property of their respective owners.
7
8
/*
9
Implementation of Batch Normalization layer.
10
*/
11
12
#include "../precomp.hpp"
13
#include "../op_halide.hpp"
14
#include "../op_inf_engine.hpp"
15
#include <opencv2/dnn/shape_utils.hpp>
16
17
#ifdef HAVE_OPENCL
18
#include "opencl_kernels_dnn.hpp"
19
#endif
20
21
namespace cv
22
{
23
namespace dnn
24
{
25
26
class BatchNormLayerImpl CV_FINAL : public BatchNormLayer
{
public:
    // Folded affine transform: y = weights_ * x + bias_, per channel.
    // Computed once in the constructor from mean/variance (and optional
    // scale/shift) blobs, so forward() is a single fused multiply-add.
    Mat weights_, bias_;
    // Lazily-initialized OpenCL copies of weights_/bias_ (see forward_ocl).
    UMat umat_weight, umat_bias;

    BatchNormLayerImpl(const LayerParams& params)
    {
        setParamsFrom(params);
        // blobs[0] = mean, blobs[1] = variance; optional extra blobs below.
        CV_Assert(blobs.size() >= 2);

        hasWeights = params.get<bool>("has_weight", false);
        hasBias = params.get<bool>("has_bias", false);
        useGlobalStats = params.get<bool>("use_global_stats", true);
        if(params.get<bool>("scale_bias", false))
            hasWeights = hasBias = true;
        epsilon = params.get<float>("eps", 1E-5);

        size_t n = blobs[0].total();
        CV_Assert(blobs[1].total() == n &&
                  blobs[0].isContinuous() && blobs[1].isContinuous() &&
                  blobs[0].type() == CV_32F && blobs[1].type() == CV_32F);

        // Caffe-style normalization factor: blobs[2] holds the moving-average
        // accumulation count; mean/variance must be divided by it.
        float varMeanScale = 1.f;
        if (!hasWeights && !hasBias && blobs.size() > 2 && useGlobalStats) {
            CV_Assert(blobs.size() == 3); CV_CheckTypeEQ(blobs[2].type(), CV_32FC1, "");
            varMeanScale = blobs[2].at<float>(0);
            if (varMeanScale != 0)
                varMeanScale = 1/varMeanScale;
        }

        // Optional scale/shift blobs are appended after mean/variance(/count):
        // the bias blob is always last, the weights blob right before it.
        const int biasBlobIndex = blobs.size() - 1;
        const int weightsBlobIndex = biasBlobIndex - hasBias;

        if( hasWeights )
        {
            CV_Assert((size_t)weightsBlobIndex < blobs.size());
            const Mat& w = blobs[weightsBlobIndex];
            CV_Assert(w.isContinuous() && w.type() == CV_32F && w.total() == (size_t)n);
        }

        if( hasBias )
        {
            CV_Assert((size_t)biasBlobIndex < blobs.size());
            // Fixed: validate the bias blob, not the weights blob
            // (previously indexed blobs[weightsBlobIndex] by mistake).
            const Mat& b = blobs[biasBlobIndex];
            CV_Assert(b.isContinuous() && b.type() == CV_32F && b.total() == (size_t)n);
        }

        const float* meanData = blobs[0].ptr<float>();
        const float* stdData = blobs[1].ptr<float>();
        const float* weightsData = hasWeights ? blobs[weightsBlobIndex].ptr<float>() : 0;
        const float* biasData = hasBias ? blobs[biasBlobIndex].ptr<float>() : 0;

        weights_.create(1, (int)n, CV_32F);
        bias_.create(1, (int)n, CV_32F);

        float* dstWeightsData = weights_.ptr<float>();
        float* dstBiasData = bias_.ptr<float>();

        // Fold normalization and optional scale/shift into one affine pair:
        //   w = gamma / sqrt(var + eps),  b = beta - w * mean
        for (size_t i = 0; i < n; ++i)
        {
            float w = (hasWeights ? weightsData[i] : 1.0f) / sqrt(stdData[i] * varMeanScale + epsilon);
            dstWeightsData[i] = w;
            dstBiasData[i] = (hasBias ? biasData[i] : 0.0f) - w * meanData[i] * varMeanScale;
        }
    }

    // Expose the folded per-channel scale/shift so later layers can fuse us.
    void getScaleShift(Mat& scale, Mat& shift) const CV_OVERRIDE
    {
        scale = weights_;
        shift = bias_;
    }

    // Fuse a following scale/shift-style layer into this one by composing
    // the affine transforms. Returns false if the top layer exposes no
    // scale/shift or its channel count is incompatible.
    virtual bool tryFuse(Ptr<Layer>& top) CV_OVERRIDE
    {
        Mat w, b;
        top->getScaleShift(w, b);
        if (w.empty() && b.empty())
            return false;

        const int numChannels = weights_.total();
        const int numFusedWeights = w.total();
        const int numFusedBias = b.total();

        // Fused factors must be either per-channel or a single scalar.
        if ((numFusedWeights != numChannels && numFusedWeights != 1 && !w.empty()) ||
            (numFusedBias != numChannels && numFusedBias != 1 && !b.empty()))
            return false;

        if (!w.empty())
        {
            w = w.reshape(1, 1);
            if (numFusedWeights == 1)
            {
                multiply(weights_, w.at<float>(0), weights_);
                multiply(bias_, w.at<float>(0), bias_);
            }
            else
            {
                multiply(weights_, w, weights_);
                multiply(bias_, w, bias_);
            }
        }
        if (!b.empty())
        {
            b = b.reshape(1, 1);
            if (numFusedBias == 1)
                add(bias_, b.at<float>(0), bias_);
            else
                add(bias_, b.reshape(1, 1), bias_);
        }
        return true;
    }

    // Output shape equals input shape; training mode is only supported for
    // batch size 1 because per-batch statistics are not implemented.
    bool getMemoryShapes(const std::vector<MatShape> &inputs,
                         const int requiredOutputs,
                         std::vector<MatShape> &outputs,
                         std::vector<MatShape> &internals) const CV_OVERRIDE
    {
        if (!useGlobalStats && inputs[0][0] != 1)
            CV_Error(Error::StsNotImplemented, "Batch normalization in training mode with batch size > 1");
        Layer::getMemoryShapes(inputs, requiredOutputs, outputs, internals);
        return true;
    }

    virtual bool supportBackend(int backendId) CV_OVERRIDE
    {
        return backendId == DNN_BACKEND_OPENCV ||
               (backendId == DNN_BACKEND_HALIDE && haveHalide()) ||
               (backendId == DNN_BACKEND_INFERENCE_ENGINE && haveInfEngine());
    }

#ifdef HAVE_OPENCL
    // OpenCL path: per-channel fused multiply-add via the batch_norm kernel.
    // Returns false to let the caller fall back to the CPU implementation.
    bool forward_ocl(InputArrayOfArrays inputs_, OutputArrayOfArrays outputs_, OutputArrayOfArrays internals_)
    {
        std::vector<UMat> inputs;
        std::vector<UMat> outputs;

        bool use_half = (inputs_.depth() == CV_16S);
        inputs_.getUMatVector(inputs);
        outputs_.getUMatVector(outputs);

        CV_Assert(blobs.size() >= 2);
        CV_Assert(inputs.size() == 1);

        // 2-D half-precision inputs are not handled by the kernel below.
        if (use_half && inputs[0].dims == 2)
            return false;

        // Upload folded coefficients to the device once.
        if (umat_weight.empty())
        {
            weights_.copyTo(umat_weight);
            bias_.copyTo(umat_bias);
        }

        UMat &inpBlob = inputs[0];
        CV_Assert(inpBlob.dims == 2 || inpBlob.dims == 4);
        int groups = inpBlob.size[0];
        int channels = inpBlob.size[1];
        int rows = inpBlob.dims > 2 ? inpBlob.size[2] : 1;
        int cols = inpBlob.dims > 2 ? inpBlob.size[3] : 1;

        String opts = (use_half) ? " -DDtype=half" : " -DDtype=float";
        for (size_t ii = 0; ii < outputs.size(); ii++)
        {
            if (inpBlob.dims == 2)
            {
                // 2-D case: plain element-wise multiply/add suffices.
                UMat& src = inputs[ii];
                UMat& dst = outputs[ii];
                multiply(src, weights_, dst);
                add(dst, bias_, dst);
            }
            else
            {
                // 4-D case: flatten to (N*C) x (H*W) and run the kernel with
                // a vector width of 8, 4 or 1 depending on row divisibility.
                MatShape s = shape(groups * channels, rows * cols);
                UMat src = inputs[ii].reshape(1, s.size(), &s[0]);
                UMat dst = outputs[ii].reshape(1, s.size(), &s[0]);
                int number = (s[1] % 8 == 0) ? 8 : ((s[1] % 4 == 0) ? 4 : 1);
                String buildopt = format("-DNUM=%d", number) + opts;
                String kname = format("batch_norm%d", number);
                if (number == 1)
                    buildopt += format(" -Dconvert_T=convert_%s", use_half ? "half" : "float");
                else
                    buildopt += format(" -Dconvert_T=convert_%s%d", use_half ? "half" : "float", number);
                ocl::Kernel kernel(kname.c_str(), ocl::dnn::batchnorm_oclsrc, buildopt);
                if (kernel.empty())
                    return false;
                size_t global[] = { (size_t)s[0], (size_t)(s[1] / number) };
                kernel.set(0, ocl::KernelArg::PtrReadOnly(src));
                kernel.set(1, (int)s[0]);
                kernel.set(2, (int)s[1]);
                kernel.set(3, (int)channels);
                kernel.set(4, ocl::KernelArg::PtrReadOnly(umat_weight));
                kernel.set(5, ocl::KernelArg::PtrReadOnly(umat_bias));
                kernel.set(6, ocl::KernelArg::PtrWriteOnly(dst));
                bool ret = kernel.run(2, global, NULL, false);
                if (!ret)
                    return false;
            }
        }
        return true;
    }
#endif

    // CPU forward: for every (batch, channel) plane apply y = w * x + b
    // using convertTo's fused scale/shift.
    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
    {
        CV_TRACE_FUNCTION();
        CV_TRACE_ARG_VALUE(name, "name", name.c_str());

        CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget),
                   forward_ocl(inputs_arr, outputs_arr, internals_arr))

        // Half-precision tensors go through the generic fallback path.
        if (inputs_arr.depth() == CV_16S)
        {
            forward_fallback(inputs_arr, outputs_arr, internals_arr);
            return;
        }

        std::vector<Mat> inputs, outputs;
        inputs_arr.getMatVector(inputs);
        outputs_arr.getMatVector(outputs);

        CV_Assert(blobs.size() >= 2);
        CV_Assert(inputs.size() == 1);

        Mat &inpBlob = inputs[0];
        CV_Assert(inpBlob.dims == 2 || inpBlob.dims == 4);
        int rows = inpBlob.dims > 2 ? inpBlob.size[2] : 1;
        int cols = inpBlob.dims > 2 ? inpBlob.size[3] : 1;

        for (size_t ii = 0; ii < outputs.size(); ii++)
        {
            Mat &outBlob = outputs[ii];

            for(int num = 0; num < outBlob.size[0]; num++)
            {
                for (int n = 0; n < outBlob.size[1]; n++)
                {
                    float w = weights_.at<float>(n);
                    float b = bias_.at<float>(n);
                    Mat inpBlobPlane(rows, cols, CV_32F, inpBlob.ptr<float>(num, n));
                    Mat outBlobPlane(rows, cols, CV_32F, outBlob.ptr<float>(num, n));
                    inpBlobPlane.convertTo(outBlobPlane, CV_32F, w, b);
                }
            }
        }
    }

    // Vectorized slice used by fused activation paths: processes channels
    // [cn0, cn1), each plane being `planeSize` floats; only the first `len`
    // elements of each plane are written.
    void forwardSlice(const float* srcptr, float* dstptr, int len, size_t planeSize, int cn0, int cn1) const CV_OVERRIDE
    {
        for( int cn = cn0; cn < cn1; cn++, srcptr += planeSize, dstptr += planeSize )
        {
            int i = 0;
            float w = weights_.at<float>(cn);
            float b = bias_.at<float>(cn);
#if CV_SIMD128
            // Process 16 floats per iteration (4 SIMD lanes x 4 registers).
            v_float32x4 wV = v_setall_f32(w), bV = v_setall_f32(b);
            for( ; i <= len - 16; i += 16 )
            {
                v_float32x4 x0 = v_load(srcptr + i);
                v_float32x4 x1 = v_load(srcptr + i + 4);
                v_float32x4 x2 = v_load(srcptr + i + 8);
                v_float32x4 x3 = v_load(srcptr + i + 12);
                x0 = v_muladd(x0, w, b);
                x1 = v_muladd(x1, w, b);
                x2 = v_muladd(x2, w, b);
                x3 = v_muladd(x3, w, b);
                v_store(dstptr + i, x0);
                v_store(dstptr + i + 4, x1);
                v_store(dstptr + i + 8, x2);
                v_store(dstptr + i + 12, x3);
            }
#endif
            // Scalar tail (and full loop when SIMD is unavailable).
            for( ; i < len; i++ )
                dstptr[i] = w * srcptr[i] + b;
        }
    }

    // Try to append this layer onto an existing Halide pipeline node,
    // fusing batch-norm into the producer's schedule.
    virtual Ptr<BackendNode> tryAttach(const Ptr<BackendNode>& node) CV_OVERRIDE
    {
        switch (node->backendId)
        {
            case DNN_BACKEND_HALIDE:
            {
#ifdef HAVE_HALIDE
                auto base = node.dynamicCast<HalideBackendNode>();
                Halide::Func& input = base->funcs.back();
                Halide::Var x("x"), y("y"), c("c"), n("n");
                Halide::Func top = attachHalide(input(x, y, c, n));
                return Ptr<BackendNode>(new HalideBackendNode(base, top));
#endif  // HAVE_HALIDE
                break;
            }
        }
        return Ptr<BackendNode>();
    }

    virtual Ptr<BackendNode> initHalide(const std::vector<Ptr<BackendWrapper> > &inputs) CV_OVERRIDE
    {
#ifdef HAVE_HALIDE
        Halide::Buffer<float> input = halideBuffer(inputs[0]);
        Halide::Var x("x"), y("y"), c("c"), n("n");
        Halide::Func top = attachHalide(input(x, y, c, n));
        return Ptr<BackendNode>(new HalideBackendNode(top));
#endif  // HAVE_HALIDE
        return Ptr<BackendNode>();
    }

#ifdef HAVE_HALIDE
    // attachHalide can work both with Halide::Buffer and Halide::Func. In the
    // second case it will be a fusion.
    Halide::Func attachHalide(const Halide::Expr& input)
    {
        Halide::Func top = (name.empty() ? Halide::Func() : Halide::Func(name));
        Halide::Var x("x"), y("y"), c("c"), n("n");

        const int numChannels = weights_.total();
        auto weights = wrapToHalideBuffer(weights_, {numChannels});
        auto bias = wrapToHalideBuffer(bias_, {numChannels});
        top(x, y, c, n) = input * weights(c) + bias(c);
        return top;
    }
#endif  // HAVE_HALIDE

    // Inference Engine backend: batch-norm maps directly onto a ScaleShift
    // layer using the pre-folded weights_/bias_.
    virtual Ptr<BackendNode> initInfEngine(const std::vector<Ptr<BackendWrapper> >&) CV_OVERRIDE
    {
#ifdef HAVE_INF_ENGINE
        InferenceEngine::LayerParams lp;
        lp.name = name;
        lp.type = "ScaleShift";
        lp.precision = InferenceEngine::Precision::FP32;
        std::shared_ptr<InferenceEngine::ScaleShiftLayer> ieLayer(new InferenceEngine::ScaleShiftLayer(lp));

        const size_t numChannels = weights_.total();
        ieLayer->_weights = wrapToInfEngineBlob(weights_, {numChannels}, InferenceEngine::Layout::C);
        ieLayer->_biases = wrapToInfEngineBlob(bias_, {numChannels}, InferenceEngine::Layout::C);

        return Ptr<BackendNode>(new InfEngineBackendNode(ieLayer));
#endif  // HAVE_INF_ENGINE
        return Ptr<BackendNode>();
    }

    // 3 FLOPs per element: multiply, add, and the implicit store/convert.
    virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
                           const std::vector<MatShape> &outputs) const CV_OVERRIDE
    {
        CV_UNUSED(outputs); // suppress unused variable warning

        int64 flops = 0;
        for (size_t i = 0; i < inputs.size(); i++)  // size_t avoids signed/unsigned mismatch
        {
            flops += 3*total(inputs[i]);
        }
        return flops;
    }

private:
    bool useGlobalStats;
};
382
383
// Factory: construct the reference implementation of the BatchNorm layer.
Ptr<BatchNormLayer> BatchNormLayer::create(const LayerParams& params)
{
    Ptr<BatchNormLayer> layer(new BatchNormLayerImpl(params));
    return layer;
}
387
388
} // namespace dnn
389
} // namespace cv
390
391