Path: blob/master/modules/dnn/src/darknet/darknet_importer.cpp
16339 views
/*M///////////////////////////////////////////////////////////////////////////////////////1//2// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.3//4// By downloading, copying, installing or using the software you agree to this license.5// If you do not agree to this license, do not download, install,6// copy or use the software.7//8//9// License Agreement10// For Open Source Computer Vision Library11// (3-clause BSD License)12//13// Copyright (C) 2017, Intel Corporation, all rights reserved.14// Third party copyrights are property of their respective owners.15//16// Redistribution and use in source and binary forms, with or without modification,17// are permitted provided that the following conditions are met:18//19// * Redistributions of source code must retain the above copyright notice,20// this list of conditions and the following disclaimer.21//22// * Redistributions in binary form must reproduce the above copyright notice,23// this list of conditions and the following disclaimer in the documentation24// and/or other materials provided with the distribution.25//26// * Neither the names of the copyright holders nor the names of the contributors27// may be used to endorse or promote products derived from this software28// without specific prior written permission.29//30// This software is provided by the copyright holders and contributors "as is" and31// any express or implied warranties, including, but not limited to, the implied32// warranties of merchantability and fitness for a particular purpose are disclaimed.33// In no event shall copyright holders or contributors be liable for any direct,34// indirect, incidental, special, exemplary, or consequential damages35// (including, but not limited to, procurement of substitute goods or services;36// loss of use, data, or profits; or business interruption) however caused37// and on any theory of liability, whether in contract, strict liability,38// or tort (including negligence or otherwise) arising in any way out of39// the use of this software, even if advised of the possibility of such damage.40//41//M*/4243#include "../precomp.hpp"4445#include <iostream>46#include <fstream>47#include <algorithm>48#include <vector>49#include <map>5051#include "darknet_io.hpp"525354namespace cv {55namespace dnn {56CV__DNN_INLINE_NS_BEGIN5758namespace59{6061class DarknetImporter62{63darknet::NetParameter net;6465public:6667DarknetImporter() {}6869DarknetImporter(std::istream &cfgStream, std::istream &darknetModelStream)70{71CV_TRACE_FUNCTION();7273ReadNetParamsFromCfgStreamOrDie(cfgStream, &net);74ReadNetParamsFromBinaryStreamOrDie(darknetModelStream, &net);75}7677DarknetImporter(std::istream &cfgStream)78{79CV_TRACE_FUNCTION();8081ReadNetParamsFromCfgStreamOrDie(cfgStream, &net);82}8384struct BlobNote85{86BlobNote(const std::string &_name, int _layerId, int _outNum) :87name(_name), layerId(_layerId), outNum(_outNum) {}8889std::string name;90int layerId, outNum;91};9293std::vector<BlobNote> addedBlobs;94std::map<String, int> layerCounter;9596void populateNet(Net dstNet)97{98CV_TRACE_FUNCTION();99100int layersSize = net.layer_size();101layerCounter.clear();102addedBlobs.clear();103addedBlobs.reserve(layersSize + 1);104105//setup input layer names106{107std::vector<String> netInputs(net.input_size());108for (int inNum = 0; inNum < net.input_size(); inNum++)109{110addedBlobs.push_back(BlobNote(net.input(inNum), 0, inNum));111netInputs[inNum] = net.input(inNum);112}113dstNet.setInputsNames(netInputs);114}115116for (int li = 0; li < layersSize; li++)117{118const darknet::LayerParameter &layer = net.layer(li);119String name = layer.name();120String type = layer.type();121LayerParams layerParams = layer.getLayerParams();122123int repetitions = layerCounter[name]++;124if (repetitions)125name += cv::format("_%d", repetitions);126127int id = dstNet.addLayer(name, type, layerParams);128129// iterate many bottoms layers (for example for: route -1, -4)130for (int inNum = 0; inNum < layer.bottom_size(); inNum++)131addInput(layer.bottom(inNum), id, inNum, dstNet, layer.name());132133for (int outNum = 0; outNum < layer.top_size(); outNum++)134addOutput(layer, id, outNum);135}136137addedBlobs.clear();138}139140void addOutput(const darknet::LayerParameter &layer, int layerId, int outNum)141{142const std::string &name = layer.top(outNum);143144bool haveDups = false;145for (int idx = (int)addedBlobs.size() - 1; idx >= 0; idx--)146{147if (addedBlobs[idx].name == name)148{149haveDups = true;150break;151}152}153154if (haveDups)155{156bool isInplace = layer.bottom_size() > outNum && layer.bottom(outNum) == name;157if (!isInplace)158CV_Error(Error::StsBadArg, "Duplicate blobs produced by multiple sources");159}160161addedBlobs.push_back(BlobNote(name, layerId, outNum));162}163164void addInput(const std::string &name, int layerId, int inNum, Net &dstNet, std::string nn)165{166int idx;167for (idx = (int)addedBlobs.size() - 1; idx >= 0; idx--)168{169if (addedBlobs[idx].name == name)170break;171}172173if (idx < 0)174{175CV_Error(Error::StsObjectNotFound, "Can't find output blob \"" + name + "\"");176return;177}178179dstNet.connect(addedBlobs[idx].layerId, addedBlobs[idx].outNum, layerId, inNum);180}181};182183static Net readNetFromDarknet(std::istream &cfgFile, std::istream &darknetModel)184{185Net net;186DarknetImporter darknetImporter(cfgFile, darknetModel);187darknetImporter.populateNet(net);188return net;189}190191static Net readNetFromDarknet(std::istream &cfgFile)192{193Net net;194DarknetImporter darknetImporter(cfgFile);195darknetImporter.populateNet(net);196return net;197}198199}200201Net readNetFromDarknet(const String &cfgFile, const String &darknetModel /*= String()*/)202{203std::ifstream cfgStream(cfgFile.c_str());204if (!cfgStream.is_open())205{206CV_Error(cv::Error::StsParseError, "Failed to parse NetParameter file: " + std::string(cfgFile));207}208if (darknetModel != String())209{210std::ifstream darknetModelStream(darknetModel.c_str(), std::ios::binary);211if (!darknetModelStream.is_open())212{213CV_Error(cv::Error::StsParseError, "Failed to parse NetParameter file: " + std::string(darknetModel));214}215return readNetFromDarknet(cfgStream, darknetModelStream);216}217else218return readNetFromDarknet(cfgStream);219}220221struct BufferStream : public std::streambuf222{223BufferStream(const char* s, std::size_t n)224{225char* ptr = const_cast<char*>(s);226setg(ptr, ptr, ptr + n);227}228};229230Net readNetFromDarknet(const char *bufferCfg, size_t lenCfg, const char *bufferModel, size_t lenModel)231{232BufferStream cfgBufferStream(bufferCfg, lenCfg);233std::istream cfgStream(&cfgBufferStream);234if (lenModel)235{236BufferStream weightsBufferStream(bufferModel, lenModel);237std::istream weightsStream(&weightsBufferStream);238return readNetFromDarknet(cfgStream, weightsStream);239}240else241return readNetFromDarknet(cfgStream);242}243244Net readNetFromDarknet(const std::vector<uchar>& bufferCfg, const std::vector<uchar>& bufferModel)245{246const char* bufferCfgPtr = reinterpret_cast<const char*>(&bufferCfg[0]);247const char* bufferModelPtr = bufferModel.empty() ? NULL :248reinterpret_cast<const char*>(&bufferModel[0]);249return readNetFromDarknet(bufferCfgPtr, bufferCfg.size(),250bufferModelPtr, bufferModel.size());251}252253CV__DNN_INLINE_NS_END254}} // namespace255256257