github.com/AntonOrnatskyi/goproxy@v0.0.0-20190205095733-4526a9fa18b4/core/dst/mux.go (about) 1 // Copyright 2014 The DST Authors. All rights reserved. 2 // Use of this source code is governed by an MIT-style 3 // license that can be found in the LICENSE file. 4 5 package dst 6 7 import ( 8 "fmt" 9 "runtime/debug" 10 11 "net" 12 "sync" 13 "time" 14 ) 15 16 const ( 17 maxIncomingRequests = 1024 18 maxPacketSize = 500 19 handshakeTimeout = 5 * time.Second 20 handshakeInterval = 1 * time.Second 21 ) 22 23 // Mux is a UDP multiplexer of DST connections. 24 type Mux struct { 25 conn net.PacketConn 26 packetSize int 27 28 conns map[connectionID]*Conn 29 handshakes map[connectionID]chan packet 30 connsMut sync.Mutex 31 32 incoming chan *Conn 33 closed chan struct{} 34 closeOnce sync.Once 35 36 buffers *sync.Pool 37 } 38 39 // NewMux creates a new DST Mux on top of a packet connection. 40 func NewMux(conn net.PacketConn, packetSize int) *Mux { 41 if packetSize <= 0 { 42 packetSize = maxPacketSize 43 } 44 m := &Mux{ 45 conn: conn, 46 packetSize: packetSize, 47 conns: map[connectionID]*Conn{}, 48 handshakes: make(map[connectionID]chan packet), 49 incoming: make(chan *Conn, maxIncomingRequests), 50 closed: make(chan struct{}), 51 buffers: &sync.Pool{ 52 New: func() interface{} { 53 return make([]byte, packetSize) 54 }, 55 }, 56 } 57 58 // Attempt to maximize buffer space. Start at 16 MB and work downwards 0.5 59 // MB at a time. 
60 61 if conn, ok := conn.(*net.UDPConn); ok { 62 for buf := 16384 * 1024; buf >= 512*1024; buf -= 512 * 1024 { 63 err := conn.SetReadBuffer(buf) 64 if err == nil { 65 if debugMux { 66 log.Println(m, "read buffer is", buf) 67 } 68 break 69 } 70 } 71 for buf := 16384 * 1024; buf >= 512*1024; buf -= 512 * 1024 { 72 err := conn.SetWriteBuffer(buf) 73 if err == nil { 74 if debugMux { 75 log.Println(m, "write buffer is", buf) 76 } 77 break 78 } 79 } 80 } 81 82 go func() { 83 defer func() { 84 if e := recover(); e != nil { 85 fmt.Printf("crashed, err: %s\nstack:%s", e, string(debug.Stack())) 86 } 87 }() 88 m.readerLoop() 89 }() 90 return m 91 } 92 93 // Accept waits for and returns the next connection to the listener. 94 func (m *Mux) Accept() (net.Conn, error) { 95 return m.AcceptDST() 96 } 97 98 // AcceptDST waits for and returns the next connection to the listener. 99 func (m *Mux) AcceptDST() (*Conn, error) { 100 conn, ok := <-m.incoming 101 if !ok { 102 return nil, ErrClosedMux 103 } 104 return conn, nil 105 } 106 107 // Close closes the listener. 108 // Any blocked Accept operations will be unblocked and return errors. 109 func (m *Mux) Close() error { 110 var err error = ErrClosedMux 111 m.closeOnce.Do(func() { 112 err = m.conn.Close() 113 close(m.incoming) 114 close(m.closed) 115 }) 116 return err 117 } 118 119 // Addr returns the listener's network address. 120 func (m *Mux) Addr() net.Addr { 121 return m.conn.LocalAddr() 122 } 123 124 // Dial connects to the address on the named network. 125 // 126 // Network must be "dst". 127 // 128 // Addresses have the form host:port. If host is a literal IPv6 address or 129 // host name, it must be enclosed in square brackets as in "[::1]:80", 130 // "[ipv6-host]:http" or "[ipv6-host%zone]:80". The functions JoinHostPort and 131 // SplitHostPort manipulate addresses in this form. 
//
// Examples:
//
//	Dial("dst", "12.34.56.78:80")
//	Dial("dst", "google.com:http")
//	Dial("dst", "[2001:db8::1]:http")
//	Dial("dst", "[fe80::1%lo0]:80")
func (m *Mux) Dial(network, addr string) (net.Conn, error) {
	return m.DialDST(network, addr)
}

