Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
snail007
GitHub Repository: snail007/goproxy
Path: blob/master/utils/structs.go
686 views
1
package utils
2
3
import (
	"bytes"
	"crypto/subtle"
	"crypto/tls"
	"encoding/base64"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net"
	"net/url"
	"strings"
	"time"
)
16
17
type Checker struct {
18
data ConcurrentMap
19
blockedMap ConcurrentMap
20
directMap ConcurrentMap
21
interval int64
22
timeout int
23
}
24
// CheckerItem is the record Checker keeps for one host: how to probe it and
// how many probes have succeeded or failed so far.
type CheckerItem struct {
	IsHTTPS                 bool   // probe by connecting to Host instead of GETting URL
	Method, URL             string // request method and URL passed to Checker.Add
	Domain, Host            string // Domain is the host part of Host (Host without its port)
	Data                    []byte // request data passed to Checker.Add; not used by the probe
	SuccessCount, FailCount uint   // probe outcome counters
}
34
35
//NewChecker args:
36
//timeout : tcp timeout milliseconds ,connect to host
37
//interval: recheck domain interval seconds
38
func NewChecker(timeout int, interval int64, blockedFile, directFile string) Checker {
39
ch := Checker{
40
data: NewConcurrentMap(),
41
interval: interval,
42
timeout: timeout,
43
}
44
ch.blockedMap = ch.loadMap(blockedFile)
45
ch.directMap = ch.loadMap(directFile)
46
if !ch.blockedMap.IsEmpty() {
47
log.Printf("blocked file loaded , domains : %d", ch.blockedMap.Count())
48
}
49
if !ch.directMap.IsEmpty() {
50
log.Printf("direct file loaded , domains : %d", ch.directMap.Count())
51
}
52
ch.start()
53
return ch
54
}
55
56
// loadMap reads the file at f, if it exists, and returns a map whose keys are
// the file's non-blank lines with surrounding CR, space and tab characters
// removed (each value is true). A missing file yields an empty map; a read
// error is logged and also yields an empty map.
func (c *Checker) loadMap(f string) (dataMap ConcurrentMap) {
	dataMap = NewConcurrentMap()
	if !PathExists(f) {
		return dataMap
	}
	raw, err := ioutil.ReadFile(f)
	if err != nil {
		log.Printf("load file err:%s", err)
		return dataMap
	}
	for _, entry := range strings.Split(string(raw), "\n") {
		if entry = strings.Trim(entry, "\r \t"); entry == "" {
			continue
		}
		dataMap.Set(entry, true)
	}
	return dataMap
}
73
// start launches the background probe loop. In every round it walks all
// tracked items and probes each one that isNeedCheck still accepts, in its
// own goroutine, storing the updated counters back into c.data. It then
// sleeps c.interval seconds. The loop never exits.
func (c *Checker) start() {
	probe := func(item CheckerItem) {
		if !c.isNeedCheck(item) {
			return
		}
		var err error
		if item.IsHTTPS {
			// For HTTPS items only a connection to Host is attempted; no
			// request is sent over it.
			var conn net.Conn
			if conn, err = ConnectHost(item.Host, c.timeout); err == nil {
				conn.SetDeadline(time.Now().Add(time.Millisecond))
				conn.Close()
			}
		} else {
			err = HTTPGet(item.URL, c.timeout)
		}
		if err == nil {
			item.SuccessCount++
		} else {
			item.FailCount++
		}
		c.data.Set(item.Host, item)
	}

	go func() {
		for {
			for _, v := range c.data.Items() {
				go probe(v.(CheckerItem))
			}
			time.Sleep(time.Duration(c.interval) * time.Second)
		}
	}()
}
104
func (c *Checker) isNeedCheck(item CheckerItem) bool {
105
var minCount uint = 5
106
if (item.SuccessCount >= minCount && item.SuccessCount > item.FailCount) ||
107
(item.FailCount >= minCount && item.SuccessCount > item.FailCount) ||
108
c.domainIsInMap(item.Host, false) ||
109
c.domainIsInMap(item.Host, true) {
110
return false
111
}
112
return true
113
}
114
// IsBlocked reports whether address should be treated as blocked, together
// with the failure and success counts recorded for it.
//
// The checks apply in this order:
//   - The blocked list takes precedence over the direct list; a host on
//     either list reports zero counts.
//   - A host that has never been added is treated as blocked.
//   - Otherwise the host is blocked when its failures are at least as many
//     as its successes.
func (c *Checker) IsBlocked(address string) (blocked bool, failN, successN uint) {
	switch {
	case c.domainIsInMap(address, true):
		return true, 0, 0
	case c.domainIsInMap(address, false):
		return false, 0, 0
	}

	entry, found := c.data.Get(address)
	if !found {
		return true, 0, 0
	}
	item := entry.(CheckerItem)
	failN, successN = item.FailCount, item.SuccessCount
	return failN >= successN, failN, successN
}
133
// domainIsInMap reports whether the host part of address, or one of its parent
// domains, appears in the blocked list (blockedMap == true) or the direct list
// (blockedMap == false).
//
// For "a.b.example.com" the candidates are "example.com", "b.example.com" and
// "a.b.example.com". The bare last label ("com") is never matched on its own.
// A single-label host such as "localhost", or a bare IPv6 address, is matched
// exactly; without this it could never be listed at all. An address that
// cannot be parsed is reported as listed.
func (c *Checker) domainIsInMap(address string, blockedMap bool) bool {
	u, err := url.Parse("http://" + address)
	if err != nil {
		log.Printf("blocked check , url parse err:%s", err)
		return true
	}
	listed := func(domain string) bool {
		if blockedMap {
			return c.blockedMap.Has(domain)
		}
		return c.directMap.Has(domain)
	}

	labels := strings.Split(u.Hostname(), ".")
	if len(labels) == 1 {
		return labels[0] != "" && listed(labels[0])
	}
	// Walk from the two-label suffix up to the full hostname.
	for i := len(labels) - 2; i >= 0; i-- {
		if listed(strings.Join(labels[i:], ".")) {
			return true
		}
	}
	return false
}
156
func (c *Checker) Add(address string, isHTTPS bool, method, URL string, data []byte) {
157
if c.domainIsInMap(address, false) || c.domainIsInMap(address, true) {
158
return
159
}
160
if !isHTTPS && strings.ToLower(method) != "get" {
161
return
162
}
163
var item CheckerItem
164
u := strings.Split(address, ":")
165
item = CheckerItem{
166
URL: URL,
167
Domain: u[0],
168
Host: address,
169
Data: data,
170
IsHTTPS: isHTTPS,
171
Method: method,
172
}
173
c.data.SetIfAbsent(item.Host, item)
174
}
175
176
type BasicAuth struct {
177
data ConcurrentMap
178
}
179
180
func NewBasicAuth() BasicAuth {
181
return BasicAuth{
182
data: NewConcurrentMap(),
183
}
184
}
185
// AddFromFile loads "user:pass" credentials from file, one per line, and
// returns how many were added.
//
// Carriage returns are removed and each line is trimmed of surrounding
// spaces. After trimming, lines starting with '#' are treated as comments,
// and lines that do not split into exactly two ':'-separated fields are
// ignored. A read error is returned with n == 0.
func (ba *BasicAuth) AddFromFile(file string) (n int, err error) {
	content, err := ioutil.ReadFile(file)
	if err != nil {
		return 0, err
	}
	lines := strings.Split(strings.Replace(string(content), "\r", "", -1), "\n")
	for _, line := range lines {
		line = strings.Trim(line, " ")
		// HasPrefix takes (s, prefix). Testing the trimmed line also skips
		// indented comments, so a commented-out entry such as "#user:pass"
		// is never loaded as user "#user".
		if strings.HasPrefix(line, "#") {
			continue
		}
		u := strings.Split(line, ":")
		if len(u) == 2 {
			ba.data.Set(u[0], u[1])
			n++
		}
	}
	return n, nil
}
203
204
// Add registers every "user:pass" entry in userpassArr and returns how many
// were accepted. Entries that do not split into exactly two ':'-separated
// fields are skipped; entries are used as-is, without trimming.
func (ba *BasicAuth) Add(userpassArr []string) (n int) {
	for i := range userpassArr {
		parts := strings.Split(userpassArr[i], ":")
		if len(parts) != 2 {
			continue
		}
		ba.data.Set(parts[0], parts[1])
		n++
	}
	return n
}
214
215
// Check reports whether userpass, a "user:pass" string with surrounding spaces
// ignored, matches a registered credential. Inputs that do not split into
// exactly two ':'-separated fields never match.
func (ba *BasicAuth) Check(userpass string) (ok bool) {
	u := strings.Split(strings.Trim(userpass, " "), ":")
	if len(u) != 2 {
		return false
	}
	stored, found := ba.data.Get(u[0])
	if !found {
		return false
	}
	// userpass comes from the client. Compare in constant time so response
	// latency does not reveal how much of the password was correct.
	return subtle.ConstantTimeCompare([]byte(stored.(string)), []byte(u[1])) == 1
}
224
// Total returns how many usernames are currently registered.
func (ba *BasicAuth) Total() (n int) {
	return ba.data.Count()
}
228
229
type HTTPRequest struct {
230
HeadBuf []byte
231
conn *net.Conn
232
Host string
233
Method string
234
URL string
235
hostOrURL string
236
isBasicAuth bool
237
basicAuth *BasicAuth
238
}
239
240
func NewHTTPRequest(inConn *net.Conn, bufSize int, isBasicAuth bool, basicAuth *BasicAuth) (req HTTPRequest, err error) {
241
buf := make([]byte, bufSize)
242
len := 0
243
req = HTTPRequest{
244
conn: inConn,
245
}
246
len, err = (*inConn).Read(buf[:])
247
if err != nil {
248
if err != io.EOF {
249
err = fmt.Errorf("http decoder read err:%s", err)
250
}
251
CloseConn(inConn)
252
return
253
}
254
req.HeadBuf = buf[:len]
255
index := bytes.IndexByte(req.HeadBuf, '\n')
256
if index == -1 {
257
err = fmt.Errorf("http decoder data line err:%s", string(req.HeadBuf)[:50])
258
CloseConn(inConn)
259
return
260
}
261
fmt.Sscanf(string(req.HeadBuf[:index]), "%s%s", &req.Method, &req.hostOrURL)
262
if req.Method == "" || req.hostOrURL == "" {
263
err = fmt.Errorf("http decoder data err:%s", string(req.HeadBuf)[:50])
264
CloseConn(inConn)
265
return
266
}
267
req.Method = strings.ToUpper(req.Method)
268
req.isBasicAuth = isBasicAuth
269
req.basicAuth = basicAuth
270
log.Printf("%s:%s", req.Method, req.hostOrURL)
271
272
if req.IsHTTPS() {
273
err = req.HTTPS()
274
} else {
275
err = req.HTTP()
276
}
277
return
278
}
279
// HTTP finishes parsing a plain (non-CONNECT) proxy request in three steps:
//   - It runs BasicAuth() when basic auth is enabled.
//   - It builds the absolute URL; origin-form targets use the Host header.
//   - It sets Host to the URL's host, with port 80 added when none is given.
//
// A target that does not parse as a URL is reported as an error.
func (req *HTTPRequest) HTTP() (err error) {
	if req.isBasicAuth {
		if err = req.BasicAuth(); err != nil {
			return err
		}
	}
	if req.URL, err = req.getHTTPURL(); err != nil {
		return err
	}
	u, err := url.Parse(req.URL)
	if err != nil {
		// The target comes straight from the client. A malformed one must be
		// reported, not dereferenced through a nil *url.URL.
		return fmt.Errorf("http decoder url parse err:%s", err)
	}
	req.Host = u.Host
	req.addPortIfNot()
	return nil
}
294
// HTTPS finishes parsing a CONNECT request. When basic auth is enabled it runs
// BasicAuth() first, just as HTTP() does. It then sets Host to the request
// target, adding port 443 when none is given. The
// "200 Connection established" reply is not written here; HTTPSReply does that.
func (req *HTTPRequest) HTTPS() (err error) {
	if req.isBasicAuth {
		// Without this check, CONNECT tunnels bypassed the authentication
		// that plain HTTP requests are subject to.
		if err = req.BasicAuth(); err != nil {
			return err
		}
	}
	req.Host = req.hostOrURL
	req.addPortIfNot()
	return nil
}
300
func (req *HTTPRequest) HTTPSReply() (err error) {
301
_, err = fmt.Fprint(*req.conn, "HTTP/1.1 200 Connection established\r\n\r\n")
302
return
303
}
304
// IsHTTPS reports whether the request method is CONNECT. NewHTTPRequest
// upper-cases Method, so the comparison is exact.
func (req *HTTPRequest) IsHTTPS() bool {
	switch req.Method {
	case "CONNECT":
		return true
	default:
		return false
	}
}
307
308
// BasicAuth validates the request's "Authorization" header, expected as
// "<scheme> <base64 of user:pass>", against req.basicAuth.
//
// On failure:
//   - A missing header gets a 401 reply carrying a WWW-Authenticate challenge.
//   - Wrong credentials get a plain 401 reply.
//   - Malformed header data gets no reply.
//
// In every failure case the client connection is closed and a non-nil error
// is returned.
//
// NOTE(review): the scheme word is not checked to be "Basic". Proxies
// conventionally use Proxy-Authorization with 407 rather than Authorization
// with 401 — confirm clients expect this.
func (req *HTTPRequest) BasicAuth() (err error) {
	fail := func(cause error) error {
		CloseConn(req.conn)
		return cause
	}

	authorization, err := req.getHeader("Authorization")
	if err != nil {
		fmt.Fprint((*req.conn), "HTTP/1.1 401 Unauthorized\r\nWWW-Authenticate: Basic realm=\"\"\r\n\r\nUnauthorized")
		return fail(err)
	}

	fields := strings.Fields(authorization)
	if len(fields) != 2 {
		return fail(fmt.Errorf("authorization data error,ERR:%s", authorization))
	}

	credentials, decodeErr := base64.StdEncoding.DecodeString(fields[1])
	if decodeErr != nil {
		return fail(fmt.Errorf("authorization data parse error,ERR:%s", decodeErr))
	}

	if !req.basicAuth.Check(string(credentials)) {
		fmt.Fprint((*req.conn), "HTTP/1.1 401 Unauthorized\r\n\r\nUnauthorized")
		return fail(fmt.Errorf("basic auth fail"))
	}
	return nil
}
340
// getHTTPURL returns the absolute URL of a plain HTTP request.
//
// A target that does not start with '/' (presumably absolute-form) is
// returned unchanged. An origin-form target ("/path") is prefixed with
// "http://" and the Host header value. A missing Host header is returned
// as an error.
func (req *HTTPRequest) getHTTPURL() (URL string, err error) {
	target := req.hostOrURL
	if strings.HasPrefix(target, "/") {
		host, hostErr := req.getHeader("host")
		if hostErr != nil {
			return "", hostErr
		}
		return fmt.Sprintf("http://%s%s", host, target), nil
	}
	return target, nil
}
351
// getHeader returns the value of the first header field named key
// (case-insensitive) in req.HeadBuf, with surrounding spaces and tabs trimmed.
//
// Only the header section is searched. The request line is skipped, and the
// search stops at the first empty line, so body bytes that happen to be in
// HeadBuf are never mistaken for headers. Both CRLF and bare LF line endings
// are accepted. The error names the header that was not found.
func (req *HTTPRequest) getHeader(key string) (val string, err error) {
	lines := strings.Split(string(req.HeadBuf), "\n")
	// lines[0] is the request line; header fields start after it.
	for i := 1; i < len(lines); i++ {
		line := strings.TrimRight(lines[i], "\r")
		if line == "" {
			break // end of the header section
		}
		field := strings.SplitN(line, ":", 2)
		if len(field) != 2 {
			continue
		}
		if strings.EqualFold(strings.Trim(field[0], " \t"), key) {
			return strings.Trim(field[1], " \t"), nil
		}
	}
	return "", fmt.Errorf("can not find %s header", key)
}
368
369
// addPortIfNot appends the default port to req.Host when it carries none, and
// returns the resulting host. The default is 443 for CONNECT and 80 otherwise.
//
// A host counts as port-less in two cases:
//   - It is not bracketed and contains no ':' (a name or IPv4 address).
//   - It is a bracketed IPv6 literal with nothing after ']' (e.g. "[::1]").
//
// NOTE(review): an unbracketed IPv6 address such as "::1" is left unchanged
// because its colons are indistinguishable from a port separator.
func (req *HTTPRequest) addPortIfNot() (newHost string) {
	port := "80"
	if req.IsHTTPS() {
		port = "443"
	}
	host := req.Host
	bracketed := strings.HasPrefix(host, "[")
	if (!bracketed && !strings.Contains(host, ":")) || (bracketed && strings.HasSuffix(host, "]")) {
		req.Host = host + ":" + port
	}
	return req.Host
}
382
383
type OutPool struct {
384
Pool ConnPool
385
dur int
386
isTLS bool
387
certBytes []byte
388
keyBytes []byte
389
address string
390
timeout int
391
}
392
393
// NewOutPool creates an OutPool for address and its underlying connection pool,
// sized by InitialCap and MaxCap.
//
// Pool connections come from getConn; released connections get a 1 ms
// deadline and are then closed. When InitialCap > 0, initPoolDeamon is
// started. Failure to create the pool terminates the process via log.Fatalf.
func NewOutPool(dur int, isTLS bool, certBytes, keyBytes []byte, address string, timeout int, InitialCap int, MaxCap int) (op OutPool) {
	op = OutPool{
		dur:       dur,
		isTLS:     isTLS,
		certBytes: certBytes,
		keyBytes:  keyBytes,
		address:   address,
		timeout:   timeout,
	}

	release := func(conn interface{}) {
		if conn == nil {
			return
		}
		c := conn.(net.Conn)
		c.SetDeadline(time.Now().Add(time.Millisecond))
		c.Close()
	}
	factory := func() (interface{}, error) {
		return op.getConn()
	}

	var err error
	op.Pool, err = NewConnPool(poolConfig{
		IsActive:   func(interface{}) bool { return true },
		Release:    release,
		InitialCap: InitialCap,
		MaxCap:     MaxCap,
		Factory:    factory,
	})
	if err != nil {
		log.Fatalf("init conn pool fail ,%s", err)
	}
	if InitialCap <= 0 {
		log.Printf("conn pool closed")
		return op
	}
	log.Printf("init conn pool success")
	op.initPoolDeamon()
	return op
}
431
// getConn dials a new connection to op.address and returns it as a net.Conn
// stored in conn. When op.isTLS is set it dials via TlsConnectHost with op's
// cert/key bytes; otherwise it uses ConnectHost.
//
// NOTE(review): TlsConnectHost returns a tls.Conn by value (see the variable
// type below). Its address is taken here after the copy. tls.Conn contains
// sync.Mutex fields and is not meant to be copied; returning *tls.Conn from
// TlsConnectHost would avoid this.
func (op *OutPool) getConn() (conn interface{}, err error) {
	if !op.isTLS {
		return ConnectHost(op.address, op.timeout)
	}
	var tlsConn tls.Conn
	if tlsConn, err = TlsConnectHost(op.address, op.timeout, op.certBytes, op.keyBytes); err != nil {
		return nil, err
	}
	return net.Conn(&tlsConn), nil
}
443
444
// initPoolDeamon starts a goroutine that checks the upstream every op.dur
// seconds by dialing a throwaway connection with getConn.
//
// If that dial fails, every pooled connection is released with
// Pool.ReleaseAll. If it succeeds, the probe connection is closed right away.
// Nothing is started when op.dur <= 0. The goroutine never exits.
func (op *OutPool) initPoolDeamon() {
	if op.dur <= 0 {
		return
	}
	go func() {
		log.Printf("pool deamon started")
		pause := time.Duration(op.dur) * time.Second
		for {
			time.Sleep(pause)
			probe, err := op.getConn()
			if err != nil {
				log.Printf("pool deamon err %s , release pool", err)
				op.Pool.ReleaseAll()
				continue
			}
			c := probe.(net.Conn)
			c.SetDeadline(time.Now().Add(time.Millisecond))
			c.Close()
		}
	}()
}
463
464