Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
Tetragramm
GitHub Repository: Tetragramm/opencv
Path: blob/master/modules/dnn/src/caffe/caffe_shrinker.cpp
16339 views
1
// This file is part of OpenCV project.
2
// It is subject to the license terms in the LICENSE file found in the top-level directory
3
// of this distribution and at http://opencv.org/license.html.
4
//
5
// Copyright (C) 2017, Intel Corporation, all rights reserved.
6
// Third party copyrights are property of their respective owners.
7
8
#include "../precomp.hpp"
9
10
#ifdef HAVE_PROTOBUF
11
#include <fstream>
12
#include "caffe_io.hpp"
13
#endif
14
15
namespace cv { namespace dnn {
16
CV__DNN_INLINE_NS_BEGIN
17
18
#ifdef HAVE_PROTOBUF
19
20
void shrinkCaffeModel(const String& src, const String& dst, const std::vector<String>& layersTypes)
21
{
22
CV_TRACE_FUNCTION();
23
24
std::vector<String> types(layersTypes);
25
if (types.empty())
26
{
27
types.push_back("Convolution");
28
types.push_back("InnerProduct");
29
}
30
31
caffe::NetParameter net;
32
ReadNetParamsFromBinaryFileOrDie(src.c_str(), &net);
33
34
for (int i = 0; i < net.layer_size(); ++i)
35
{
36
caffe::LayerParameter* lp = net.mutable_layer(i);
37
if (std::find(types.begin(), types.end(), lp->type()) == types.end())
38
{
39
continue;
40
}
41
for (int j = 0; j < lp->blobs_size(); ++j)
42
{
43
caffe::BlobProto* blob = lp->mutable_blobs(j);
44
CV_Assert(blob->data_size() != 0); // float32 array.
45
46
Mat floats(1, blob->data_size(), CV_32FC1, (void*)blob->data().data());
47
Mat halfs(1, blob->data_size(), CV_16SC1);
48
convertFp16(floats, halfs); // Convert to float16.
49
50
blob->clear_data(); // Clear float32 data.
51
52
// Set float16 data.
53
blob->set_raw_data(halfs.data, halfs.total() * halfs.elemSize());
54
blob->set_raw_data_type(caffe::FLOAT16);
55
}
56
}
57
#if GOOGLE_PROTOBUF_VERSION < 3005000
58
size_t msgSize = saturate_cast<size_t>(net.ByteSize());
59
#else
60
size_t msgSize = net.ByteSizeLong();
61
#endif
62
std::vector<uint8_t> output(msgSize);
63
net.SerializeWithCachedSizesToArray(&output[0]);
64
65
std::ofstream ofs(dst.c_str(), std::ios::binary);
66
ofs.write((const char*)&output[0], msgSize);
67
ofs.close();
68
}
69
70
#else
71
72
void shrinkCaffeModel(const String& src, const String& dst, const std::vector<String>& types)
73
{
74
CV_Error(cv::Error::StsNotImplemented, "libprotobuf required to import data from Caffe models");
75
}
76
77
#endif // HAVE_PROTOBUF
78
79
CV__DNN_INLINE_NS_END
80
}} // namespace
81
82