// DialDST connects to the address on the named network and returns the
// concrete *Conn.
//
// Network must be "dst"; any other value yields ErrNotDST.
//
// Addresses have the form host:port. If host is a literal IPv6 address, it
// must be enclosed in square brackets as in "[::1]:80",
// "[ipv6-host]:http" or "[ipv6-host%zone]:80". The functions JoinHostPort and
// SplitHostPort manipulate addresses in this form.
//
// Examples:
//
//	DialDST("dst", "12.34.56.78:80")
//	DialDST("dst", "google.com:http")
//	DialDST("dst", "[2001:db8::1]:http")
//	DialDST("dst", "[fe80::1%lo0]:80")
func (m *Mux) DialDST(network, addr string) (*Conn, error) {
	if network != "dst" {
		return nil, ErrNotDST
	}

	dst, err := net.ResolveUDPAddr("udp", addr)
	if err != nil {
		return nil, err
	}

	// The reader goroutine delivers handshake replies on this channel. With an
	// unbuffered channel, a reply arriving after clientHandshake has returned
	// (but before the entry is deleted below) would block the reader, and
	// with it every connection on the mux. One slot absorbs such a
	// straggler.
	resp := make(chan packet, 1)

	// Register the pending handshake under a fresh local connection ID so
	// the reader can route replies to it.
	m.connsMut.Lock()
	connID := m.newConnID()
	m.handshakes[connID] = resp
	m.connsMut.Unlock()

	conn, err := m.clientHandshake(dst, connID, resp)

	m.connsMut.Lock()
	defer m.connsMut.Unlock()
	delete(m.handshakes, connID)

	if err != nil {
		return nil, err
	}

	m.conns[connID] = conn
	return conn, nil
}

// clientHandshake performs the client side of the handshake (i.e.
// Dial). It sends a handshake request immediately and again every
// handshakeInterval, echoing back any cookie the peer hands out. It stops
// when a handshake response arrives on resp, when the mux is closed
// (ErrClosedConn), or when handshakeTimeout elapses (ErrHandshakeTimeout).
// On success the returned connection has already been started. The caller
// is responsible for registering it in m.conns.
func (m *Mux) clientHandshake(dst net.Addr, connID connectionID, resp chan packet) (*Conn, error) {
	if debugMux {
		log.Printf("%v dial %v connID %v", m, dst, connID)
	}

	// Fires immediately so the first request goes out without delay.
	nextHandshake := time.NewTimer(0)
	defer nextHandshake.Stop()

	// Named "timeout" so it does not shadow the handshakeTimeout constant.
	timeout := time.NewTimer(handshakeTimeout)
	defer timeout.Stop()

	var remoteCookie uint32
	seqNo := randomSeqNo()

	for {
		select {
		case <-m.closed:
			// Failure. The mux has been closed.
			return nil, ErrClosedConn

		case <-timeout.C:
			// Handshake timeout. Abort.
			return nil, ErrHandshakeTimeout

		case <-nextHandshake.C:
			// (Re)send the handshake request with the latest cookie we hold.
			m.write(packet{
				src: connID,
				dst: dst,
				hdr: header{
					packetType: typeHandshake,
					flags:      flagRequest,
					connID:     0,
					sequenceNo: seqNo,
					timestamp:  timestampMicros(),
				},
				data: handshakeData{uint32(m.packetSize), connID, remoteCookie}.marshal(),
			})
			nextHandshake.Reset(handshakeInterval)

		case pkt := <-resp:
			hd := unmarshalHandshakeData(pkt.data)

			switch {
			case pkt.hdr.flags&flagCookie == flagCookie:
				// The peer wants the request resent with this cookie value.
				remoteCookie = hd.cookie
				nextHandshake.Reset(0)

			case pkt.hdr.flags&flagResponse == flagResponse:
				// Successful handshake response.
				conn := newConn(m, dst)

				conn.connID = connID
				conn.remoteConnID = hd.connID
				conn.nextRecvSeqNo = pkt.hdr.sequenceNo + 1
				// Use the smaller of the two sides' packet sizes.
				conn.packetSize = int(hd.packetSize)
				if conn.packetSize > m.packetSize {
					conn.packetSize = m.packetSize
				}

				conn.nextSeqNo = seqNo + 1

				conn.start()

				return conn, nil
			}
		}
	}
}

