// Path: blob/master/fastai/text/models/forget_mult_cuda_kernel.cu
#include <ATen/ATen.h>
#include <THC/THC.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

/*
ForgetMult recurrence used by fastai's QRNN:

    h_t = f_t * x_t + (1 - f_t) * h_{t-1}

Grid layout for both kernels: the x-dimension (blockIdx.x * blockDim.x +
threadIdx.x) tiles the hidden dimension, the y-dimension is one block row
per batch element. Each thread owns a single (batch, hidden) pair and walks
the sequence dimension serially, since every timestep depends on the
previous hidden state; parallelism comes from batch_size * n_hidden.
*/

template <typename scalar_t>
__global__ void forget_mult_cuda_forward_kernel(const scalar_t* __restrict__ x,
    const scalar_t* __restrict__ f, scalar_t* __restrict__ output,
    size_t batch_size, size_t seq_length, size_t n_hidden, bool batch_first) {
  /*
  Note: output is assumed to be one timestep longer than f or x where output[0] = h_{-1}
  This means output array has a size of seq_length+1 on the word dimension
  */
  const int hid = blockIdx.x * blockDim.x + threadIdx.x;
  const int bid = blockIdx.y * blockDim.y + threadIdx.y;
  if (hid < n_hidden && bid < batch_size){
    // Sequential scan over timesteps; ts indexes into output (1..seq_length),
    // ts-1 indexes into x/f.
    for (int ts = 1; ts < seq_length + 1; ts++) {
      int i = 0;            // flat index into x / f (seq_length timesteps)
      int dst_i = 0;        // flat index of h_ts in output (seq_length+1 timesteps)
      int dst_iminus1 = 0;  // flat index of h_{ts-1} in output
      if (batch_first){
        // Layout: (batch, seq, hidden)
        i = bid * n_hidden * seq_length + (ts-1) * n_hidden + hid;
        dst_i = bid * n_hidden * (seq_length+1) + (ts-0) * n_hidden + hid;
        dst_iminus1 = bid * n_hidden * (seq_length+1) + (ts-1) * n_hidden + hid;
      }
      else {
        // Layout: (seq, batch, hidden)
        i = (ts-1) * n_hidden * batch_size + bid * n_hidden + hid;
        dst_i = (ts-0) * n_hidden * batch_size + bid * n_hidden + hid;
        dst_iminus1 = (ts-1) * n_hidden * batch_size + bid * n_hidden + hid;
      }
      // h_t = f_t * x_t + (1 - f_t) * h_{t-1}
      output[dst_i] = f[i] * x[i];
      output[dst_i] += (1 - f[i]) * output[dst_iminus1];
    }
  }
}

template <typename scalar_t>
__global__ void forget_mult_cuda_backward_kernel(const scalar_t* __restrict__ x,
    const scalar_t* __restrict__ f, const scalar_t* __restrict__ output,
    const scalar_t* __restrict__ grad_output, scalar_t* __restrict__ grad_x,
    scalar_t* __restrict__ grad_f, scalar_t* __restrict__ grad_h,
    size_t batch_size, size_t seq_length, size_t n_hidden, bool batch_first) {
  const int hid = blockIdx.x * blockDim.x + threadIdx.x;
  const int bid = blockIdx.y * blockDim.y + threadIdx.y;
  // Gradient accumulated backwards through the recurrent hidden state;
  // held in double so long sequences don't lose precision in float kernels.
  double running_f = 0;
  if(hid < n_hidden && bid < batch_size){
    // Reverse scan: propagate dL/dh from the last timestep back to the first.
    for (int ts = seq_length; ts >= 0 + 1; ts--) {
      int i = 0;            // flat index into x / f / grad_output
      int dst_i = 0;        // flat index of h_ts in output (unused below, kept for symmetry)
      int dst_iminus1 = 0;  // flat index of h_{ts-1} in output
      if (batch_first){
        i = bid * n_hidden * seq_length + (ts-1) * n_hidden + hid;
        dst_i = bid * n_hidden * (seq_length+1) + (ts-0) * n_hidden + hid;
        dst_iminus1 = bid * n_hidden * (seq_length+1) + (ts-1) * n_hidden + hid;
      }
      else {
        i = (ts-1) * n_hidden * batch_size + bid * n_hidden + hid;
        dst_i = (ts-0) * n_hidden * batch_size + bid * n_hidden + hid;
        dst_iminus1 = (ts-1) * n_hidden * batch_size + bid * n_hidden + hid;
      }
      running_f += grad_output[i];
      // dL/dx_t = f_t * g_t ;  dL/df_t = (x_t - h_{t-1}) * g_t
      grad_x[i] = f[i] * running_f;
      grad_f[i] = (x[i] - output[dst_iminus1]) * running_f;
      // The line below is likely more numerically stable than (1 - f[i]) * running_f;
      running_f = running_f - f[i] * running_f;
    }
    // Whatever gradient survives the full reverse scan belongs to h_{-1}.
    grad_h[bid * n_hidden + hid] = running_f;
  }
}

// Launches the forward kernel. `output` must be preallocated with one extra
// timestep holding h_{-1} at index 0 (see kernel note); it is filled in place
// and returned. Expects 3-D x/f with hidden size on dim 2.
at::Tensor forget_mult_cuda_forward(at::Tensor x, at::Tensor f, at::Tensor output, bool batch_first) {
  const auto batch_size = (batch_first) ? x.size(0) : x.size(1);
  const auto seq_length = (batch_first) ? x.size(1) : x.size(0);
  const auto n_hidden = x.size(2);

  const int threads = 1024;
  // Ceil-div over the hidden dimension; one block row per batch element.
  const dim3 blocks((n_hidden + threads - 1) / threads, batch_size);
  AT_DISPATCH_FLOATING_TYPES(x.type(), "forget_mult_cuda_forward", ([&] {
    forget_mult_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
        x.data<scalar_t>(), f.data<scalar_t>(), output.data<scalar_t>(), batch_size,
        seq_length, n_hidden, batch_first);
  }));

  THCudaCheck(cudaGetLastError());
  return output;
}

// Launches the backward kernel and returns {grad_x, grad_f, grad_h}, where
// grad_h is the gradient w.r.t. the initial hidden state h_{-1}.
std::vector<at::Tensor> forget_mult_cuda_backward(at::Tensor x, at::Tensor f,
    at::Tensor output, at::Tensor grad_output, bool batch_first) {
  const auto batch_size = (batch_first) ? x.size(0) : x.size(1);
  const auto seq_length = (batch_first) ? x.size(1) : x.size(0);
  const auto n_hidden = x.size(2);

  auto grad_x = at::zeros_like(x);
  auto grad_f = at::zeros_like(x);
  auto grad_h = at::zeros({batch_size, n_hidden}, x.options());

  const int threads = 1024;
  const dim3 blocks((n_hidden + threads - 1) / threads, batch_size);
  // BUGFIX: the dispatch label previously said "forget_mult_cuda_forward",
  // so unsupported-dtype errors raised from the backward pass were mislabeled.
  AT_DISPATCH_FLOATING_TYPES(x.type(), "forget_mult_cuda_backward", ([&] {
    forget_mult_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
        x.data<scalar_t>(), f.data<scalar_t>(), output.data<scalar_t>(), grad_output.data<scalar_t>(),
        grad_x.data<scalar_t>(), grad_f.data<scalar_t>(), grad_h.data<scalar_t>(), batch_size,
        seq_length, n_hidden, batch_first);
  }));

  THCudaCheck(cudaGetLastError());
  return {grad_x, grad_f, grad_h};
}