Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Tetragramm
GitHub Repository: Tetragramm/opencv
Path: blob/master/modules/dnn/src/vkcom/include/op_conv.hpp
16344 views
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
//
// Copyright (C) 2018, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
8
#ifndef OPENCV_DNN_VKCOM_OP_CONV_HPP
9
#define OPENCV_DNN_VKCOM_OP_CONV_HPP
10
11
#include "vkcom.hpp"
12
#include "op_base.hpp"
13
14
namespace cv { namespace dnn { namespace vkcom {
15
16
#ifdef HAVE_VULKAN
17
18
// Selector stored in ConvShaderConfig::shader_type.
// NOTE(review): the names suggest each value picks a different convolution
// compute-shader implementation -- confirm against the matching .cpp file.
//
// This stays a plain (unscoped) enum so the enumerators remain visible in the
// enclosing scope and keep converting implicitly to int, as existing users expect.
enum ConvShaderType
{
    kConvShaderTypeBasic = 0,
    kConvShaderTypeIDLF = 1,
    // Implicitly 2: one past the last enumerator above, i.e. the number of
    // enumerators that precede it. Keep it last when adding new values.
    kConvShaderTypeNum
};
24
25
struct ConvShaderConfig
26
{
27
int local_size_x;
28
int local_size_y;
29
int local_size_z;
30
int block_height;
31
int block_width;
32
int block_depth;
33
ConvShaderType shader_type;
34
};
35
36
class OpConv : public OpBase
37
{
38
public:
39
OpConv(const int out_channel, const bool has_bias,
40
const int* filter_size, const int* pad,
41
const int* stride, const int* dilation,
42
const int activation, const int group,
43
const int padding_mode);
44
void reshapeOutTensor(Tensor& in, Tensor& out);
45
bool forward(Tensor& in, Tensor& filter_weights, Tensor& bias, Tensor& out);
46
virtual bool forward(std::vector<Tensor>& ins,
47
std::vector<Tensor>& blobs,
48
std::vector<Tensor>& outs) CV_OVERRIDE;
49
private:
50
bool init(const int out_channel, const bool has_bias,
51
const int* filter_size, const int* pad,
52
const int* stride, const int* dilation,
53
const int activation, const int group,
54
const int padding_mode);
55
bool computeGroupCount();
56
57
int batch_;
58
int in_height_;
59
int in_width_;
60
int in_channel_;
61
int out_height_;
62
int out_width_;
63
int out_channel_;
64
int filter_height_;
65
int filter_width_;
66
int stride_height_;
67
int stride_width_;
68
int padding_top_;
69
int padding_left_;
70
int dilation_height_;
71
int dilation_width_;
72
int activation_;
73
PaddingMode padding_mode_;
74
int group_;
75
int has_bias_;
76
Tensor swizzled_weights;
77
ConvShaderConfig config_;
78
bool dwconv_;
79
};
80
81
#endif // HAVE_VULKAN
82
83
}}} // namespace cv::dnn::vkcom
84
85
#endif // OPENCV_DNN_VKCOM_OP_CONV_HPP
86
87