github.com/sagernet/sing-box@v1.9.0-rc.20/transport/trojan/service.go (about)

     1  package trojan
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  
     7  	"github.com/sagernet/sing/common/auth"
     8  	"github.com/sagernet/sing/common/buf"
     9  	"github.com/sagernet/sing/common/bufio"
    10  	E "github.com/sagernet/sing/common/exceptions"
    11  	M "github.com/sagernet/sing/common/metadata"
    12  	N "github.com/sagernet/sing/common/network"
    13  	"github.com/sagernet/sing/common/rw"
    14  )
    15  
// Handler is the callback set a Service dispatches accepted connections to.
// It combines the TCP connection handler, the UDP (packet) connection handler
// and the error handler interfaces from the sing library.
type Handler interface {
	N.TCPConnectionHandler
	N.UDPConnectionHandler
	E.Handler
}
    21  
// Service is the server side of the Trojan protocol. K is the caller-chosen
// user identifier type.
type Service[K comparable] struct {
	// users maps a user identifier to that user's 56-byte key.
	users           map[K][56]byte
	// keys is the reverse index of users: key -> user identifier. It is the
	// table NewConnection consults to authenticate a request.
	keys            map[[56]byte]K
	// handler receives authenticated connections.
	handler         Handler
	// fallbackHandler receives connections that fail authentication; it may be
	// nil, in which case such connections are rejected with an error.
	fallbackHandler N.TCPConnectionHandler
}
    28  
    29  func NewService[K comparable](handler Handler, fallbackHandler N.TCPConnectionHandler) *Service[K] {
    30  	return &Service[K]{
    31  		users:           make(map[K][56]byte),
    32  		keys:            make(map[[56]byte]K),
    33  		handler:         handler,
    34  		fallbackHandler: fallbackHandler,
    35  	}
    36  }
    37  
// ErrUserExists is returned by UpdateUsers when the same user appears twice or
// when two users share the same password.
var ErrUserExists = E.New("user already exists")
    39  
    40  func (s *Service[K]) UpdateUsers(userList []K, passwordList []string) error {
    41  	users := make(map[K][56]byte)
    42  	keys := make(map[[56]byte]K)
    43  	for i, user := range userList {
    44  		if _, loaded := users[user]; loaded {
    45  			return ErrUserExists
    46  		}
    47  		key := Key(passwordList[i])
    48  		if oldUser, loaded := keys[key]; loaded {
    49  			return E.Extend(ErrUserExists, "password used by ", oldUser)
    50  		}
    51  		users[user] = key
    52  		keys[key] = user
    53  	}
    54  	s.users = users
    55  	s.keys = keys
    56  	return nil
    57  }
    58  
    59  func (s *Service[K]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
    60  	var key [KeyLength]byte
    61  	n, err := conn.Read(key[:])
    62  	if err != nil {
    63  		return err
    64  	} else if n != KeyLength {
    65  		return s.fallback(ctx, conn, metadata, key[:n], E.New("bad request size"))
    66  	}
    67  
    68  	if user, loaded := s.keys[key]; loaded {
    69  		ctx = auth.ContextWithUser(ctx, user)
    70  	} else {
    71  		return s.fallback(ctx, conn, metadata, key[:], E.New("bad request"))
    72  	}
    73  
    74  	err = rw.SkipN(conn, 2)
    75  	if err != nil {
    76  		return E.Cause(err, "skip crlf")
    77  	}
    78  
    79  	command, err := rw.ReadByte(conn)
    80  	if err != nil {
    81  		return E.Cause(err, "read command")
    82  	}
    83  
    84  	switch command {
    85  	case CommandTCP, CommandUDP, CommandMux:
    86  	default:
    87  		return E.New("unknown command ", command)
    88  	}
    89  
    90  	// var destination M.Socksaddr
    91  	destination, err := M.SocksaddrSerializer.ReadAddrPort(conn)
    92  	if err != nil {
    93  		return E.Cause(err, "read destination")
    94  	}
    95  
    96  	err = rw.SkipN(conn, 2)
    97  	if err != nil {
    98  		return E.Cause(err, "skip crlf")
    99  	}
   100  
   101  	metadata.Protocol = "trojan"
   102  	metadata.Destination = destination
   103  
   104  	switch command {
   105  	case CommandTCP:
   106  		return s.handler.NewConnection(ctx, conn, metadata)
   107  	case CommandUDP:
   108  		return s.handler.NewPacketConnection(ctx, &PacketConn{Conn: conn}, metadata)
   109  	// case CommandMux:
   110  	default:
   111  		return HandleMuxConnection(ctx, conn, metadata, s.handler)
   112  	}
   113  }
   114  
   115  func (s *Service[K]) fallback(ctx context.Context, conn net.Conn, metadata M.Metadata, header []byte, err error) error {
   116  	if s.fallbackHandler == nil {
   117  		return E.Extend(err, "fallback disabled")
   118  	}
   119  	conn = bufio.NewCachedConn(conn, buf.As(header).ToOwned())
   120  	return s.fallbackHandler.NewConnection(ctx, conn, metadata)
   121  }
   122  
// PacketConn adapts a stream connection to the packet-oriented interface used
// for Trojan UDP: packets are read and written through ReadPacket and
// WritePacket on the embedded net.Conn.
type PacketConn struct {
	net.Conn
	// readWaitOptions is not referenced in this part of the file; presumably
	// it is set and used by a read-waiter implementation elsewhere in the
	// package (TODO confirm).
	readWaitOptions N.ReadWaitOptions
}
   127  
   128  func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
   129  	return ReadPacket(c.Conn, buffer)
   130  }
   131  
   132  func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   133  	return WritePacket(c.Conn, buffer, destination)
   134  }
   135  
   136  func (c *PacketConn) FrontHeadroom() int {
   137  	return M.MaxSocksaddrLength + 4
   138  }
   139  
// NeedAdditionalReadDeadline always reports true.
// NOTE(review): presumably this tells the caller to apply its own read
// deadline handling on top of this connection; confirm against the consumer
// of this method.
func (c *PacketConn) NeedAdditionalReadDeadline() bool {
	return true
}
   143  
// Upstream returns the embedded net.Conn that this PacketConn wraps.
func (c *PacketConn) Upstream() any {
	return c.Conn
}