// fastai/text/models/bwd_forget_mult_cuda_kernel.cu
// CUDA kernels for the backward-direction ForgetMult (QRNN) recurrence.
#include <ATen/ATen.h>
#include <THC/THC.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

// Forward pass of the backward-direction ForgetMult recurrence:
//   h_t = f_t * x_t + (1 - f_t) * h_{t+1}, scanned from ts = seq_length-1 down to 0.
//
// Launch layout: blockIdx.x/threadIdx.x index the hidden dimension,
// blockIdx.y/threadIdx.y index the batch dimension; each thread owns one
// (batch, hidden) pair and walks the whole sequence serially.
//
// Preconditions:
//   - x and f have shape (batch, seq, hidden) when batch_first, else (seq, batch, hidden).
//   - output is one timestep LONGER than x/f on the sequence dimension:
//     output[seq_length] holds the initial hidden state h_{+1}, and this kernel
//     fills output[0..seq_length-1].
template <typename scalar_t>
__global__ void bwd_forget_mult_cuda_forward_kernel(const scalar_t* __restrict__ x,
                const scalar_t* __restrict__ f, scalar_t* __restrict__ output,
                size_t batch_size, size_t seq_length, size_t n_hidden, bool batch_first) {
  const int hid = blockIdx.x * blockDim.x + threadIdx.x;
  const int bid = blockIdx.y * blockDim.y + threadIdx.y;
  if (hid < n_hidden && bid < batch_size){
    // Walk the sequence in reverse: each step reads the already-computed h_{t+1}.
    for (int ts = seq_length-1; ts >= 0; ts--) {
      int i = 0;          // flat index into x/f (seq_length timesteps)
      int dst_i = 0;      // flat index into output at timestep ts (seq_length+1 timesteps)
      int dst_iplus1 = 0; // flat index into output at timestep ts+1
      if (batch_first){
        i = bid * n_hidden * seq_length + (ts+0) * n_hidden + hid;
        dst_i = bid * n_hidden * (seq_length+1) + (ts+0) * n_hidden + hid;
        dst_iplus1 = bid * n_hidden * (seq_length+1) + (ts+1) * n_hidden + hid;
      }
      else {
        i = (ts+0) * n_hidden * batch_size + bid * n_hidden + hid;
        dst_i = (ts+0) * n_hidden * batch_size + bid * n_hidden + hid;
        dst_iplus1 = (ts+1) * n_hidden * batch_size + bid * n_hidden + hid;
      }
      // Single fused global store (the original issued `=` then `+=`, i.e. two
      // read-modify-write transactions on the same address).
      output[dst_i] = f[i] * x[i] + (1 - f[i]) * output[dst_iplus1];
    }
  }
}

// Backward pass of the recurrence above. Accumulates the incoming gradient
// while walking FORWARD in time (the adjoint of the reverse scan), producing:
//   grad_x[t] = f_t * g_t
//   grad_f[t] = (x_t - h_{t+1}) * g_t
//   grad_h    = gradient w.r.t. the initial hidden state output[seq_length]
// where g_t is the running accumulated gradient.
//
// Same launch layout and tensor-shape preconditions as the forward kernel;
// `output` is the (seq_length+1)-long buffer produced by the forward pass.
template <typename scalar_t>
__global__ void bwd_forget_mult_cuda_backward_kernel(const scalar_t* __restrict__ x,
                const scalar_t* __restrict__ f, const scalar_t* __restrict__ output,
                const scalar_t* __restrict__ grad_output, scalar_t* __restrict__ grad_x,
                scalar_t* __restrict__ grad_f, scalar_t* __restrict__ grad_h,
                size_t batch_size, size_t seq_length, size_t n_hidden, bool batch_first) {
  const int hid = blockIdx.x * blockDim.x + threadIdx.x;
  const int bid = blockIdx.y * blockDim.y + threadIdx.y;
  // Accumulated in double to limit round-off over long sequences, then
  // narrowed back to scalar_t on store.
  double running_f = 0;
  if(hid < n_hidden && bid < batch_size){
    for (int ts = 0; ts < seq_length; ts++) {
      int i = 0;
      int dst_iplus1 = 0;
      if (batch_first){
        i = bid * n_hidden * seq_length + (ts+0) * n_hidden + hid;
        dst_iplus1 = bid * n_hidden * (seq_length+1) + (ts+1) * n_hidden + hid;
      }
      else {
        i = (ts+0) * n_hidden * batch_size + bid * n_hidden + hid;
        dst_iplus1 = (ts+1) * n_hidden * batch_size + bid * n_hidden + hid;
      }
      running_f += grad_output[i];
      grad_x[i] = f[i] * running_f;
      grad_f[i] = (x[i] - output[dst_iplus1]) * running_f;
      // The line below is likely more numerically stable than (1 - f[i]) * running_f;
      running_f = running_f - f[i] * running_f;
    }
    // Whatever gradient survives the whole scan flows into the initial state.
    grad_h[bid * n_hidden + hid] = running_f;
  }
}

