github.com/lightlus/netstack@v1.2.0/example/forward.go (about)

     1  package example
     2  
     3  import (
     4  	"errors"
     5  	"io"
     6  	"net"
     7  	"runtime/debug"
     8  	"strconv"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/lightlus/elog"
    13  	"github.com/lightlus/netstack/tcpip"
    14  	"github.com/lightlus/netstack/tcpip/buffer"
    15  	"github.com/lightlus/netstack/tcpip/link/channel"
    16  	"github.com/lightlus/netstack/tcpip/network/arp"
    17  	"github.com/lightlus/netstack/tcpip/network/ipv4"
    18  	"github.com/lightlus/netstack/tcpip/stack"
    19  	"github.com/lightlus/netstack/tcpip/transport/tcp"
    20  	"github.com/lightlus/netstack/tcpip/transport/udp"
    21  	"github.com/lightlus/netstack/waiter"
    22  )
    23  
    24  const (
    25  	TCP_MAX_CONNECTION_SIZE  = 1024
    26  	FORWARD_CH_WRITE_SIZE    = 4096
    27  	UDP_MAX_BUFFER_SIZE      = 8192
    28  	TCP_MAX_BUFFER_SIZE      = 8192
    29  	UDP_READ_BUFFER_SIZE     = 524288
    30  	UDP_WRITE_BUFFER_SIZE    = 262144
    31  	TCP_READ_BUFFER_SIZE     = 524288
    32  	TCP_WRITE_BUFFER_SIZE    = 262144
    33  	UDP_CONNECTION_IDLE_TIME = 1
    34  	CH_WRITE_SIZE            = 100
    35  	TCP_CONNECT_TIMEOUT      = 5
    36  	TCP_CONNECT_RETRY        = 3
    37  )
    38  
    39  type LocalForwarder struct {
    40  	s       *stack.Stack
    41  	ep      *channel.Endpoint
    42  	wq      *waiter.Queue
    43  	closed  bool
    44  	handler func([]byte)
    45  	localip string
    46  }
    47  
    48  func PanicHandler() {
    49  	if err := recover(); err != nil {
    50  		elog.Error("Panic Exception:", err)
    51  		elog.Error(string(debug.Stack()))
    52  	}
    53  }
    54  
    55  func NewLocalForwarder() (*LocalForwarder, error) {
    56  
    57  	forwarder := &LocalForwarder{}
    58  
    59  	//create MAC address
    60  	maddr, err := net.ParseMAC("01:01:01:01:01:01")
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  
    65  	// Create the net stack with ip and tcp protocols, then add a tun-based
    66  	// NIC and address.
    67  	s := stack.New(stack.Options{
    68  		NetworkProtocols:   []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()},
    69  		TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()},
    70  	})
    71  
    72  	//create link channel for packet input
    73  	ep := channel.New(FORWARD_CH_WRITE_SIZE, 1500, tcpip.LinkAddress(maddr))
    74  
    75  	//create NIC
    76  	if err := s.CreateNIC(1, ep); err != nil {
    77  		return nil, errors.New(err.String())
    78  	}
    79  
    80  	//create a subnet for 0.0.0.0/0
    81  	subnet1, err := tcpip.NewSubnet(tcpip.Address(net.IPv4(0, 0, 0, 0).To4()), tcpip.AddressMask(net.IPv4Mask(0, 0, 0, 0)))
    82  	if err != nil {
    83  		return nil, err
    84  	}
    85  
    86  	//add 0.0.0.0/0 to netstack,then netstack can process destination address in "0.0.0.0/0"
    87  	if err := s.AddAddressRange(1, ipv4.ProtocolNumber, subnet1); err != nil {
    88  		return nil, errors.New(err.String())
    89  	}
    90  
    91  	//add arp address
    92  	if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
    93  		return nil, errors.New(err.String())
    94  	}
    95  
    96  	subnet, err := tcpip.NewSubnet(tcpip.Address(net.IPv4(0, 0, 0, 0).To4()), tcpip.AddressMask(net.IPv4Mask(0, 0, 0, 0)))
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  	// Add default route.
   101  	s.SetRouteTable([]tcpip.Route{
   102  		{
   103  			Destination: subnet,
   104  			NIC:         1,
   105  		},
   106  	})
   107  
   108  	//create udp forwarder
   109  	uf := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
   110  		go forwarder.forwardUDP(r)
   111  	})
   112  
   113  	//set udp packet handler
   114  	s.SetTransportProtocolHandler(udp.ProtocolNumber, uf.HandlePacket)
   115  
   116  	//create tcp forworder
   117  	tf := tcp.NewForwarder(s, 0, TCP_MAX_CONNECTION_SIZE, func(r *tcp.ForwarderRequest) {
   118  		go forwarder.forwardTCP(r)
   119  	})
   120  	//set tcp packet handler
   121  	s.SetTransportProtocolHandler(tcp.ProtocolNumber, tf.HandlePacket)
   122  	forwarder.closed = false
   123  	forwarder.s = s
   124  	forwarder.ep = ep
   125  	forwarder.wq = &waiter.Queue{}
   126  	return forwarder, nil
   127  
   128  }
   129  
   130  func (lf *LocalForwarder) SetPacketHandler(handler func([]byte)) {
   131  	lf.handler = handler
   132  }
   133  
   134  func (lf *LocalForwarder) SetLocalIP(ip string) {
   135  	lf.localip = ip
   136  }
   137  
   138  //packet from tun device tcp/ip
   139  func (lf *LocalForwarder) Write(pkg []byte) {
   140  	if lf.closed {
   141  		return
   142  	}
   143  	pkgBuffer := tcpip.PacketBuffer{Data: buffer.NewViewFromBytes(pkg).ToVectorisedView()}
   144  	lf.ep.InjectInbound(ipv4.ProtocolNumber, pkgBuffer)
   145  }
   146  
   147  //packet from netstack
   148  func (lf *LocalForwarder) read() {
   149  	for {
   150  		pkgInfo, err := lf.ep.Read()
   151  		if err != nil {
   152  			elog.Info(err)
   153  			return
   154  		}
   155  		view := buffer.NewVectorisedView(1, []buffer.View{pkgInfo.Pkt.Header.View()})
   156  		view.Append(pkgInfo.Pkt.Data)
   157  		if lf.handler != nil {
   158  			lf.handler(view.ToView())
   159  		}
   160  	}
   161  }
   162  
   163  func (lf *LocalForwarder) StartProcess() {
   164  	go lf.read()
   165  }
   166  
   167  func (lf *LocalForwarder) ClearConnect() {
   168  	lf.wq.Notify(waiter.EventIn)
   169  }
   170  
   171  func (lf *LocalForwarder) Close() {
   172  	defer PanicHandler()
   173  
   174  	if lf.closed {
   175  		return
   176  	}
   177  	lf.closed = true
   178  
   179  	lf.wq.Notify(waiter.EventIn)
   180  	time.Sleep(time.Millisecond * 100)
   181  	lf.ep.Close()
   182  }
   183  
   184  func (lf *LocalForwarder) forwardTCP(r *tcp.ForwarderRequest) {
   185  
   186  	wq := &waiter.Queue{}
   187  	ep, err := r.CreateEndpoint(wq)
   188  	if err != nil {
   189  		elog.Error("create tcp endpint error", err)
   190  		r.Complete(true)
   191  		return
   192  	}
   193  
   194  	if lf.closed {
   195  		r.Complete(true)
   196  		ep.Close()
   197  		return
   198  	}
   199  
   200  	elog.Debug(r.ID(), "tcp connect")
   201  
   202  	var err1 error
   203  
   204  	localip := lf.localip
   205  	var laddr *net.TCPAddr
   206  	if localip != "" {
   207  		laddr, _ = net.ResolveTCPAddr("tcp4", localip+":0")
   208  	}
   209  
   210  	addr, _ := ep.GetLocalAddress()
   211  	raddr := addr.Addr.String() + ":" + strconv.Itoa(int(addr.Port))
   212  	var conn net.Conn
   213  	for i := 0; i < TCP_CONNECT_RETRY; i++ {
   214  		d := net.Dialer{Timeout: time.Second * TCP_CONNECT_TIMEOUT, LocalAddr: laddr}
   215  		conn, err1 = d.Dial("tcp4", raddr)
   216  		if err1 != nil {
   217  			continue
   218  		}
   219  		break
   220  	}
   221  
   222  	if err1 != nil {
   223  		elog.Println("conn dial fail,", err1)
   224  		r.Complete(true)
   225  		ep.Close()
   226  		return
   227  	}
   228  
   229  	tcpconn := conn.(*net.TCPConn)
   230  	tcpconn.SetNoDelay(true)
   231  	tcpconn.SetKeepAlive(true)
   232  	tcpconn.SetWriteBuffer(TCP_WRITE_BUFFER_SIZE)
   233  	tcpconn.SetReadBuffer(TCP_READ_BUFFER_SIZE)
   234  	tcpconn.SetKeepAlivePeriod(time.Second * 15)
   235  
   236  	go lf.tcpRead(r, wq, ep, conn)
   237  	go lf.tcpWrite(r, wq, ep, conn)
   238  }
   239  
   240  func (lf *LocalForwarder) udpRead(r *udp.ForwarderRequest, ep tcpip.Endpoint, wq *waiter.Queue, conn *net.UDPConn, timer *time.Ticker) {
   241  
   242  	defer func() {
   243  		elog.Debug(r.ID(), "udp closed")
   244  		ep.Close()
   245  		conn.Close()
   246  	}()
   247  
   248  	waitEntry, notifyCh := waiter.NewChannelEntry(nil)
   249  	wq.EventRegister(&waitEntry, waiter.EventIn)
   250  	defer wq.EventUnregister(&waitEntry)
   251  
   252  	gwaitEntry, gnotifyCh := waiter.NewChannelEntry(nil)
   253  
   254  	lf.wq.EventRegister(&gwaitEntry, waiter.EventIn)
   255  	defer lf.wq.EventUnregister(&gwaitEntry)
   256  
   257  	wch := make(chan []byte, CH_WRITE_SIZE)
   258  
   259  	defer close(wch)
   260  
   261  	writer := func() {
   262  		for {
   263  			pkt, ok := <-wch
   264  			if !ok {
   265  				elog.Debug("udp wch closed,exit write process")
   266  				return
   267  			} else {
   268  				_, err1 := conn.Write(pkt)
   269  				if err1 != nil {
   270  					if err1 != io.EOF && !strings.Contains(err1.Error(), "use of closed network connection") {
   271  						elog.Info("udp conn write error", err1)
   272  					}
   273  					return
   274  				}
   275  			}
   276  		}
   277  	}
   278  
   279  	go writer()
   280  
   281  	lastTime := time.Now()
   282  
   283  	for {
   284  		var addr tcpip.FullAddress
   285  		v, _, err := ep.Read(&addr)
   286  		if err != nil {
   287  			if err == tcpip.ErrWouldBlock {
   288  
   289  				select {
   290  				case <-notifyCh:
   291  					continue
   292  				case <-gnotifyCh:
   293  					return
   294  				case <-timer.C:
   295  					if time.Now().Sub(lastTime) > time.Minute*UDP_CONNECTION_IDLE_TIME {
   296  						elog.Infof("udp %v connection expired,close it", r.ID())
   297  						timer.Stop()
   298  						return
   299  					} else {
   300  						continue
   301  					}
   302  				}
   303  			} else if err != tcpip.ErrClosedForReceive && err != tcpip.ErrClosedForSend {
   304  				elog.Info("udp ep read fail,", err)
   305  			}
   306  			return
   307  		}
   308  
   309  		wch <- v
   310  		lastTime = time.Now()
   311  	}
   312  }
   313  
   314  func (lf *LocalForwarder) udpWrite(r *udp.ForwarderRequest, ep tcpip.Endpoint, wq *waiter.Queue, conn *net.UDPConn, addr *tcpip.FullAddress) {
   315  
   316  	defer func() {
   317  		ep.Close()
   318  		conn.Close()
   319  	}()
   320  
   321  	for {
   322  		var udppkg []byte = make([]byte, UDP_MAX_BUFFER_SIZE)
   323  		n, err1 := conn.Read(udppkg)
   324  
   325  		if err1 != nil {
   326  			if err1 != io.EOF &&
   327  				!strings.Contains(err1.Error(), "use of closed network connection") &&
   328  				!strings.Contains(err1.Error(), "connection refused") {
   329  				elog.Info("udp conn read error,", err1)
   330  			}
   331  			return
   332  		}
   333  		udppkg1 := udppkg[:n]
   334  		_, _, err := ep.Write(tcpip.SlicePayload(udppkg1), tcpip.WriteOptions{To: addr})
   335  		if err != nil {
   336  			elog.Info("udp ep write fail,", err)
   337  			return
   338  		}
   339  	}
   340  }
   341  
   342  func (lf *LocalForwarder) forwardUDP(r *udp.ForwarderRequest) {
   343  	wq := &waiter.Queue{}
   344  	ep, err := r.CreateEndpoint(wq)
   345  	if err != nil {
   346  		elog.Error("create udp endpint error", err)
   347  		return
   348  	}
   349  
   350  	if lf.closed {
   351  		ep.Close()
   352  		return
   353  	}
   354  
   355  	elog.Debug(r.ID(), "udp connect")
   356  
   357  	localip := lf.localip
   358  	var err1 error
   359  	var laddr *net.UDPAddr
   360  	if localip != "" {
   361  		laddr, _ = net.ResolveUDPAddr("udp4", localip+":0")
   362  	}
   363  
   364  	raddr, _ := net.ResolveUDPAddr("udp4", r.ID().LocalAddress.To4().String()+":"+strconv.Itoa(int(r.ID().LocalPort)))
   365  
   366  	conn, err1 := net.DialUDP("udp4", laddr, raddr)
   367  	if err1 != nil {
   368  		elog.Error("udp conn dial error ", err1)
   369  		ep.Close()
   370  		return
   371  	}
   372  
   373  	conn.SetReadBuffer(UDP_READ_BUFFER_SIZE)
   374  	conn.SetWriteBuffer(UDP_WRITE_BUFFER_SIZE)
   375  
   376  	timer := time.NewTicker(time.Minute)
   377  	addr := &tcpip.FullAddress{Addr: r.ID().RemoteAddress, Port: r.ID().RemotePort}
   378  
   379  	go lf.udpRead(r, ep, wq, conn, timer)
   380  	go lf.udpWrite(r, ep, wq, conn, addr)
   381  }
   382  
   383  func (lf *LocalForwarder) tcpRead(r *tcp.ForwarderRequest, wq *waiter.Queue, ep tcpip.Endpoint, conn net.Conn) {
   384  	defer func() {
   385  		elog.Debug(r.ID(), "tcp closed")
   386  		r.Complete(true)
   387  		ep.Close()
   388  		conn.Close()
   389  	}()
   390  
   391  	// Create wait queue entry that notifies a channel.
   392  	waitEntry, notifyCh := waiter.NewChannelEntry(nil)
   393  
   394  	wq.EventRegister(&waitEntry, waiter.EventIn)
   395  	defer wq.EventUnregister(&waitEntry)
   396  
   397  	// Create wait queue entry that notifies a channel.
   398  	gwaitEntry, gnotifyCh := waiter.NewChannelEntry(nil)
   399  
   400  	lf.wq.EventRegister(&gwaitEntry, waiter.EventIn)
   401  	defer lf.wq.EventUnregister(&gwaitEntry)
   402  
   403  	wch := make(chan []byte, CH_WRITE_SIZE)
   404  
   405  	defer close(wch)
   406  
   407  	writer := func() {
   408  		for {
   409  			pkt, ok := <-wch
   410  			if !ok {
   411  				elog.Debug("wch closed,exit write process")
   412  				return
   413  			} else {
   414  				_, err1 := conn.Write(pkt)
   415  				if err1 != nil {
   416  					if err1 != io.EOF && !strings.Contains(err1.Error(), "use of closed network connection") {
   417  						elog.Infof("tcp %v conn write error,%v", r.ID(), err1)
   418  					}
   419  					return
   420  				}
   421  			}
   422  		}
   423  	}
   424  
   425  	go writer()
   426  
   427  	for {
   428  		v, _, err := ep.Read(nil)
   429  		if err != nil {
   430  
   431  			if err == tcpip.ErrWouldBlock {
   432  				select {
   433  				case <-notifyCh:
   434  					continue
   435  				case <-gnotifyCh:
   436  					return
   437  				}
   438  
   439  			} else if err != tcpip.ErrClosedForReceive && err != tcpip.ErrClosedForSend {
   440  				elog.Infof("tcp %v endpoint read fail,%v", r.ID(), err)
   441  			}
   442  			return
   443  		}
   444  		wch <- v
   445  	}
   446  }
   447  
   448  func (lf *LocalForwarder) tcpWrite(r *tcp.ForwarderRequest, wq *waiter.Queue, ep tcpip.Endpoint, conn net.Conn) {
   449  	defer func() {
   450  		ep.Close()
   451  		conn.Close()
   452  	}()
   453  
   454  	for {
   455  		var buf []byte = make([]byte, TCP_MAX_BUFFER_SIZE)
   456  		n, err := conn.Read(buf)
   457  		if err != nil {
   458  			if err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") {
   459  				elog.Infof("tcp %v conn read error,%v", r.ID(), err)
   460  			}
   461  			break
   462  		}
   463  
   464  		ep.Write(tcpip.SlicePayload(buf[:n]), tcpip.WriteOptions{})
   465  	}
   466  }