// Path: blob/master/modules/dnn/src/layers/batch_norm_layer.cpp
// 16337 views
// This file is part of OpenCV project.1// It is subject to the license terms in the LICENSE file found in the top-level directory2// of this distribution and at http://opencv.org/license.html.34// Copyright (C) 2016, Intel Corporation, all rights reserved.5// Third party copyrights are property of their respective owners.67/*8Implementation of Batch Normalization layer.9*/1011#include "../precomp.hpp"12#include "../op_halide.hpp"13#include "../op_inf_engine.hpp"14#include <opencv2/dnn/shape_utils.hpp>1516#ifdef HAVE_OPENCL17#include "opencl_kernels_dnn.hpp"18#endif1920namespace cv21{22namespace dnn23{2425class BatchNormLayerImpl CV_FINAL : public BatchNormLayer26{27public:28Mat weights_, bias_;29UMat umat_weight, umat_bias;3031BatchNormLayerImpl(const LayerParams& params)32{33setParamsFrom(params);34CV_Assert(blobs.size() >= 2);3536hasWeights = params.get<bool>("has_weight", false);37hasBias = params.get<bool>("has_bias", false);38useGlobalStats = params.get<bool>("use_global_stats", true);39if(params.get<bool>("scale_bias", false))40hasWeights = hasBias = true;41epsilon = params.get<float>("eps", 1E-5);4243size_t n = blobs[0].total();44CV_Assert(blobs[1].total() == n &&45blobs[0].isContinuous() && blobs[1].isContinuous() &&46blobs[0].type() == CV_32F && blobs[1].type() == CV_32F);4748float varMeanScale = 1.f;49if (!hasWeights && !hasBias && blobs.size() > 2 && useGlobalStats) {50CV_Assert(blobs.size() == 3); CV_CheckTypeEQ(blobs[2].type(), CV_32FC1, "");51varMeanScale = blobs[2].at<float>(0);52if (varMeanScale != 0)53varMeanScale = 1/varMeanScale;54}5556const int biasBlobIndex = blobs.size() - 1;57const int weightsBlobIndex = biasBlobIndex - hasBias;5859if( hasWeights )60{61CV_Assert((size_t)weightsBlobIndex < blobs.size());62const Mat& w = blobs[weightsBlobIndex];63CV_Assert(w.isContinuous() && w.type() == CV_32F && w.total() == (size_t)n);64}6566if( hasBias )67{68CV_Assert((size_t)biasBlobIndex < blobs.size());69const Mat& b = 
blobs[weightsBlobIndex];70CV_Assert(b.isContinuous() && b.type() == CV_32F && b.total() == (size_t)n);71}7273const float* meanData = blobs[0].ptr<float>();74const float* stdData = blobs[1].ptr<float>();75const float* weightsData = hasWeights ? blobs[weightsBlobIndex].ptr<float>() : 0;76const float* biasData = hasBias ? blobs[biasBlobIndex].ptr<float>() : 0;7778weights_.create(1, (int)n, CV_32F);79bias_.create(1, (int)n, CV_32F);8081float* dstWeightsData = weights_.ptr<float>();82float* dstBiasData = bias_.ptr<float>();8384for (size_t i = 0; i < n; ++i)85{86float w = (hasWeights ? weightsData[i] : 1.0f) / sqrt(stdData[i] * varMeanScale + epsilon);87dstWeightsData[i] = w;88dstBiasData[i] = (hasBias ? biasData[i] : 0.0f) - w * meanData[i] * varMeanScale;89}90}9192void getScaleShift(Mat& scale, Mat& shift) const CV_OVERRIDE93{94scale = weights_;95shift = bias_;96}9798virtual bool tryFuse(Ptr<Layer>& top) CV_OVERRIDE99{100Mat w, b;101top->getScaleShift(w, b);102if (w.empty() && b.empty())103return false;104105const int numChannels = weights_.total();106const int numFusedWeights = w.total();107const int numFusedBias = b.total();108109if ((numFusedWeights != numChannels && numFusedWeights != 1 && !w.empty()) ||110(numFusedBias != numChannels && numFusedBias != 1 && !b.empty()))111return false;112113if (!w.empty())114{115w = w.reshape(1, 1);116if (numFusedWeights == 1)117{118multiply(weights_, w.at<float>(0), weights_);119multiply(bias_, w.at<float>(0), bias_);120}121else122{123multiply(weights_, w, weights_);124multiply(bias_, w, bias_);125}126}127if (!b.empty())128{129b = b.reshape(1, 1);130if (numFusedBias == 1)131add(bias_, b.at<float>(0), bias_);132else133add(bias_, b.reshape(1, 1), bias_);134}135return true;136}137138bool getMemoryShapes(const std::vector<MatShape> &inputs,139const int requiredOutputs,140std::vector<MatShape> &outputs,141std::vector<MatShape> &internals) const CV_OVERRIDE142{143if (!useGlobalStats && inputs[0][0] != 
1)144CV_Error(Error::StsNotImplemented, "Batch normalization in training mode with batch size > 1");145Layer::getMemoryShapes(inputs, requiredOutputs, outputs, internals);146return true;147}148149virtual bool supportBackend(int backendId) CV_OVERRIDE150{151return backendId == DNN_BACKEND_OPENCV ||152backendId == DNN_BACKEND_HALIDE && haveHalide() ||153backendId == DNN_BACKEND_INFERENCE_ENGINE && haveInfEngine();154}155156#ifdef HAVE_OPENCL157bool forward_ocl(InputArrayOfArrays inputs_, OutputArrayOfArrays outputs_, OutputArrayOfArrays internals_)158{159std::vector<UMat> inputs;160std::vector<UMat> outputs;161162bool use_half = (inputs_.depth() == CV_16S);163inputs_.getUMatVector(inputs);164outputs_.getUMatVector(outputs);165166CV_Assert(blobs.size() >= 2);167CV_Assert(inputs.size() == 1);168169if (use_half && inputs[0].dims == 2)170return false;171172if (umat_weight.empty())173{174weights_.copyTo(umat_weight);175bias_.copyTo(umat_bias);176}177178UMat &inpBlob = inputs[0];179CV_Assert(inpBlob.dims == 2 || inpBlob.dims == 4);180int groups = inpBlob.size[0];181int channels = inpBlob.size[1];182int rows = inpBlob.dims > 2 ? inpBlob.size[2] : 1;183int cols = inpBlob.dims > 2 ? inpBlob.size[3] : 1;184185String opts = (use_half) ? " -DDtype=half" : " -DDtype=float";186for (size_t ii = 0; ii < outputs.size(); ii++)187{188if (inpBlob.dims == 2)189{190UMat& src = inputs[ii];191UMat& dst = outputs[ii];192multiply(src, weights_, dst);193add(dst, bias_, dst);194}195else196{197MatShape s = shape(groups * channels, rows * cols);198UMat src = inputs[ii].reshape(1, s.size(), &s[0]);199UMat dst = outputs[ii].reshape(1, s.size(), &s[0]);200int number = (s[1] % 8 == 0) ? 8 : ((s[1] % 4 == 0) ? 4 : 1);201String buildopt = format("-DNUM=%d", number) + opts;202String kname = format("batch_norm%d", number);203if (number == 1)204buildopt += format(" -Dconvert_T=convert_%s", use_half ? "half" : "float");205else206buildopt += format(" -Dconvert_T=convert_%s%d", use_half ? 
"half" : "float", number);207ocl::Kernel kernel(kname.c_str(), ocl::dnn::batchnorm_oclsrc, buildopt);208if (kernel.empty())209return false;210size_t global[] = { (size_t)s[0], (size_t)(s[1] / number) };211kernel.set(0, ocl::KernelArg::PtrReadOnly(src));212kernel.set(1, (int)s[0]);213kernel.set(2, (int)s[1]);214kernel.set(3, (int)channels);215kernel.set(4, ocl::KernelArg::PtrReadOnly(umat_weight));216kernel.set(5, ocl::KernelArg::PtrReadOnly(umat_bias));217kernel.set(6, ocl::KernelArg::PtrWriteOnly(dst));218bool ret = kernel.run(2, global, NULL, false);219if (!ret)220return false;221}222}223return true;224}225#endif226227void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE228{229CV_TRACE_FUNCTION();230CV_TRACE_ARG_VALUE(name, "name", name.c_str());231232CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget),233forward_ocl(inputs_arr, outputs_arr, internals_arr))234235if (inputs_arr.depth() == CV_16S)236{237forward_fallback(inputs_arr, outputs_arr, internals_arr);238return;239}240241std::vector<Mat> inputs, outputs;242inputs_arr.getMatVector(inputs);243outputs_arr.getMatVector(outputs);244245CV_Assert(blobs.size() >= 2);246CV_Assert(inputs.size() == 1);247248Mat &inpBlob = inputs[0];249CV_Assert(inpBlob.dims == 2 || inpBlob.dims == 4);250int rows = inpBlob.dims > 2 ? inpBlob.size[2] : 1;251int cols = inpBlob.dims > 2 ? 
inpBlob.size[3] : 1;252253for (size_t ii = 0; ii < outputs.size(); ii++)254{255Mat &outBlob = outputs[ii];256257for(int num = 0; num < outBlob.size[0]; num++)258{259for (int n = 0; n < outBlob.size[1]; n++)260{261float w = weights_.at<float>(n);262float b = bias_.at<float>(n);263Mat inpBlobPlane(rows, cols, CV_32F, inpBlob.ptr<float>(num, n));264Mat outBlobPlane(rows, cols, CV_32F, outBlob.ptr<float>(num, n));265inpBlobPlane.convertTo(outBlobPlane, CV_32F, w, b);266}267}268}269}270271void forwardSlice(const float* srcptr, float* dstptr, int len, size_t planeSize, int cn0, int cn1) const CV_OVERRIDE272{273for( int cn = cn0; cn < cn1; cn++, srcptr += planeSize, dstptr += planeSize )274{275int i = 0;276float w = weights_.at<float>(cn);277float b = bias_.at<float>(cn);278#if CV_SIMD128279v_float32x4 wV = v_setall_f32(w), bV = v_setall_f32(b);280for( ; i <= len - 16; i += 16 )281{282v_float32x4 x0 = v_load(srcptr + i);283v_float32x4 x1 = v_load(srcptr + i + 4);284v_float32x4 x2 = v_load(srcptr + i + 8);285v_float32x4 x3 = v_load(srcptr + i + 12);286x0 = v_muladd(x0, w, b);287x1 = v_muladd(x1, w, b);288x2 = v_muladd(x2, w, b);289x3 = v_muladd(x3, w, b);290v_store(dstptr + i, x0);291v_store(dstptr + i + 4, x1);292v_store(dstptr + i + 8, x2);293v_store(dstptr + i + 12, x3);294}295#endif296for( ; i < len; i++ )297dstptr[i] = w * srcptr[i] + b;298}299}300301virtual Ptr<BackendNode> tryAttach(const Ptr<BackendNode>& node) CV_OVERRIDE302{303switch (node->backendId)304{305case DNN_BACKEND_HALIDE:306{307#ifdef HAVE_HALIDE308auto base = node.dynamicCast<HalideBackendNode>();309Halide::Func& input = base->funcs.back();310Halide::Var x("x"), y("y"), c("c"), n("n");311Halide::Func top = attachHalide(input(x, y, c, n));312return Ptr<BackendNode>(new HalideBackendNode(base, top));313#endif // HAVE_HALIDE314break;315}316}317return Ptr<BackendNode>();318}319320virtual Ptr<BackendNode> initHalide(const std::vector<Ptr<BackendWrapper> > &inputs) CV_OVERRIDE321{322#ifdef 
HAVE_HALIDE323Halide::Buffer<float> input = halideBuffer(inputs[0]);324Halide::Var x("x"), y("y"), c("c"), n("n");325Halide::Func top = attachHalide(input(x, y, c, n));326return Ptr<BackendNode>(new HalideBackendNode(top));327#endif // HAVE_HALIDE328return Ptr<BackendNode>();329}330331#ifdef HAVE_HALIDE332// attachHalide can work both with Halide::Buffer and Halide::Func. In the333// second case it will be a fusion.334Halide::Func attachHalide(const Halide::Expr& input)335{336Halide::Func top = (name.empty() ? Halide::Func() : Halide::Func(name));337Halide::Var x("x"), y("y"), c("c"), n("n");338339const int numChannels = weights_.total();340auto weights = wrapToHalideBuffer(weights_, {numChannels});341auto bias = wrapToHalideBuffer(bias_, {numChannels});342top(x, y, c, n) = input * weights(c) + bias(c);343return top;344}345#endif // HAVE_HALIDE346347virtual Ptr<BackendNode> initInfEngine(const std::vector<Ptr<BackendWrapper> >&) CV_OVERRIDE348{349#ifdef HAVE_INF_ENGINE350InferenceEngine::LayerParams lp;351lp.name = name;352lp.type = "ScaleShift";353lp.precision = InferenceEngine::Precision::FP32;354std::shared_ptr<InferenceEngine::ScaleShiftLayer> ieLayer(new InferenceEngine::ScaleShiftLayer(lp));355356const size_t numChannels = weights_.total();357ieLayer->_weights = wrapToInfEngineBlob(weights_, {numChannels}, InferenceEngine::Layout::C);358ieLayer->_biases = wrapToInfEngineBlob(bias_, {numChannels}, InferenceEngine::Layout::C);359360return Ptr<BackendNode>(new InfEngineBackendNode(ieLayer));361#endif // HAVE_INF_ENGINE362return Ptr<BackendNode>();363}364365virtual int64 getFLOPS(const std::vector<MatShape> &inputs,366const std::vector<MatShape> &outputs) const CV_OVERRIDE367{368CV_UNUSED(outputs); // suppress unused variable warning369370int64 flops = 0;371for(int i = 0; i < inputs.size(); i++)372{373flops += 3*total(inputs[i]);374}375return flops;376}377378private:379bool useGlobalStats;380};381382Ptr<BatchNormLayer> BatchNormLayer::create(const LayerParams& 
params)383{384return Ptr<BatchNormLayer>(new BatchNormLayerImpl(params));385}386387} // namespace dnn388} // namespace cv389390391