Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
snail007
GitHub Repository: snail007/goproxy
Path: blob/master/services/udp.go
686 views
1
package services
2
3
import (
4
"bufio"
5
"fmt"
6
"hash/crc32"
7
"io"
8
"log"
9
"net"
10
"github.com/snail007/goproxy/utils"
11
"runtime/debug"
12
"strconv"
13
"strings"
14
"time"
15
)
16
17
type UDP struct {
18
p utils.ConcurrentMap
19
outPool utils.OutPool
20
cfg UDPArgs
21
sc *utils.ServerChannel
22
}
23
24
func NewUDP() Service {
25
return &UDP{
26
outPool: utils.OutPool{},
27
p: utils.NewConcurrentMap(),
28
}
29
}
30
// InitService prepares per-parent resources. A udp parent is dialed per
// packet in OutToUDP, so only the other parent types need the outbound pool.
func (s *UDP) InitService() {
	if *s.cfg.ParentType == TYPE_UDP {
		return
	}
	s.InitOutConnPool()
}
35
// StopService releases every connection held by the outbound pool, if a pool
// was created.
func (s *UDP) StopService() {
	if pool := s.outPool.Pool; pool != nil {
		pool.ReleaseAll()
	}
}
40
// Start applies args (which must be a UDPArgs), prepares the outbound
// connection pool and starts the local UDP listener on *Local, dispatching
// every received packet to callback.
//
// The process exits via log.Fatalf when no parent is configured. A malformed
// *Local (not host:port, or a non-numeric port) and listener failures are
// returned as errors.
func (s *UDP) Start(args interface{}) (err error) {
	s.cfg = args.(UDPArgs)
	if *s.cfg.Parent == "" {
		log.Fatalf("parent required for udp %s", *s.cfg.Local)
	}
	log.Printf("use %s parent %s", *s.cfg.ParentType, *s.cfg.Parent)

	// Validate the listen address before InitService so that a bad address
	// is reported instead of silently listening on port 0, and no pool is
	// created for a service that cannot start.
	host, portStr, err := net.SplitHostPort(*s.cfg.Local)
	if err != nil {
		return fmt.Errorf("invalid local address %s: %s", *s.cfg.Local, err)
	}
	port, err := strconv.Atoi(portStr)
	if err != nil {
		return fmt.Errorf("invalid local port %s: %s", portStr, err)
	}

	s.InitService()

	sc := utils.NewServerChannel(host, port)
	s.sc = &sc
	if err = sc.ListenUDP(s.callback); err != nil {
		return
	}
	log.Printf("udp proxy on %s", (*sc.UDPListener).LocalAddr())
	return
}
61
62
// Clean tears the service down; it delegates entirely to StopService.
func (s *UDP) Clean() {
	s.StopService()
}
65
func (s *UDP) callback(packet []byte, localAddr, srcAddr *net.UDPAddr) {
66
defer func() {
67
if err := recover(); err != nil {
68
log.Printf("udp conn handler crashed with err : %s \nstack: %s", err, string(debug.Stack()))
69
}
70
}()
71
var err error
72
switch *s.cfg.ParentType {
73
case TYPE_TCP:
74
fallthrough
75
case TYPE_TLS:
76
err = s.OutToTCP(packet, localAddr, srcAddr)
77
case TYPE_UDP:
78
err = s.OutToUDP(packet, localAddr, srcAddr)
79
default:
80
err = fmt.Errorf("unkown parent type %s", *s.cfg.ParentType)
81
}
82
if err != nil {
83
log.Printf("connect to %s parent %s fail, ERR:%s", *s.cfg.ParentType, *s.cfg.Parent, err)
84
}
85
}
86
// GetConn returns the parent connection stored under connKey.
//
// If no connection is stored, a new one is taken from the outbound pool. The
// same applies if the stored one was removed concurrently by the relay
// goroutine, or is not a net.Conn. The new connection is stored under
// connKey and returned with isNew set to true, so the caller knows it must
// start reading from it.
//
// A single Get (instead of Has followed by Get) closes the window in which
// the key could disappear between the two calls. In that window Get returned
// nil and the net.Conn type assertion panicked.
func (s *UDP) GetConn(connKey string) (conn net.Conn, isNew bool, err error) {
	if stored, found := s.p.Get(connKey); found {
		if c, ok := stored.(net.Conn); ok {
			return c, false, nil
		}
	}
	pooled, err := s.outPool.Pool.Get()
	if err != nil {
		return nil, false, err
	}
	c, ok := pooled.(net.Conn)
	if !ok {
		return nil, false, fmt.Errorf("out pool returned %T, want net.Conn", pooled)
	}
	s.p.Set(connKey, c)
	return c, true, nil
}
101
// OutToTCP forwards packet, received on localAddr from srcAddr, to the
// tcp/tls parent.
//
// Packets are spread over parent connections keyed by:
//   - the CRC32 of localAddr, rounded down to a multiple of 10, plus
//   - the CRC32 of srcAddr modulo *PoolSize (10 when PoolSize is 0).
//
// The first packet for a key takes a connection via GetConn and starts
// relayTCPResponses on it. Each packet is framed with utils.UDPPacket
// together with srcAddr.String(), and the relay routes replies to the
// address carried in each returned frame. Failures are logged and returned.
func (s *UDP) OutToTCP(packet []byte, localAddr, srcAddr *net.UDPAddr) (err error) {
	mod := uint64(*s.cfg.PoolSize)
	if mod == 0 {
		mod = 10
	}
	numLocal := uint64(crc32.ChecksumIEEE([]byte(localAddr.String())))
	numSrc := uint64(crc32.ChecksumIEEE([]byte(srcAddr.String())))
	// Summed in 64 bits so the addition cannot wrap around.
	connKey := numLocal/10*10 + numSrc%mod
	key := strconv.FormatUint(connKey, 10)

	conn, isNew, err := s.GetConn(key)
	if err != nil {
		log.Printf("udp get conn to %s parent %s fail, ERR:%s", *s.cfg.ParentType, *s.cfg.Parent, err)
		return
	}
	if isNew {
		go s.relayTCPResponses(conn, connKey, key, srcAddr)
	}

	// bufio errors are sticky, but check Write too so nothing is dropped silently.
	writer := bufio.NewWriter(conn)
	if _, err = writer.Write(utils.UDPPacket(srcAddr.String(), packet)); err == nil {
		err = writer.Flush()
	}
	if err != nil {
		log.Printf("write udp packet to %s fail ,flush err:%s", *s.cfg.Parent, err)
		return
	}
	return
}

