Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
snail007
GitHub Repository: snail007/goproxy
Path: blob/master/utils/functions.go
686 views
1
package utils
2
3
import (
4
"bufio"
5
"bytes"
6
"crypto/tls"
7
"crypto/x509"
8
"encoding/binary"
9
"errors"
10
"fmt"
11
"io"
12
"log"
13
"net"
14
"net/http"
15
"os"
16
"os/exec"
17
"sync"
18
19
"runtime/debug"
20
"strconv"
21
"strings"
22
"time"
23
)
24
25
// IoBind pipes data between dst and src in both directions and returns
// immediately. Each direction runs in its own goroutine.
//
// Bytes read from src are written to dst and reported to cfn with
// isPositive == false. Bytes read from dst are written to src and reported
// with isPositive == true. When bytesPreSec > 0, the reading side of each
// direction is wrapped with NewReader and limited via
// SetRateLimit(bytesPreSec) (presumably bytes per second — confirm against
// NewReader).
//
// fn is invoked at most once, with the first error returned by either
// direction. A normal close shows up here as io.EOF, because ioCopy reports
// EOF as an error. isSrcErr is true when that error came from the read side
// of the failing direction. For the dst->src direction, the read side is dst.
//
// fn and cfn may be nil. A panic inside a direction is recovered and logged;
// fn is not invoked for that direction in that case.
func IoBind(dst io.ReadWriter, src io.ReadWriter, fn func(isSrcErr bool, err error), cfn func(count int, isPositive bool), bytesPreSec float64) {
	var once sync.Once

	// pump copies from -> to until either side fails, then reports the
	// failure through fn (shared once across both directions).
	pump := func(to io.Writer, from io.ReadWriter, isPositive bool) {
		defer func() {
			if e := recover(); e != nil {
				log.Printf("IoBind crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
			}
		}()
		onCount := func(c int) {
			if cfn != nil {
				cfn(c, isPositive)
			}
		}
		var err error
		var isSrcErr bool
		if bytesPreSec > 0 {
			limited := NewReader(from)
			limited.SetRateLimit(bytesPreSec)
			_, isSrcErr, err = ioCopy(to, limited, onCount)
		} else {
			_, isSrcErr, err = ioCopy(to, from, onCount)
		}
		if err != nil && fn != nil {
			once.Do(func() {
				fn(isSrcErr, err)
			})
		}
	}

	go pump(dst, src, false)
	go pump(src, dst, true)
}
79
// ioCopy copies from src to dst until one side fails. It is like io.Copy,
// except that io.EOF is not treated as success: the loop only ends on an
// error, and that error (including io.EOF) is returned.
//
// written is the number of bytes successfully written to dst. isSrcErr is
// true when err came from src.Read. It is false when err came from dst.Write,
// or is io.ErrShortWrite or an invalid write count reported by dst.
//
// Every non-nil callback in fn is invoked with the size of each successful
// write.
func ioCopy(dst io.Writer, src io.Reader, fn ...func(count int)) (written int64, isSrcErr bool, err error) {
	buf := make([]byte, 32*1024)
	for {
		nr, er := src.Read(buf)
		if nr > 0 {
			nw, ew := dst.Write(buf[:nr])
			// A Writer must report 0 <= nw <= nr. Reject anything else so
			// that written is never inflated by a misbehaving Writer.
			if nw < 0 || nw > nr {
				nw = 0
				if ew == nil {
					ew = errors.New("invalid write result")
				}
			}
			if nw > 0 {
				written += int64(nw)
				for _, notify := range fn {
					if notify != nil {
						notify(nw)
					}
				}
			}
			if ew != nil {
				return written, false, ew
			}
			if nw != nr {
				return written, false, io.ErrShortWrite
			}
		}
		// Data returned together with an error is written first (above),
		// then the read error is reported.
		if er != nil {
			return written, true, er
		}
	}
}
108
// TlsConnectHost connects to host, given as "host:port", and returns the TLS
// client connection built by TlsConnect. IPv6 literals must be bracketed, as
// in "[::1]:443". Leading and trailing whitespace in host is ignored.
// timeout is forwarded to TlsConnect, which uses it as the dial timeout in
// milliseconds.
//
// A missing port, a malformed address or a non-numeric port is returned as an
// error instead of panicking or dialing port 0.
func TlsConnectHost(host string, timeout int, certBytes, keyBytes []byte) (conn tls.Conn, err error) {
	hostname, portText, err := net.SplitHostPort(strings.TrimSpace(host))
	if err != nil {
		return
	}
	port, err := strconv.Atoi(portText)
	if err != nil {
		return
	}
	return TlsConnect(hostname, port, timeout, certBytes, keyBytes)
}
113
114
// TlsConnect dials host:port over TCP with a timeout of timeout milliseconds.
// It wraps the connection in a TLS client configured by getRequestTlsConfig.
// The TLS handshake is not performed here; crypto/tls runs it on the first
// Read or Write, or on an explicit Handshake call.
//
// The address is built with net.JoinHostPort, so IPv6 literals such as "::1"
// are bracketed correctly.
func TlsConnect(host string, port, timeout int, certBytes, keyBytes []byte) (conn tls.Conn, err error) {
	conf, err := getRequestTlsConfig(certBytes, keyBytes)
	if err != nil {
		return
	}
	addr := net.JoinHostPort(host, strconv.Itoa(port))
	rawConn, err := net.DialTimeout("tcp", addr, time.Duration(timeout)*time.Millisecond)
	if err != nil {
		return
	}
	// NOTE(review): returning tls.Conn by value copies a struct that contains
	// mutexes. The signature is kept for compatibility with existing callers.
	return *tls.Client(rawConn, conf), nil
}
125
func getRequestTlsConfig(certBytes, keyBytes []byte) (conf *tls.Config, err error) {
126
var cert tls.Certificate
127
cert, err = tls.X509KeyPair(certBytes, keyBytes)
128
if err != nil {
129
return
130
}
131
serverCertPool := x509.NewCertPool()
132
ok := serverCertPool.AppendCertsFromPEM(certBytes)
133
if !ok {
134
err = errors.New("failed to parse root certificate")
135
}
136
conf = &tls.Config{
137
RootCAs: serverCertPool,
138
Certificates: []tls.Certificate{cert},
139
ServerName: "proxy",
140
InsecureSkipVerify: false,
141
}
142
return
143
}
144
145
// ConnectHost opens a plain TCP connection to hostAndPort. The attempt is
// abandoned after timeout milliseconds. On failure conn is nil and err
// describes the dial error.
func ConnectHost(hostAndPort string, timeout int) (conn net.Conn, err error) {
	limit := time.Millisecond * time.Duration(timeout)
	return net.DialTimeout("tcp", hostAndPort, limit)
}
149
func ListenTls(ip string, port int, certBytes, keyBytes []byte) (ln *net.Listener, err error) {
150
var cert tls.Certificate
151
cert, err = tls.X509KeyPair(certBytes, keyBytes)
152
if err != nil {
153
return
154
}
155
clientCertPool := x509.NewCertPool()
156
ok := clientCertPool.AppendCertsFromPEM(certBytes)
157
if !ok {
158
err = errors.New("failed to parse root certificate")
159
}
160
config := &tls.Config{
161
ClientCAs: clientCertPool,
162
ServerName: "proxy",
163
Certificates: []tls.Certificate{cert},
164
ClientAuth: tls.RequireAndVerifyClientCert,
165
}
166
_ln, err := tls.Listen("tcp", fmt.Sprintf("%s:%d", ip, port), config)
167
if err == nil {
168
ln = &_ln
169
}
170
return
171
}
172
// PathExists reports whether _path exists. Only a "does not exist" error from
// os.Stat yields false. Any other failure, such as permission denied, is
// treated as the path existing.
func PathExists(_path string) bool {
	_, statErr := os.Stat(_path)
	return !os.IsNotExist(statErr)
}
179
// HTTPGet issues a GET request to URL with an overall client timeout of
// timeout milliseconds. It only reports whether a response was obtained; the
// status code is not inspected, so a 404 or 500 still yields a nil error.
//
// Each call uses its own Transport. The response body is closed and the
// transport's idle connections are released before returning.
func HTTPGet(URL string, timeout int) (err error) {
	transport := &http.Transport{}
	client := &http.Client{
		Transport: transport,
		Timeout:   time.Duration(timeout) * time.Millisecond,
	}

	var response *http.Response
	defer func() {
		if response != nil && response.Body != nil {
			response.Body.Close()
		}
		transport.CloseIdleConnections()
	}()

	response, err = client.Get(URL)
	return err
}
196
197
// CloseConn closes *conn when both the pointer and the connection it refers to
// are non-nil; otherwise it does nothing. A deadline 1ms in the future is set
// before closing, presumably to make pending reads or writes return promptly.
// Errors from SetDeadline and Close are ignored.
func CloseConn(conn *net.Conn) {
	if conn == nil || *conn == nil {
		return
	}
	c := *conn
	c.SetDeadline(time.Now().Add(time.Millisecond))
	c.Close()
}
203
func Keygen() (err error) {
204
cmd := exec.Command("sh", "-c", "openssl genrsa -out proxy.key 2048")
205
out, err := cmd.CombinedOutput()
206
if err != nil {
207
log.Printf("err:%s", err)
208
return
209
}
210
fmt.Println(string(out))
211
cmd = exec.Command("sh", "-c", `openssl req -new -key proxy.key -x509 -days 3650 -out proxy.crt -subj /C=CN/ST=BJ/O="Localhost Ltd"/CN=proxy`)
212
out, err = cmd.CombinedOutput()
213
if err != nil {
214
log.Printf("err:%s", err)
215
return
216
}
217
fmt.Println(string(out))
218
return
219
}
220
// GetAllInterfaceAddr returns the IPv4 address of every network interface that
// is up, loopback included. IPv6 addresses are skipped. So are interfaces
// whose address list cannot be read.
//
// The error from net.Interfaces is returned as is. An error is also returned
// when no IPv4 address is found at all.
func GetAllInterfaceAddr() ([]net.IP, error) {
	ifaces, err := net.Interfaces()
	if err != nil {
		return nil, err
	}

	var found []net.IP
	for _, iface := range ifaces {
		if iface.Flags&net.FlagUp == 0 {
			continue // interface is down
		}
		addrs, addrErr := iface.Addrs()
		if addrErr != nil {
			continue
		}
		for _, a := range addrs {
			var raw net.IP
			if ipNet, ok := a.(*net.IPNet); ok {
				raw = ipNet.IP
			} else if ipAddr, ok := a.(*net.IPAddr); ok {
				raw = ipAddr.IP
			}
			// To4 is nil for IPv6 addresses and for a nil IP.
			if v4 := raw.To4(); v4 != nil {
				found = append(found, v4)
			}
		}
	}

	if len(found) == 0 {
		return nil, fmt.Errorf("no address Found, net.InterfaceAddrs: %v", found)
	}
	return found, nil
}
264
func UDPPacket(srcAddr string, packet []byte) []byte {
265
addrBytes := []byte(srcAddr)
266
addrLength := uint16(len(addrBytes))
267
bodyLength := uint16(len(packet))
268
pkg := new(bytes.Buffer)
269
binary.Write(pkg, binary.LittleEndian, addrLength)
270
binary.Write(pkg, binary.LittleEndian, addrBytes)
271
binary.Write(pkg, binary.LittleEndian, bodyLength)
272
binary.Write(pkg, binary.LittleEndian, packet)
273
return pkg.Bytes()
274
}
275
// ReadUDPPacket reads one frame with this layout from *conn:
//
//	[addrLen uint16 LE][srcAddr bytes][bodyLen uint16 LE][packet bytes]
//
// Each field is read in full with io.ReadFull or binary.Read. A frame that
// ends early therefore yields a non-nil error, either io.ErrUnexpectedEOF or
// the connection's read error, rather than a silently truncated result. On
// error, srcAddr and packet are empty.
//
// NOTE(review): a new bufio.Reader is created on every call. Any bytes it
// buffers beyond the end of this frame are dropped when the function returns,
// for example the start of the next frame on the same stream. Callers that
// read consecutive frames from one connection may lose data. TODO: confirm
// with callers and consider a long-lived reader.
func ReadUDPPacket(conn *net.Conn) (srcAddr string, packet []byte, err error) {
	reader := bufio.NewReader(*conn)
	var addrLength, bodyLength uint16

	if err = binary.Read(reader, binary.LittleEndian, &addrLength); err != nil {
		return "", nil, err
	}
	addr := make([]byte, addrLength)
	// A single Read may return fewer bytes than requested, so use ReadFull.
	if _, err = io.ReadFull(reader, addr); err != nil {
		return "", nil, err
	}

	if err = binary.Read(reader, binary.LittleEndian, &bodyLength); err != nil {
		return "", nil, err
	}
	body := make([]byte, bodyLength)
	if _, err = io.ReadFull(reader, body); err != nil {
		return "", nil, err
	}
	return string(addr), body, nil
}
307
308
// type sockaddr struct {
309
// family uint16
310
// data [14]byte
311
// }
312
313
// const SO_ORIGINAL_DST = 80
314
315
// realServerAddress returns an intercepted connection's original destination.
316
// func realServerAddress(conn *net.Conn) (string, error) {
317
// tcpConn, ok := (*conn).(*net.TCPConn)
318
// if !ok {
319
// return "", errors.New("not a TCPConn")
320
// }
321
322
// file, err := tcpConn.File()
323
// if err != nil {
324
// return "", err
325
// }
326
327
// // To avoid potential problems from making the socket non-blocking.
328
// tcpConn.Close()
329
// *conn, err = net.FileConn(file)
330
// if err != nil {
331
// return "", err
332
// }
333
334
// defer file.Close()
335
// fd := file.Fd()
336
337
// var addr sockaddr
338
// size := uint32(unsafe.Sizeof(addr))
339
// err = getsockopt(int(fd), syscall.SOL_IP, SO_ORIGINAL_DST, uintptr(unsafe.Pointer(&addr)), &size)
340
// if err != nil {
341
// return "", err
342
// }
343
344
// var ip net.IP
345
// switch addr.family {
346
// case syscall.AF_INET:
347
// ip = addr.data[2:6]
348
// default:
349
// return "", errors.New("unrecognized address family")
350
// }
351
352
// port := int(addr.data[0])<<8 + int(addr.data[1])
353
354
// return net.JoinHostPort(ip.String(), strconv.Itoa(port)), nil
355
// }
356
357
// func getsockopt(s int, level int, name int, val uintptr, vallen *uint32) (err error) {
358
// _, _, e1 := syscall.Syscall6(syscall.SYS_GETSOCKOPT, uintptr(s), uintptr(level), uintptr(name), uintptr(val), uintptr(unsafe.Pointer(vallen)), 0)
359
// if e1 != 0 {
360
// err = e1
361
// }
362
// return
363
// }
364
365