Skip to content

Commit 503b469

Browse files
committed
Simplify sliding window API
1 parent 1c7c14e commit 503b469

File tree

4 files changed

+39
-38
lines changed

4 files changed

+39
-38
lines changed

Diff for: compress_notjs.go

+25-18
Original file line numberDiff line numberDiff line change
@@ -120,41 +120,48 @@ type slidingWindow struct {
120120

121121
var swPool = map[int]*sync.Pool{}
122122

123-
func newSlidingWindow(n int) *slidingWindow {
123+
func (sw *slidingWindow) init(n int) {
124+
if sw.buf != nil {
125+
return
126+
}
127+
124128
p, ok := swPool[n]
125129
if !ok {
126130
p = &sync.Pool{}
127131
swPool[n] = p
128132
}
129-
sw, ok := p.Get().(*slidingWindow)
133+
buf, ok := p.Get().([]byte)
130134
if ok {
131-
return sw
132-
}
133-
return &slidingWindow{
134-
buf: make([]byte, 0, n),
135+
sw.buf = buf[:0]
136+
} else {
137+
sw.buf = make([]byte, 0, n)
135138
}
136139
}
137140

138-
func returnSlidingWindow(sw *slidingWindow) {
139-
sw.buf = sw.buf[:0]
140-
swPool[cap(sw.buf)].Put(sw)
141+
func (sw *slidingWindow) close() {
142+
if sw.buf == nil {
143+
return
144+
}
145+
146+
swPool[cap(sw.buf)].Put(sw.buf)
147+
sw.buf = nil
141148
}
142149

143-
func (w *slidingWindow) write(p []byte) {
144-
if len(p) >= cap(w.buf) {
145-
w.buf = w.buf[:cap(w.buf)]
146-
p = p[len(p)-cap(w.buf):]
147-
copy(w.buf, p)
150+
func (sw *slidingWindow) write(p []byte) {
151+
if len(p) >= cap(sw.buf) {
152+
sw.buf = sw.buf[:cap(sw.buf)]
153+
p = p[len(p)-cap(sw.buf):]
154+
copy(sw.buf, p)
148155
return
149156
}
150157

151-
left := cap(w.buf) - len(w.buf)
158+
left := cap(sw.buf) - len(sw.buf)
152159
if left < len(p) {
153160
// We need to shift spaceNeeded bytes from the end to make room for p at the end.
154161
spaceNeeded := len(p) - left
155-
copy(w.buf, w.buf[spaceNeeded:])
156-
w.buf = w.buf[:len(w.buf)-spaceNeeded]
162+
copy(sw.buf, sw.buf[spaceNeeded:])
163+
sw.buf = sw.buf[:len(sw.buf)-spaceNeeded]
157164
}
158165

159-
w.buf = append(w.buf, p...)
166+
sw.buf = append(sw.buf, p...)
160167
}

Diff for: compress_test.go

+6-5
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,13 @@ func Test_slidingWindow(t *testing.T) {
2121

2222
input := xrand.String(maxWindow)
2323
windowLength := xrand.Int(maxWindow)
24-
r := newSlidingWindow(windowLength)
25-
r.write([]byte(input))
24+
var sw slidingWindow
25+
sw.init(windowLength)
26+
sw.write([]byte(input))
2627

27-
assert.Equal(t, "window length", windowLength, cap(r.buf))
28-
if !strings.HasSuffix(input, string(r.buf)) {
29-
t.Fatalf("r.buf is not a suffix of input: %q and %q", input, r.buf)
28+
assert.Equal(t, "window length", windowLength, cap(sw.buf))
29+
if !strings.HasSuffix(input, string(sw.buf)) {
30+
t.Fatalf("r.buf is not a suffix of input: %q and %q", input, sw.buf)
3031
}
3132
})
3233
}

Diff for: conn_test.go

+3-4
Original file line numberDiff line numberDiff line change
@@ -351,13 +351,12 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) {
351351
ctx, cancel := context.WithCancel(tt.ctx)
352352

353353
discardLoopErr := xsync.Go(func() error {
354+
defer c.Close(websocket.StatusInternalError, "")
355+
354356
for {
355357
_, _, err := c.Read(ctx)
356-
if websocket.CloseStatus(err) == websocket.StatusNormalClosure {
357-
return nil
358-
}
359358
if err != nil {
360-
return err
359+
return assertCloseStatus(websocket.StatusNormalClosure, err)
361360
}
362361
}
363362
})

Diff for: read.go

+5-11
Original file line numberDiff line numberDiff line change
@@ -87,15 +87,11 @@ func newMsgReader(c *Conn) *msgReader {
8787
}
8888

8989
func (mr *msgReader) resetFlate() {
90-
if mr.flateContextTakeover() && mr.dict == nil {
91-
mr.dict = newSlidingWindow(32768)
92-
}
93-
9490
if mr.flateContextTakeover() {
95-
mr.flateReader = getFlateReader(readerFunc(mr.read), mr.dict.buf)
96-
} else {
97-
mr.flateReader = getFlateReader(readerFunc(mr.read), nil)
91+
mr.dict.init(32768)
9892
}
93+
94+
mr.flateReader = getFlateReader(readerFunc(mr.read), mr.dict.buf)
9995
mr.limitReader.r = mr.flateReader
10096
mr.flateTail.Reset(deflateMessageTail)
10197
}
@@ -111,9 +107,7 @@ func (mr *msgReader) close() {
111107
mr.c.readMu.Lock(context.Background())
112108
mr.returnFlateReader()
113109

114-
if mr.dict != nil {
115-
returnSlidingWindow(mr.dict)
116-
}
110+
mr.dict.close()
117111
}
118112

119113
func (mr *msgReader) flateContextTakeover() bool {
@@ -325,7 +319,7 @@ type msgReader struct {
325319
flateReader io.Reader
326320
flateTail strings.Reader
327321
limitReader *limitReader
328-
dict *slidingWindow
322+
dict slidingWindow
329323

330324
fin bool
331325
payloadLength int64

0 commit comments

Comments
 (0)