|
| 1 | +#include <torch/extension.h> |
| 2 | + |
| 3 | +#include <cuda.h> |
| 4 | +#include <cuda_runtime.h> |
| 5 | + |
| 6 | +#include <vector> |
| 7 | + |
| 8 | +namespace { |
| 9 | +template <typename scalar_t> |
| 10 | +__device__ __forceinline__ scalar_t sigmoid(scalar_t z) { |
| 11 | + return 1.0 / (1.0 + exp(-z)); |
| 12 | +} |
| 13 | + |
| 14 | +template <typename scalar_t> |
| 15 | +__device__ __forceinline__ scalar_t d_sigmoid(scalar_t z) { |
| 16 | + const auto s = sigmoid(z); |
| 17 | + return (1.0 - s) * s; |
| 18 | +} |
| 19 | + |
| 20 | +template <typename scalar_t> |
| 21 | +__device__ __forceinline__ scalar_t d_tanh(scalar_t z) { |
| 22 | + const auto t = tanh(z); |
| 23 | + return 1 - (t * t); |
| 24 | +} |
| 25 | + |
| 26 | +template <typename scalar_t> |
| 27 | +__device__ __forceinline__ scalar_t elu(scalar_t z, scalar_t alpha = 1.0) { |
| 28 | + return fmaxf(0.0, z) + fminf(0.0, alpha * (exp(z) - 1.0)); |
| 29 | +} |
| 30 | + |
| 31 | +template <typename scalar_t> |
| 32 | +__device__ __forceinline__ scalar_t d_elu(scalar_t z, scalar_t alpha = 1.0) { |
| 33 | + const auto e = exp(z); |
| 34 | + const auto d_relu = z < 0.0 ? 0.0 : 1.0; |
| 35 | + return d_relu + (((alpha * (e - 1.0)) < 0.0) ? (alpha * e) : 0.0); |
| 36 | +} |
| 37 | + |
| 38 | +template <typename scalar_t> |
| 39 | +__global__ void lltm_cuda_forward_kernel( |
| 40 | + const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gates, |
| 41 | + const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_cell, |
| 42 | + torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h, |
| 43 | + torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell, |
| 44 | + torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate, |
| 45 | + torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate, |
| 46 | + torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell) { |
| 47 | + //batch index |
| 48 | + const int n = blockIdx.y; |
| 49 | + // column index |
| 50 | + const int c = blockIdx.x * blockDim.x + threadIdx.x; |
| 51 | + if (c < gates.size(2)){ |
| 52 | + input_gate[n][c] = sigmoid(gates[n][0][c]); |
| 53 | + output_gate[n][c] = sigmoid(gates[n][1][c]); |
| 54 | + candidate_cell[n][c] = elu(gates[n][2][c]); |
| 55 | + new_cell[n][c] = |
| 56 | + old_cell[n][c] + candidate_cell[n][c] * input_gate[n][c]; |
| 57 | + new_h[n][c] = tanh(new_cell[n][c]) * output_gate[n][c]; |
| 58 | + } |
| 59 | +} |
| 60 | + |
| 61 | +template <typename scalar_t> |
| 62 | +__global__ void lltm_cuda_backward_kernel( |
| 63 | + torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_cell, |
| 64 | + torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> d_gates, |
| 65 | + const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h, |
| 66 | + const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_cell, |
| 67 | + const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell, |
| 68 | + const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate, |
| 69 | + const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate, |
| 70 | + const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell, |
| 71 | + const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gate_weights) { |
| 72 | + //batch index |
| 73 | + const int n = blockIdx.y; |
| 74 | + // column index |
| 75 | + const int c = blockIdx.x * blockDim.x + threadIdx.x; |
| 76 | + if (c < d_gates.size(2)){ |
| 77 | + const auto d_output_gate = tanh(new_cell[n][c]) * grad_h[n][c]; |
| 78 | + const auto d_tanh_new_cell = output_gate[n][c] * grad_h[n][c]; |
| 79 | + const auto d_new_cell = |
| 80 | + d_tanh(new_cell[n][c]) * d_tanh_new_cell + grad_cell[n][c]; |
| 81 | + |
| 82 | + |
| 83 | + d_old_cell[n][c] = d_new_cell; |
| 84 | + const auto d_candidate_cell = input_gate[n][c] * d_new_cell; |
| 85 | + const auto d_input_gate = candidate_cell[n][c] * d_new_cell; |
| 86 | + |
| 87 | + d_gates[n][0][c] = |
| 88 | + d_input_gate * d_sigmoid(gate_weights[n][0][c]); |
| 89 | + d_gates[n][1][c] = |
| 90 | + d_output_gate * d_sigmoid(gate_weights[n][1][c]); |
| 91 | + d_gates[n][2][c] = |
| 92 | + d_candidate_cell * d_elu(gate_weights[n][2][c]); |
| 93 | + } |
| 94 | +} |
| 95 | +} // namespace |
| 96 | + |
| 97 | +std::vector<torch::Tensor> lltm_cuda_forward( |
| 98 | + torch::Tensor input, |
| 99 | + torch::Tensor weights, |
| 100 | + torch::Tensor bias, |
| 101 | + torch::Tensor old_h, |
| 102 | + torch::Tensor old_cell) { |
| 103 | + auto X = torch::cat({old_h, input}, /*dim=*/1); |
| 104 | + auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1)); |
| 105 | + |
| 106 | + const auto batch_size = old_cell.size(0); |
| 107 | + const auto state_size = old_cell.size(1); |
| 108 | + |
| 109 | + auto gates = gate_weights.reshape({batch_size, 3, state_size}); |
| 110 | + auto new_h = torch::zeros_like(old_cell); |
| 111 | + auto new_cell = torch::zeros_like(old_cell); |
| 112 | + auto input_gate = torch::zeros_like(old_cell); |
| 113 | + auto output_gate = torch::zeros_like(old_cell); |
| 114 | + auto candidate_cell = torch::zeros_like(old_cell); |
| 115 | + |
| 116 | + const int threads = 1024; |
| 117 | + const dim3 blocks((state_size + threads - 1) / threads, batch_size); |
| 118 | + |
| 119 | + AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] { |
| 120 | + lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>( |
| 121 | + gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(), |
| 122 | + old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 123 | + new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 124 | + new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 125 | + input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 126 | + output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 127 | + candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>()); |
| 128 | + })); |
| 129 | + |
| 130 | + return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates}; |
| 131 | +} |
| 132 | + |
| 133 | +std::vector<torch::Tensor> lltm_cuda_backward( |
| 134 | + torch::Tensor grad_h, |
| 135 | + torch::Tensor grad_cell, |
| 136 | + torch::Tensor new_cell, |
| 137 | + torch::Tensor input_gate, |
| 138 | + torch::Tensor output_gate, |
| 139 | + torch::Tensor candidate_cell, |
| 140 | + torch::Tensor X, |
| 141 | + torch::Tensor gates, |
| 142 | + torch::Tensor weights) { |
| 143 | + auto d_old_cell = torch::zeros_like(new_cell); |
| 144 | + auto d_gates = torch::zeros_like(gates); |
| 145 | + |
| 146 | + const auto batch_size = new_cell.size(0); |
| 147 | + const auto state_size = new_cell.size(1); |
| 148 | + |
| 149 | + const int threads = 1024; |
| 150 | + const dim3 blocks((state_size + threads - 1) / threads, batch_size); |
| 151 | + |
| 152 | + AT_DISPATCH_FLOATING_TYPES(X.type(), "lltm_forward_cuda", ([&] { |
| 153 | + lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>( |
| 154 | + d_old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 155 | + d_gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(), |
| 156 | + grad_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 157 | + grad_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 158 | + new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 159 | + input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 160 | + output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 161 | + candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 162 | + gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>()); |
| 163 | + })); |
| 164 | + |
| 165 | + auto d_gate_weights = d_gates.flatten(1, 2); |
| 166 | + auto d_weights = d_gate_weights.t().mm(X); |
| 167 | + auto d_bias = d_gate_weights.sum(/*dim=*/0, /*keepdim=*/true); |
| 168 | + |
| 169 | + auto d_X = d_gate_weights.mm(weights); |
| 170 | + auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size); |
| 171 | + auto d_input = d_X.slice(/*dim=*/1, state_size); |
| 172 | + |
| 173 | + return {d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates}; |
| 174 | +} |
0 commit comments