Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
jantic
GitHub Repository: jantic/deoldify
Path: blob/master/fastai/text/models/forget_mult_cuda.cpp
840 views
1
#include <torch/torch.h>
2
3
#include <vector>
4
5
// CUDA forward declarations
6
at::Tensor forget_mult_cuda_forward(at::Tensor x, at::Tensor f, at::Tensor output, bool batch_first);
7
8
// C++ interface
9
10
// Input-validation helpers shared by the forward/backward wrappers below.
// AT_ASSERTM raises a c10::Error carrying the given message when the check fails.
// NOTE: x.is_cuda() replaces the deprecated x.type().is_cuda() accessor
// (Tensor::type() has been deprecated since PyTorch 1.x).
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
// Wrapped in do/while(0) so `CHECK_INPUT(t);` acts as a single statement even
// inside an unbraced if/else body (the old two-statement expansion would not).
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)
13
14
// C++ entry point for the ForgetMult forward pass.
// Validates that every tensor argument is a contiguous CUDA tensor, then
// dispatches to the CUDA implementation. `output` is passed through to the
// kernel and returned (presumably pre-allocated by the caller — confirm
// against the .cu implementation).
at::Tensor forget_mult_forward(at::Tensor x, at::Tensor f, at::Tensor output, bool batch_first) {
  CHECK_INPUT(x);
  CHECK_INPUT(f);
  CHECK_INPUT(output);
  return forget_mult_cuda_forward(x, f, output, batch_first);
}
18
19
// CUDA backward declaration (defined in the companion .cu translation unit).
std::vector<at::Tensor> forget_mult_cuda_backward(at::Tensor x, at::Tensor f, at::Tensor output,
                                                  at::Tensor grad_output, bool batch_first);

// C++ entry point for the ForgetMult backward pass.
// Validates the tensor arguments and dispatches to the CUDA implementation,
// returning the gradient tensors it produces.
std::vector<at::Tensor> forget_mult_backward(at::Tensor x, at::Tensor f, at::Tensor output,
                                             at::Tensor grad_output, bool batch_first) {
  CHECK_INPUT(x);
  CHECK_INPUT(f);
  CHECK_INPUT(output);
  // Bug fix: grad_output previously reached the kernel unchecked; it must meet
  // the same CUDA + contiguity contract as every other tensor argument.
  CHECK_INPUT(grad_output);
  return forget_mult_cuda_backward(x, f, output, grad_output, batch_first);
}
27
28
// Python bindings: expose the two wrappers as `forward` and `backward` on the
// extension module whose name is supplied by TORCH_EXTENSION_NAME at build time.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &forget_mult_forward, "ForgetMult forward (CUDA)");
  m.def("backward", &forget_mult_backward, "ForgetMult backward (CUDA)");
}
32
33