github.com/ipfans/trojan-go@v0.11.0/tunnel/dokodemo/server.go (about)

     1  package dokodemo
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/ipfans/trojan-go/common"
    10  	"github.com/ipfans/trojan-go/config"
    11  	"github.com/ipfans/trojan-go/log"
    12  	"github.com/ipfans/trojan-go/tunnel"
    13  )
    14  
    15  type Server struct {
    16  	tunnel.Server
    17  	tcpListener net.Listener
    18  	udpListener net.PacketConn
    19  	packetChan  chan tunnel.PacketConn
    20  	timeout     time.Duration
    21  	targetAddr  *tunnel.Address
    22  	mappingLock sync.Mutex
    23  	mapping     map[string]*PacketConn
    24  	ctx         context.Context
    25  	cancel      context.CancelFunc
    26  }
    27  
    28  func (s *Server) dispatchLoop() {
    29  	fixedMetadata := &tunnel.Metadata{
    30  		Address: s.targetAddr,
    31  	}
    32  	for {
    33  		buf := make([]byte, MaxPacketSize)
    34  		n, addr, err := s.udpListener.ReadFrom(buf)
    35  		if err != nil {
    36  			select {
    37  			case <-s.ctx.Done():
    38  			default:
    39  				log.Fatal(common.NewError("dokodemo failed to read from udp socket").Base(err))
    40  			}
    41  			return
    42  		}
    43  		log.Debug("udp packet from", addr)
    44  		s.mappingLock.Lock()
    45  		if conn, found := s.mapping[addr.String()]; found {
    46  			conn.input <- buf[:n]
    47  			s.mappingLock.Unlock()
    48  			continue
    49  		}
    50  		ctx, cancel := context.WithCancel(s.ctx)
    51  		conn := &PacketConn{
    52  			input:      make(chan []byte, 16),
    53  			output:     make(chan []byte, 16),
    54  			metadata:   fixedMetadata,
    55  			src:        addr,
    56  			PacketConn: s.udpListener,
    57  			ctx:        ctx,
    58  			cancel:     cancel,
    59  		}
    60  		s.mapping[addr.String()] = conn
    61  		s.mappingLock.Unlock()
    62  
    63  		conn.input <- buf[:n]
    64  		s.packetChan <- conn
    65  
    66  		go func(conn *PacketConn) {
    67  			for {
    68  				select {
    69  				case payload := <-conn.output:
    70  					// "Multiple goroutines may invoke methods on a Conn simultaneously."
    71  					_, err := s.udpListener.WriteTo(payload, conn.src)
    72  					if err != nil {
    73  						log.Error(common.NewError("dokodemo udp write error").Base(err))
    74  						return
    75  					}
    76  				case <-s.ctx.Done():
    77  					return
    78  				case <-time.After(s.timeout):
    79  					s.mappingLock.Lock()
    80  					delete(s.mapping, conn.src.String())
    81  					s.mappingLock.Unlock()
    82  					conn.Close()
    83  					log.Debug("closing timeout packetConn")
    84  					return
    85  				}
    86  			}
    87  		}(conn)
    88  	}
    89  }
    90  
    91  func (s *Server) AcceptConn(tunnel.Tunnel) (tunnel.Conn, error) {
    92  	conn, err := s.tcpListener.Accept()
    93  	if err != nil {
    94  		log.Fatal(common.NewError("dokodemo failed to accept connection").Base(err))
    95  	}
    96  	return &Conn{
    97  		Conn: conn,
    98  		targetMetadata: &tunnel.Metadata{
    99  			Address: s.targetAddr,
   100  		},
   101  	}, nil
   102  }
   103  
   104  func (s *Server) AcceptPacket(tunnel.Tunnel) (tunnel.PacketConn, error) {
   105  	select {
   106  	case conn := <-s.packetChan:
   107  		return conn, nil
   108  	case <-s.ctx.Done():
   109  		return nil, common.NewError("dokodemo server closed")
   110  	}
   111  }
   112  
   113  func (s *Server) Close() error {
   114  	s.cancel()
   115  	s.tcpListener.Close()
   116  	s.udpListener.Close()
   117  	return nil
   118  }
   119  
   120  func NewServer(ctx context.Context, _ tunnel.Server) (*Server, error) {
   121  	cfg := config.FromContext(ctx, Name).(*Config)
   122  	targetAddr := tunnel.NewAddressFromHostPort("tcp", cfg.TargetHost, cfg.TargetPort)
   123  	listenAddr := tunnel.NewAddressFromHostPort("tcp", cfg.LocalHost, cfg.LocalPort)
   124  
   125  	tcpListener, err := net.Listen("tcp", listenAddr.String())
   126  	if err != nil {
   127  		return nil, common.NewError("failed to listen tcp").Base(err)
   128  	}
   129  	udpListener, err := net.ListenPacket("udp", listenAddr.String())
   130  	if err != nil {
   131  		return nil, common.NewError("failed to listen udp").Base(err)
   132  	}
   133  
   134  	ctx, cancel := context.WithCancel(ctx)
   135  	server := &Server{
   136  		tcpListener: tcpListener,
   137  		udpListener: udpListener,
   138  		targetAddr:  targetAddr,
   139  		mapping:     make(map[string]*PacketConn),
   140  		packetChan:  make(chan tunnel.PacketConn, 32),
   141  		timeout:     time.Second * time.Duration(cfg.UDPTimeout),
   142  		ctx:         ctx,
   143  		cancel:      cancel,
   144  	}
   145  	go server.dispatchLoop()
   146  	return server, nil
   147  }