GitHub Repository: jantic/deoldify
Path: blob/master/fastai/text/models/bwd_forget_mult_cuda_kernel.cu
#include <ATen/ATen.h>
#include <THC/THC.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

template <typename scalar_t>
__global__ void bwd_forget_mult_cuda_forward_kernel(const scalar_t* __restrict__ x,
                            const scalar_t* __restrict__ f, scalar_t* __restrict__ output,
                            size_t batch_size, size_t seq_length, size_t n_hidden, bool batch_first) {
  /*
  Note: output is assumed to be one timestep longer than f or x, where output[seq_length] = h_{+1}.
  This means the output array has a size of seq_length+1 on the word dimension
  (see the usage sketch after bwd_forget_mult_cuda_forward below).
  */
  // One thread per (batch element, hidden unit); each thread walks the sequence
  // right-to-left and applies h_t = f_t * x_t + (1 - f_t) * h_{t+1}.
  const int hid = blockIdx.x * blockDim.x + threadIdx.x;
  const int bid = blockIdx.y * blockDim.y + threadIdx.y;
  if (hid < n_hidden && bid < batch_size){
    for (int ts = seq_length-1; ts >= 0; ts--) {
      int i = 0;
      int dst_i = 0;
      int dst_iplus1 = 0;
      if (batch_first){
        i = bid * n_hidden * seq_length + (ts+0) * n_hidden + hid;
        dst_i = bid * n_hidden * (seq_length+1) + (ts+0) * n_hidden + hid;
        dst_iplus1 = bid * n_hidden * (seq_length+1) + (ts+1) * n_hidden + hid;
      }
      else {
        i = (ts+0) * n_hidden * batch_size + bid * n_hidden + hid;
        dst_i = (ts+0) * n_hidden * batch_size + bid * n_hidden + hid;
        dst_iplus1 = (ts+1) * n_hidden * batch_size + bid * n_hidden + hid;
      }
      output[dst_i] = f[i] * x[i];
      output[dst_i] += (1 - f[i]) * output[dst_iplus1];
    }
  }
}

template <typename scalar_t>
__global__ void bwd_forget_mult_cuda_backward_kernel(const scalar_t* __restrict__ x,
                            const scalar_t* __restrict__ f, const scalar_t* __restrict__ output,
                            const scalar_t* __restrict__ grad_output, scalar_t* __restrict__ grad_x,
                            scalar_t* __restrict__ grad_f, scalar_t* __restrict__ grad_h,
                            size_t batch_size, size_t seq_length, size_t n_hidden, bool batch_first) {
  const int hid = blockIdx.x * blockDim.x + threadIdx.x;
  const int bid = blockIdx.y * blockDim.y + threadIdx.y;
  // running_f accumulates dL/dh_t; because h_t depends on h_{t+1}, the gradient
  // is propagated forward in time, i.e. from ts to ts+1.
  double running_f = 0;
  if(hid < n_hidden && bid < batch_size){
    for (int ts = 0; ts < seq_length; ts++) {
      int i = 0;
      int dst_iplus1 = 0;
      if (batch_first){
        i = bid * n_hidden * seq_length + (ts+0) * n_hidden + hid;
        dst_iplus1 = bid * n_hidden * (seq_length+1) + (ts+1) * n_hidden + hid;
      }
      else {
        i = (ts+0) * n_hidden * batch_size + bid * n_hidden + hid;
        dst_iplus1 = (ts+1) * n_hidden * batch_size + bid * n_hidden + hid;
      }
      running_f += grad_output[i];
      grad_x[i] = f[i] * running_f;
      grad_f[i] = (x[i] - output[dst_iplus1]) * running_f;
      // The line below is likely more numerically stable than (1 - f[i]) * running_f;
      running_f = running_f - f[i] * running_f;
    }
    // Whatever is left after the last timestep is the gradient w.r.t. the
    // initial hidden state h_{+1} (stored at output[seq_length]).
    grad_h[bid * n_hidden + hid] = running_f;
  }
}
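
// ---------------------------------------------------------------------------
// Gradient derivation (added annotation, not in the original source). For the
// right-to-left recurrence h_t = f_t * x_t + (1 - f_t) * h_{t+1}:
//
//   dh_t/dx_t     = f_t
//   dh_t/df_t     = x_t - h_{t+1}
//   dh_t/dh_{t+1} = 1 - f_t
//
// With g_t = dL/dh_t held in running_f, the kernel above computes
//   grad_x[t] = f_t * g_t
//   grad_f[t] = (x_t - h_{t+1}) * g_t
//   g_{t+1}   = (1 - f_t) * g_t + grad_output[t+1]   (the multiply happens at
//                                                      the end of iteration t,
//                                                      the add at the start of
//                                                      iteration t+1)
// and grad_h receives g_{seq_length}, the gradient of the initial state h_{+1}.
// ---------------------------------------------------------------------------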

at::Tensor bwd_forget_mult_cuda_forward(at::Tensor x, at::Tensor f, at::Tensor output, bool batch_first) {
  const auto batch_size = (batch_first) ? x.size(0) : x.size(1);
  const auto seq_length = (batch_first) ? x.size(1) : x.size(0);
  const auto n_hidden = x.size(2);

  const int threads = 1024;
  const dim3 blocks((n_hidden + threads - 1) / threads, batch_size);
  AT_DISPATCH_FLOATING_TYPES(x.type(), "bwd_forget_mult_cuda_forward", ([&] {
    bwd_forget_mult_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
      x.data<scalar_t>(), f.data<scalar_t>(), output.data<scalar_t>(), batch_size,
      seq_length, n_hidden, batch_first);
  }));

  THCudaCheck(cudaGetLastError());
  return output;
}
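
// ---------------------------------------------------------------------------
// Usage sketch (added illustration, not part of the original file). The entry
// point above expects the caller to pre-allocate `output` one timestep longer
// than x/f and to seed output[seq_length] with the initial hidden state h_{+1}.
// The helper below is a hypothetical example of that contract for the
// batch_first layout; its name and the `h_init` argument (assumed shape
// (batch, n_hidden)) are illustrative and do not exist elsewhere in this repo.
at::Tensor bwd_forget_mult_cuda_forward_example(at::Tensor x, at::Tensor f, at::Tensor h_init) {
  const auto batch_size = x.size(0);
  const auto seq_length = x.size(1);
  const auto n_hidden = x.size(2);
  // One extra timestep on the word dimension to hold h_{+1}.
  auto output = at::zeros({batch_size, seq_length + 1, n_hidden}, x.options());
  output.select(/*dim=*/1, /*index=*/seq_length).copy_(h_init);
  // After the call, output.narrow(1, 0, seq_length) holds h_t for every timestep.
  return bwd_forget_mult_cuda_forward(x, f, output, /*batch_first=*/true);
}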

std::vector<at::Tensor> bwd_forget_mult_cuda_backward(at::Tensor x, at::Tensor f,
                            at::Tensor output, at::Tensor grad_output, bool batch_first) {
  const auto batch_size = (batch_first) ? x.size(0) : x.size(1);
  const auto seq_length = (batch_first) ? x.size(1) : x.size(0);
  const auto n_hidden = x.size(2);

  auto grad_x = at::zeros_like(x);
  auto grad_f = at::zeros_like(x);
  auto grad_h = at::zeros({batch_size, n_hidden}, x.options());

  const int threads = 1024;
  const dim3 blocks((n_hidden + threads - 1) / threads, batch_size);
  AT_DISPATCH_FLOATING_TYPES(x.type(), "bwd_forget_mult_cuda_backward", ([&] {
    bwd_forget_mult_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
      x.data<scalar_t>(), f.data<scalar_t>(), output.data<scalar_t>(), grad_output.data<scalar_t>(),
      grad_x.data<scalar_t>(), grad_f.data<scalar_t>(), grad_h.data<scalar_t>(), batch_size,
      seq_length, n_hidden, batch_first);
  }));

  THCudaCheck(cudaGetLastError());
  return {grad_x, grad_f, grad_h};
}
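
// ---------------------------------------------------------------------------
// Hedged sketch (an assumption, not a file shown here): CUDA kernels like
// these are usually exposed to Python through a separate C++ translation unit
// built as a PyTorch extension. A minimal companion binding could look roughly
// like the commented code below; the file name, the exported names
// "forward"/"backward", and the docstrings are illustrative only.
//
//   // bwd_forget_mult_cuda.cpp (hypothetical companion file)
//   #include <torch/extension.h>
//   #include <vector>
//
//   // Declarations of the functions defined in this .cu file.
//   at::Tensor bwd_forget_mult_cuda_forward(at::Tensor x, at::Tensor f,
//                                           at::Tensor output, bool batch_first);
//   std::vector<at::Tensor> bwd_forget_mult_cuda_backward(at::Tensor x, at::Tensor f,
//                                           at::Tensor output, at::Tensor grad_output,
//                                           bool batch_first);
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("forward", &bwd_forget_mult_cuda_forward, "BwdForgetMult forward (CUDA)");
//     m.def("backward", &bwd_forget_mult_cuda_backward, "BwdForgetMult backward (CUDA)");
//   }
// ---------------------------------------------------------------------------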