// readerLoop reads datagrams from the shared socket until a read error
// occurs, at which point it closes the mux. Handshake packets are dispatched
// to incomingHandshake. Everything else goes to the established connection
// named by the header's connection ID, or is dropped if there is none.
func (m *Mux) readerLoop() {
	buf := make([]byte, m.packetSize)
	for {
		buf = buf[:cap(buf)]
		n, from, err := m.conn.ReadFrom(buf)
		if err != nil {
			m.Close()
			return
		}
		if n < dstHeaderLen {
			// A runt datagram cannot hold a complete header. Anyone can send
			// one, so drop it here rather than risk a panic in header
			// decoding, which would take down this loop and with it the
			// whole mux.
			if debugMux {
				log.Printf("%v dropping %d byte packet from %v", m, n, from)
			}
			continue
		}
		buf = buf[:n]

		hdr := unmarshalHeader(buf)

		// Copy the payload out of buf, which is reused for the next read.
		var bufCopy []byte
		if len(buf) > dstHeaderLen {
			bufCopy = m.buffers.Get().([]byte)[:len(buf)-dstHeaderLen]
			copy(bufCopy, buf[dstHeaderLen:])
		}

		pkt := packet{hdr: hdr, data: bufCopy}
		if debugMux {
			log.Println(m, "read", pkt)
		}

		if hdr.packetType == typeHandshake {
			m.incomingHandshake(from, hdr, bufCopy)
			continue
		}

		m.connsMut.Lock()
		conn, ok := m.conns[hdr.connID]
		m.connsMut.Unlock()

		switch {
		case ok:
			// NOTE(review): blocking send. A connection that stops draining
			// conn.in stalls this loop for every connection. Confirm the
			// Conn side always consumes it.
			conn.in <- pkt
		case debugMux && hdr.packetType != typeShutdown:
			log.Printf("packet %v for unknown conn %v", hdr, hdr.connID)
		}
	}
}

// incomingHandshake routes a handshake packet. Connection ID zero marks a
// new request from a remote dialer. Any other ID addresses one of our own
// outgoing dials.
func (m *Mux) incomingHandshake(from net.Addr, hdr header, data []byte) {
	if hdr.connID == 0 {
		// A new incoming handshake request.
		m.incomingHandshakeRequest(from, hdr, data)
		return
	}
	// A response to an ongoing handshake.
	m.incomingHandshakeResponse(from, hdr, data)
}

