github.com/xxf098/lite-proxy@v0.15.1-0.20230422081941-12c69f323218/tunnel/socks/server.go (about)

     1  package socks
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"io/ioutil"
    10  	"net"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/xxf098/lite-proxy/common"
    15  	"github.com/xxf098/lite-proxy/log"
    16  	"github.com/xxf098/lite-proxy/tunnel"
    17  )
    18  
    19  const (
    20  	Connect       tunnel.Command = 1
    21  	Associate     tunnel.Command = 3
    22  	MaxPacketSize                = 1024 * 8
    23  )
    24  
// Server is a SOCKS5 front end layered on top of another tunnel.Server.
// It performs the SOCKS5 handshake on TCP connections accepted from the
// underlay and demultiplexes SOCKS5 UDP datagrams from a single shared packet
// socket into per-client sessions.
type Server struct {
	// connChan carries connections whose CONNECT request has been answered;
	// drained by AcceptConn.
	connChan         chan tunnel.Conn
	// packetChan carries newly created UDP sessions; drained by AcceptPacket.
	packetChan       chan tunnel.PacketConn
	// underlay is the lower-level server supplying TCP connections and the
	// shared UDP packet connection.
	underlay         tunnel.Server
	// localHost and localPort are advertised to clients as the UDP relay
	// address in the UDP ASSOCIATE reply.
	localHost        string
	localPort        int
	// timeout is set to 5s by NewServer. NOTE(review): presumably the UDP
	// session idle timeout — confirm where it is read.
	timeout          time.Duration
	// listenPacketConn is the single UDP socket shared by all UDP sessions.
	listenPacketConn net.PacketConn
	// mapping maps a client's UDP source address (net.Addr.String()) to its
	// session; guarded by mappingLock.
	mapping          map[string]*PacketConn
	mappingLock      sync.RWMutex
	// ctx is cancelled by Close; the accept/dispatch paths watch it to stop.
	ctx              context.Context
	cancel           context.CancelFunc
}
    38  
    39  func (s *Server) AcceptConn(tunnel.Tunnel) (tunnel.Conn, error) {
    40  	select {
    41  	case conn := <-s.connChan:
    42  		return conn, nil
    43  	case <-s.ctx.Done():
    44  		return nil, errors.New("socks server closed")
    45  	}
    46  }
    47  
    48  func (s *Server) AcceptPacket(tunnel.Tunnel) (tunnel.PacketConn, error) {
    49  	select {
    50  	case conn := <-s.packetChan:
    51  		return conn, nil
    52  	case <-s.ctx.Done():
    53  		return nil, common.NewError("socks server closed")
    54  	}
    55  }
    56  
    57  func (s *Server) Close() error {
    58  	s.cancel()
    59  	return s.underlay.Close()
    60  }
    61  
    62  func (s *Server) handshake(conn net.Conn) (*Conn, error) {
    63  	version := [1]byte{}
    64  	if _, err := conn.Read(version[:]); err != nil {
    65  		return nil, common.NewError("failed to read socks version").Base(err)
    66  	}
    67  	if version[0] != 5 {
    68  		return nil, common.NewError(fmt.Sprintf("invalid socks version %d", version[0]))
    69  	}
    70  	nmethods := [1]byte{}
    71  	if _, err := conn.Read(nmethods[:]); err != nil {
    72  		return nil, common.NewError("failed to read NMETHODS")
    73  	}
    74  	if _, err := io.CopyN(ioutil.Discard, conn, int64(nmethods[0])); err != nil {
    75  		return nil, common.NewError("socks failed to read methods").Base(err)
    76  	}
    77  	if _, err := conn.Write([]byte{0x5, 0x0}); err != nil {
    78  		return nil, common.NewError("failed to respond auth").Base(err)
    79  	}
    80  
    81  	buf := [3]byte{}
    82  	if _, err := conn.Read(buf[:]); err != nil {
    83  		return nil, common.NewError("failed to read command")
    84  	}
    85  
    86  	addr := new(tunnel.Address)
    87  	if err := addr.ReadFrom(conn); err != nil {
    88  		return nil, err
    89  	}
    90  
    91  	return &Conn{
    92  		metadata: &tunnel.Metadata{
    93  			Command: tunnel.Command(buf[1]),
    94  			Address: addr,
    95  		},
    96  		Conn: conn,
    97  	}, nil
    98  }
    99  
   100  func (s *Server) connect(conn net.Conn) error {
   101  	_, err := conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
   102  	return err
   103  }
   104  
   105  func (s *Server) associate(conn net.Conn, addr *tunnel.Address) error {
   106  	buf := bytes.NewBuffer([]byte{0x05, 0x00, 0x00})
   107  	common.Must(addr.WriteTo(buf))
   108  	_, err := conn.Write(buf.Bytes())
   109  	return err
   110  }
   111  
   112  func (s *Server) packetDispatchLoop() {
   113  	for {
   114  		buf := make([]byte, MaxPacketSize)
   115  		n, src, err := s.listenPacketConn.ReadFrom(buf)
   116  		if err != nil {
   117  			select {
   118  			case <-s.ctx.Done():
   119  				log.D("exiting")
   120  				return
   121  			default:
   122  				continue
   123  			}
   124  		}
   125  		log.D("socks recv udp packet from", src)
   126  		s.mappingLock.RLock()
   127  		conn, found := s.mapping[src.String()]
   128  		s.mappingLock.RUnlock()
   129  		if !found {
   130  			ctx, cancel := context.WithCancel(s.ctx)
   131  			conn = &PacketConn{
   132  				input:      make(chan *packetInfo, 128),
   133  				output:     make(chan *packetInfo, 128),
   134  				ctx:        ctx,
   135  				cancel:     cancel,
   136  				PacketConn: s.listenPacketConn,
   137  				src:        src,
   138  			}
   139  			go func(conn *PacketConn) {
   140  				defer conn.Close()
   141  				for {
   142  					select {
   143  					case info := <-conn.output:
   144  						buf := bytes.NewBuffer(make([]byte, 0, MaxPacketSize))
   145  						buf.Write([]byte{0, 0, 0}) //RSV, FRAG
   146  						common.Must(info.metadata.Address.WriteTo(buf))
   147  						buf.Write(info.payload)
   148  						_, err := s.listenPacketConn.WriteTo(buf.Bytes(), conn.src)
   149  						if err != nil {
   150  							log.E("socks failed to respond packet to", src)
   151  							return
   152  						}
   153  						log.D("socks respond udp packet to", src, "metadata", info.metadata)
   154  					case <-time.After(time.Second * 5):
   155  						log.I("socks udp session timeout, closed")
   156  						s.mappingLock.Lock()
   157  						delete(s.mapping, src.String())
   158  						s.mappingLock.Unlock()
   159  						return
   160  					case <-conn.ctx.Done():
   161  						log.I("socks udp session closed")
   162  						return
   163  					}
   164  				}
   165  			}(conn)
   166  
   167  			s.mappingLock.Lock()
   168  			s.mapping[src.String()] = conn
   169  			s.mappingLock.Unlock()
   170  
   171  			s.packetChan <- conn
   172  			log.I("socks new udp session from", src)
   173  		}
   174  		r := bytes.NewBuffer(buf[3:n])
   175  		address := new(tunnel.Address)
   176  		if err := address.ReadFrom(r); err != nil {
   177  			log.Error(common.NewError("socks failed to parse incoming packet").Base(err))
   178  			continue
   179  		}
   180  		payload := make([]byte, MaxPacketSize)
   181  		length, err := r.Read(payload)
   182  		select {
   183  		case conn.input <- &packetInfo{
   184  			metadata: &tunnel.Metadata{
   185  				Address: address,
   186  			},
   187  			payload: payload[:length],
   188  		}:
   189  		default:
   190  			log.W("socks udp queue full")
   191  		}
   192  	}
   193  }
   194  
   195  func (s *Server) acceptLoop() {
   196  	for {
   197  		conn, err := s.underlay.AcceptConn(&Tunnel{})
   198  		if err != nil {
   199  			log.Error(common.NewError("socks accept err").Base(err))
   200  			return
   201  		}
   202  		go func(conn net.Conn) {
   203  			newConn, err := s.handshake(conn)
   204  			if err != nil {
   205  				log.Error(common.NewError("socks failed to handshake with client").Base(err))
   206  				return
   207  			}
   208  			log.I("socks connection from", conn.RemoteAddr(), "metadata", newConn.metadata.String())
   209  			switch newConn.metadata.Command {
   210  			case Connect:
   211  				if err := s.connect(newConn); err != nil {
   212  					log.Error(common.NewError("socks failed to respond CONNECT").Base(err))
   213  					newConn.Close()
   214  					return
   215  				}
   216  				s.connChan <- newConn
   217  				return
   218  			case Associate:
   219  				defer newConn.Close()
   220  				associateAddr := tunnel.NewAddressFromHostPort("udp", s.localHost, s.localPort)
   221  				if err := s.associate(newConn, associateAddr); err != nil {
   222  					log.Error(common.NewError("socks failed to respond to associate request").Base(err))
   223  					return
   224  				}
   225  				buf := [16]byte{}
   226  				newConn.Read(buf[:])
   227  				log.D("socks udp session ends")
   228  			default:
   229  				log.Error(common.NewError(fmt.Sprintf("unknown socks command %d", newConn.metadata.Command)))
   230  				newConn.Close()
   231  			}
   232  		}(conn)
   233  	}
   234  }
   235  
   236  func NewServer(ctx context.Context, underlay tunnel.Server) (*Server, error) {
   237  	listenPacketConn, err := underlay.AcceptPacket(&Tunnel{})
   238  	if err != nil {
   239  		return nil, err
   240  	}
   241  	localHost := ctx.Value("LocalHost").(string)
   242  	localPort := ctx.Value("LocalPort").(int)
   243  	ctx, cancel := context.WithCancel(ctx)
   244  	server := &Server{
   245  		underlay:         underlay,
   246  		ctx:              ctx,
   247  		cancel:           cancel,
   248  		connChan:         make(chan tunnel.Conn, 32),
   249  		packetChan:       make(chan tunnel.PacketConn, 32),
   250  		localHost:        localHost,
   251  		localPort:        localPort,
   252  		timeout:          5 * time.Second,
   253  		listenPacketConn: listenPacketConn,
   254  		mapping:          make(map[string]*PacketConn),
   255  	}
   256  	go server.acceptLoop()
   257  	go server.packetDispatchLoop()
   258  	log.D("socks server created", localPort)
   259  	return server, nil
   260  }