github.com/sagernet/sing-box@v1.2.7/transport/trojan/service.go (about)

     1  package trojan
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  
     7  	"github.com/sagernet/sing/common"
     8  	"github.com/sagernet/sing/common/auth"
     9  	"github.com/sagernet/sing/common/buf"
    10  	"github.com/sagernet/sing/common/bufio"
    11  	E "github.com/sagernet/sing/common/exceptions"
    12  	M "github.com/sagernet/sing/common/metadata"
    13  	N "github.com/sagernet/sing/common/network"
    14  	"github.com/sagernet/sing/common/rw"
    15  )
    16  
    17  type Handler interface {
    18  	N.TCPConnectionHandler
    19  	N.UDPConnectionHandler
    20  	E.Handler
    21  }
    22  
    23  type Service[K comparable] struct {
    24  	users           map[K][56]byte
    25  	keys            map[[56]byte]K
    26  	handler         Handler
    27  	fallbackHandler N.TCPConnectionHandler
    28  }
    29  
    30  func NewService[K comparable](handler Handler, fallbackHandler N.TCPConnectionHandler) *Service[K] {
    31  	return &Service[K]{
    32  		users:           make(map[K][56]byte),
    33  		keys:            make(map[[56]byte]K),
    34  		handler:         handler,
    35  		fallbackHandler: fallbackHandler,
    36  	}
    37  }
    38  
    39  var ErrUserExists = E.New("user already exists")
    40  
    41  func (s *Service[K]) UpdateUsers(userList []K, passwordList []string) error {
    42  	users := make(map[K][56]byte)
    43  	keys := make(map[[56]byte]K)
    44  	for i, user := range userList {
    45  		if _, loaded := users[user]; loaded {
    46  			return ErrUserExists
    47  		}
    48  		key := Key(passwordList[i])
    49  		if oldUser, loaded := keys[key]; loaded {
    50  			return E.Extend(ErrUserExists, "password used by ", oldUser)
    51  		}
    52  		users[user] = key
    53  		keys[key] = user
    54  	}
    55  	s.users = users
    56  	s.keys = keys
    57  	return nil
    58  }
    59  
    60  func (s *Service[K]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
    61  	var key [KeyLength]byte
    62  	n, err := conn.Read(common.Dup(key[:]))
    63  	if err != nil {
    64  		return err
    65  	} else if n != KeyLength {
    66  		return s.fallback(ctx, conn, metadata, key[:n], E.New("bad request size"))
    67  	}
    68  
    69  	if user, loaded := s.keys[key]; loaded {
    70  		ctx = auth.ContextWithUser(ctx, user)
    71  	} else {
    72  		return s.fallback(ctx, conn, metadata, key[:], E.New("bad request"))
    73  	}
    74  
    75  	err = rw.SkipN(conn, 2)
    76  	if err != nil {
    77  		return E.Cause(err, "skip crlf")
    78  	}
    79  
    80  	command, err := rw.ReadByte(conn)
    81  	if err != nil {
    82  		return E.Cause(err, "read command")
    83  	}
    84  
    85  	switch command {
    86  	case CommandTCP, CommandUDP, CommandMux:
    87  	default:
    88  		return E.New("unknown command ", command)
    89  	}
    90  
    91  	// var destination M.Socksaddr
    92  	destination, err := M.SocksaddrSerializer.ReadAddrPort(conn)
    93  	if err != nil {
    94  		return E.Cause(err, "read destination")
    95  	}
    96  
    97  	err = rw.SkipN(conn, 2)
    98  	if err != nil {
    99  		return E.Cause(err, "skip crlf")
   100  	}
   101  
   102  	metadata.Protocol = "trojan"
   103  	metadata.Destination = destination
   104  
   105  	switch command {
   106  	case CommandTCP:
   107  		return s.handler.NewConnection(ctx, conn, metadata)
   108  	case CommandUDP:
   109  		return s.handler.NewPacketConnection(ctx, &PacketConn{conn}, metadata)
   110  	// case CommandMux:
   111  	default:
   112  		return HandleMuxConnection(ctx, conn, metadata, s.handler)
   113  	}
   114  }
   115  
   116  func (s *Service[K]) fallback(ctx context.Context, conn net.Conn, metadata M.Metadata, header []byte, err error) error {
   117  	if s.fallbackHandler == nil {
   118  		return E.Extend(err, "fallback disabled")
   119  	}
   120  	conn = bufio.NewCachedConn(conn, buf.As(header).ToOwned())
   121  	return s.fallbackHandler.NewConnection(ctx, conn, metadata)
   122  }
   123  
   124  type PacketConn struct {
   125  	net.Conn
   126  }
   127  
   128  func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
   129  	return ReadPacket(c.Conn, buffer)
   130  }
   131  
   132  func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   133  	return WritePacket(c.Conn, buffer, destination)
   134  }
   135  
   136  func (c *PacketConn) FrontHeadroom() int {
   137  	return M.MaxSocksaddrLength + 4
   138  }
   139  
   140  func (c *PacketConn) NeedAdditionalReadDeadline() bool {
   141  	return true
   142  }
   143  
   144  func (c *PacketConn) Upstream() any {
   145  	return c.Conn
   146  }