github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/socks4a/server.go (about)

     1  package socks4a
     2  
import (
	"bytes"
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"net"
	"time"
	"unsafe"

	"github.com/Asutorufa/yuhaiin/pkg/log"
	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
	"github.com/Asutorufa/yuhaiin/pkg/protos/config/listener"
	"github.com/Asutorufa/yuhaiin/pkg/protos/statistic"
	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
)
    19  
    20  const (
    21  	CommandConnect byte = 0x01
    22  	CommandBind    byte = 0x02
    23  )
    24  
    25  type Server struct {
    26  	lis        net.Listener
    27  	usernameID string
    28  
    29  	*netapi.ChannelServer
    30  }
    31  
    32  func (s *Server) Handle(conn net.Conn) error {
    33  	addr, err := s.Handshake(conn)
    34  	if err != nil {
    35  		_, _ = conn.Write([]byte{0, 91, 0, 0, 0, 0, 0, 0})
    36  		return fmt.Errorf("handshake failed: %w", err)
    37  	}
    38  
    39  	return s.SendStream(&netapi.StreamMeta{
    40  		Source:      conn.RemoteAddr(),
    41  		Destination: addr,
    42  		Inbound:     conn.LocalAddr(),
    43  		Src:         conn,
    44  		Address:     addr,
    45  	})
    46  }
    47  
    48  func (s *Server) Handshake(conn net.Conn) (netapi.Address, error) {
    49  	buf := pool.GetBytesBuffer(8)
    50  	defer buf.Free()
    51  
    52  	if _, err := io.ReadFull(conn, buf.Bytes()); err != nil {
    53  		return nil, err
    54  	}
    55  
    56  	if buf.Bytes()[0] != 0x04 {
    57  		return nil, fmt.Errorf("unknown socks version: %d", buf.Bytes()[0])
    58  	}
    59  
    60  	if buf.Bytes()[1] != CommandConnect {
    61  		return nil, fmt.Errorf("unsupported command: %d", buf.Bytes()[1])
    62  	}
    63  
    64  	port := binary.BigEndian.Uint16(buf.Bytes()[2:4])
    65  	dstAddr := buf.Bytes()[4:8]
    66  	userId, err := readData(conn)
    67  	if err != nil {
    68  		return nil, err
    69  	}
    70  
    71  	if s.usernameID != "" && !bytes.Equal(userId, unsafe.Slice(unsafe.StringData(s.usernameID), len(s.usernameID))) {
    72  		return nil, fmt.Errorf("username not match")
    73  	}
    74  
    75  	var target netapi.Address
    76  	if dstAddr[0] == 0 && dstAddr[1] == 0 && dstAddr[2] == 0 && dstAddr[3] != 0 {
    77  		host, err := readData(conn)
    78  		if err != nil {
    79  			return nil, err
    80  		}
    81  		target = netapi.ParseAddressPort(statistic.Type_tcp, string(host), netapi.ParsePort(port))
    82  	} else {
    83  		target = netapi.ParseIPAddrPort(statistic.Type_tcp, dstAddr, int(port))
    84  	}
    85  
    86  	_, _ = conn.Write([]byte{0, 90})
    87  	_, _ = conn.Write(buf.Bytes()[2:8])
    88  	return target, nil
    89  }
    90  
    91  func readData(conn net.Conn) ([]byte, error) {
    92  	var data []byte
    93  
    94  	buf := make([]byte, 1)
    95  
    96  	for {
    97  		if _, err := io.ReadFull(conn, buf); err != nil {
    98  			return nil, err
    99  		}
   100  
   101  		if buf[0] == 0 {
   102  			break
   103  		}
   104  
   105  		data = append(data, buf[0])
   106  	}
   107  
   108  	return data, nil
   109  }
   110  
   111  func (s *Server) Close() error {
   112  	s.ChannelServer.Close()
   113  
   114  	if s.lis != nil {
   115  		return s.lis.Close()
   116  	}
   117  
   118  	return nil
   119  }
   120  
   121  func (s *Server) Server() {
   122  	defer s.Close()
   123  	for {
   124  		conn, err := s.lis.Accept()
   125  		if err != nil {
   126  			log.Error("socks5 accept failed", "err", err)
   127  
   128  			if ne, ok := err.(net.Error); ok && ne.Temporary() {
   129  				continue
   130  			}
   131  			return
   132  		}
   133  
   134  		go func() {
   135  			if err := s.Handle(conn); err != nil {
   136  				if errors.Is(err, netapi.ErrBlocked) {
   137  					log.Debug(err.Error())
   138  				} else {
   139  					log.Error("socks5 server handle failed", "err", err)
   140  				}
   141  			}
   142  		}()
   143  
   144  	}
   145  }
   146  
   147  func (s *Server) AcceptPacket() (*netapi.Packet, error) {
   148  	return nil, io.EOF
   149  }
   150  
// init registers NewServer with the listener package when this package is
// imported. Presumably RegisterProtocol keys the constructor by its
// *listener.Inbound_Socks4A parameter type — NOTE(review): confirm in
// listener.RegisterProtocol.
func init() {
	listener.RegisterProtocol(NewServer)
}
   154  
   155  func NewServer(o *listener.Inbound_Socks4A) func(netapi.Listener) (netapi.Accepter, error) {
   156  	return func(ii netapi.Listener) (netapi.Accepter, error) {
   157  		lis, err := ii.Stream(context.TODO())
   158  		if err != nil {
   159  			return nil, err
   160  		}
   161  
   162  		s := &Server{
   163  			usernameID:    o.Socks4A.Username,
   164  			lis:           lis,
   165  			ChannelServer: netapi.NewChannelServer(),
   166  		}
   167  
   168  		go s.Server()
   169  
   170  		return s, nil
   171  	}
   172  }