Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Tetragramm
GitHub Repository: Tetragramm/opencv
Path: blob/master/modules/dnn/src/vkcom/include/op_pool.hpp
16344 views
1
// This file is part of OpenCV project.
2
// It is subject to the license terms in the LICENSE file found in the top-level directory
3
// of this distribution and at http://opencv.org/license.html.
4
//
5
// Copyright (C) 2018, Intel Corporation, all rights reserved.
6
// Third party copyrights are property of their respective owners.
7
8
#ifndef OPENCV_DNN_VKCOM_OP_POOL_HPP
9
#define OPENCV_DNN_VKCOM_OP_POOL_HPP
10
11
#include "vkcom.hpp"
12
#include "op_base.hpp"
13
14
namespace cv { namespace dnn { namespace vkcom {
15
16
#ifdef HAVE_VULKAN
17
18
/// Pooling reduction selected when constructing an OpPool.
/// Enumerator values are pinned explicitly so their numbering is visible.
enum PoolType
{
    kPoolTypeAvg = 0,  ///< average pooling (per the name; implemented outside this header)
    kPoolTypeMax = 1,  ///< max pooling (per the name; implemented outside this header)
    kPoolTypeNum = 2   ///< number of pool kinds — presumably a count sentinel, not a selectable type
};
19
20
/// Work-group and tiling parameters consumed by the pooling compute shader.
/// NOTE(review): field meanings below are inferred from their names —
/// confirm against op_pool.cpp and the pool shader source.
struct PoolShaderConfig {
    // Compute-shader local work-group size along x, y, z
    // (presumably mapped to the shader's local_size_* qualifiers).
    int local_size_x, local_size_y, local_size_z;
    // Size of the block handled per invocation along height, width, depth
    // (assumed — TODO confirm).
    int block_height, block_width, block_depth;
};
29
30
class OpPool: public OpBase
31
{
32
public:
33
OpPool(const int* filter_size, const int* pad, const int* stride,
34
const int padding_mode, const PoolType pool_type,
35
const bool avg_pool_padded_area);
36
bool forward(Tensor& in, Tensor& out, Tensor& mask);
37
void reshapeOutTensor(Tensor& in, Tensor& out);
38
virtual bool forward(std::vector<Tensor>& ins,
39
std::vector<Tensor>& blobs,
40
std::vector<Tensor>& outs) CV_OVERRIDE;
41
private:
42
bool init(const int* filter_size, const int* pad, const int* stride,
43
const int padding_mode, const PoolType type, const bool avg_pool_padded_area);
44
bool computeGroupCount();
45
46
int batch_;
47
int channels_;
48
int in_height_;
49
int in_width_;
50
int out_height_;
51
int out_width_;
52
int filter_height_;
53
int filter_width_;
54
int stride_height_;
55
int stride_width_;
56
int padding_left_;
57
int padding_top_;
58
PoolType pool_type_;
59
int avg_pool_padded_area_;
60
int need_mask_;
61
PaddingMode padding_mode_;
62
int activation_;
63
PoolShaderConfig config_;
64
};
65
66
#endif // HAVE_VULKAN
67
68
}}} // namespace cv::dnn::vkcom
69
70
#endif // OPENCV_DNN_VKCOM_OP_POOL_HPP
71
72