github.com/ipfans/trojan-go@v0.11.0/tunnel/socks/server.go (about)

     1  package socks
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"net"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/ipfans/trojan-go/common"
    14  	"github.com/ipfans/trojan-go/config"
    15  	"github.com/ipfans/trojan-go/log"
    16  	"github.com/ipfans/trojan-go/tunnel"
    17  )
    18  
    19  const (
    20  	Connect   tunnel.Command = 1
    21  	Associate tunnel.Command = 3
    22  )
    23  
    24  const (
    25  	MaxPacketSize = 1024 * 8
    26  )
    27  
    28  type Server struct {
    29  	connChan         chan tunnel.Conn
    30  	packetChan       chan tunnel.PacketConn
    31  	underlay         tunnel.Server
    32  	localHost        string
    33  	localPort        int
    34  	timeout          time.Duration
    35  	listenPacketConn tunnel.PacketConn
    36  	mapping          map[string]*PacketConn
    37  	mappingLock      sync.RWMutex
    38  	ctx              context.Context
    39  	cancel           context.CancelFunc
    40  }
    41  
    42  func (s *Server) AcceptConn(tunnel.Tunnel) (tunnel.Conn, error) {
    43  	select {
    44  	case conn := <-s.connChan:
    45  		return conn, nil
    46  	case <-s.ctx.Done():
    47  		return nil, common.NewError("socks server closed")
    48  	}
    49  }
    50  
    51  func (s *Server) AcceptPacket(tunnel.Tunnel) (tunnel.PacketConn, error) {
    52  	select {
    53  	case conn := <-s.packetChan:
    54  		return conn, nil
    55  	case <-s.ctx.Done():
    56  		return nil, common.NewError("socks server closed")
    57  	}
    58  }
    59  
    60  func (s *Server) Close() error {
    61  	s.cancel()
    62  	return s.underlay.Close()
    63  }
    64  
    65  func (s *Server) handshake(conn net.Conn) (*Conn, error) {
    66  	version := [1]byte{}
    67  	if _, err := conn.Read(version[:]); err != nil {
    68  		return nil, common.NewError("failed to read socks version").Base(err)
    69  	}
    70  	if version[0] != 5 {
    71  		return nil, common.NewError(fmt.Sprintf("invalid socks version %d", version[0]))
    72  	}
    73  	nmethods := [1]byte{}
    74  	if _, err := conn.Read(nmethods[:]); err != nil {
    75  		return nil, common.NewError("failed to read NMETHODS")
    76  	}
    77  	if _, err := io.CopyN(ioutil.Discard, conn, int64(nmethods[0])); err != nil {
    78  		return nil, common.NewError("socks failed to read methods").Base(err)
    79  	}
    80  	if _, err := conn.Write([]byte{0x5, 0x0}); err != nil {
    81  		return nil, common.NewError("failed to respond auth").Base(err)
    82  	}
    83  
    84  	buf := [3]byte{}
    85  	if _, err := conn.Read(buf[:]); err != nil {
    86  		return nil, common.NewError("failed to read command")
    87  	}
    88  
    89  	addr := new(tunnel.Address)
    90  	if err := addr.ReadFrom(conn); err != nil {
    91  		return nil, err
    92  	}
    93  
    94  	return &Conn{
    95  		metadata: &tunnel.Metadata{
    96  			Command: tunnel.Command(buf[1]),
    97  			Address: addr,
    98  		},
    99  		Conn: conn,
   100  	}, nil
   101  }
   102  
   103  func (s *Server) connect(conn net.Conn) error {
   104  	_, err := conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
   105  	return err
   106  }
   107  
   108  func (s *Server) associate(conn net.Conn, addr *tunnel.Address) error {
   109  	buf := bytes.NewBuffer([]byte{0x05, 0x00, 0x00})
   110  	common.Must(addr.WriteTo(buf))
   111  	_, err := conn.Write(buf.Bytes())
   112  	return err
   113  }
   114  
   115  func (s *Server) packetDispatchLoop() {
   116  	for {
   117  		buf := make([]byte, MaxPacketSize)
   118  		n, src, err := s.listenPacketConn.ReadFrom(buf)
   119  		if err != nil {
   120  			select {
   121  			case <-s.ctx.Done():
   122  				log.Debug("exiting")
   123  				return
   124  			default:
   125  				continue
   126  			}
   127  		}
   128  		log.Debug("socks recv udp packet from", src)
   129  		s.mappingLock.RLock()
   130  		conn, found := s.mapping[src.String()]
   131  		s.mappingLock.RUnlock()
   132  		if !found {
   133  			ctx, cancel := context.WithCancel(s.ctx)
   134  			conn = &PacketConn{
   135  				input:      make(chan *packetInfo, 128),
   136  				output:     make(chan *packetInfo, 128),
   137  				ctx:        ctx,
   138  				cancel:     cancel,
   139  				PacketConn: s.listenPacketConn,
   140  				src:        src,
   141  			}
   142  			go func(conn *PacketConn) {
   143  				defer conn.Close()
   144  				for {
   145  					select {
   146  					case info := <-conn.output:
   147  						buf := bytes.NewBuffer(make([]byte, 0, MaxPacketSize))
   148  						buf.Write([]byte{0, 0, 0}) // RSV, FRAG
   149  						common.Must(info.metadata.Address.WriteTo(buf))
   150  						buf.Write(info.payload)
   151  						_, err := s.listenPacketConn.WriteTo(buf.Bytes(), conn.src)
   152  						if err != nil {
   153  							log.Error("socks failed to respond packet to", src)
   154  							return
   155  						}
   156  						log.Debug("socks respond udp packet to", src, "metadata", info.metadata)
   157  					case <-time.After(time.Second * 5):
   158  						log.Info("socks udp session timeout, closed")
   159  						s.mappingLock.Lock()
   160  						delete(s.mapping, src.String())
   161  						s.mappingLock.Unlock()
   162  						return
   163  					case <-conn.ctx.Done():
   164  						log.Info("socks udp session closed")
   165  						return
   166  					}
   167  				}
   168  			}(conn)
   169  
   170  			s.mappingLock.Lock()
   171  			s.mapping[src.String()] = conn
   172  			s.mappingLock.Unlock()
   173  
   174  			s.packetChan <- conn
   175  			log.Info("socks new udp session from", src)
   176  		}
   177  		r := bytes.NewBuffer(buf[3:n])
   178  		address := new(tunnel.Address)
   179  		if err := address.ReadFrom(r); err != nil {
   180  			log.Error(common.NewError("socks failed to parse incoming packet").Base(err))
   181  			continue
   182  		}
   183  		payload := make([]byte, MaxPacketSize)
   184  		length, _ := r.Read(payload)
   185  		select {
   186  		case conn.input <- &packetInfo{
   187  			metadata: &tunnel.Metadata{
   188  				Address: address,
   189  			},
   190  			payload: payload[:length],
   191  		}:
   192  		default:
   193  			log.Warn("socks udp queue full")
   194  		}
   195  	}
   196  }
   197  
   198  func (s *Server) acceptLoop() {
   199  	for {
   200  		conn, err := s.underlay.AcceptConn(&Tunnel{})
   201  		if err != nil {
   202  			log.Error(common.NewError("socks accept err").Base(err))
   203  			return
   204  		}
   205  		go func(conn net.Conn) {
   206  			newConn, err := s.handshake(conn)
   207  			if err != nil {
   208  				log.Error(common.NewError("socks failed to handshake with client").Base(err))
   209  				return
   210  			}
   211  			log.Info("socks connection from", conn.RemoteAddr(), "metadata", newConn.metadata.String())
   212  			switch newConn.metadata.Command {
   213  			case Connect:
   214  				if err := s.connect(newConn); err != nil {
   215  					log.Error(common.NewError("socks failed to respond CONNECT").Base(err))
   216  					newConn.Close()
   217  					return
   218  				}
   219  				s.connChan <- newConn
   220  				return
   221  			case Associate:
   222  				defer newConn.Close()
   223  				associateAddr := tunnel.NewAddressFromHostPort("udp", s.localHost, s.localPort)
   224  				if err := s.associate(newConn, associateAddr); err != nil {
   225  					log.Error(common.NewError("socks failed to respond to associate request").Base(err))
   226  					return
   227  				}
   228  				buf := [16]byte{}
   229  				newConn.Read(buf[:])
   230  				log.Debug("socks udp session ends")
   231  			default:
   232  				log.Error(common.NewError(fmt.Sprintf("unknown socks command %d", newConn.metadata.Command)))
   233  				newConn.Close()
   234  			}
   235  		}(conn)
   236  	}
   237  }
   238  
   239  // NewServer create a socks server
   240  func NewServer(ctx context.Context, underlay tunnel.Server) (tunnel.Server, error) {
   241  	cfg := config.FromContext(ctx, Name).(*Config)
   242  	listenPacketConn, err := underlay.AcceptPacket(&Tunnel{})
   243  	if err != nil {
   244  		return nil, common.NewError("socks failed to listen packet from underlying server")
   245  	}
   246  	ctx, cancel := context.WithCancel(ctx)
   247  	server := &Server{
   248  		underlay:         underlay,
   249  		ctx:              ctx,
   250  		cancel:           cancel,
   251  		connChan:         make(chan tunnel.Conn, 32),
   252  		packetChan:       make(chan tunnel.PacketConn, 32),
   253  		localHost:        cfg.LocalHost,
   254  		localPort:        cfg.LocalPort,
   255  		timeout:          time.Duration(cfg.UDPTimeout) * time.Second,
   256  		listenPacketConn: listenPacketConn,
   257  		mapping:          make(map[string]*PacketConn),
   258  	}
   259  	go server.acceptLoop()
   260  	go server.packetDispatchLoop()
   261  	log.Debug("socks server created")
   262  	return server, nil
   263  }