// relayTCPResponses runs in its own goroutine for every new parent
// connection created by OutToTCP. It reads frames from conn with
// utils.ReadUDPPacket and writes each payload, via the local listener, to the
// client address carried in the frame.
//
// It stops on the first read error, including EOF. A loop that kept reading
// after an error would spin forever on a dead connection. On exit, including
// after a recovered panic, key is removed from the connection map and conn is
// closed, so the next packet for that key takes a fresh connection instead of
// leaking this one.
func (s *UDP) relayTCPResponses(conn net.Conn, connKey uint64, key string, srcAddr *net.UDPAddr) {
	defer func() {
		if e := recover(); e != nil {
			log.Printf("udp conn handler out to tcp crashed with err : %v \nstack: %s", e, string(debug.Stack()))
		}
		s.p.Remove(key)
		conn.Close()
	}()
	log.Printf("conn %d created , local: %s", connKey, srcAddr.String())
	for {
		from, body, err := utils.ReadUDPPacket(&conn)
		if err != nil {
			if err != io.EOF && err != io.ErrUnexpectedEOF {
				log.Printf("read udp packet from parent fail, conn %d released, err: %s", connKey, err)
			}
			return
		}
		// The frame is expected to carry the client address as produced by
		// (*net.UDPAddr).String() in OutToTCP (assuming the parent echoes it
		// back unchanged). IPv6 addresses look like "[::1]:5353", so split at
		// the last colon and strip the brackets instead of splitting on every
		// colon.
		sep := strings.LastIndex(from, ":")
		var port int
		if sep >= 0 {
			port, err = strconv.Atoi(from[sep+1:])
		}
		if sep < 0 || err != nil {
			log.Printf("parse revecived udp packet fail, addr error : %s", from)
			continue
		}
		host := strings.TrimSuffix(strings.TrimPrefix(from[:sep], "["), "]")
		dstAddr := &net.UDPAddr{IP: net.ParseIP(host), Port: port}
		if _, err = s.sc.UDPListener.WriteToUDP(body, dstAddr); err != nil {
			log.Printf("udp response to local %s fail,ERR:%s", dstAddr, err)
		}
	}
}
162
// OutToUDP relays packet straight to the udp parent *Parent:
//  1. dial a new UDP socket and send packet;
//  2. wait for a single response datagram, bounded by a deadline of
//     *Timeout milliseconds;
//  3. write that response back to srcAddr through the local listener.
//
// localAddr is unused. Failures are logged and returned.
func (s *UDP) OutToUDP(packet []byte, localAddr, srcAddr *net.UDPAddr) (err error) {
	// Large enough for any UDP payload, so replies are never truncated.
	const maxDatagram = 65535

	dstAddr, err := net.ResolveUDPAddr("udp", *s.cfg.Parent)
	if err != nil {
		// dstAddr is nil here, so log the configured parent instead.
		log.Printf("resolve udp addr %s fail,ERR:%s", *s.cfg.Parent, err)
		return
	}
	clientSrcAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0}
	conn, err := net.DialUDP("udp", clientSrcAddr, dstAddr)
	if err != nil {
		log.Printf("connect to udp %s fail,ERR:%s", dstAddr.String(), err)
		return
	}
	// A new socket is opened for every packet; close it or each relayed
	// packet leaks a file descriptor.
	defer conn.Close()

	if err = conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))); err != nil {
		log.Printf("set deadline on udp conn to %s fail,ERR:%s", dstAddr.String(), err)
		return
	}
	if _, err = conn.Write(packet); err != nil {
		log.Printf("send udp packet to %s fail,ERR:%s", dstAddr.String(), err)
		return
	}
	buf := make([]byte, maxDatagram)
	n, _, err := conn.ReadFromUDP(buf)
	if err != nil {
		log.Printf("read udp response from %s fail ,ERR:%s", dstAddr.String(), err)
		return
	}
	if _, err = s.sc.UDPListener.WriteToUDP(buf[:n], srcAddr); err != nil {
		log.Printf("send udp response to cluster fail ,ERR:%s", err)
		return
	}
	return
}
197
func (s *UDP) InitOutConnPool() {
198
if *s.cfg.ParentType == TYPE_TLS || *s.cfg.ParentType == TYPE_TCP {
199
//dur int, isTLS bool, certBytes, keyBytes []byte,
200
//parent string, timeout int, InitialCap int, MaxCap int
201
s.outPool = utils.NewOutPool(
202
*s.cfg.CheckParentInterval,
203
*s.cfg.ParentType == TYPE_TLS,
204
s.cfg.CertBytes, s.cfg.KeyBytes,
205
*s.cfg.Parent,
206
*s.cfg.Timeout,
207
*s.cfg.PoolSize,
208
*s.cfg.PoolSize*2,
209
)
210
}
211
}
212
213