Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
Tetragramm
GitHub Repository: Tetragramm/opencv
Path: blob/master/modules/dnn/src/caffe/caffe_importer.cpp
16339 views
1
/*M///////////////////////////////////////////////////////////////////////////////////////
2
//
3
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4
//
5
// By downloading, copying, installing or using the software you agree to this license.
6
// If you do not agree to this license, do not download, install,
7
// copy or use the software.
8
//
9
//
10
// License Agreement
11
// For Open Source Computer Vision Library
12
//
13
// Copyright (C) 2013, OpenCV Foundation, all rights reserved.
14
// Third party copyrights are property of their respective owners.
15
//
16
// Redistribution and use in source and binary forms, with or without modification,
17
// are permitted provided that the following conditions are met:
18
//
19
// * Redistribution's of source code must retain the above copyright notice,
20
// this list of conditions and the following disclaimer.
21
//
22
// * Redistribution's in binary form must reproduce the above copyright notice,
23
// this list of conditions and the following disclaimer in the documentation
24
// and/or other materials provided with the distribution.
25
//
26
// * The name of the copyright holders may not be used to endorse or promote products
27
// derived from this software without specific prior written permission.
28
//
29
// This software is provided by the copyright holders and contributors "as is" and
30
// any express or implied warranties, including, but not limited to, the implied
31
// warranties of merchantability and fitness for a particular purpose are disclaimed.
32
// In no event shall the Intel Corporation or contributors be liable for any direct,
33
// indirect, incidental, special, exemplary, or consequential damages
34
// (including, but not limited to, procurement of substitute goods or services;
35
// loss of use, data, or profits; or business interruption) however caused
36
// and on any theory of liability, whether in contract, strict liability,
37
// or tort (including negligence or otherwise) arising in any way out of
38
// the use of this software, even if advised of the possibility of such damage.
39
//
40
//M*/
41
42
#include "../precomp.hpp"
43
44
#ifdef HAVE_PROTOBUF
45
#include <iostream>
46
#include <fstream>
47
#include <sstream>
48
#include <algorithm>
49
#include <google/protobuf/message.h>
50
#include <google/protobuf/text_format.h>
51
#include <google/protobuf/io/zero_copy_stream_impl.h>
52
#include "caffe_io.hpp"
53
#endif
54
55
namespace cv {
56
namespace dnn {
57
CV__DNN_INLINE_NS_BEGIN
58
59
#ifdef HAVE_PROTOBUF
60
using ::google::protobuf::RepeatedField;
61
using ::google::protobuf::RepeatedPtrField;
62
using ::google::protobuf::Message;
63
using ::google::protobuf::Descriptor;
64
using ::google::protobuf::FieldDescriptor;
65
using ::google::protobuf::Reflection;
66
67
namespace
68
{
69
70
template<typename T>
71
static cv::String toString(const T &v)
72
{
73
std::ostringstream ss;
74
ss << v;
75
return ss.str();
76
}
77
78
class CaffeImporter
79
{
80
caffe::NetParameter net;
81
caffe::NetParameter netBinary;
82
83
public:
84
85
CaffeImporter(const char *pototxt, const char *caffeModel)
86
{
87
CV_TRACE_FUNCTION();
88
89
ReadNetParamsFromTextFileOrDie(pototxt, &net);
90
91
if (caffeModel && caffeModel[0])
92
ReadNetParamsFromBinaryFileOrDie(caffeModel, &netBinary);
93
}
94
95
CaffeImporter(const char *dataProto, size_t lenProto,
96
const char *dataModel, size_t lenModel)
97
{
98
CV_TRACE_FUNCTION();
99
100
ReadNetParamsFromTextBufferOrDie(dataProto, lenProto, &net);
101
102
if (dataModel != NULL && lenModel > 0)
103
ReadNetParamsFromBinaryBufferOrDie(dataModel, lenModel, &netBinary);
104
}
105
106
void extractCustomParams(const google::protobuf::UnknownFieldSet& unknownFields, cv::dnn::LayerParams &params)
107
{
108
const int numFields = unknownFields.field_count();
109
for (int i = 0; i < numFields; ++i)
110
{
111
const google::protobuf::UnknownField& field = unknownFields.field(i);
112
CV_Assert(field.type() == google::protobuf::UnknownField::TYPE_GROUP);
113
std::string fieldName = field.group().field(0).length_delimited();
114
std::string fieldValue = field.group().field(1).length_delimited();
115
params.set(fieldName, fieldValue);
116
}
117
}
118
119
void addParam(const Message &msg, const FieldDescriptor *field, cv::dnn::LayerParams &params)
120
{
121
const Reflection *refl = msg.GetReflection();
122
int type = field->cpp_type();
123
bool isRepeated = field->is_repeated();
124
const std::string &name = field->name();
125
126
#define SET_UP_FILED(getter, arrayConstr, gtype) \
127
if (isRepeated) { \
128
const RepeatedField<gtype> &v = refl->GetRepeatedField<gtype>(msg, field); \
129
params.set(name, DictValue::arrayConstr(v.begin(), (int)v.size())); \
130
} \
131
else { \
132
params.set(name, refl->getter(msg, field)); \
133
}
134
135
switch (type)
136
{
137
case FieldDescriptor::CPPTYPE_INT32:
138
SET_UP_FILED(GetInt32, arrayInt, ::google::protobuf::int32);
139
break;
140
case FieldDescriptor::CPPTYPE_UINT32:
141
SET_UP_FILED(GetUInt32, arrayInt, ::google::protobuf::uint32);
142
break;
143
case FieldDescriptor::CPPTYPE_INT64:
144
SET_UP_FILED(GetInt32, arrayInt, ::google::protobuf::int64);
145
break;
146
case FieldDescriptor::CPPTYPE_UINT64:
147
SET_UP_FILED(GetUInt32, arrayInt, ::google::protobuf::uint64);
148
break;
149
case FieldDescriptor::CPPTYPE_BOOL:
150
SET_UP_FILED(GetBool, arrayInt, bool);
151
break;
152
case FieldDescriptor::CPPTYPE_DOUBLE:
153
SET_UP_FILED(GetDouble, arrayReal, double);
154
break;
155
case FieldDescriptor::CPPTYPE_FLOAT:
156
SET_UP_FILED(GetFloat, arrayReal, float);
157
break;
158
case FieldDescriptor::CPPTYPE_STRING:
159
if (isRepeated) {
160
const RepeatedPtrField<std::string> &v = refl->GetRepeatedPtrField<std::string>(msg, field);
161
params.set(name, DictValue::arrayString(v.begin(), (int)v.size()));
162
}
163
else {
164
params.set(name, refl->GetString(msg, field));
165
}
166
break;
167
case FieldDescriptor::CPPTYPE_ENUM:
168
if (isRepeated) {
169
int size = refl->FieldSize(msg, field);
170
std::vector<cv::String> buf(size);
171
for (int i = 0; i < size; i++)
172
buf[i] = refl->GetRepeatedEnum(msg, field, i)->name();
173
params.set(name, DictValue::arrayString(buf.begin(), size));
174
}
175
else {
176
params.set(name, refl->GetEnum(msg, field)->name());
177
}
178
break;
179
default:
180
CV_Error(Error::StsError, "Unknown type \"" + String(field->type_name()) + "\" in prototxt");
181
break;
182
}
183
}
184
185
inline static bool ends_with_param(const std::string &str)
186
{
187
static const std::string _param("_param");
188
return (str.size() >= _param.size()) && str.compare(str.size() - _param.size(), _param.size(), _param) == 0;
189
}
190
191
void extractLayerParams(const Message &msg, cv::dnn::LayerParams &params, bool isInternal = false)
192
{
193
const Descriptor *msgDesc = msg.GetDescriptor();
194
const Reflection *msgRefl = msg.GetReflection();
195
196
for (int fieldId = 0; fieldId < msgDesc->field_count(); fieldId++)
197
{
198
const FieldDescriptor *fd = msgDesc->field(fieldId);
199
200
if (!isInternal && !ends_with_param(fd->name()))
201
continue;
202
203
const google::protobuf::UnknownFieldSet& unknownFields = msgRefl->GetUnknownFields(msg);
204
bool hasData = fd->is_required() ||
205
(fd->is_optional() && msgRefl->HasField(msg, fd)) ||
206
(fd->is_repeated() && msgRefl->FieldSize(msg, fd) > 0) ||
207
!unknownFields.empty();
208
if (!hasData)
209
continue;
210
211
extractCustomParams(unknownFields, params);
212
if (fd->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE)
213
{
214
if (fd->is_repeated()) //Extract only first item!
215
extractLayerParams(msgRefl->GetRepeatedMessage(msg, fd, 0), params, true);
216
else
217
extractLayerParams(msgRefl->GetMessage(msg, fd), params, true);
218
}
219
else
220
{
221
addParam(msg, fd, params);
222
}
223
}
224
}
225
226
void blobShapeFromProto(const caffe::BlobProto &pbBlob, MatShape& shape)
227
{
228
shape.clear();
229
if (pbBlob.has_num() || pbBlob.has_channels() || pbBlob.has_height() || pbBlob.has_width())
230
{
231
shape.push_back(pbBlob.num());
232
shape.push_back(pbBlob.channels());
233
shape.push_back(pbBlob.height());
234
shape.push_back(pbBlob.width());
235
}
236
else if (pbBlob.has_shape())
237
{
238
const caffe::BlobShape &_shape = pbBlob.shape();
239
240
for (int i = 0; i < _shape.dim_size(); i++)
241
shape.push_back((int)_shape.dim(i));
242
}
243
else
244
shape.resize(1, 1); // Is a scalar.
245
}
246
247
void blobFromProto(const caffe::BlobProto &pbBlob, cv::Mat &dstBlob)
248
{
249
MatShape shape;
250
blobShapeFromProto(pbBlob, shape);
251
252
dstBlob.create((int)shape.size(), &shape[0], CV_32F);
253
if (pbBlob.data_size())
254
{
255
// Single precision floats.
256
CV_Assert(pbBlob.data_size() == (int)dstBlob.total());
257
258
CV_DbgAssert(pbBlob.GetDescriptor()->FindFieldByLowercaseName("data")->cpp_type() == FieldDescriptor::CPPTYPE_FLOAT);
259
Mat(dstBlob.dims, &dstBlob.size[0], CV_32F, (void*)pbBlob.data().data()).copyTo(dstBlob);
260
}
261
else
262
{
263
// Half precision floats.
264
CV_Assert(pbBlob.raw_data_type() == caffe::FLOAT16);
265
std::string raw_data = pbBlob.raw_data();
266
267
CV_Assert(raw_data.size() / 2 == (int)dstBlob.total());
268
269
Mat halfs((int)shape.size(), &shape[0], CV_16SC1, (void*)raw_data.c_str());
270
convertFp16(halfs, dstBlob);
271
}
272
}
273
274
void extractBinaryLayerParams(const caffe::LayerParameter& layer, LayerParams& layerParams)
275
{
276
const std::string &name = layer.name();
277
278
int li;
279
for (li = 0; li != netBinary.layer_size(); li++)
280
{
281
const caffe::LayerParameter& binLayer = netBinary.layer(li);
282
// Break if the layer name is the same and the blobs are not cleared
283
if (binLayer.name() == name && binLayer.blobs_size() != 0)
284
break;
285
}
286
287
if (li == netBinary.layer_size())
288
return;
289
290
caffe::LayerParameter* binLayer = netBinary.mutable_layer(li);
291
const int numBlobs = binLayer->blobs_size();
292
layerParams.blobs.resize(numBlobs);
293
for (int bi = 0; bi < numBlobs; bi++)
294
{
295
blobFromProto(binLayer->blobs(bi), layerParams.blobs[bi]);
296
}
297
binLayer->clear_blobs();
298
CV_Assert(numBlobs == binLayer->blobs().ClearedCount());
299
for (int bi = 0; bi < numBlobs; bi++)
300
{
301
delete binLayer->mutable_blobs()->ReleaseCleared();
302
}
303
}
304
305
struct BlobNote
306
{
307
BlobNote(const std::string &_name, int _layerId, int _outNum) :
308
name(_name), layerId(_layerId), outNum(_outNum) {}
309
310
std::string name;
311
int layerId, outNum;
312
};
313
314
std::vector<BlobNote> addedBlobs;
315
std::map<String, int> layerCounter;
316
317
void populateNet(Net dstNet)
318
{
319
CV_TRACE_FUNCTION();
320
321
int layersSize = net.layer_size();
322
layerCounter.clear();
323
addedBlobs.clear();
324
addedBlobs.reserve(layersSize + 1);
325
326
//setup input layer names
327
std::vector<String> netInputs(net.input_size());
328
{
329
for (int inNum = 0; inNum < net.input_size(); inNum++)
330
{
331
addedBlobs.push_back(BlobNote(net.input(inNum), 0, inNum));
332
netInputs[inNum] = net.input(inNum);
333
}
334
}
335
336
for (int li = 0; li < layersSize; li++)
337
{
338
const caffe::LayerParameter &layer = net.layer(li);
339
String name = layer.name();
340
String type = layer.type();
341
LayerParams layerParams;
342
343
extractLayerParams(layer, layerParams);
344
extractBinaryLayerParams(layer, layerParams);
345
346
int repetitions = layerCounter[name]++;
347
if (repetitions)
348
name += String("_") + toString(repetitions);
349
350
if (type == "Input")
351
{
352
for (int outNum = 0; outNum < layer.top_size(); outNum++)
353
{
354
addOutput(layer, 0, outNum);
355
addedBlobs.back().outNum = netInputs.size();
356
netInputs.push_back(addedBlobs.back().name);
357
}
358
continue;
359
}
360
else if (type == "BatchNorm")
361
{
362
if (!layerParams.get<bool>("use_global_stats", true))
363
{
364
CV_Assert_N(layer.bottom_size() == 1, layer.top_size() == 1);
365
366
LayerParams mvnParams;
367
mvnParams.set("eps", layerParams.get<float>("eps", 1e-5));
368
std::string mvnName = name + "/mvn";
369
370
int repetitions = layerCounter[mvnName]++;
371
if (repetitions)
372
mvnName += String("_") + toString(repetitions);
373
374
int mvnId = dstNet.addLayer(mvnName, "MVN", mvnParams);
375
addInput(layer.bottom(0), mvnId, 0, dstNet);
376
addOutput(layer, mvnId, 0);
377
net.mutable_layer(li)->set_bottom(0, layer.top(0));
378
layerParams.blobs[0].setTo(0); // mean
379
layerParams.blobs[1].setTo(1); // std
380
}
381
}
382
else if ("ConvolutionDepthwise" == type)
383
{
384
type = "Convolution";
385
}
386
387
int id = dstNet.addLayer(name, type, layerParams);
388
389
for (int inNum = 0; inNum < layer.bottom_size(); inNum++)
390
addInput(layer.bottom(inNum), id, inNum, dstNet);
391
392
for (int outNum = 0; outNum < layer.top_size(); outNum++)
393
addOutput(layer, id, outNum);
394
}
395
dstNet.setInputsNames(netInputs);
396
397
addedBlobs.clear();
398
}
399
400
void addOutput(const caffe::LayerParameter &layer, int layerId, int outNum)
401
{
402
const std::string &name = layer.top(outNum);
403
404
bool haveDups = false;
405
for (int idx = (int)addedBlobs.size() - 1; idx >= 0; idx--)
406
{
407
if (addedBlobs[idx].name == name)
408
{
409
haveDups = true;
410
break;
411
}
412
}
413
414
if (haveDups)
415
{
416
bool isInplace = layer.bottom_size() > outNum && layer.bottom(outNum) == name;
417
if (!isInplace)
418
CV_Error(Error::StsBadArg, "Duplicate blobs produced by multiple sources");
419
}
420
421
addedBlobs.push_back(BlobNote(name, layerId, outNum));
422
}
423
424
void addInput(const std::string &name, int layerId, int inNum, Net &dstNet)
425
{
426
int idx;
427
for (idx = (int)addedBlobs.size() - 1; idx >= 0; idx--)
428
{
429
if (addedBlobs[idx].name == name)
430
break;
431
}
432
433
if (idx < 0)
434
{
435
CV_Error(Error::StsObjectNotFound, "Can't find output blob \"" + name + "\"");
436
return;
437
}
438
439
dstNet.connect(addedBlobs[idx].layerId, addedBlobs[idx].outNum, layerId, inNum);
440
}
441
};
442
443
}
444
445
Net readNetFromCaffe(const String &prototxt, const String &caffeModel /*= String()*/)
{
    // Parse the text description (and the weights file, if a path was given) first,
    // then translate the parsed layers into a fresh dnn network.
    CaffeImporter importer(prototxt.c_str(), caffeModel.c_str());
    Net result;
    importer.populateNet(result);
    return result;
}
452
453
Net readNetFromCaffe(const char *bufferProto, size_t lenProto,
                     const char *bufferModel, size_t lenModel)
{
    // In-memory variant: the model buffer may be NULL/empty, in which case no weights are loaded.
    CaffeImporter importer(bufferProto, lenProto, bufferModel, lenModel);
    Net result;
    importer.populateNet(result);
    return result;
}
461
462
/**
 * Reads a Caffe network from byte vectors holding the .prototxt text and (optionally)
 * the .caffemodel weights. An empty @p bufferModel means "no weights".
 */
Net readNetFromCaffe(const std::vector<uchar>& bufferProto, const std::vector<uchar>& bufferModel)
{
    // Use data() rather than &v[0]: indexing an empty vector is undefined behaviour.
    const char* bufferProtoPtr = reinterpret_cast<const char*>(bufferProto.data());
    const char* bufferModelPtr = bufferModel.empty() ? NULL :
                                 reinterpret_cast<const char*>(bufferModel.data());
    return readNetFromCaffe(bufferProtoPtr, bufferProto.size(),
                            bufferModelPtr, bufferModel.size());
}
470
471
#endif //HAVE_PROTOBUF
472
473
CV__DNN_INLINE_NS_END
474
}} // namespace
475
476