// Path: blob/master/modules/dnn/src/layers/recurrent_layers.cpp
// 16337 views
/*M///////////////////////////////////////////////////////////////////////////////////////1//2// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.3//4// By downloading, copying, installing or using the software you agree to this license.5// If you do not agree to this license, do not download, install,6// copy or use the software.7//8//9// License Agreement10// For Open Source Computer Vision Library11//12// Copyright (C) 2013, OpenCV Foundation, all rights reserved.13// Copyright (C) 2017, Intel Corporation, all rights reserved.14// Third party copyrights are property of their respective owners.15//16// Redistribution and use in source and binary forms, with or without modification,17// are permitted provided that the following conditions are met:18//19// * Redistribution's of source code must retain the above copyright notice,20// this list of conditions and the following disclaimer.21//22// * Redistribution's in binary form must reproduce the above copyright notice,23// this list of conditions and the following disclaimer in the documentation24// and/or other materials provided with the distribution.25//26// * The name of the copyright holders may not be used to endorse or promote products27// derived from this software without specific prior written permission.28//29// This software is provided by the copyright holders and contributors "as is" and30// any express or implied warranties, including, but not limited to, the implied31// warranties of merchantability and fitness for a particular purpose are disclaimed.32// In no event shall the Intel Corporation or contributors be liable for any direct,33// indirect, incidental, special, exemplary, or consequential damages34// (including, but not limited to, procurement of substitute goods or services;35// loss of use, data, or profits; or business interruption) however caused36// and on any theory of liability, whether in contract, strict liability,37// or tort (including negligence or otherwise) 
arising in any way out of38// the use of this software, even if advised of the possibility of such damage.39//40//M*/4142#include "../precomp.hpp"43#include <iostream>44#include <iterator>45#include <cmath>46#include <opencv2/dnn/shape_utils.hpp>4748namespace cv49{50namespace dnn51{5253template<typename Dtype>54static void tanh(const Mat &src, Mat &dst)55{56MatConstIterator_<Dtype> itSrc = src.begin<Dtype>();57MatIterator_<Dtype> itDst = dst.begin<Dtype>();5859for (; itSrc != src.end<Dtype>(); itSrc++, itDst++)60*itDst = std::tanh(*itSrc);61}6263//TODO: make utils method64static void tanh(const Mat &src, Mat &dst)65{66dst.create(src.dims, (const int*)src.size, src.type());6768if (src.type() == CV_32F)69tanh<float>(src, dst);70else if (src.type() == CV_64F)71tanh<double>(src, dst);72else73CV_Error(Error::StsUnsupportedFormat, "Function supports only floating point types");74}7576static void sigmoid(const Mat &src, Mat &dst)77{78cv::exp(-src, dst);79cv::pow(1 + dst, -1, dst);80}8182class LSTMLayerImpl CV_FINAL : public LSTMLayer83{84int numTimeStamps, numSamples;85bool allocated;8687MatShape outTailShape; //shape of single output sample88MatShape outTsShape; //shape of N output samples8990bool useTimestampDim;91bool produceCellOutput;92float forgetBias, cellClip;93bool useCellClip, usePeephole;9495public:9697LSTMLayerImpl(const LayerParams& params)98: numTimeStamps(0), numSamples(0)99{100setParamsFrom(params);101102if (!blobs.empty())103{104CV_Assert(blobs.size() >= 3);105106blobs[2] = blobs[2].reshape(1, 1);107108const Mat& Wh = blobs[0];109const Mat& Wx = blobs[1];110const Mat& bias = blobs[2];111CV_Assert(Wh.dims == 2 && Wx.dims == 2);112CV_Assert(Wh.rows == Wx.rows);113CV_Assert(Wh.rows == 4*Wh.cols);114CV_Assert(Wh.rows == (int)bias.total());115CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type());116117// Peephole weights.118if (blobs.size() > 3)119{120CV_Assert(blobs.size() == 6);121const int N = Wh.cols;122for (int i = 3; i < 6; 
++i)123{124CV_Assert(blobs[i].rows == N && blobs[i].cols == N);125CV_Assert(blobs[i].type() == bias.type());126}127}128}129useTimestampDim = params.get<bool>("use_timestamp_dim", true);130produceCellOutput = params.get<bool>("produce_cell_output", false);131forgetBias = params.get<float>("forget_bias", 0.0f);132cellClip = params.get<float>("cell_clip", 0.0f);133useCellClip = params.get<bool>("use_cell_clip", false);134usePeephole = params.get<bool>("use_peephole", false);135136allocated = false;137outTailShape.clear();138}139140void setUseTimstampsDim(bool use) CV_OVERRIDE141{142CV_Assert(!allocated);143useTimestampDim = use;144}145146void setProduceCellOutput(bool produce) CV_OVERRIDE147{148CV_Assert(!allocated);149produceCellOutput = produce;150}151152void setOutShape(const MatShape &outTailShape_) CV_OVERRIDE153{154CV_Assert(!allocated || total(outTailShape) == total(outTailShape_));155outTailShape = outTailShape_;156}157158void setWeights(const Mat &Wh, const Mat &Wx, const Mat &bias) CV_OVERRIDE159{160CV_Assert(Wh.dims == 2 && Wx.dims == 2);161CV_Assert(Wh.rows == Wx.rows);162CV_Assert(Wh.rows == 4*Wh.cols);163CV_Assert(Wh.rows == (int)bias.total());164CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type());165166blobs.resize(3);167blobs[0] = Mat(Wh.clone());168blobs[1] = Mat(Wx.clone());169blobs[2] = Mat(bias.clone()).reshape(1, 1);170}171172bool getMemoryShapes(const std::vector<MatShape> &inputs,173const int requiredOutputs,174std::vector<MatShape> &outputs,175std::vector<MatShape> &internals) const CV_OVERRIDE176{177CV_Assert(!usePeephole && blobs.size() == 3 || usePeephole && blobs.size() == 6);178CV_Assert(inputs.size() == 1);179const MatShape& inp0 = inputs[0];180181const Mat &Wh = blobs[0], &Wx = blobs[1];182int _numOut = Wh.size[1];183int _numInp = Wx.size[1];184MatShape outTailShape_(outTailShape), outResShape;185186if (!outTailShape_.empty())187CV_Assert(total(outTailShape_) == _numOut);188else189outTailShape_.assign(1, _numOut);190191int 
_numSamples;192if (useTimestampDim)193{194CV_Assert(inp0.size() >= 2 && total(inp0, 2) == _numInp);195_numSamples = inp0[1];196outResShape.push_back(inp0[0]);197}198else199{200CV_Assert(inp0.size() >= 2 && total(inp0, 1) == _numInp);201_numSamples = inp0[0];202}203204outResShape.push_back(_numSamples);205outResShape.insert(outResShape.end(), outTailShape_.begin(), outTailShape_.end());206207size_t noutputs = produceCellOutput ? 2 : 1;208outputs.assign(noutputs, outResShape);209210internals.assign(1, shape(_numSamples, _numOut)); // hInternal211internals.push_back(shape(_numSamples, _numOut)); // cInternal212internals.push_back(shape(_numSamples, 1)); // dummyOnes213internals.push_back(shape(_numSamples, 4*_numOut)); // gates214215return false;216}217218void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays) CV_OVERRIDE219{220std::vector<Mat> input;221inputs_arr.getMatVector(input);222223CV_Assert(!usePeephole && blobs.size() == 3 || usePeephole && blobs.size() == 6);224CV_Assert(input.size() == 1);225const Mat& inp0 = input[0];226227Mat &Wh = blobs[0], &Wx = blobs[1];228int numOut = Wh.size[1];229int numInp = Wx.size[1];230231if (!outTailShape.empty())232CV_Assert(total(outTailShape) == numOut);233else234outTailShape.assign(1, numOut);235236if (useTimestampDim)237{238CV_Assert(inp0.dims >= 2 && (int)inp0.total(2) == numInp);239numTimeStamps = inp0.size[0];240numSamples = inp0.size[1];241}242else243{244CV_Assert(inp0.dims >= 2 && (int)inp0.total(1) == numInp);245numTimeStamps = 1;246numSamples = inp0.size[0];247}248249outTsShape.clear();250outTsShape.push_back(numSamples);251outTsShape.insert(outTsShape.end(), outTailShape.begin(), outTailShape.end());252253allocated = true;254}255256void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE257{258CV_TRACE_FUNCTION();259CV_TRACE_ARG_VALUE(name, "name", name.c_str());260261if (inputs_arr.depth() == CV_16S)262{263forward_fallback(inputs_arr, 
outputs_arr, internals_arr);264return;265}266267std::vector<Mat> input, output, internals;268inputs_arr.getMatVector(input);269outputs_arr.getMatVector(output);270internals_arr.getMatVector(internals);271272const Mat &Wh = blobs[0];273const Mat &Wx = blobs[1];274const Mat &bias = blobs[2];275276int numOut = Wh.size[1];277278Mat hInternal = internals[0], cInternal = internals[1],279dummyOnes = internals[2], gates = internals[3];280hInternal.setTo(0.);281cInternal.setTo(0.);282dummyOnes.setTo(1.);283284int numSamplesTotal = numTimeStamps*numSamples;285Mat xTs = input[0].reshape(1, numSamplesTotal);286287Mat hOutTs = output[0].reshape(1, numSamplesTotal);288Mat cOutTs = produceCellOutput ? output[1].reshape(1, numSamplesTotal) : Mat();289290for (int ts = 0; ts < numTimeStamps; ts++)291{292Range curRowRange(ts*numSamples, (ts + 1)*numSamples);293Mat xCurr = xTs.rowRange(curRowRange);294295gemm(xCurr, Wx, 1, gates, 0, gates, GEMM_2_T); // Wx * x_t296gemm(hInternal, Wh, 1, gates, 1, gates, GEMM_2_T); //+Wh * h_{t-1}297gemm(dummyOnes, bias, 1, gates, 1, gates); //+b298299Mat gateI = gates.colRange(0*numOut, 1*numOut);300Mat gateF = gates.colRange(1*numOut, 2*numOut);301Mat gateO = gates.colRange(2*numOut, 3*numOut);302Mat gateG = gates.colRange(3*numOut, 4*numOut);303304if (forgetBias)305add(gateF, forgetBias, gateF);306307if (usePeephole)308{309Mat gatesIF = gates.colRange(0, 2*numOut);310gemm(cInternal, blobs[3], 1, gateI, 1, gateI);311gemm(cInternal, blobs[4], 1, gateF, 1, gateF);312sigmoid(gatesIF, gatesIF);313}314else315{316Mat gatesIFO = gates.colRange(0, 3*numOut);317sigmoid(gatesIFO, gatesIFO);318}319320tanh(gateG, gateG);321322//compute c_t323multiply(gateF, cInternal, gateF); // f_t (*) c_{t-1}324multiply(gateI, gateG, gateI); // i_t (*) g_t325add(gateF, gateI, cInternal); // c_t = f_t (*) c_{t-1} + i_t (*) g_t326327if (useCellClip)328{329min(cInternal, cellClip, cInternal);330max(cInternal, -cellClip, cInternal);331}332if (usePeephole)333{334gemm(cInternal, 
blobs[5], 1, gateO, 1, gateO);335sigmoid(gateO, gateO);336}337338//compute h_t339tanh(cInternal, hInternal);340multiply(gateO, hInternal, hInternal);341342//save results in output blobs343hInternal.copyTo(hOutTs.rowRange(curRowRange));344if (produceCellOutput)345cInternal.copyTo(cOutTs.rowRange(curRowRange));346}347}348};349350Ptr<LSTMLayer> LSTMLayer::create(const LayerParams& params)351{352return Ptr<LSTMLayer>(new LSTMLayerImpl(params));353}354355int LSTMLayer::inputNameToIndex(String inputName)356{357if (toLowerCase(inputName) == "x")358return 0;359return -1;360}361362int LSTMLayer::outputNameToIndex(const String& outputName)363{364if (toLowerCase(outputName) == "h")365return 0;366else if (toLowerCase(outputName) == "c")367return 1;368return -1;369}370371372class RNNLayerImpl : public RNNLayer373{374int numX, numH, numO;375int numSamples, numTimestamps, numSamplesTotal;376int dtype;377Mat Whh, Wxh, bh;378Mat Who, bo;379bool produceH;380381public:382383RNNLayerImpl(const LayerParams& params)384: numX(0), numH(0), numO(0), numSamples(0), numTimestamps(0), numSamplesTotal(0), dtype(0)385{386setParamsFrom(params);387type = "RNN";388produceH = false;389}390391void setProduceHiddenOutput(bool produce = false) CV_OVERRIDE392{393produceH = produce;394}395396void setWeights(const Mat &W_xh, const Mat &b_h, const Mat &W_hh, const Mat &W_ho, const Mat &b_o) CV_OVERRIDE397{398CV_Assert(W_hh.dims == 2 && W_xh.dims == 2);399CV_Assert(W_hh.size[0] == W_xh.size[0] && W_hh.size[0] == W_hh.size[1] && (int)b_h.total() == W_xh.size[0]);400CV_Assert(W_ho.size[0] == (int)b_o.total());401CV_Assert(W_ho.size[1] == W_hh.size[1]);402403blobs.resize(5);404blobs[0] = Mat(W_xh.clone());405blobs[1] = Mat(b_h.clone());406blobs[2] = Mat(W_hh.clone());407blobs[3] = Mat(W_ho.clone());408blobs[4] = Mat(b_o.clone());409}410411bool getMemoryShapes(const std::vector<MatShape> &inputs,412const int requiredOutputs,413std::vector<MatShape> &outputs,414std::vector<MatShape> &internals) const 
CV_OVERRIDE415{416CV_Assert(inputs.size() >= 1 && inputs.size() <= 2);417418Mat Who_ = blobs[3];419Mat Wxh_ = blobs[0];420421int numTimestamps_ = inputs[0][0];422int numSamples_ = inputs[0][1];423424int numO_ = Who_.rows;425int numH_ = Wxh_.rows;426427outputs.clear();428int dims[] = {numTimestamps_, numSamples_, numO_};429outputs.push_back(shape(dims, 3));430dims[2] = numH_;431if (produceH)432outputs.push_back(shape(dims, 3));433434internals.assign(2, shape(numSamples_, numH_));435internals.push_back(shape(numSamples_, 1));436437return false;438}439440void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays) CV_OVERRIDE441{442std::vector<Mat> input, outputs;443inputs_arr.getMatVector(input);444445CV_Assert(input.size() >= 1 && input.size() <= 2);446447Wxh = blobs[0];448bh = blobs[1];449Whh = blobs[2];450Who = blobs[3];451bo = blobs[4];452453numH = Wxh.rows;454numX = Wxh.cols;455numO = Who.rows;456457const Mat& inp0 = input[0];458459CV_Assert(inp0.dims >= 2);460CV_Assert(inp0.total(2) == numX);461dtype = CV_32F;462CV_Assert(inp0.type() == dtype);463numTimestamps = inp0.size[0];464numSamples = inp0.size[1];465numSamplesTotal = numTimestamps * numSamples;466467bh = bh.reshape(1, 1); //is 1 x numH Mat468bo = bo.reshape(1, 1); //is 1 x numO Mat469}470471void reshapeOutput(std::vector<Mat> &output)472{473output.resize(produceH ? 
2 : 1);474int sz0[] = { numTimestamps, numSamples, numO };475output[0].create(3, sz0, dtype);476if (produceH)477{478int sz1[] = { numTimestamps, numSamples, numH };479output[1].create(3, sz1, dtype);480}481}482483void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE484{485CV_TRACE_FUNCTION();486CV_TRACE_ARG_VALUE(name, "name", name.c_str());487488if (inputs_arr.depth() == CV_16S)489{490forward_fallback(inputs_arr, outputs_arr, internals_arr);491return;492}493494std::vector<Mat> input, output, internals;495inputs_arr.getMatVector(input);496outputs_arr.getMatVector(output);497internals_arr.getMatVector(internals);498499Mat xTs = input[0].reshape(1, numSamplesTotal);500Mat oTs = output[0].reshape(1, numSamplesTotal);501Mat hTs = produceH ? output[1].reshape(1, numSamplesTotal) : Mat();502Mat hCurr = internals[0];503Mat hPrev = internals[1];504Mat dummyBiasOnes = internals[2];505506hPrev.setTo(0.);507dummyBiasOnes.setTo(1.);508509for (int ts = 0; ts < numTimestamps; ts++)510{511Range curRowRange = Range(ts * numSamples, (ts + 1) * numSamples);512Mat xCurr = xTs.rowRange(curRowRange);513514gemm(hPrev, Whh, 1, hCurr, 0, hCurr, GEMM_2_T); // W_{hh} * h_{prev}515gemm(xCurr, Wxh, 1, hCurr, 1, hCurr, GEMM_2_T); //+W_{xh} * x_{curr}516gemm(dummyBiasOnes, bh, 1, hCurr, 1, hCurr); //+bh517tanh(hCurr, hPrev);518519Mat oCurr = oTs.rowRange(curRowRange);520gemm(hPrev, Who, 1, oCurr, 0, oCurr, GEMM_2_T); // W_{ho} * h_{prev}521gemm(dummyBiasOnes, bo, 1, oCurr, 1, oCurr); //+b_o522tanh(oCurr, oCurr);523524if (produceH)525hPrev.copyTo(hTs.rowRange(curRowRange));526}527}528};529530CV_EXPORTS_W Ptr<RNNLayer> RNNLayer::create(const LayerParams& params)531{532return Ptr<RNNLayer>(new RNNLayerImpl(params));533}534535}536}537538539