// Host launcher for the forward kernel.
//   x, f:   (batch, seq, hidden) if batch_first else (seq, batch, hidden), CUDA, floating.
//   output: same layout but seq dimension of length seq_length+1; slot
//           [seq_length] must already contain the initial hidden state.
// Returns `output` (filled in place).
at::Tensor bwd_forget_mult_cuda_forward(at::Tensor x, at::Tensor f, at::Tensor output, bool batch_first) {
  const auto batch_size = (batch_first) ? x.size(0) : x.size(1);
  const auto seq_length = (batch_first) ? x.size(1) : x.size(0);
  const auto n_hidden = x.size(2);

  const int threads = 1024;
  // Ceil-divide the hidden dimension across blockIdx.x; one blockIdx.y per batch row.
  const dim3 blocks((n_hidden + threads - 1) / threads, batch_size);
  AT_DISPATCH_FLOATING_TYPES(x.type(), "bwd_forget_mult_cuda_forward", ([&] {
    bwd_forget_mult_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
      x.data<scalar_t>(), f.data<scalar_t>(), output.data<scalar_t>(), batch_size,
      seq_length, n_hidden, batch_first);
  }));

  THCudaCheck(cudaGetLastError());
  return output;
}

// Host launcher for the backward kernel.
// Allocates and returns {grad_x, grad_f, grad_h}; grad_h has shape (batch, hidden).
std::vector<at::Tensor> bwd_forget_mult_cuda_backward(at::Tensor x, at::Tensor f,
        at::Tensor output, at::Tensor grad_output, bool batch_first) {
  const auto batch_size = (batch_first) ? x.size(0) : x.size(1);
  const auto seq_length = (batch_first) ? x.size(1) : x.size(0);
  const auto n_hidden = x.size(2);

  auto grad_x = at::zeros_like(x);
  auto grad_f = at::zeros_like(x);
  auto grad_h = at::zeros({batch_size, n_hidden}, x.options());

  const int threads = 1024;
  const dim3 blocks((n_hidden + threads - 1) / threads, batch_size);
  // Fixed: the dispatch label previously said "..._forward", which mislabeled
  // any dtype-dispatch error raised from this backward path.
  AT_DISPATCH_FLOATING_TYPES(x.type(), "bwd_forget_mult_cuda_backward", ([&] {
    bwd_forget_mult_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
      x.data<scalar_t>(), f.data<scalar_t>(), output.data<scalar_t>(), grad_output.data<scalar_t>(),
      grad_x.data<scalar_t>(), grad_f.data<scalar_t>(), grad_h.data<scalar_t>(), batch_size,
      seq_length, n_hidden, batch_first);
  }));

  THCudaCheck(cudaGetLastError());
  return {grad_x, grad_f, grad_h};
}