Path: blob/master/modules/dnn/src/caffe/caffe_importer.cpp
16339 views
/*M///////////////////////////////////////////////////////////////////////////////////////1//2// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.3//4// By downloading, copying, installing or using the software you agree to this license.5// If you do not agree to this license, do not download, install,6// copy or use the software.7//8//9// License Agreement10// For Open Source Computer Vision Library11//12// Copyright (C) 2013, OpenCV Foundation, all rights reserved.13// Third party copyrights are property of their respective owners.14//15// Redistribution and use in source and binary forms, with or without modification,16// are permitted provided that the following conditions are met:17//18// * Redistribution's of source code must retain the above copyright notice,19// this list of conditions and the following disclaimer.20//21// * Redistribution's in binary form must reproduce the above copyright notice,22// this list of conditions and the following disclaimer in the documentation23// and/or other materials provided with the distribution.24//25// * The name of the copyright holders may not be used to endorse or promote products26// derived from this software without specific prior written permission.27//28// This software is provided by the copyright holders and contributors "as is" and29// any express or implied warranties, including, but not limited to, the implied30// warranties of merchantability and fitness for a particular purpose are disclaimed.31// In no event shall the Intel Corporation or contributors be liable for any direct,32// indirect, incidental, special, exemplary, or consequential damages33// (including, but not limited to, procurement of substitute goods or services;34// loss of use, data, or profits; or business interruption) however caused35// and on any theory of liability, whether in contract, strict liability,36// or tort (including negligence or otherwise) arising in any way out of37// the use of this software, even if advised of the possibility of such damage.38//39//M*/4041#include "../precomp.hpp"4243#ifdef HAVE_PROTOBUF44#include <iostream>45#include <fstream>46#include <sstream>47#include <algorithm>48#include <google/protobuf/message.h>49#include <google/protobuf/text_format.h>50#include <google/protobuf/io/zero_copy_stream_impl.h>51#include "caffe_io.hpp"52#endif5354namespace cv {55namespace dnn {56CV__DNN_INLINE_NS_BEGIN5758#ifdef HAVE_PROTOBUF59using ::google::protobuf::RepeatedField;60using ::google::protobuf::RepeatedPtrField;61using ::google::protobuf::Message;62using ::google::protobuf::Descriptor;63using ::google::protobuf::FieldDescriptor;64using ::google::protobuf::Reflection;6566namespace67{6869template<typename T>70static cv::String toString(const T &v)71{72std::ostringstream ss;73ss << v;74return ss.str();75}7677class CaffeImporter78{79caffe::NetParameter net;80caffe::NetParameter netBinary;8182public:8384CaffeImporter(const char *pototxt, const char *caffeModel)85{86CV_TRACE_FUNCTION();8788ReadNetParamsFromTextFileOrDie(pototxt, &net);8990if (caffeModel && caffeModel[0])91ReadNetParamsFromBinaryFileOrDie(caffeModel, &netBinary);92}9394CaffeImporter(const char *dataProto, size_t lenProto,95const char *dataModel, size_t lenModel)96{97CV_TRACE_FUNCTION();9899ReadNetParamsFromTextBufferOrDie(dataProto, lenProto, &net);100101if (dataModel != NULL && lenModel > 0)102ReadNetParamsFromBinaryBufferOrDie(dataModel, lenModel, &netBinary);103}104105void extractCustomParams(const google::protobuf::UnknownFieldSet& unknownFields, cv::dnn::LayerParams ¶ms)106{107const int numFields = unknownFields.field_count();108for (int i = 0; i < numFields; ++i)109{110const google::protobuf::UnknownField& field = unknownFields.field(i);111CV_Assert(field.type() == google::protobuf::UnknownField::TYPE_GROUP);112std::string fieldName = field.group().field(0).length_delimited();113std::string fieldValue = field.group().field(1).length_delimited();114params.set(fieldName, fieldValue);115}116}117118void addParam(const Message &msg, const FieldDescriptor *field, cv::dnn::LayerParams ¶ms)119{120const Reflection *refl = msg.GetReflection();121int type = field->cpp_type();122bool isRepeated = field->is_repeated();123const std::string &name = field->name();124125#define SET_UP_FILED(getter, arrayConstr, gtype) \126if (isRepeated) { \127const RepeatedField<gtype> &v = refl->GetRepeatedField<gtype>(msg, field); \128params.set(name, DictValue::arrayConstr(v.begin(), (int)v.size())); \129} \130else { \131params.set(name, refl->getter(msg, field)); \132}133134switch (type)135{136case FieldDescriptor::CPPTYPE_INT32:137SET_UP_FILED(GetInt32, arrayInt, ::google::protobuf::int32);138break;139case FieldDescriptor::CPPTYPE_UINT32:140SET_UP_FILED(GetUInt32, arrayInt, ::google::protobuf::uint32);141break;142case FieldDescriptor::CPPTYPE_INT64:143SET_UP_FILED(GetInt32, arrayInt, ::google::protobuf::int64);144break;145case FieldDescriptor::CPPTYPE_UINT64:146SET_UP_FILED(GetUInt32, arrayInt, ::google::protobuf::uint64);147break;148case FieldDescriptor::CPPTYPE_BOOL:149SET_UP_FILED(GetBool, arrayInt, bool);150break;151case FieldDescriptor::CPPTYPE_DOUBLE:152SET_UP_FILED(GetDouble, arrayReal, double);153break;154case FieldDescriptor::CPPTYPE_FLOAT:155SET_UP_FILED(GetFloat, arrayReal, float);156break;157case FieldDescriptor::CPPTYPE_STRING:158if (isRepeated) {159const RepeatedPtrField<std::string> &v = refl->GetRepeatedPtrField<std::string>(msg, field);160params.set(name, DictValue::arrayString(v.begin(), (int)v.size()));161}162else {163params.set(name, refl->GetString(msg, field));164}165break;166case FieldDescriptor::CPPTYPE_ENUM:167if (isRepeated) {168int size = refl->FieldSize(msg, field);169std::vector<cv::String> buf(size);170for (int i = 0; i < size; i++)171buf[i] = refl->GetRepeatedEnum(msg, field, i)->name();172params.set(name, DictValue::arrayString(buf.begin(), size));173}174else {175params.set(name, refl->GetEnum(msg, field)->name());176}177break;178default:179CV_Error(Error::StsError, "Unknown type \"" + String(field->type_name()) + "\" in prototxt");180break;181}182}183184inline static bool ends_with_param(const std::string &str)185{186static const std::string _param("_param");187return (str.size() >= _param.size()) && str.compare(str.size() - _param.size(), _param.size(), _param) == 0;188}189190void extractLayerParams(const Message &msg, cv::dnn::LayerParams ¶ms, bool isInternal = false)191{192const Descriptor *msgDesc = msg.GetDescriptor();193const Reflection *msgRefl = msg.GetReflection();194195for (int fieldId = 0; fieldId < msgDesc->field_count(); fieldId++)196{197const FieldDescriptor *fd = msgDesc->field(fieldId);198199if (!isInternal && !ends_with_param(fd->name()))200continue;201202const google::protobuf::UnknownFieldSet& unknownFields = msgRefl->GetUnknownFields(msg);203bool hasData = fd->is_required() ||204(fd->is_optional() && msgRefl->HasField(msg, fd)) ||205(fd->is_repeated() && msgRefl->FieldSize(msg, fd) > 0) ||206!unknownFields.empty();207if (!hasData)208continue;209210extractCustomParams(unknownFields, params);211if (fd->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE)212{213if (fd->is_repeated()) //Extract only first item!214extractLayerParams(msgRefl->GetRepeatedMessage(msg, fd, 0), params, true);215else216extractLayerParams(msgRefl->GetMessage(msg, fd), params, true);217}218else219{220addParam(msg, fd, params);221}222}223}224225void blobShapeFromProto(const caffe::BlobProto &pbBlob, MatShape& shape)226{227shape.clear();228if (pbBlob.has_num() || pbBlob.has_channels() || pbBlob.has_height() || pbBlob.has_width())229{230shape.push_back(pbBlob.num());231shape.push_back(pbBlob.channels());232shape.push_back(pbBlob.height());233shape.push_back(pbBlob.width());234}235else if (pbBlob.has_shape())236{237const caffe::BlobShape &_shape = pbBlob.shape();238239for (int i = 0; i < _shape.dim_size(); i++)240shape.push_back((int)_shape.dim(i));241}242else243shape.resize(1, 1); // Is a scalar.244}245246void blobFromProto(const caffe::BlobProto &pbBlob, cv::Mat &dstBlob)247{248MatShape shape;249blobShapeFromProto(pbBlob, shape);250251dstBlob.create((int)shape.size(), &shape[0], CV_32F);252if (pbBlob.data_size())253{254// Single precision floats.255CV_Assert(pbBlob.data_size() == (int)dstBlob.total());256257CV_DbgAssert(pbBlob.GetDescriptor()->FindFieldByLowercaseName("data")->cpp_type() == FieldDescriptor::CPPTYPE_FLOAT);258Mat(dstBlob.dims, &dstBlob.size[0], CV_32F, (void*)pbBlob.data().data()).copyTo(dstBlob);259}260else261{262// Half precision floats.263CV_Assert(pbBlob.raw_data_type() == caffe::FLOAT16);264std::string raw_data = pbBlob.raw_data();265266CV_Assert(raw_data.size() / 2 == (int)dstBlob.total());267268Mat halfs((int)shape.size(), &shape[0], CV_16SC1, (void*)raw_data.c_str());269convertFp16(halfs, dstBlob);270}271}272273void extractBinaryLayerParams(const caffe::LayerParameter& layer, LayerParams& layerParams)274{275const std::string &name = layer.name();276277int li;278for (li = 0; li != netBinary.layer_size(); li++)279{280const caffe::LayerParameter& binLayer = netBinary.layer(li);281// Break if the layer name is the same and the blobs are not cleared282if (binLayer.name() == name && binLayer.blobs_size() != 0)283break;284}285286if (li == netBinary.layer_size())287return;288289caffe::LayerParameter* binLayer = netBinary.mutable_layer(li);290const int numBlobs = binLayer->blobs_size();291layerParams.blobs.resize(numBlobs);292for (int bi = 0; bi < numBlobs; bi++)293{294blobFromProto(binLayer->blobs(bi), layerParams.blobs[bi]);295}296binLayer->clear_blobs();297CV_Assert(numBlobs == binLayer->blobs().ClearedCount());298for (int bi = 0; bi < numBlobs; bi++)299{300delete binLayer->mutable_blobs()->ReleaseCleared();301}302}303304struct BlobNote305{306BlobNote(const std::string &_name, int _layerId, int _outNum) :307name(_name), layerId(_layerId), outNum(_outNum) {}308309std::string name;310int layerId, outNum;311};312313std::vector<BlobNote> addedBlobs;314std::map<String, int> layerCounter;315316void populateNet(Net dstNet)317{318CV_TRACE_FUNCTION();319320int layersSize = net.layer_size();321layerCounter.clear();322addedBlobs.clear();323addedBlobs.reserve(layersSize + 1);324325//setup input layer names326std::vector<String> netInputs(net.input_size());327{328for (int inNum = 0; inNum < net.input_size(); inNum++)329{330addedBlobs.push_back(BlobNote(net.input(inNum), 0, inNum));331netInputs[inNum] = net.input(inNum);332}333}334335for (int li = 0; li < layersSize; li++)336{337const caffe::LayerParameter &layer = net.layer(li);338String name = layer.name();339String type = layer.type();340LayerParams layerParams;341342extractLayerParams(layer, layerParams);343extractBinaryLayerParams(layer, layerParams);344345int repetitions = layerCounter[name]++;346if (repetitions)347name += String("_") + toString(repetitions);348349if (type == "Input")350{351for (int outNum = 0; outNum < layer.top_size(); outNum++)352{353addOutput(layer, 0, outNum);354addedBlobs.back().outNum = netInputs.size();355netInputs.push_back(addedBlobs.back().name);356}357continue;358}359else if (type == "BatchNorm")360{361if (!layerParams.get<bool>("use_global_stats", true))362{363CV_Assert_N(layer.bottom_size() == 1, layer.top_size() == 1);364365LayerParams mvnParams;366mvnParams.set("eps", layerParams.get<float>("eps", 1e-5));367std::string mvnName = name + "/mvn";368369int repetitions = layerCounter[mvnName]++;370if (repetitions)371mvnName += String("_") + toString(repetitions);372373int mvnId = dstNet.addLayer(mvnName, "MVN", mvnParams);374addInput(layer.bottom(0), mvnId, 0, dstNet);375addOutput(layer, mvnId, 0);376net.mutable_layer(li)->set_bottom(0, layer.top(0));377layerParams.blobs[0].setTo(0); // mean378layerParams.blobs[1].setTo(1); // std379}380}381else if ("ConvolutionDepthwise" == type)382{383type = "Convolution";384}385386int id = dstNet.addLayer(name, type, layerParams);387388for (int inNum = 0; inNum < layer.bottom_size(); inNum++)389addInput(layer.bottom(inNum), id, inNum, dstNet);390391for (int outNum = 0; outNum < layer.top_size(); outNum++)392addOutput(layer, id, outNum);393}394dstNet.setInputsNames(netInputs);395396addedBlobs.clear();397}398399void addOutput(const caffe::LayerParameter &layer, int layerId, int outNum)400{401const std::string &name = layer.top(outNum);402403bool haveDups = false;404for (int idx = (int)addedBlobs.size() - 1; idx >= 0; idx--)405{406if (addedBlobs[idx].name == name)407{408haveDups = true;409break;410}411}412413if (haveDups)414{415bool isInplace = layer.bottom_size() > outNum && layer.bottom(outNum) == name;416if (!isInplace)417CV_Error(Error::StsBadArg, "Duplicate blobs produced by multiple sources");418}419420addedBlobs.push_back(BlobNote(name, layerId, outNum));421}422423void addInput(const std::string &name, int layerId, int inNum, Net &dstNet)424{425int idx;426for (idx = (int)addedBlobs.size() - 1; idx >= 0; idx--)427{428if (addedBlobs[idx].name == name)429break;430}431432if (idx < 0)433{434CV_Error(Error::StsObjectNotFound, "Can't find output blob \"" + name + "\"");435return;436}437438dstNet.connect(addedBlobs[idx].layerId, addedBlobs[idx].outNum, layerId, inNum);439}440};441442}443444Net readNetFromCaffe(const String &prototxt, const String &caffeModel /*= String()*/)445{446CaffeImporter caffeImporter(prototxt.c_str(), caffeModel.c_str());447Net net;448caffeImporter.populateNet(net);449return net;450}451452Net readNetFromCaffe(const char *bufferProto, size_t lenProto,453const char *bufferModel, size_t lenModel)454{455CaffeImporter caffeImporter(bufferProto, lenProto, bufferModel, lenModel);456Net net;457caffeImporter.populateNet(net);458return net;459}460461Net readNetFromCaffe(const std::vector<uchar>& bufferProto, const std::vector<uchar>& bufferModel)462{463const char* bufferProtoPtr = reinterpret_cast<const char*>(&bufferProto[0]);464const char* bufferModelPtr = bufferModel.empty() ? NULL :465reinterpret_cast<const char*>(&bufferModel[0]);466return readNetFromCaffe(bufferProtoPtr, bufferProto.size(),467bufferModelPtr, bufferModel.size());468}469470#endif //HAVE_PROTOBUF471472CV__DNN_INLINE_NS_END473}} // namespace474475476