Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
gitpod-io
GitHub Repository: gitpod-io/gitpod
Path: blob/main/components/gitpod-protocol/go/reconnecting-ws.go
2498 views
1
// Copyright (c) 2020 Gitpod GmbH. All rights reserved.
2
// Licensed under the GNU Affero General Public License (AGPL).
3
// See License.AGPL.txt in the project root for license information.
4
5
package protocol
6
7
import (
8
"context"
9
"encoding/json"
10
"errors"
11
"fmt"
12
"net/http"
13
"sync"
14
"time"
15
16
backoff "github.com/cenkalti/backoff/v4"
17
"github.com/gorilla/websocket"
18
"github.com/sirupsen/logrus"
19
)
20
21
// ErrClosed is the error reported by the reconnecting websocket after it has
// been shut down with Close (Close is also invoked internally when the Dial
// context is done while dialing).
var ErrClosed = errors.New("reconnecting-ws: closed")

// ErrBadHandshake reports that the server rejected the websocket opening
// handshake. URL is the address that was dialed and Resp is the server's
// response, which may be nil.
type ErrBadHandshake struct {
	URL  string
	Resp *http.Response
}

// Error describes the failed handshake, including the HTTP status code
// (0 when no response is available) and the dialed URL.
func (e *ErrBadHandshake) Error() string {
	code := 0
	if resp := e.Resp; resp != nil {
		code = resp.StatusCode
	}
	return fmt.Sprintf("reconnecting-ws: bad handshake: code %v - URL: %v", code, e.URL)
}
38
39
// The ReconnectingWebsocket represents a Reconnecting WebSocket connection.
40
type ReconnectingWebsocket struct {
41
url string
42
reqHeader http.Header
43
handshakeTimeout time.Duration
44
45
once sync.Once
46
closeErr error
47
closedCh chan struct{}
48
connCh chan chan *WebsocketConnection
49
errCh chan error
50
51
log *logrus.Entry
52
53
ReconnectionHandler func()
54
55
badHandshakeCount uint8
56
badHandshakeMax uint8
57
}
58
59
// NewReconnectingWebsocket creates a new instance of ReconnectingWebsocket
60
func NewReconnectingWebsocket(url string, reqHeader http.Header, log *logrus.Entry) *ReconnectingWebsocket {
61
return &ReconnectingWebsocket{
62
url: url,
63
reqHeader: reqHeader,
64
handshakeTimeout: 2 * time.Second,
65
connCh: make(chan chan *WebsocketConnection),
66
closedCh: make(chan struct{}),
67
errCh: make(chan error),
68
log: log,
69
badHandshakeCount: 0,
70
badHandshakeMax: 15,
71
}
72
}
73
74
// Close closes the underlying webscoket connection.
75
func (rc *ReconnectingWebsocket) Close() error {
76
return rc.closeWithError(ErrClosed)
77
}
78
79
func (rc *ReconnectingWebsocket) closeWithError(closeErr error) error {
80
rc.once.Do(func() {
81
rc.closeErr = closeErr
82
close(rc.closedCh)
83
})
84
return nil
85
}
86
87
// EnsureConnection ensures ws connections
88
// Returns only if connection is permanently failed
89
// If the passed handler returns false as closed then err is returned to the client,
90
// otherwise err is treated as a connection error, and new conneciton is provided.
91
func (rc *ReconnectingWebsocket) EnsureConnection(handler func(conn *WebsocketConnection) (closed bool, err error)) error {
92
for {
93
connCh := make(chan *WebsocketConnection, 1)
94
select {
95
case <-rc.closedCh:
96
return rc.closeErr
97
case rc.connCh <- connCh:
98
}
99
conn := <-connCh
100
closed, err := handler(conn)
101
if !closed {
102
return err
103
}
104
select {
105
case <-rc.closedCh:
106
return rc.closeErr
107
case rc.errCh <- err:
108
}
109
}
110
}
111
112
// isJSONError reports whether err (or any error it wraps) comes from
// encoding/json, i.e. it describes a problem with the value being encoded or
// decoded rather than with the connection. WriteObject and ReadObject return
// such errors to the caller instead of triggering a reconnect.
//
// Decoding errors include *json.UnmarshalTypeError (well-formed JSON of the
// wrong shape) and *json.InvalidUnmarshalError (non-pointer target). Treating
// either as a connection failure would cause a needless reconnect.
func isJSONError(err error) bool {
	var (
		marshalerErr        *json.MarshalerError
		syntaxErr           *json.SyntaxError
		unsupportedTypeErr  *json.UnsupportedTypeError
		unsupportedValueErr *json.UnsupportedValueError
		unmarshalTypeErr    *json.UnmarshalTypeError
		invalidUnmarshalErr *json.InvalidUnmarshalError
	)
	// errors.As also matches errors wrapped with %w by intermediate layers.
	return errors.As(err, &marshalerErr) ||
		errors.As(err, &syntaxErr) ||
		errors.As(err, &unsupportedTypeErr) ||
		errors.As(err, &unsupportedValueErr) ||
		errors.As(err, &unmarshalTypeErr) ||
		errors.As(err, &invalidUnmarshalErr)
}
128
129
// WriteObject writes the JSON encoding of v as a message.
130
// See the documentation for encoding/json Marshal for details about the conversion of Go values to JSON.
131
func (rc *ReconnectingWebsocket) WriteObject(v interface{}) error {
132
return rc.EnsureConnection(func(conn *WebsocketConnection) (bool, error) {
133
err := conn.WriteJSON(v)
134
closed := err != nil && !isJSONError(err)
135
return closed, err
136
})
137
}
138
139
// ReadObject reads the next JSON-encoded message from the connection and stores it in the value pointed to by v.
140
// See the documentation for the encoding/json Unmarshal function for details about the conversion of JSON to a Go value.
141
func (rc *ReconnectingWebsocket) ReadObject(v interface{}) error {
142
return rc.EnsureConnection(func(conn *WebsocketConnection) (bool, error) {
143
err := conn.ReadJSON(v)
144
closed := err != nil && !isJSONError(err)
145
return closed, err
146
})
147
}
148
149
// Dial creates a new client connection.
150
func (rc *ReconnectingWebsocket) Dial(ctx context.Context) error {
151
var conn *WebsocketConnection
152
defer func() {
153
if conn == nil {
154
return
155
}
156
rc.log.WithField("url", rc.url).Debug("connection is permanently closed")
157
conn.Close()
158
}()
159
160
conn = rc.connect(ctx)
161
162
for {
163
select {
164
case <-rc.closedCh:
165
return rc.closeErr
166
case connCh := <-rc.connCh:
167
connCh <- conn
168
case <-rc.errCh:
169
if conn != nil {
170
conn.Close()
171
}
172
173
time.Sleep(1 * time.Second)
174
conn = rc.connect(ctx)
175
if conn != nil && rc.ReconnectionHandler != nil {
176
go rc.ReconnectionHandler()
177
}
178
}
179
}
180
}
181
182
func (rc *ReconnectingWebsocket) connect(ctx context.Context) *WebsocketConnection {
183
exp := &backoff.ExponentialBackOff{
184
InitialInterval: 2 * time.Second,
185
RandomizationFactor: 0.5,
186
Multiplier: 1.5,
187
MaxInterval: 30 * time.Second,
188
MaxElapsedTime: 0,
189
Stop: backoff.Stop,
190
Clock: backoff.SystemClock,
191
}
192
exp.Reset()
193
for {
194
// Gorilla websocket does not check if context is valid when dialing so we do it prior
195
select {
196
case <-ctx.Done():
197
rc.log.WithField("url", rc.url).Debug("context done...closing")
198
rc.Close()
199
return nil
200
default:
201
}
202
203
dialer := websocket.Dialer{HandshakeTimeout: rc.handshakeTimeout}
204
conn, resp, err := dialer.DialContext(ctx, rc.url, rc.reqHeader)
205
if err == nil {
206
rc.log.WithField("url", rc.url).Debug("connection was successfully established")
207
ws, err := NewWebsocketConnection(context.Background(), conn, func(staleErr error) {
208
rc.errCh <- staleErr
209
})
210
if err == nil {
211
rc.badHandshakeCount = 0
212
return ws
213
}
214
}
215
216
var statusCode int
217
if resp != nil {
218
statusCode = resp.StatusCode
219
}
220
221
// 200 is bad gateway for ws, we should keep trying
222
if err == websocket.ErrBadHandshake && statusCode != 200 {
223
rc.badHandshakeCount++
224
// if mal-formed handshake request (unauthorized, forbidden) or client actions (redirect) are required then fail immediately
225
// otherwise try several times and fail, maybe temporarily unavailable, like server restart
226
if rc.badHandshakeCount > rc.badHandshakeMax || (http.StatusMultipleChoices <= statusCode && statusCode < http.StatusInternalServerError) {
227
_ = rc.closeWithError(&ErrBadHandshake{rc.url, resp})
228
return nil
229
}
230
}
231
232
delay := exp.NextBackOff()
233
rc.log.WithError(err).
234
WithField("url", rc.url).
235
WithField("badHandshakeCount", fmt.Sprintf("%d/%d", rc.badHandshakeCount, rc.badHandshakeMax)).
236
WithField("statusCode", statusCode).
237
WithField("delay", delay.String()).
238
Error("failed to connect, trying again...")
239
select {
240
case <-rc.closedCh:
241
return nil
242
case <-time.After(delay):
243
}
244
}
245
}
246
247