Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Tetragramm
GitHub Repository: Tetragramm/opencv
Path: blob/master/modules/dnn/src/layers/shuffle_channel_layer.cpp
16337 views
1
// This file is part of OpenCV project.
2
// It is subject to the license terms in the LICENSE file found in the top-level directory
3
// of this distribution and at http://opencv.org/license.html.
4
5
// Copyright (C) 2018, Intel Corporation, all rights reserved.
6
// Third party copyrights are property of their respective owners.
7
#include "../precomp.hpp"
8
9
namespace cv { namespace dnn {
10
11
/**
 * Implementation of the channel-shuffle layer.
 *
 * For group == 1 the layer copies its input to its output (or does nothing
 * when both share the same buffer). For any other group count the input is
 * viewed as [N, group, C / group, H * W] and the two middle axes are swapped
 * by an internal PermuteLayer, producing [N, C / group, group, H * W].
 */
class ShuffleChannelLayerImpl CV_FINAL : public ShuffleChannelLayer
{
public:
    /**
     * Reads the "group" parameter (defaults to 1 when absent) and forwards
     * the remaining parameters to setParamsFrom().
     */
    ShuffleChannelLayerImpl(const LayerParams& params)
    {
        group = params.get<int>("group", 1);
        setParamsFrom(params);
    }

    /**
     * Validates the input shape and delegates the output shape computation
     * to Layer::getMemoryShapes().
     *
     * Requires exactly one 4-dimensional input whose second dimension is
     * divisible by a positive group count.
     *
     * @return true only when group == 1; presumably this tells the caller
     *         that the layer may run in-place (forward() skips all work when
     *         input and output share a buffer) -- confirm against the
     *         Layer::getMemoryShapes contract.
     */
    bool getMemoryShapes(const std::vector<MatShape> &inputs,
                         const int requiredOutputs,
                         std::vector<MatShape> &outputs,
                         std::vector<MatShape> &internals) const CV_OVERRIDE
    {
        CV_Assert(inputs.size() == 1 && inputs[0].size() == 4);
        // A zero group would make the modulo below undefined, and a negative
        // one would only fail later inside reshape(); reject both up front.
        CV_Assert(group > 0);
        CV_Assert(inputs[0][1] % group == 0);
        Layer::getMemoryShapes(inputs, requiredOutputs, outputs, internals);
        return group == 1;
    }

    /**
     * Prepares the internal permute layer and the reshape targets used by
     * forward() / forward_ocl(). Nothing is set up when group == 1.
     */
    virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE
    {
        if (group == 1)
            return;

        std::vector<Mat> inputs, outputs;
        inputs_arr.getMatVector(inputs);
        outputs_arr.getMatVector(outputs);

        // Swap axes 1 and 2 of the 4D view built below. The order is an
        // integer list, so keep it in an integer array.
        LayerParams lp;
        const int order[] = {0, 2, 1, 3};
        lp.set("order", DictValue::arrayInt(&order[0], 4));
        permute = PermuteLayer::create(lp);

        const Mat& inp = inputs[0];
        const Mat& out = outputs[0];

        // Input view: [N, group, C / group, H * W].
        permuteInpShape.resize(4);
        permuteInpShape[0] = inp.size[0];
        permuteInpShape[1] = group;
        permuteInpShape[2] = inp.size[1] / group;
        permuteInpShape[3] = inp.size[2]*inp.size[3];

        // Output view: the input view with its two middle axes exchanged.
        permuteOutShape.resize(4);
        permuteOutShape[0] = permuteInpShape[0];
        permuteOutShape[1] = permuteInpShape[2];
        permuteOutShape[2] = permuteInpShape[1];
        permuteOutShape[3] = permuteInpShape[3];

        std::vector<Mat> permuteInputs(1, inp.reshape(1, permuteInpShape));
        std::vector<Mat> permuteOutputs(1, out.reshape(1, permuteOutShape));
        permute->finalize(permuteInputs, permuteOutputs);
    }

#ifdef HAVE_OPENCL
    /**
     * UMat-based forward pass. Work is skipped when the `u` members of the
     * first input and the first output are equal (presumably the in-place
     * case). Otherwise the data is either permuted through the internal
     * layer or copied as-is when no permute layer was created.
     *
     * @return always true.
     */
    bool forward_ocl(InputArrayOfArrays inps, OutputArrayOfArrays outs, OutputArrayOfArrays internals)
    {
        std::vector<UMat> inputs;
        std::vector<UMat> outputs;

        inps.getUMatVector(inputs);
        outs.getUMatVector(outputs);

        if (inputs[0].u != outputs[0].u)
        {
            if (!permute.empty())
            {
                // The vectors hold exactly four entries (see finalize), so
                // the narrowing to int is safe; spell it out explicitly.
                inputs[0] = inputs[0].reshape(1, static_cast<int>(permuteInpShape.size()), &permuteInpShape[0]);
                outputs[0] = outputs[0].reshape(1, static_cast<int>(permuteOutShape.size()), &permuteOutShape[0]);
                permute->preferableTarget = preferableTarget;
                permute->forward(inputs, outputs, internals);
            }
            else
                inputs[0].copyTo(outputs[0]);
        }
        return true;
    }
#endif

    /**
     * Forward pass. Tries the OpenCL path first when an OpenCL target is
     * selected, hands CV_16S inputs to forward_fallback(), and otherwise
     * permutes (or simply copies) the first input into the first output.
     * Nothing is done when input and output point at the same data.
     */
    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
    {
        CV_TRACE_FUNCTION();
        CV_TRACE_ARG_VALUE(name, "name", name.c_str());

        CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget),
                   forward_ocl(inputs_arr, outputs_arr, internals_arr))

        if (inputs_arr.depth() == CV_16S)
        {
            forward_fallback(inputs_arr, outputs_arr, internals_arr);
            return;
        }

        std::vector<Mat> inputs, outputs, internals;
        inputs_arr.getMatVector(inputs);
        outputs_arr.getMatVector(outputs);
        internals_arr.getMatVector(internals);

        Mat inp = inputs[0];
        Mat out = outputs[0];
        if (inp.data != out.data)
        {
            if (!permute.empty())
            {
                // Reinterpret both buffers with the 4D shapes prepared in
                // finalize() and let the permute layer do the shuffle.
                inp = inp.reshape(1, permuteInpShape);
                out = out.reshape(1, permuteOutShape);
                std::vector<Mat> permuteInputs(1, inp);
                std::vector<Mat> permuteOutputs(1, out);
                permute->forward(permuteInputs, permuteOutputs, internals);
            }
            else
                inp.copyTo(out);
        }
    }

private:
    // Internal permute layer; stays empty unless finalize() ran with group != 1.
    Ptr<PermuteLayer> permute;
    // 4D shapes the input/output are reshaped to before being permuted.
    std::vector<int> permuteInpShape, permuteOutShape;
};
130
131
// Factory: builds the concrete shuffle-channel implementation from `params`
// and hands it back through the generic Layer pointer type.
Ptr<Layer> ShuffleChannelLayer::create(const LayerParams& params)
{
    Ptr<Layer> layer(new ShuffleChannelLayerImpl(params));
    return layer;
}
135
136
} // namespace dnn
137
} // namespace cv
138
139