// incomingHandshakeRequest answers a remote dial. A request without the
// expected cookie for its source address gets a cookie reply and nothing
// else. A request with the right cookie creates, starts and registers a new
// connection, sends the handshake response, and queues the connection for
// Accept.
//
// NOTE(review): the dialer retransmits every handshakeInterval until it sees
// a response, and each valid request here creates a fresh connection. A slow
// response can therefore leave orphaned server-side connections.
// NOTE(review): data comes straight off the wire. Confirm that
// unmarshalHandshakeData tolerates a short or missing payload.
func (m *Mux) incomingHandshakeRequest(from net.Addr, hdr header, data []byte) {
	if hdr.flags&flagRequest != flagRequest {
		log.Printf("Handshake pattern with flags 0x%x to connID zero", hdr.flags)
		return
	}

	hd := unmarshalHandshakeData(data)

	correctCookie := cookie(from)
	if hd.cookie != correctCookie {
		// Incorrect or missing SYN cookie. Send back a handshake
		// with the expected one.
		m.write(packet{
			dst: from,
			hdr: header{
				packetType: typeHandshake,
				flags:      flagResponse | flagCookie,
				connID:     hd.connID,
				timestamp:  timestampMicros(),
			},
			data: handshakeData{
				packetSize: uint32(m.packetSize),
				cookie:     correctCookie,
			}.marshal(),
		})
		return
	}

	seqNo := randomSeqNo()

	m.connsMut.Lock()
	connID := m.newConnID()

	conn := newConn(m, from)
	conn.connID = connID
	conn.remoteConnID = hd.connID
	conn.nextSeqNo = seqNo + 1
	conn.nextRecvSeqNo = hdr.sequenceNo + 1
	// Use the smaller of the two sides' packet sizes.
	conn.packetSize = int(hd.packetSize)
	if conn.packetSize > m.packetSize {
		conn.packetSize = m.packetSize
	}
	conn.start()

	m.conns[connID] = conn
	m.connsMut.Unlock()

	m.write(packet{
		dst: from,
		hdr: header{
			packetType: typeHandshake,
			flags:      flagResponse,
			connID:     hd.connID,
			sequenceNo: seqNo,
			timestamp:  timestampMicros(),
		},
		data: handshakeData{
			connID:     conn.connID,
			packetSize: uint32(conn.packetSize),
		}.marshal(),
	})

	// Queue for Accept. Do not block forever once the mux is shut down: with
	// nobody accepting, a full backlog would otherwise wedge the reader.
	select {
	case m.incoming <- conn:
	case <-m.closed:
	}
}

// incomingHandshakeResponse hands a handshake reply to the dial waiting on
// hdr.connID, if there is one.
func (m *Mux) incomingHandshakeResponse(from net.Addr, hdr header, data []byte) {
	m.connsMut.Lock()
	handShake, ok := m.handshakes[hdr.connID]
	m.connsMut.Unlock()

	if !ok {
		if debugMux && hdr.packetType != typeShutdown {
			log.Printf("Handshake packet %v for unknown conn %v", hdr, hdr.connID)
		}
		return
	}

	// This is a response to a handshake in progress. Never block: the dialer
	// may already have finished or timed out while its entry is still
	// registered, and a blocked send here would stall the reader goroutine
	// for every connection. A dropped reply is recovered by the dialer's
	// periodic retransmission.
	select {
	case handShake <- packet{dst: nil, hdr: hdr, data: data}:
	default:
		if debugMux {
			log.Printf("Handshake packet %v for conn %v dropped; dialer not waiting", hdr, hdr.connID)
		}
	}
}

// write marshals pkt's header and payload into a pooled buffer and sends it
// to pkt.dst. Pool buffers hold m.packetSize bytes, so callers must keep
// dstHeaderLen+len(pkt.data) within that size or the reslice below panics.
func (m *Mux) write(pkt packet) (int, error) {
	buf := m.buffers.Get().([]byte)
	buf = buf[:dstHeaderLen+len(pkt.data)]
	pkt.hdr.marshal(buf)
	copy(buf[dstHeaderLen:], pkt.data)
	if debugMux {
		log.Println(m, "write", pkt)
	}
	n, err := m.conn.WriteTo(buf, pkt.dst)
	m.buffers.Put(buf)
	return n, err
}

// String identifies the mux by its local address, for log output.
func (m *Mux) String() string {
	return fmt.Sprintf("Mux-%v", m.Addr())
}

// newConnID returns a random connection ID that is in use neither by an
// established connection nor by a pending dial. The caller must hold
// m.connsMut, since both maps are read.
func (m *Mux) newConnID() connectionID {
	for {
		connID := randomConnID()
		if _, ok := m.conns[connID]; ok {
			continue
		}
		if _, ok := m.handshakes[connID]; ok {
			continue
		}
		return connID
	}
}

// removeConn unregisters c, so readerLoop stops routing packets to it.
func (m *Mux) removeConn(c *Conn) {
	m.connsMut.Lock()
	delete(m.conns, c.connID)
	m.connsMut.Unlock()
}