Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Tetragramm
GitHub Repository: Tetragramm/opencv
Path: blob/master/modules/dnn/src/halide_scheduler.cpp
16337 views
1
// This file is part of OpenCV project.
2
// It is subject to the license terms in the LICENSE file found in the top-level directory
3
// of this distribution and at http://opencv.org/license.html.
4
//
5
// Copyright (C) 2017, Intel Corporation, all rights reserved.
6
// Third party copyrights are property of their respective owners.
7
8
#include "precomp.hpp"
9
#include "halide_scheduler.hpp"
10
#include "op_halide.hpp"
11
12
namespace cv
13
{
14
namespace dnn
15
{
16
17
#ifdef HAVE_HALIDE
18
static void applySplit(const FileNode& directive, Halide::Func& func,
19
const FileNode& params)
20
{
21
for (const auto& varNode : directive)
22
{
23
const std::string varName = varNode.name();
24
const std::string factorName = (std::string)varNode;
25
Halide::Var var(varName);
26
Halide::Var outerVar(varName + "o");
27
Halide::Var innerVar(varName + "i");
28
// If split factor is integer or parameters map has parameter value.
29
CV_Assert(varNode.isString() && !params[factorName].empty() ||
30
varNode.isInt());
31
int factor = (int)(varNode.isInt() ? varNode : params[factorName]);
32
func.split(var, outerVar, innerVar, factor);
33
}
34
}
35
36
static void applyReorder(const FileNode& directive, Halide::Func& func)
37
{
38
std::string varName;
39
const int numVars = directive.size();
40
std::vector<Halide::VarOrRVar> reorderedVars;
41
reorderedVars.reserve(numVars);
42
for (int i = 0; i < numVars; ++i)
43
{
44
directive[i] >> varName;
45
reorderedVars.push_back(Halide::Var(varName));
46
}
47
func.reorder(reorderedVars);
48
}
49
50
static void applyFuse(const FileNode& directive, Halide::Func& func)
51
{
52
CV_Assert(directive["src"].size() >= 2);
53
CV_Assert(directive["dst"].size() == 1);
54
55
std::string str;
56
directive["src"][0] >> str;
57
Halide::Var firstVar(str);
58
directive["src"][1] >> str;
59
Halide::Var secondVar(str);
60
directive["dst"] >> str;
61
Halide::Var dstVar(str);
62
63
func.fuse(firstVar, secondVar, dstVar);
64
for (int i = 2, n = directive["src"].size(); i < n; ++i)
65
{
66
directive["src"][i] >> str;
67
func.fuse(Halide::Var(str), dstVar, dstVar);
68
}
69
}
70
71
static void applyParallel(const FileNode& directive, Halide::Func& func)
72
{
73
std::string varName;
74
for (int i = 0, n = directive.size(); i < n; ++i)
75
{
76
directive[i] >> varName;
77
func.parallel(Halide::Var(varName));
78
}
79
}
80
81
static void applyUnroll(const FileNode& directive, Halide::Func& func)
82
{
83
std::string varName;
84
for (int i = 0, n = directive.size(); i < n; ++i)
85
{
86
directive[i] >> varName;
87
func.unroll(Halide::Var(varName));
88
}
89
}
90
91
static void applyVectorize(const FileNode& directive, Halide::Func& func,
92
const FileNode& params)
93
{
94
for (const auto& varNode : directive)
95
{
96
const std::string varName = varNode.name();
97
const std::string factorName = (std::string)varNode;
98
// If split factor is integer or parameters map has parameter value.
99
CV_Assert(varNode.isString() && !params[factorName].empty() ||
100
varNode.isInt());
101
int factor = (int)(varNode.isInt() ? varNode : params[factorName]);
102
Halide::Var var(varName);
103
Halide::Var inner(varName + "v");
104
func.split(var, var, inner, factor);
105
func.vectorize(inner);
106
}
107
}
108
109
static void applyStoreAt(const FileNode& directive, Halide::Func& func,
110
std::map<std::string, Halide::Func>& funcsMap)
111
{
112
for (const auto& funcNode : directive)
113
{
114
const std::string targetFuncName = funcNode.name();
115
if (funcsMap.find(targetFuncName) == funcsMap.end())
116
CV_Error(cv::Error::StsParseError, "Function " + targetFuncName +
117
" is not represented in Halide pipeline");
118
Halide::Func targetFunc = funcsMap[targetFuncName];
119
func.store_at(targetFunc, (std::string)funcNode);
120
break;
121
}
122
}
123
124
static void applyComputeAt(const FileNode& directive, Halide::Func& func,
125
std::map<std::string, Halide::Func>& funcsMap)
126
{
127
for (const auto& funcNode : directive)
128
{
129
const std::string targetFuncName = funcNode.name();
130
if (funcsMap.find(targetFuncName) == funcsMap.end())
131
CV_Error(cv::Error::StsParseError, "Function " + targetFuncName +
132
" is not represented in Halide pipeline");
133
Halide::Func targetFunc = funcsMap[targetFuncName];
134
func.compute_at(targetFunc, (std::string)funcNode);
135
break;
136
}
137
}
138
139
static void applyComputeRoot(const FileNode& directive, Halide::Func& func)
140
{
141
bool compute_root;
142
directive >> compute_root;
143
if (compute_root)
144
func.compute_root();
145
}
146
147
static void applyGpuBlocks(const FileNode& directive, Halide::Func& func)
148
{
149
std::string varName;
150
for (int i = 0, n = directive.size(); i < n; ++i)
151
{
152
directive[i] >> varName;
153
func.gpu_blocks(Halide::Var(varName));
154
}
155
}
156
157
static void applyGpuThreads(const FileNode& directive, Halide::Func& func)
158
{
159
std::string varName;
160
for (int i = 0, n = directive.size(); i < n; ++i)
161
{
162
directive[i] >> varName;
163
func.gpu_threads(Halide::Var(varName));
164
}
165
}
166
167
static void apply(const FileNode& directives, Halide::Func& func,
168
std::map<std::string, Halide::Func>& funcsMap,
169
const FileNode& params)
170
{
171
for (const auto& directive : directives)
172
{
173
if (directive.name() == "split")
174
applySplit(directive, func, params);
175
else if (directive.name() == "reorder")
176
applyReorder(directive, func);
177
else if (directive.name() == "fuse")
178
applyFuse(directive, func);
179
else if (directive.name() == "parallel")
180
applyParallel(directive, func);
181
else if (directive.name() == "unroll")
182
applyUnroll(directive, func);
183
else if (directive.name() == "vectorize")
184
applyVectorize(directive, func, params);
185
else if (directive.name() == "store_at")
186
applyStoreAt(directive, func, funcsMap);
187
else if (directive.name() == "compute_at")
188
applyComputeAt(directive, func, funcsMap);
189
else if (directive.name() == "compute_root")
190
applyComputeRoot(directive, func);
191
else if (directive.name() == "gpu_blocks")
192
applyGpuBlocks(directive, func);
193
else if (directive.name() == "gpu_threads")
194
applyGpuThreads(directive, func);
195
else
196
CV_Error(Error::StsNotImplemented, "Scheduling directive " +
197
directive.name() + " is not implemented.");
198
}
199
}
200
201
// Removes every '$' sign together with the run of decimal digits that
// immediately follows it, e.g. "func$12" -> "func", "f$1$2" -> "f".
// A '$' not followed by digits is still removed on its own ("a$b" -> "ab").
static std::string Deunique(std::string str)
{
    // Use size_type/npos rather than int/-1: std::string::find returns
    // size_type, and narrowing npos (or npos - pos) to int relies on
    // implementation-defined conversions.
    std::string::size_type pos = str.find('$');
    while (pos != std::string::npos)
    {
        std::string::size_type end = str.find_first_not_of("0123456789", pos + 1);
        if (end == std::string::npos)
            end = str.size();
        str.erase(pos, end - pos);
        // Everything before `pos` has already been scanned.
        pos = str.find('$', pos);
    }
    return str;
}
217
#endif // HAVE_HALIDE
218
219
// Opens `configFile` for reading as the scheduling description.
// An empty path leaves the storage unopened, so process() will return false.
HalideScheduler::HalideScheduler(const std::string& configFile)
{
    if (configFile.empty())
        return;
    fs = FileStorage(configFile, FileStorage::READ);
}
224
225
// Releases the scheduling file storage if it is open.
HalideScheduler::~HalideScheduler()
{
    if (!fs.isOpened())
        return;
    fs.release();
}
230
231
// Applies the schedules from the opened scheduling file to the Halide
// functions of `node`.
// Returns false if no file is open, if Halide support is not compiled in, or
// if the top (last) function has no scheduling entry. Returns true otherwise.
// Raises StsParseError if the file lacks a "scheduling" node or references an
// undefined pattern; asserts that `node` is a non-empty HalideBackendNode.
bool HalideScheduler::process(Ptr<BackendNode>& node)
{
#ifdef HAVE_HALIDE
    if (!fs.isOpened())
        return false;

    const FileNode& scheduleNode = fs["scheduling"];
    if (scheduleNode.empty())
        CV_Error(cv::Error::StsParseError, "Scheduling file should has scheduling node");

    std::string str;
    std::map<std::string, Halide::Func> funcsMap;  // Scheduled functions.
    // For every function, from top to bottom, we try to find a scheduling node.
    // Scheduling is successful (return true) if for the first function (top)
    // node is represented.
    CV_Assert(!node.empty());
    // Check the cast instead of dereferencing a possibly-null result.
    Ptr<HalideBackendNode> halideNode = node.dynamicCast<HalideBackendNode>();
    CV_Assert(!halideNode.empty());
    std::vector<Halide::Func>& funcs = halideNode->funcs;
    for (int i = static_cast<int>(funcs.size()) - 1; i >= 0; --i)
    {
        Halide::Func& func = funcs[i];
        // For functions with the same name Halide generates unique names
        // for example func, func$1, func$2.
        // They are always formed with '$' and number.
        std::string funcName = Deunique(func.name());

        const FileNode& funcNode = scheduleNode[funcName];
        if (!funcNode.empty())
        {
            if (!funcNode["pattern"].empty())
            {
                funcNode["pattern"] >> str;
                // Look the pattern up once and reuse the node.
                const FileNode patternNode = fs["patterns"][str];
                if (patternNode.empty())
                    CV_Error(cv::Error::StsParseError, "Scheduling pattern " + str +
                             " is not defined");
                apply(patternNode, func, funcsMap, funcNode["params"]);
            }
            else
            {
                apply(funcNode, func, funcsMap, funcNode["params"]);
            }
        }
        else if (funcsMap.empty())
        {
            // The top function is not scheduled: nothing is applied.
            return false;
        }
        funcsMap[funcName] = func;
    }
    return true;
#endif  // HAVE_HALIDE
    return false;
}
283
284
} // namespace dnn
285
} // namespace cv
286
287