
Commit 9d01b29

Rewrite the semiring classes and integrated in new kernels (#37)
1 parent b34fe11 commit 9d01b29

21 files changed (+1218 −433 lines)

Diff for: kernels/lltm_cuda.cpp

+81
@@ -0,0 +1,81 @@
#include <torch/extension.h>

#include <vector>

// CUDA forward declarations

std::vector<torch::Tensor> lltm_cuda_forward(
    torch::Tensor input,
    torch::Tensor weights,
    torch::Tensor bias,
    torch::Tensor old_h,
    torch::Tensor old_cell);

std::vector<torch::Tensor> lltm_cuda_backward(
    torch::Tensor grad_h,
    torch::Tensor grad_cell,
    torch::Tensor new_cell,
    torch::Tensor input_gate,
    torch::Tensor output_gate,
    torch::Tensor candidate_cell,
    torch::Tensor X,
    torch::Tensor gate_weights,
    torch::Tensor weights);

// C++ interface

// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> lltm_forward(
    torch::Tensor input,
    torch::Tensor weights,
    torch::Tensor bias,
    torch::Tensor old_h,
    torch::Tensor old_cell) {
  CHECK_INPUT(input);
  CHECK_INPUT(weights);
  CHECK_INPUT(bias);
  CHECK_INPUT(old_h);
  CHECK_INPUT(old_cell);

  return lltm_cuda_forward(input, weights, bias, old_h, old_cell);
}

std::vector<torch::Tensor> lltm_backward(
    torch::Tensor grad_h,
    torch::Tensor grad_cell,
    torch::Tensor new_cell,
    torch::Tensor input_gate,
    torch::Tensor output_gate,
    torch::Tensor candidate_cell,
    torch::Tensor X,
    torch::Tensor gate_weights,
    torch::Tensor weights) {
  CHECK_INPUT(grad_h);
  CHECK_INPUT(grad_cell);
  CHECK_INPUT(input_gate);
  CHECK_INPUT(output_gate);
  CHECK_INPUT(candidate_cell);
  CHECK_INPUT(X);
  CHECK_INPUT(gate_weights);
  CHECK_INPUT(weights);

  return lltm_cuda_backward(
      grad_h,
      grad_cell,
      new_cell,
      input_gate,
      output_gate,
      candidate_cell,
      X,
      gate_weights,
      weights);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &lltm_forward, "LLTM forward (CUDA)");
  m.def("backward", &lltm_backward, "LLTM backward (CUDA)");
}
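
This binding follows the usual two-file CUDA extension layout: the .cpp file checks inputs and forwards to the .cu implementations declared above. As a hedged sketch of how it can be compiled and called from Python via just-in-time compilation with torch.utils.cpp_extension.load (the source paths, module name, and tensor shapes below are illustrative assumptions, not part of this commit):

# Hedged sketch: JIT-compile and call the binding above from Python.
# The source paths, module name, and tensor shapes are assumptions for illustration.
import torch
from torch.utils.cpp_extension import load

lltm_cuda = load(
    name="lltm_cuda",
    sources=["kernels/lltm_cuda.cpp", "kernels/lltm_cuda_kernel.cu"],
)

batch, features, state = 4, 8, 16
inp = torch.randn(batch, features, device="cuda")
weights = torch.randn(3 * state, features + state, device="cuda")
bias = torch.randn(1, 3 * state, device="cuda")
old_h = torch.randn(batch, state, device="cuda")
old_cell = torch.randn(batch, state, device="cuda")

new_h, new_cell, *saved = lltm_cuda.forward(inp, weights, bias, old_h, old_cell)
print(new_h.shape)  # torch.Size([4, 16])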

Diff for: kernels/lltm_cuda_kernel.cu

+174
@@ -0,0 +1,174 @@
#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

namespace {
template <typename scalar_t>
__device__ __forceinline__ scalar_t sigmoid(scalar_t z) {
  return 1.0 / (1.0 + exp(-z));
}

template <typename scalar_t>
__device__ __forceinline__ scalar_t d_sigmoid(scalar_t z) {
  const auto s = sigmoid(z);
  return (1.0 - s) * s;
}

template <typename scalar_t>
__device__ __forceinline__ scalar_t d_tanh(scalar_t z) {
  const auto t = tanh(z);
  return 1 - (t * t);
}

template <typename scalar_t>
__device__ __forceinline__ scalar_t elu(scalar_t z, scalar_t alpha = 1.0) {
  return fmaxf(0.0, z) + fminf(0.0, alpha * (exp(z) - 1.0));
}

template <typename scalar_t>
__device__ __forceinline__ scalar_t d_elu(scalar_t z, scalar_t alpha = 1.0) {
  const auto e = exp(z);
  const auto d_relu = z < 0.0 ? 0.0 : 1.0;
  return d_relu + (((alpha * (e - 1.0)) < 0.0) ? (alpha * e) : 0.0);
}

template <typename scalar_t>
__global__ void lltm_cuda_forward_kernel(
    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gates,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_cell,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell) {
  // batch index
  const int n = blockIdx.y;
  // column index
  const int c = blockIdx.x * blockDim.x + threadIdx.x;
  if (c < gates.size(2)) {
    input_gate[n][c] = sigmoid(gates[n][0][c]);
    output_gate[n][c] = sigmoid(gates[n][1][c]);
    candidate_cell[n][c] = elu(gates[n][2][c]);
    new_cell[n][c] =
        old_cell[n][c] + candidate_cell[n][c] * input_gate[n][c];
    new_h[n][c] = tanh(new_cell[n][c]) * output_gate[n][c];
  }
}

template <typename scalar_t>
__global__ void lltm_cuda_backward_kernel(
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_cell,
    torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> d_gates,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_cell,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell,
    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gate_weights) {
  // batch index
  const int n = blockIdx.y;
  // column index
  const int c = blockIdx.x * blockDim.x + threadIdx.x;
  if (c < d_gates.size(2)) {
    const auto d_output_gate = tanh(new_cell[n][c]) * grad_h[n][c];
    const auto d_tanh_new_cell = output_gate[n][c] * grad_h[n][c];
    const auto d_new_cell =
        d_tanh(new_cell[n][c]) * d_tanh_new_cell + grad_cell[n][c];

    d_old_cell[n][c] = d_new_cell;
    const auto d_candidate_cell = input_gate[n][c] * d_new_cell;
    const auto d_input_gate = candidate_cell[n][c] * d_new_cell;

    d_gates[n][0][c] =
        d_input_gate * d_sigmoid(gate_weights[n][0][c]);
    d_gates[n][1][c] =
        d_output_gate * d_sigmoid(gate_weights[n][1][c]);
    d_gates[n][2][c] =
        d_candidate_cell * d_elu(gate_weights[n][2][c]);
  }
}
} // namespace

std::vector<torch::Tensor> lltm_cuda_forward(
    torch::Tensor input,
    torch::Tensor weights,
    torch::Tensor bias,
    torch::Tensor old_h,
    torch::Tensor old_cell) {
  auto X = torch::cat({old_h, input}, /*dim=*/1);
  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));

  const auto batch_size = old_cell.size(0);
  const auto state_size = old_cell.size(1);

  auto gates = gate_weights.reshape({batch_size, 3, state_size});
  auto new_h = torch::zeros_like(old_cell);
  auto new_cell = torch::zeros_like(old_cell);
  auto input_gate = torch::zeros_like(old_cell);
  auto output_gate = torch::zeros_like(old_cell);
  auto candidate_cell = torch::zeros_like(old_cell);

  const int threads = 1024;
  const dim3 blocks((state_size + threads - 1) / threads, batch_size);

  AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
    lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
        gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
        old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
  }));

  return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
}

std::vector<torch::Tensor> lltm_cuda_backward(
    torch::Tensor grad_h,
    torch::Tensor grad_cell,
    torch::Tensor new_cell,
    torch::Tensor input_gate,
    torch::Tensor output_gate,
    torch::Tensor candidate_cell,
    torch::Tensor X,
    torch::Tensor gates,
    torch::Tensor weights) {
  auto d_old_cell = torch::zeros_like(new_cell);
  auto d_gates = torch::zeros_like(gates);

  const auto batch_size = new_cell.size(0);
  const auto state_size = new_cell.size(1);

  const int threads = 1024;
  const dim3 blocks((state_size + threads - 1) / threads, batch_size);

  AT_DISPATCH_FLOATING_TYPES(X.type(), "lltm_forward_cuda", ([&] {
    lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
        d_old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        d_gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
        grad_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        grad_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>());
  }));

  auto d_gate_weights = d_gates.flatten(1, 2);
  auto d_weights = d_gate_weights.t().mm(X);
  auto d_bias = d_gate_weights.sum(/*dim=*/0, /*keepdim=*/true);

  auto d_X = d_gate_weights.mm(weights);
  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
  auto d_input = d_X.slice(/*dim=*/1, state_size);

  return {d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates};
}
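
For readers checking the kernel's math against eager PyTorch, the fused forward pass above corresponds to the following reference. This is an illustrative sketch only, not code from this commit; it mirrors the LLTM cell used in the PyTorch C++/CUDA extension tutorial:

# Eager-mode reference for lltm_cuda_forward, useful for checking the kernel.
# Illustrative sketch only; names and shapes mirror the C++/CUDA code above.
import torch
import torch.nn.functional as F

def lltm_forward_reference(inp, weights, bias, old_h, old_cell):
    X = torch.cat([old_h, inp], dim=1)                 # [batch, state + features]
    gate_weights = torch.addmm(bias, X, weights.t())   # [batch, 3 * state]
    gates = gate_weights.reshape(old_cell.size(0), 3, old_cell.size(1))

    input_gate = torch.sigmoid(gates[:, 0])
    output_gate = torch.sigmoid(gates[:, 1])
    candidate_cell = F.elu(gates[:, 2])                # alpha defaults to 1.0

    new_cell = old_cell + candidate_cell * input_gate
    new_h = torch.tanh(new_cell) * output_gate
    return new_h, new_cell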

Diff for: kernels/setup.py

+14
@@ -0,0 +1,14 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='lltm_cuda',
    ext_modules=[
        CUDAExtension('lltm_cuda', [
            'lltm_cuda.cpp',
            'lltm_cuda_kernel.cu',
        ]),
    ],
    cmdclass={
        'build_ext': BuildExtension
    })
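
Building ahead of time with this setup.py registers a compiled module importable under the name given to CUDAExtension. A small hedged check, assuming the build is run from the kernels/ directory (the shell commands are illustrative assumptions, not part of this commit):

# Hedged sketch: build the extension ahead of time, then import it by the name
# given to CUDAExtension. The shell commands are assumptions, not part of this commit.
#
#   cd kernels
#   python setup.py install   # or: pip install .
#
import lltm_cuda

# The pybind11 module from lltm_cuda.cpp should expose both entry points.
assert hasattr(lltm_cuda, "forward")
assert hasattr(lltm_cuda, "backward")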

Diff for: setup.py

+1-1
@@ -5,7 +5,7 @@
     version="0.3",
     author="Alexander Rush",
     author_email="arush@cornell.edu",
-    packages=["torch_struct", "torch_struct.data", "torch_struct.networks"],
+    packages=["torch_struct", "torch_struct.data", "torch_struct.networks", "torch_struct.semirings"],
     package_data={"torch_struct": []},
     url="https://github.com/harvardnlp/pytorch_struct",
     install_requires=["torch"],

Diff for: torch_struct/__init__.py

+13-1
@@ -19,17 +19,23 @@
 from .rl import SelfCritical
 from .semirings import (
     LogSemiring,
+    FastLogSemiring,
+    TempMax,
+    FastMaxSemiring,
+    FastSampleSemiring,
     StdSemiring,
     KMaxSemiring,
     SparseMaxSemiring,
     SampledSemiring,
     MaxSemiring,
     EntropySemiring,
     MultiSampledSemiring,
+    CheckpointSemiring,
+    CheckpointShardSemiring,
 )


-version = "0.2"
+version = "0.3"

 # For flake8 compatibility.
 __all__ = [
@@ -44,6 +50,9 @@
     MaxSemiring,
     SparseMaxSemiring,
     KMaxSemiring,
+    FastLogSemiring,
+    FastMaxSemiring,
+    FastSampleSemiring,
     EntropySemiring,
     MultiSampledSemiring,
     SelfCritical,
@@ -59,4 +68,7 @@
     HMM,
     AlignmentCRF,
     Alignment,
+    CheckpointSemiring,
+    CheckpointShardSemiring,
+    TempMax
 ]
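
The semirings exported here share a small class interface (sum and product reductions over a chosen dimension). A hedged smoke test of the re-exported names, assuming the LogSemiring/MaxSemiring interface is unchanged and without guessing the constructor arguments of the new Fast*/Checkpoint* variants:

# Hedged smoke test of the re-exported semirings. LogSemiring.sum is assumed to be
# logsumexp and MaxSemiring.sum a max over the reduced dimension; the new
# Fast*/Checkpoint* variants are imported but their constructor arguments are not
# guessed here.
import torch
from torch_struct import LogSemiring, MaxSemiring, CheckpointSemiring  # noqa: F401

scores = torch.randn(4, 5)
log_total = LogSemiring.sum(scores, dim=-1)   # shape: [4]
best = MaxSemiring.sum(scores, dim=-1)        # shape: [4]
print(log_total.shape, best.shape)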

Diff for: torch_struct/autoregressive.py

-13
@@ -205,19 +205,6 @@ def greedy_tempmax(self, alpha):
         a, b, c = self._beam_search(TempMax(alpha), alpha)
         return a.squeeze(0), b.squeeze(0), c.squeeze(0)

-    def greedy_tempmax(self, alpha):
-        """
-        Compute differentiable scheduled sampling using greedy search.
-
-        Based on:
-
-        * Differentiable Scheduled Sampling for Credit Assignment :cite:`goyal2017differentiable`
-
-        Returns:
-            greedy_path (*batch x N x C*)
-        """
-        return self._beam_search(TempMax(alpha), alpha)[0].squeeze(0)
-
     def beam_topk(self, K):
         """
         Compute "top-k" using beam search