Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Tetragramm
GitHub Repository: Tetragramm/opencv
Path: blob/master/modules/dnn/src/layers/recurrent_layers.cpp
16337 views
1
/*M///////////////////////////////////////////////////////////////////////////////////////
2
//
3
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4
//
5
// By downloading, copying, installing or using the software you agree to this license.
6
// If you do not agree to this license, do not download, install,
7
// copy or use the software.
8
//
9
//
10
// License Agreement
11
// For Open Source Computer Vision Library
12
//
13
// Copyright (C) 2013, OpenCV Foundation, all rights reserved.
14
// Copyright (C) 2017, Intel Corporation, all rights reserved.
15
// Third party copyrights are property of their respective owners.
16
//
17
// Redistribution and use in source and binary forms, with or without modification,
18
// are permitted provided that the following conditions are met:
19
//
20
// * Redistribution's of source code must retain the above copyright notice,
21
// this list of conditions and the following disclaimer.
22
//
23
// * Redistribution's in binary form must reproduce the above copyright notice,
24
// this list of conditions and the following disclaimer in the documentation
25
// and/or other materials provided with the distribution.
26
//
27
// * The name of the copyright holders may not be used to endorse or promote products
28
// derived from this software without specific prior written permission.
29
//
30
// This software is provided by the copyright holders and contributors "as is" and
31
// any express or implied warranties, including, but not limited to, the implied
32
// warranties of merchantability and fitness for a particular purpose are disclaimed.
33
// In no event shall the Intel Corporation or contributors be liable for any direct,
34
// indirect, incidental, special, exemplary, or consequential damages
35
// (including, but not limited to, procurement of substitute goods or services;
36
// loss of use, data, or profits; or business interruption) however caused
37
// and on any theory of liability, whether in contract, strict liability,
38
// or tort (including negligence or otherwise) arising in any way out of
39
// the use of this software, even if advised of the possibility of such damage.
40
//
41
//M*/
42
43
#include "../precomp.hpp"
44
#include <iostream>
45
#include <iterator>
46
#include <cmath>
47
#include <opencv2/dnn/shape_utils.hpp>
48
49
namespace cv
50
{
51
namespace dnn
52
{
53
54
template<typename Dtype>
55
static void tanh(const Mat &src, Mat &dst)
56
{
57
MatConstIterator_<Dtype> itSrc = src.begin<Dtype>();
58
MatIterator_<Dtype> itDst = dst.begin<Dtype>();
59
60
for (; itSrc != src.end<Dtype>(); itSrc++, itDst++)
61
*itDst = std::tanh(*itSrc);
62
}
63
64
//TODO: make utils method
65
static void tanh(const Mat &src, Mat &dst)
66
{
67
dst.create(src.dims, (const int*)src.size, src.type());
68
69
if (src.type() == CV_32F)
70
tanh<float>(src, dst);
71
else if (src.type() == CV_64F)
72
tanh<double>(src, dst);
73
else
74
CV_Error(Error::StsUnsupportedFormat, "Function supports only floating point types");
75
}
76
77
static void sigmoid(const Mat &src, Mat &dst)
78
{
79
cv::exp(-src, dst);
80
cv::pow(1 + dst, -1, dst);
81
}
82
83
// LSTM recurrent layer implementation.
//
// Trained weights live in blobs:
//   blobs[0] = Wh   : recurrent weights, (4*numOut) x numOut
//   blobs[1] = Wx   : input weights,     (4*numOut) x numInp
//   blobs[2] = bias : 1 x (4*numOut) (reshaped to a row in the ctor)
//   blobs[3..5]     : optional peephole weights, numOut x numOut each
// The 4*numOut rows/columns hold the concatenated gates in the order
// i (input), f (forget), o (output), g (candidate) — see the colRange
// slicing in forward().
class LSTMLayerImpl CV_FINAL : public LSTMLayer
{
    int numTimeStamps, numSamples;
    bool allocated;         // set once finalize() has run; guards pre-allocation setters

    MatShape outTailShape;  //shape of single output sample
    MatShape outTsShape;    //shape of N output samples

    bool useTimestampDim;   // input is [T x N x ...] when true, [N x ...] (T==1) when false
    bool produceCellOutput; // expose cell state c_t as a second output blob
    float forgetBias, cellClip;
    bool useCellClip, usePeephole;

public:

    // Reads layer options from params and, when weights are supplied in
    // params.blobs, validates their shapes and types.
    LSTMLayerImpl(const LayerParams& params)
        : numTimeStamps(0), numSamples(0)
    {
        setParamsFrom(params);

        if (!blobs.empty())
        {
            CV_Assert(blobs.size() >= 3);

            // Flatten the bias to a single row so gemm(dummyOnes, bias, ...) broadcasts it.
            blobs[2] = blobs[2].reshape(1, 1);

            const Mat& Wh = blobs[0];
            const Mat& Wx = blobs[1];
            const Mat& bias = blobs[2];
            CV_Assert(Wh.dims == 2 && Wx.dims == 2);
            CV_Assert(Wh.rows == Wx.rows);
            CV_Assert(Wh.rows == 4*Wh.cols);    // four stacked gates
            CV_Assert(Wh.rows == (int)bias.total());
            CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type());

            // Peephole weights.
            if (blobs.size() > 3)
            {
                CV_Assert(blobs.size() == 6);
                const int N = Wh.cols;
                for (int i = 3; i < 6; ++i)
                {
                    CV_Assert(blobs[i].rows == N && blobs[i].cols == N);
                    CV_Assert(blobs[i].type() == bias.type());
                }
            }
        }
        useTimestampDim = params.get<bool>("use_timestamp_dim", true);
        produceCellOutput = params.get<bool>("produce_cell_output", false);
        forgetBias = params.get<float>("forget_bias", 0.0f);
        cellClip = params.get<float>("cell_clip", 0.0f);
        useCellClip = params.get<bool>("use_cell_clip", false);
        usePeephole = params.get<bool>("use_peephole", false);

        allocated = false;
        outTailShape.clear();
    }

    // Toggle the leading timestamp dimension; only valid before finalize().
    void setUseTimstampsDim(bool use) CV_OVERRIDE
    {
        CV_Assert(!allocated);
        useTimestampDim = use;
    }

    // Toggle the extra cell-state output; only valid before finalize().
    void setProduceCellOutput(bool produce) CV_OVERRIDE
    {
        CV_Assert(!allocated);
        produceCellOutput = produce;
    }

    // Set the per-sample output shape; after allocation only a reshape with
    // the same total element count is accepted.
    void setOutShape(const MatShape &outTailShape_) CV_OVERRIDE
    {
        CV_Assert(!allocated || total(outTailShape) == total(outTailShape_));
        outTailShape = outTailShape_;
    }

    // Install (deep-copied) weights, validating the same shape relations as
    // the constructor does for params-provided blobs.
    void setWeights(const Mat &Wh, const Mat &Wx, const Mat &bias) CV_OVERRIDE
    {
        CV_Assert(Wh.dims == 2 && Wx.dims == 2);
        CV_Assert(Wh.rows == Wx.rows);
        CV_Assert(Wh.rows == 4*Wh.cols);
        CV_Assert(Wh.rows == (int)bias.total());
        CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type());

        blobs.resize(3);
        blobs[0] = Mat(Wh.clone());
        blobs[1] = Mat(Wx.clone());
        blobs[2] = Mat(bias.clone()).reshape(1, 1);
    }

    // Computes output shapes ([T x] N x outTail) and the four internal
    // buffers (h state, c state, a column of ones for bias broadcast, gates).
    bool getMemoryShapes(const std::vector<MatShape> &inputs,
                         const int requiredOutputs,
                         std::vector<MatShape> &outputs,
                         std::vector<MatShape> &internals) const CV_OVERRIDE
    {
        // Parenthesized explicitly: '&&' binds tighter than '||'.
        CV_Assert((!usePeephole && blobs.size() == 3) || (usePeephole && blobs.size() == 6));
        CV_Assert(inputs.size() == 1);
        const MatShape& inp0 = inputs[0];

        const Mat &Wh = blobs[0], &Wx = blobs[1];
        int _numOut = Wh.size[1];
        int _numInp = Wx.size[1];
        MatShape outTailShape_(outTailShape), outResShape;

        if (!outTailShape_.empty())
            CV_Assert(total(outTailShape_) == _numOut);
        else
            outTailShape_.assign(1, _numOut);

        int _numSamples;
        if (useTimestampDim)
        {
            CV_Assert(inp0.size() >= 2 && total(inp0, 2) == _numInp);
            _numSamples = inp0[1];
            outResShape.push_back(inp0[0]);
        }
        else
        {
            CV_Assert(inp0.size() >= 2 && total(inp0, 1) == _numInp);
            _numSamples = inp0[0];
        }

        outResShape.push_back(_numSamples);
        outResShape.insert(outResShape.end(), outTailShape_.begin(), outTailShape_.end());

        size_t noutputs = produceCellOutput ? 2 : 1;
        outputs.assign(noutputs, outResShape);

        internals.assign(1, shape(_numSamples, _numOut)); // hInternal
        internals.push_back(shape(_numSamples, _numOut)); // cInternal
        internals.push_back(shape(_numSamples, 1)); // dummyOnes
        internals.push_back(shape(_numSamples, 4*_numOut)); // gates

        return false;
    }

    // Caches numTimeStamps/numSamples from the actual input and freezes the
    // layer configuration (allocated = true).
    void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays) CV_OVERRIDE
    {
        std::vector<Mat> input;
        inputs_arr.getMatVector(input);

        // Parenthesized explicitly: '&&' binds tighter than '||'.
        CV_Assert((!usePeephole && blobs.size() == 3) || (usePeephole && blobs.size() == 6));
        CV_Assert(input.size() == 1);
        const Mat& inp0 = input[0];

        Mat &Wh = blobs[0], &Wx = blobs[1];
        int numOut = Wh.size[1];
        int numInp = Wx.size[1];

        if (!outTailShape.empty())
            CV_Assert(total(outTailShape) == numOut);
        else
            outTailShape.assign(1, numOut);

        if (useTimestampDim)
        {
            CV_Assert(inp0.dims >= 2 && (int)inp0.total(2) == numInp);
            numTimeStamps = inp0.size[0];
            numSamples = inp0.size[1];
        }
        else
        {
            CV_Assert(inp0.dims >= 2 && (int)inp0.total(1) == numInp);
            numTimeStamps = 1;
            numSamples = inp0.size[0];
        }

        outTsShape.clear();
        outTsShape.push_back(numSamples);
        outTsShape.insert(outTsShape.end(), outTailShape.begin(), outTailShape.end());

        allocated = true;
    }

    // Runs the LSTM recurrence over all timestamps:
    //   gates = Wx*x_t + Wh*h_{t-1} + b
    //   c_t   = f_t (*) c_{t-1} + i_t (*) g_t
    //   h_t   = o_t (*) tanh(c_t)
    // with optional forget-gate bias, peephole connections and cell clipping.
    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
    {
        CV_TRACE_FUNCTION();
        CV_TRACE_ARG_VALUE(name, "name", name.c_str());

        // Half-precision inputs go through the generic fallback path.
        if (inputs_arr.depth() == CV_16S)
        {
            forward_fallback(inputs_arr, outputs_arr, internals_arr);
            return;
        }

        std::vector<Mat> input, output, internals;
        inputs_arr.getMatVector(input);
        outputs_arr.getMatVector(output);
        internals_arr.getMatVector(internals);

        const Mat &Wh = blobs[0];
        const Mat &Wx = blobs[1];
        const Mat &bias = blobs[2];

        int numOut = Wh.size[1];

        Mat hInternal = internals[0], cInternal = internals[1],
                dummyOnes = internals[2], gates = internals[3];
        hInternal.setTo(0.);    // h_0 = 0
        cInternal.setTo(0.);    // c_0 = 0
        dummyOnes.setTo(1.);    // ones column used to broadcast the bias row via gemm

        int numSamplesTotal = numTimeStamps*numSamples;
        Mat xTs = input[0].reshape(1, numSamplesTotal);

        Mat hOutTs = output[0].reshape(1, numSamplesTotal);
        Mat cOutTs = produceCellOutput ? output[1].reshape(1, numSamplesTotal) : Mat();

        for (int ts = 0; ts < numTimeStamps; ts++)
        {
            Range curRowRange(ts*numSamples, (ts + 1)*numSamples);
            Mat xCurr = xTs.rowRange(curRowRange);

            gemm(xCurr, Wx, 1, gates, 0, gates, GEMM_2_T);      // Wx * x_t
            gemm(hInternal, Wh, 1, gates, 1, gates, GEMM_2_T);  //+Wh * h_{t-1}
            gemm(dummyOnes, bias, 1, gates, 1, gates);          //+b

            // Gate slices, in i/f/o/g order.
            Mat gateI = gates.colRange(0*numOut, 1*numOut);
            Mat gateF = gates.colRange(1*numOut, 2*numOut);
            Mat gateO = gates.colRange(2*numOut, 3*numOut);
            Mat gateG = gates.colRange(3*numOut, 4*numOut);

            if (forgetBias)
                add(gateF, forgetBias, gateF);

            if (usePeephole)
            {
                // Peephole: i and f additionally see c_{t-1}; o is handled
                // after c_t is computed.
                Mat gatesIF = gates.colRange(0, 2*numOut);
                gemm(cInternal, blobs[3], 1, gateI, 1, gateI);
                gemm(cInternal, blobs[4], 1, gateF, 1, gateF);
                sigmoid(gatesIF, gatesIF);
            }
            else
            {
                Mat gatesIFO = gates.colRange(0, 3*numOut);
                sigmoid(gatesIFO, gatesIFO);
            }

            tanh(gateG, gateG);

            //compute c_t
            multiply(gateF, cInternal, gateF);  // f_t (*) c_{t-1}
            multiply(gateI, gateG, gateI);      // i_t (*) g_t
            add(gateF, gateI, cInternal);       // c_t = f_t (*) c_{t-1} + i_t (*) g_t

            if (useCellClip)
            {
                // Clamp c_t into [-cellClip, cellClip].
                min(cInternal, cellClip, cInternal);
                max(cInternal, -cellClip, cInternal);
            }
            if (usePeephole)
            {
                gemm(cInternal, blobs[5], 1, gateO, 1, gateO);
                sigmoid(gateO, gateO);
            }

            //compute h_t
            tanh(cInternal, hInternal);
            multiply(gateO, hInternal, hInternal);

            //save results in output blobs
            hInternal.copyTo(hOutTs.rowRange(curRowRange));
            if (produceCellOutput)
                cInternal.copyTo(cOutTs.rowRange(curRowRange));
        }
    }
};
350
351
// Factory: build an LSTM layer from the given parameters.
Ptr<LSTMLayer> LSTMLayer::create(const LayerParams& params)
{
    return makePtr<LSTMLayerImpl>(params);
}
355
356
// Map an input name to its index: "x" (any case) -> 0, anything else -> -1.
int LSTMLayer::inputNameToIndex(String inputName)
{
    return (toLowerCase(inputName) == "x") ? 0 : -1;
}
362
363
int LSTMLayer::outputNameToIndex(const String& outputName)
364
{
365
if (toLowerCase(outputName) == "h")
366
return 0;
367
else if (toLowerCase(outputName) == "c")
368
return 1;
369
return -1;
370
}
371
372
373
class RNNLayerImpl : public RNNLayer
374
{
375
int numX, numH, numO;
376
int numSamples, numTimestamps, numSamplesTotal;
377
int dtype;
378
Mat Whh, Wxh, bh;
379
Mat Who, bo;
380
bool produceH;
381
382
public:
383
384
RNNLayerImpl(const LayerParams& params)
385
: numX(0), numH(0), numO(0), numSamples(0), numTimestamps(0), numSamplesTotal(0), dtype(0)
386
{
387
setParamsFrom(params);
388
type = "RNN";
389
produceH = false;
390
}
391
392
void setProduceHiddenOutput(bool produce = false) CV_OVERRIDE
393
{
394
produceH = produce;
395
}
396
397
void setWeights(const Mat &W_xh, const Mat &b_h, const Mat &W_hh, const Mat &W_ho, const Mat &b_o) CV_OVERRIDE
398
{
399
CV_Assert(W_hh.dims == 2 && W_xh.dims == 2);
400
CV_Assert(W_hh.size[0] == W_xh.size[0] && W_hh.size[0] == W_hh.size[1] && (int)b_h.total() == W_xh.size[0]);
401
CV_Assert(W_ho.size[0] == (int)b_o.total());
402
CV_Assert(W_ho.size[1] == W_hh.size[1]);
403
404
blobs.resize(5);
405
blobs[0] = Mat(W_xh.clone());
406
blobs[1] = Mat(b_h.clone());
407
blobs[2] = Mat(W_hh.clone());
408
blobs[3] = Mat(W_ho.clone());
409
blobs[4] = Mat(b_o.clone());
410
}
411
412
bool getMemoryShapes(const std::vector<MatShape> &inputs,
413
const int requiredOutputs,
414
std::vector<MatShape> &outputs,
415
std::vector<MatShape> &internals) const CV_OVERRIDE
416
{
417
CV_Assert(inputs.size() >= 1 && inputs.size() <= 2);
418
419
Mat Who_ = blobs[3];
420
Mat Wxh_ = blobs[0];
421
422
int numTimestamps_ = inputs[0][0];
423
int numSamples_ = inputs[0][1];
424
425
int numO_ = Who_.rows;
426
int numH_ = Wxh_.rows;
427
428
outputs.clear();
429
int dims[] = {numTimestamps_, numSamples_, numO_};
430
outputs.push_back(shape(dims, 3));
431
dims[2] = numH_;
432
if (produceH)
433
outputs.push_back(shape(dims, 3));
434
435
internals.assign(2, shape(numSamples_, numH_));
436
internals.push_back(shape(numSamples_, 1));
437
438
return false;
439
}
440
441
void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays) CV_OVERRIDE
442
{
443
std::vector<Mat> input, outputs;
444
inputs_arr.getMatVector(input);
445
446
CV_Assert(input.size() >= 1 && input.size() <= 2);
447
448
Wxh = blobs[0];
449
bh = blobs[1];
450
Whh = blobs[2];
451
Who = blobs[3];
452
bo = blobs[4];
453
454
numH = Wxh.rows;
455
numX = Wxh.cols;
456
numO = Who.rows;
457
458
const Mat& inp0 = input[0];
459
460
CV_Assert(inp0.dims >= 2);
461
CV_Assert(inp0.total(2) == numX);
462
dtype = CV_32F;
463
CV_Assert(inp0.type() == dtype);
464
numTimestamps = inp0.size[0];
465
numSamples = inp0.size[1];
466
numSamplesTotal = numTimestamps * numSamples;
467
468
bh = bh.reshape(1, 1); //is 1 x numH Mat
469
bo = bo.reshape(1, 1); //is 1 x numO Mat
470
}
471
472
void reshapeOutput(std::vector<Mat> &output)
473
{
474
output.resize(produceH ? 2 : 1);
475
int sz0[] = { numTimestamps, numSamples, numO };
476
output[0].create(3, sz0, dtype);
477
if (produceH)
478
{
479
int sz1[] = { numTimestamps, numSamples, numH };
480
output[1].create(3, sz1, dtype);
481
}
482
}
483
484
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
485
{
486
CV_TRACE_FUNCTION();
487
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
488
489
if (inputs_arr.depth() == CV_16S)
490
{
491
forward_fallback(inputs_arr, outputs_arr, internals_arr);
492
return;
493
}
494
495
std::vector<Mat> input, output, internals;
496
inputs_arr.getMatVector(input);
497
outputs_arr.getMatVector(output);
498
internals_arr.getMatVector(internals);
499
500
Mat xTs = input[0].reshape(1, numSamplesTotal);
501
Mat oTs = output[0].reshape(1, numSamplesTotal);
502
Mat hTs = produceH ? output[1].reshape(1, numSamplesTotal) : Mat();
503
Mat hCurr = internals[0];
504
Mat hPrev = internals[1];
505
Mat dummyBiasOnes = internals[2];
506
507
hPrev.setTo(0.);
508
dummyBiasOnes.setTo(1.);
509
510
for (int ts = 0; ts < numTimestamps; ts++)
511
{
512
Range curRowRange = Range(ts * numSamples, (ts + 1) * numSamples);
513
Mat xCurr = xTs.rowRange(curRowRange);
514
515
gemm(hPrev, Whh, 1, hCurr, 0, hCurr, GEMM_2_T); // W_{hh} * h_{prev}
516
gemm(xCurr, Wxh, 1, hCurr, 1, hCurr, GEMM_2_T); //+W_{xh} * x_{curr}
517
gemm(dummyBiasOnes, bh, 1, hCurr, 1, hCurr); //+bh
518
tanh(hCurr, hPrev);
519
520
Mat oCurr = oTs.rowRange(curRowRange);
521
gemm(hPrev, Who, 1, oCurr, 0, oCurr, GEMM_2_T); // W_{ho} * h_{prev}
522
gemm(dummyBiasOnes, bo, 1, oCurr, 1, oCurr); //+b_o
523
tanh(oCurr, oCurr);
524
525
if (produceH)
526
hPrev.copyTo(hTs.rowRange(curRowRange));
527
}
528
}
529
};
530
531
// Factory: build an RNN layer from the given parameters.
CV_EXPORTS_W Ptr<RNNLayer> RNNLayer::create(const LayerParams& params)
{
    return makePtr<RNNLayerImpl>(params);
}
535
536
}
537
}
538
539