Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Tetragramm
GitHub Repository: Tetragramm/opencv
Path: blob/master/modules/dnn/src/op_vkcom.cpp
16337 views
1
// This file is part of OpenCV project.
2
// It is subject to the license terms in the LICENSE file found in the top-level directory
3
// of this distribution and at http://opencv.org/license.html.
4
//
5
// Copyright (C) 2018, Intel Corporation, all rights reserved.
6
// Third party copyrights are property of their respective owners.
7
8
#include "precomp.hpp"
9
#include <opencv2/dnn/shape_utils.hpp>
10
#include "op_vkcom.hpp"
11
12
namespace cv
13
{
14
namespace dnn
15
{
16
#ifdef HAVE_VULKAN
17
// Load the contents of a continuous CV_32F Mat into `dst`, giving `dst` the
// same shape as `src` (the data pointer and shape are handed to Tensor::reshape).
void copyToTensor(vkcom::Tensor &dst, const Mat &src)
{
    CV_Assert(src.isContinuous() && src.type() == CV_32F);

    const char* hostBytes = reinterpret_cast<const char*>(src.data);
    dst.reshape(hostBytes, shape(src));
}
24
25
// Copy the contents of `src` into `dst` as CV_32F data shaped like the tensor.
// The tensor is mapped for the duration of the copy and is unmapped afterwards
// on every path, including when building the temporary Mat or copying throws.
void copyToMat(Mat &dst, vkcom::Tensor &src)
{
    CV_Assert(dst.type() == CV_32F);

    std::vector<int> shape = src.getShape();
    void *data = src.map();
    try
    {
        // Header over the mapped memory; the actual copy happens in copyTo.
        Mat tmp(shape, CV_32F, data);
        tmp.copyTo(dst);
    }
    catch (...)
    {
        // Never leave the tensor mapped when the copy fails.
        src.unMap();
        throw;
    }
    src.unMap();
}
35
36
// Return the vkcom tensor held by a backend wrapper.
// The wrapper must be non-empty and must actually be a VkComBackendWrapper;
// both are asserted instead of dereferencing a failed cast.
vkcom::Tensor VkComTensor(const Ptr<BackendWrapper>& ptr)
{
    CV_Assert(!ptr.empty());
    Ptr<VkComBackendWrapper> wrapper = ptr.dynamicCast<VkComBackendWrapper>();
    CV_Assert(!wrapper.empty());
    return wrapper->getTensor();
}
41
42
// Mark every wrapper in `ptrs` as holding newer data on the device than on the host.
// Each entry must be a VkComBackendWrapper; a failed cast is asserted rather
// than dereferenced.
void setDirty(std::vector<Ptr<BackendWrapper> >& ptrs)
{
    for (const Ptr<BackendWrapper>& ptr : ptrs)
    {
        Ptr<VkComBackendWrapper> wrapper = ptr.dynamicCast<VkComBackendWrapper>();
        CV_Assert(!wrapper.empty());
        wrapper->setDeviceDirty();
    }
}
49
50
// Collect the vkcom tensor of each wrapper, preserving order.
std::vector<vkcom::Tensor> VkComTensors(const std::vector<Ptr<BackendWrapper> >& ptrs)
{
    const size_t count = ptrs.size();
    std::vector<vkcom::Tensor> tensors;
    tensors.reserve(count);
    for (size_t idx = 0; idx < count; ++idx)
        tensors.push_back(VkComTensor(ptrs[idx]));
    return tensors;
}
60
61
// Build a Vulkan backend node around operation `op`.
// Keeps the input wrappers (forward() uploads them before running) and caches
// their tensors in `ins`; the optional blob wrappers are cached in `blobs`.
VkComBackendNode::VkComBackendNode(const std::vector<Ptr<BackendWrapper> >& inputsWrapper,
                                   const std::shared_ptr<vkcom::OpBase>& op,
                                   const std::vector<Ptr<BackendWrapper> >& blobsWrapper)
    : BackendNode(DNN_BACKEND_VKCOM)
{
    inputsWrapper_ = inputsWrapper;
    operation = op;
    ins = VkComTensors(inputsWrapper_);

    if (blobsWrapper.empty())
        return;
    blobs = VkComTensors(blobsWrapper);
}
76
77
// Upload any host-modified inputs, then run the wrapped operation.
// Returns whatever the operation's forward() returns.
bool VkComBackendNode::forward(std::vector<vkcom::Tensor>& outs)
{
    for (const Ptr<BackendWrapper>& input : inputsWrapper_)
    {
        Ptr<VkComBackendWrapper> wrapper = input.dynamicCast<VkComBackendWrapper>();
        CV_Assert(!wrapper.empty());
        // copyToDevice only uploads when the host copy was marked dirty.
        wrapper->copyToDevice();
    }

    return operation->forward(ins, blobs, outs);
}
86
87
// Wrap host Mat `m`: its data is loaded into the tensor immediately, so host
// and device start out in sync (neither side is marked dirty).
// Only a pointer to `m` is kept; `m` must outlive this wrapper.
VkComBackendWrapper::VkComBackendWrapper(Mat& m) : BackendWrapper(DNN_BACKEND_VKCOM, DNN_TARGET_VULKAN)
{
    copyToTensor(tensor, m);
    hostDirty = deviceDirty = false;
    host = &m;
}
94
95
// Wrap host Mat `m` on top of an existing VkCom wrapper's tensor.
// The base tensor is copied into this wrapper and reshaped to `m`'s shape; the
// base must hold at least m.total() elements. No data is transferred here.
VkComBackendWrapper::VkComBackendWrapper(const Ptr<BackendWrapper>& baseBuffer, Mat& m)
    : BackendWrapper(DNN_BACKEND_VKCOM, DNN_TARGET_VULKAN)
{
    Ptr<VkComBackendWrapper> base = baseBuffer.dynamicCast<VkComBackendWrapper>();
    CV_Assert(!base.empty());

    host = &m;

    tensor = base->tensor;
    CV_Assert(tensor.count() >= m.total());
    // Passing a null data pointer: only the shape changes.
    tensor.reshape(0, shape(m));

    hostDirty = deviceDirty = false;
}
108
109
void VkComBackendWrapper::copyToHost()
110
{
111
if (deviceDirty)
112
copyToMat(*host, tensor);
113
}
114
115
void VkComBackendWrapper::setHostDirty()
116
{
117
hostDirty = true;
118
};
119
120
void VkComBackendWrapper::setDeviceDirty()
121
{
122
deviceDirty = true;
123
};
124
125
void VkComBackendWrapper::copyToDevice()
126
{
127
if (hostDirty)
128
{
129
copyToTensor(tensor, *host);
130
hostDirty = false;
131
}
132
}
133
134
// Return a copy of the wrapper's tensor object.
// NOTE(review): whether the copy shares device storage depends on
// vkcom::Tensor's copy semantics, which are not visible here.
vkcom::Tensor VkComBackendWrapper::getTensor()
{
    const vkcom::Tensor& current = tensor;
    return current;
}
138
#endif
139
void forwardVkCom(std::vector<Ptr<BackendWrapper> > &outputs,
140
const Ptr<BackendNode>& node)
141
{
142
#ifdef HAVE_VULKAN
143
CV_Assert(!node.empty());
144
145
Ptr<VkComBackendNode> node_ = node.dynamicCast<VkComBackendNode>();
146
std::vector<vkcom::Tensor> outs = VkComTensors(outputs);
147
node_->forward(outs);
148
setDirty(outputs);
149
#endif
150
}
151
152
// Report whether the Vulkan compute backend can be used: always false when
// built without HAVE_VULKAN, otherwise whatever vkcom::isAvailable() says.
bool haveVulkan()
{
    bool available = false;
#ifdef HAVE_VULKAN
    available = vkcom::isAvailable();
#endif // HAVE_VULKAN
    return available;
}
160
161
} // namespace dnn
162
} // namespace cv
163
164