github.com/polevpn/netstack@v1.10.9/example/forward.go (about) 1 package example 2 3 import ( 4 "errors" 5 "io" 6 "net" 7 "runtime/debug" 8 "strconv" 9 "strings" 10 "time" 11 12 "github.com/polevpn/elog" 13 "github.com/polevpn/netstack/tcpip" 14 "github.com/polevpn/netstack/tcpip/buffer" 15 "github.com/polevpn/netstack/tcpip/link/channel" 16 "github.com/polevpn/netstack/tcpip/network/arp" 17 "github.com/polevpn/netstack/tcpip/network/ipv4" 18 "github.com/polevpn/netstack/tcpip/stack" 19 "github.com/polevpn/netstack/tcpip/transport/tcp" 20 "github.com/polevpn/netstack/tcpip/transport/udp" 21 "github.com/polevpn/netstack/waiter" 22 ) 23 24 const ( 25 TCP_MAX_CONNECTION_SIZE = 1024 26 FORWARD_CH_WRITE_SIZE = 4096 27 UDP_MAX_BUFFER_SIZE = 8192 28 TCP_MAX_BUFFER_SIZE = 8192 29 UDP_READ_BUFFER_SIZE = 524288 30 UDP_WRITE_BUFFER_SIZE = 262144 31 TCP_READ_BUFFER_SIZE = 524288 32 TCP_WRITE_BUFFER_SIZE = 262144 33 UDP_CONNECTION_IDLE_TIME = 1 34 CH_WRITE_SIZE = 100 35 TCP_CONNECT_TIMEOUT = 5 36 TCP_CONNECT_RETRY = 3 37 ) 38 39 type LocalForwarder struct { 40 s *stack.Stack 41 ep *channel.Endpoint 42 wq *waiter.Queue 43 closed bool 44 handler func([]byte) 45 localip string 46 } 47 48 func PanicHandler() { 49 if err := recover(); err != nil { 50 elog.Error("Panic Exception:", err) 51 elog.Error(string(debug.Stack())) 52 } 53 } 54 55 func NewLocalForwarder() (*LocalForwarder, error) { 56 57 forwarder := &LocalForwarder{} 58 59 //create MAC address 60 maddr, err := net.ParseMAC("01:01:01:01:01:01") 61 if err != nil { 62 return nil, err 63 } 64 65 // Create the net stack with ip and tcp protocols, then add a tun-based 66 // NIC and address. 67 s := stack.New(stack.Options{ 68 NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()}, 69 TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()}, 70 }) 71 72 //create link channel for packet input 73 ep := channel.New(FORWARD_CH_WRITE_SIZE, 1500, tcpip.LinkAddress(maddr)) 74 75 //create NIC 76 if err := s.CreateNIC(1, ep); err != nil { 77 return nil, errors.New(err.String()) 78 } 79 80 //create a subnet for 0.0.0.0/0 81 subnet1, err := tcpip.NewSubnet(tcpip.Address(net.IPv4(0, 0, 0, 0).To4()), tcpip.AddressMask(net.IPv4Mask(0, 0, 0, 0))) 82 if err != nil { 83 return nil, err 84 } 85 86 //add 0.0.0.0/0 to netstack,then netstack can process destination address in "0.0.0.0/0" 87 if err := s.AddAddressRange(1, ipv4.ProtocolNumber, subnet1); err != nil { 88 return nil, errors.New(err.String()) 89 } 90 91 //add arp address 92 if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { 93 return nil, errors.New(err.String()) 94 } 95 96 subnet, err := tcpip.NewSubnet(tcpip.Address(net.IPv4(0, 0, 0, 0).To4()), tcpip.AddressMask(net.IPv4Mask(0, 0, 0, 0))) 97 if err != nil { 98 return nil, err 99 } 100 // Add default route. 101 s.SetRouteTable([]tcpip.Route{ 102 { 103 Destination: subnet, 104 NIC: 1, 105 }, 106 }) 107 108 //create udp forwarder 109 uf := udp.NewForwarder(s, func(r *udp.ForwarderRequest) { 110 go forwarder.forwardUDP(r) 111 }) 112 113 //set udp packet handler 114 s.SetTransportProtocolHandler(udp.ProtocolNumber, uf.HandlePacket) 115 116 //create tcp forworder 117 tf := tcp.NewForwarder(s, 0, TCP_MAX_CONNECTION_SIZE, func(r *tcp.ForwarderRequest) { 118 go forwarder.forwardTCP(r) 119 }) 120 //set tcp packet handler 121 s.SetTransportProtocolHandler(tcp.ProtocolNumber, tf.HandlePacket) 122 forwarder.closed = false 123 forwarder.s = s 124 forwarder.ep = ep 125 forwarder.wq = &waiter.Queue{} 126 return forwarder, nil 127 128 } 129 130 func (lf *LocalForwarder) SetPacketHandler(handler func([]byte)) { 131 lf.handler = handler 132 } 133 134 func (lf *LocalForwarder) SetLocalIP(ip string) { 135 lf.localip = ip 136 } 137 138 //packet from tun device tcp/ip 139 func (lf *LocalForwarder) Write(pkg []byte) { 140 if lf.closed { 141 return 142 } 143 pkgBuffer := tcpip.PacketBuffer{Data: buffer.NewViewFromBytes(pkg).ToVectorisedView()} 144 lf.ep.InjectInbound(ipv4.ProtocolNumber, pkgBuffer) 145 } 146 147 //packet from netstack 148 func (lf *LocalForwarder) read() { 149 for { 150 pkgInfo, err := lf.ep.Read() 151 if err != nil { 152 elog.Info(err) 153 return 154 } 155 view := buffer.NewVectorisedView(1, []buffer.View{pkgInfo.Pkt.Header.View()}) 156 view.Append(pkgInfo.Pkt.Data) 157 if lf.handler != nil { 158 lf.handler(view.ToView()) 159 } 160 } 161 } 162 163 func (lf *LocalForwarder) StartProcess() { 164 go lf.read() 165 } 166 167 func (lf *LocalForwarder) ClearConnect() { 168 lf.wq.Notify(waiter.EventIn) 169 } 170 171 func (lf *LocalForwarder) Close() { 172 defer PanicHandler() 173 174 if lf.closed { 175 return 176 } 177 lf.closed = true 178 179 lf.wq.Notify(waiter.EventIn) 180 time.Sleep(time.Millisecond * 100) 181 lf.ep.Close() 182 } 183 184 func (lf *LocalForwarder) forwardTCP(r *tcp.ForwarderRequest) { 185 186 wq := &waiter.Queue{} 187 ep, err := r.CreateEndpoint(wq) 188 if err != nil { 189 elog.Error("create tcp endpint error", err) 190 r.Complete(true) 191 return 192 } 193 194 if lf.closed { 195 r.Complete(true) 196 ep.Close() 197 return 198 } 199 200 elog.Debug(r.ID(), "tcp connect") 201 202 var err1 error 203 204 localip := lf.localip 205 var laddr *net.TCPAddr 206 if localip != "" { 207 laddr, _ = net.ResolveTCPAddr("tcp4", localip+":0") 208 } 209 210 addr, _ := ep.GetLocalAddress() 211 raddr := addr.Addr.String() + ":" + strconv.Itoa(int(addr.Port)) 212 var conn net.Conn 213 for i := 0; i < TCP_CONNECT_RETRY; i++ { 214 d := net.Dialer{Timeout: time.Second * TCP_CONNECT_TIMEOUT, LocalAddr: laddr} 215 conn, err1 = d.Dial("tcp4", raddr) 216 if err1 != nil { 217 continue 218 } 219 break 220 } 221 222 if err1 != nil { 223 elog.Println("conn dial fail,", err1) 224 r.Complete(true) 225 ep.Close() 226 return 227 } 228 229 tcpconn := conn.(*net.TCPConn) 230 tcpconn.SetNoDelay(true) 231 tcpconn.SetKeepAlive(true) 232 tcpconn.SetWriteBuffer(TCP_WRITE_BUFFER_SIZE) 233 tcpconn.SetReadBuffer(TCP_READ_BUFFER_SIZE) 234 tcpconn.SetKeepAlivePeriod(time.Second * 15) 235 236 go lf.tcpRead(r, wq, ep, conn) 237 go lf.tcpWrite(r, wq, ep, conn) 238 } 239 240 func (lf *LocalForwarder) udpRead(r *udp.ForwarderRequest, ep tcpip.Endpoint, wq *waiter.Queue, conn *net.UDPConn, timer *time.Ticker) { 241 242 defer func() { 243 elog.Debug(r.ID(), "udp closed") 244 ep.Close() 245 conn.Close() 246 }() 247 248 waitEntry, notifyCh := waiter.NewChannelEntry(nil) 249 wq.EventRegister(&waitEntry, waiter.EventIn) 250 defer wq.EventUnregister(&waitEntry) 251 252 gwaitEntry, gnotifyCh := waiter.NewChannelEntry(nil) 253 254 lf.wq.EventRegister(&gwaitEntry, waiter.EventIn) 255 defer lf.wq.EventUnregister(&gwaitEntry) 256 257 wch := make(chan []byte, CH_WRITE_SIZE) 258 259 defer close(wch) 260 261 writer := func() { 262 for { 263 pkt, ok := <-wch 264 if !ok { 265 elog.Debug("udp wch closed,exit write process") 266 return 267 } else { 268 _, err1 := conn.Write(pkt) 269 if err1 != nil { 270 if err1 != io.EOF && !strings.Contains(err1.Error(), "use of closed network connection") { 271 elog.Info("udp conn write error", err1) 272 } 273 return 274 } 275 } 276 } 277 } 278 279 go writer() 280 281 lastTime := time.Now() 282 283 for { 284 var addr tcpip.FullAddress 285 v, _, err := ep.Read(&addr) 286 if err != nil { 287 if err == tcpip.ErrWouldBlock { 288 289 select { 290 case <-notifyCh: 291 continue 292 case <-gnotifyCh: 293 return 294 case <-timer.C: 295 if time.Now().Sub(lastTime) > time.Minute*UDP_CONNECTION_IDLE_TIME { 296 elog.Infof("udp %v connection expired,close it", r.ID()) 297 timer.Stop() 298 return 299 } else { 300 continue 301 } 302 } 303 } else if err != tcpip.ErrClosedForReceive && err != tcpip.ErrClosedForSend { 304 elog.Info("udp ep read fail,", err) 305 } 306 return 307 } 308 309 wch <- v 310 lastTime = time.Now() 311 } 312 } 313 314 func (lf *LocalForwarder) udpWrite(r *udp.ForwarderRequest, ep tcpip.Endpoint, wq *waiter.Queue, conn *net.UDPConn, addr *tcpip.FullAddress) { 315 316 defer func() { 317 ep.Close() 318 conn.Close() 319 }() 320 321 for { 322 var udppkg []byte = make([]byte, UDP_MAX_BUFFER_SIZE) 323 n, err1 := conn.Read(udppkg) 324 325 if err1 != nil { 326 if err1 != io.EOF && 327 !strings.Contains(err1.Error(), "use of closed network connection") && 328 !strings.Contains(err1.Error(), "connection refused") { 329 elog.Info("udp conn read error,", err1) 330 } 331 return 332 } 333 udppkg1 := udppkg[:n] 334 _, _, err := ep.Write(tcpip.SlicePayload(udppkg1), tcpip.WriteOptions{To: addr}) 335 if err != nil { 336 elog.Info("udp ep write fail,", err) 337 return 338 } 339 } 340 } 341 342 func (lf *LocalForwarder) forwardUDP(r *udp.ForwarderRequest) { 343 wq := &waiter.Queue{} 344 ep, err := r.CreateEndpoint(wq) 345 if err != nil { 346 elog.Error("create udp endpint error", err) 347 return 348 } 349 350 if lf.closed { 351 ep.Close() 352 return 353 } 354 355 elog.Debug(r.ID(), "udp connect") 356 357 localip := lf.localip 358 var err1 error 359 var laddr *net.UDPAddr 360 if localip != "" { 361 laddr, _ = net.ResolveUDPAddr("udp4", localip+":0") 362 } 363 364 raddr, _ := net.ResolveUDPAddr("udp4", r.ID().LocalAddress.To4().String()+":"+strconv.Itoa(int(r.ID().LocalPort))) 365 366 conn, err1 := net.DialUDP("udp4", laddr, raddr) 367 if err1 != nil { 368 elog.Error("udp conn dial error ", err1) 369 ep.Close() 370 return 371 } 372 373 conn.SetReadBuffer(UDP_READ_BUFFER_SIZE) 374 conn.SetWriteBuffer(UDP_WRITE_BUFFER_SIZE) 375 376 timer := time.NewTicker(time.Minute) 377 addr := &tcpip.FullAddress{Addr: r.ID().RemoteAddress, Port: r.ID().RemotePort} 378 379 go lf.udpRead(r, ep, wq, conn, timer) 380 go lf.udpWrite(r, ep, wq, conn, addr) 381 } 382 383 func (lf *LocalForwarder) tcpRead(r *tcp.ForwarderRequest, wq *waiter.Queue, ep tcpip.Endpoint, conn net.Conn) { 384 defer func() { 385 elog.Debug(r.ID(), "tcp closed") 386 r.Complete(true) 387 ep.Close() 388 conn.Close() 389 }() 390 391 // Create wait queue entry that notifies a channel. 392 waitEntry, notifyCh := waiter.NewChannelEntry(nil) 393 394 wq.EventRegister(&waitEntry, waiter.EventIn) 395 defer wq.EventUnregister(&waitEntry) 396 397 // Create wait queue entry that notifies a channel. 398 gwaitEntry, gnotifyCh := waiter.NewChannelEntry(nil) 399 400 lf.wq.EventRegister(&gwaitEntry, waiter.EventIn) 401 defer lf.wq.EventUnregister(&gwaitEntry) 402 403 wch := make(chan []byte, CH_WRITE_SIZE) 404 405 defer close(wch) 406 407 writer := func() { 408 for { 409 pkt, ok := <-wch 410 if !ok { 411 elog.Debug("wch closed,exit write process") 412 return 413 } else { 414 _, err1 := conn.Write(pkt) 415 if err1 != nil { 416 if err1 != io.EOF && !strings.Contains(err1.Error(), "use of closed network connection") { 417 elog.Infof("tcp %v conn write error,%v", r.ID(), err1) 418 } 419 return 420 } 421 } 422 } 423 } 424 425 go writer() 426 427 for { 428 v, _, err := ep.Read(nil) 429 if err != nil { 430 431 if err == tcpip.ErrWouldBlock { 432 select { 433 case <-notifyCh: 434 continue 435 case <-gnotifyCh: 436 return 437 } 438 439 } else if err != tcpip.ErrClosedForReceive && err != tcpip.ErrClosedForSend { 440 elog.Infof("tcp %v endpoint read fail,%v", r.ID(), err) 441 } 442 return 443 } 444 wch <- v 445 } 446 } 447 448 func (lf *LocalForwarder) tcpWrite(r *tcp.ForwarderRequest, wq *waiter.Queue, ep tcpip.Endpoint, conn net.Conn) { 449 defer func() { 450 ep.Close() 451 conn.Close() 452 }() 453 454 for { 455 var buf []byte = make([]byte, TCP_MAX_BUFFER_SIZE) 456 n, err := conn.Read(buf) 457 if err != nil { 458 if err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") { 459 elog.Infof("tcp %v conn read error,%v", r.ID(), err) 460 } 461 break 462 } 463 464 ep.Write(tcpip.SlicePayload(buf[:n]), tcpip.WriteOptions{}) 465 } 466 }