github.com/ipfans/trojan-go@v0.11.0/tunnel/tproxy/server.go (about)

     1  //go:build linux
     2  // +build linux
     3  
     4  package tproxy
     5  
     6  import (
     7  	"context"
     8  	"io"
     9  	"net"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/ipfans/trojan-go/common"
    14  	"github.com/ipfans/trojan-go/config"
    15  	"github.com/ipfans/trojan-go/log"
    16  	"github.com/ipfans/trojan-go/tunnel"
    17  )
    18  
    19  const MaxPacketSize = 1024 * 8
    20  
    21  type Server struct {
    22  	tcpListener net.Listener
    23  	udpListener *net.UDPConn
    24  	packetChan  chan tunnel.PacketConn
    25  	timeout     time.Duration
    26  	mappingLock sync.RWMutex
    27  	mapping     map[string]*PacketConn
    28  	ctx         context.Context
    29  	cancel      context.CancelFunc
    30  }
    31  
    32  func (s *Server) Close() error {
    33  	s.cancel()
    34  	s.tcpListener.Close()
    35  	return s.udpListener.Close()
    36  }
    37  
    38  func (s *Server) AcceptConn(tunnel.Tunnel) (tunnel.Conn, error) {
    39  	conn, err := s.tcpListener.Accept()
    40  	if err != nil {
    41  		select {
    42  		case <-s.ctx.Done():
    43  		default:
    44  			log.Fatal(common.NewError("tproxy failed to accept connection").Base(err))
    45  		}
    46  		return nil, common.NewError("tproxy failed to accept conn")
    47  	}
    48  	dst, err := getOriginalTCPDest(conn.(*net.TCPConn))
    49  	if err != nil {
    50  		return nil, common.NewError("tproxy failed to obtain original address of tcp socket").Base(err)
    51  	}
    52  	address, err := tunnel.NewAddressFromAddr("tcp", dst.String())
    53  	common.Must(err)
    54  	log.Info("tproxy connection from", conn.RemoteAddr().String(), "metadata", dst.String())
    55  	return &Conn{
    56  		metadata: &tunnel.Metadata{
    57  			Address: address,
    58  		},
    59  		Conn: conn,
    60  	}, nil
    61  }
    62  
    63  func (s *Server) packetDispatchLoop() {
    64  	type tproxyPacketInfo struct {
    65  		src     *net.UDPAddr
    66  		dst     *net.UDPAddr
    67  		payload []byte
    68  	}
    69  	packetQueue := make(chan *tproxyPacketInfo, 1024)
    70  
    71  	go func() {
    72  		for {
    73  			buf := make([]byte, MaxPacketSize)
    74  			n, src, dst, err := ReadFromUDP(s.udpListener, buf)
    75  			if err != nil {
    76  				select {
    77  				case <-s.ctx.Done():
    78  				default:
    79  					log.Fatal(common.NewError("tproxy failed to read from udp socket").Base(err))
    80  				}
    81  				s.Close()
    82  				return
    83  			}
    84  			log.Debug("udp packet from", src, "metadata", dst, "size", n)
    85  			packetQueue <- &tproxyPacketInfo{
    86  				src:     src,
    87  				dst:     dst,
    88  				payload: buf[:n],
    89  			}
    90  		}
    91  	}()
    92  
    93  	for {
    94  		var info *tproxyPacketInfo
    95  		select {
    96  		case info = <-packetQueue:
    97  		case <-s.ctx.Done():
    98  			log.Debug("exiting")
    99  			return
   100  		}
   101  
   102  		s.mappingLock.RLock()
   103  		conn, found := s.mapping[info.src.String()]
   104  		s.mappingLock.RUnlock()
   105  
   106  		if !found {
   107  			ctx, cancel := context.WithCancel(s.ctx)
   108  			conn = &PacketConn{
   109  				input:      make(chan *packetInfo, 128),
   110  				output:     make(chan *packetInfo, 128),
   111  				PacketConn: s.udpListener,
   112  				ctx:        ctx,
   113  				cancel:     cancel,
   114  				src:        info.src,
   115  			}
   116  
   117  			s.mappingLock.Lock()
   118  			s.mapping[info.src.String()] = conn
   119  			s.mappingLock.Unlock()
   120  
   121  			log.Info("new tproxy udp session from", info.src.String(), "metadata", info.dst.String())
   122  			s.packetChan <- conn
   123  
   124  			go func(conn *PacketConn) {
   125  				defer conn.Close()
   126  				log.Debug("udp packet daemon for", conn.src.String())
   127  				for {
   128  					select {
   129  					case info := <-conn.output:
   130  						if info.metadata.AddressType != tunnel.IPv4 &&
   131  							info.metadata.AddressType != tunnel.IPv6 {
   132  							log.Error("tproxy invalid response metadata address", info.metadata)
   133  							continue
   134  						}
   135  						back, err := DialUDP(
   136  							"udp",
   137  							&net.UDPAddr{
   138  								IP:   info.metadata.IP,
   139  								Port: info.metadata.Port,
   140  							},
   141  							conn.src.(*net.UDPAddr),
   142  						)
   143  						if err != nil {
   144  							log.Error(common.NewError("failed to dial tproxy udp").Base(err))
   145  							return
   146  						}
   147  						n, err := back.Write(info.payload)
   148  						if err != nil {
   149  							log.Error(common.NewError("tproxy udp write error").Base(err))
   150  							return
   151  						}
   152  						log.Debug("recv packet, send back to", conn.src, "payload", len(info.payload), "sent", n)
   153  						back.Close()
   154  					case <-s.ctx.Done():
   155  						log.Debug("exiting")
   156  						return
   157  					case <-time.After(s.timeout):
   158  						s.mappingLock.Lock()
   159  						delete(s.mapping, conn.src.String())
   160  						s.mappingLock.Unlock()
   161  						log.Debug("packet session ", conn.src.String(), "timeout")
   162  						return
   163  					}
   164  				}
   165  			}(conn)
   166  		}
   167  
   168  		newInfo := &packetInfo{
   169  			metadata: &tunnel.Metadata{
   170  				Address: tunnel.NewAddressFromHostPort("udp", info.dst.IP.String(), info.dst.Port),
   171  			},
   172  			payload: info.payload,
   173  		}
   174  
   175  		select {
   176  		case conn.input <- newInfo:
   177  			log.Debug("tproxy packet sent with metadata", newInfo.metadata, "size", len(info.payload))
   178  		default:
   179  			// if we got too many packets, simply drop it
   180  			log.Warn("tproxy udp relay queue full!")
   181  		}
   182  	}
   183  }
   184  
   185  func (s *Server) AcceptPacket(tunnel.Tunnel) (tunnel.PacketConn, error) {
   186  	select {
   187  	case conn := <-s.packetChan:
   188  		log.Info("tproxy packet conn accepted")
   189  		return conn, nil
   190  	case <-s.ctx.Done():
   191  		return nil, io.EOF
   192  	}
   193  }
   194  
   195  func NewServer(ctx context.Context, _ tunnel.Server) (*Server, error) {
   196  	cfg := config.FromContext(ctx, Name).(*Config)
   197  	ctx, cancel := context.WithCancel(ctx)
   198  	listenAddr := tunnel.NewAddressFromHostPort("tcp", cfg.LocalHost, cfg.LocalPort)
   199  	ip, err := listenAddr.ResolveIP()
   200  	if err != nil {
   201  		cancel()
   202  		return nil, common.NewError("invalid tproxy local address").Base(err)
   203  	}
   204  	tcpListener, err := ListenTCP("tcp", &net.TCPAddr{
   205  		IP:   ip,
   206  		Port: cfg.LocalPort,
   207  	})
   208  	if err != nil {
   209  		cancel()
   210  		return nil, common.NewError("tproxy failed to listen tcp").Base(err)
   211  	}
   212  
   213  	udpListener, err := ListenUDP("udp", &net.UDPAddr{
   214  		IP:   ip,
   215  		Port: cfg.LocalPort,
   216  	})
   217  	if err != nil {
   218  		cancel()
   219  		return nil, common.NewError("tproxy failed to listen udp").Base(err)
   220  	}
   221  
   222  	server := &Server{
   223  		tcpListener: tcpListener,
   224  		udpListener: udpListener,
   225  		ctx:         ctx,
   226  		cancel:      cancel,
   227  		timeout:     time.Duration(cfg.UDPTimeout) * time.Second,
   228  		mapping:     make(map[string]*PacketConn),
   229  		packetChan:  make(chan tunnel.PacketConn, 32),
   230  	}
   231  	go server.packetDispatchLoop()
   232  	log.Info("tproxy server listening on", tcpListener.Addr(), "(tcp)", udpListener.LocalAddr(), "(udp)")
   233  	log.Debug("tproxy server created")
   234  	return server, nil
   235  }