github.com/sagernet/sing-box@v1.2.7/transport/trojan/service.go (about) 1 package trojan 2 3 import ( 4 "context" 5 "net" 6 7 "github.com/sagernet/sing/common" 8 "github.com/sagernet/sing/common/auth" 9 "github.com/sagernet/sing/common/buf" 10 "github.com/sagernet/sing/common/bufio" 11 E "github.com/sagernet/sing/common/exceptions" 12 M "github.com/sagernet/sing/common/metadata" 13 N "github.com/sagernet/sing/common/network" 14 "github.com/sagernet/sing/common/rw" 15 ) 16 17 type Handler interface { 18 N.TCPConnectionHandler 19 N.UDPConnectionHandler 20 E.Handler 21 } 22 23 type Service[K comparable] struct { 24 users map[K][56]byte 25 keys map[[56]byte]K 26 handler Handler 27 fallbackHandler N.TCPConnectionHandler 28 } 29 30 func NewService[K comparable](handler Handler, fallbackHandler N.TCPConnectionHandler) *Service[K] { 31 return &Service[K]{ 32 users: make(map[K][56]byte), 33 keys: make(map[[56]byte]K), 34 handler: handler, 35 fallbackHandler: fallbackHandler, 36 } 37 } 38 39 var ErrUserExists = E.New("user already exists") 40 41 func (s *Service[K]) UpdateUsers(userList []K, passwordList []string) error { 42 users := make(map[K][56]byte) 43 keys := make(map[[56]byte]K) 44 for i, user := range userList { 45 if _, loaded := users[user]; loaded { 46 return ErrUserExists 47 } 48 key := Key(passwordList[i]) 49 if oldUser, loaded := keys[key]; loaded { 50 return E.Extend(ErrUserExists, "password used by ", oldUser) 51 } 52 users[user] = key 53 keys[key] = user 54 } 55 s.users = users 56 s.keys = keys 57 return nil 58 } 59 60 func (s *Service[K]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { 61 var key [KeyLength]byte 62 n, err := conn.Read(common.Dup(key[:])) 63 if err != nil { 64 return err 65 } else if n != KeyLength { 66 return s.fallback(ctx, conn, metadata, key[:n], E.New("bad request size")) 67 } 68 69 if user, loaded := s.keys[key]; loaded { 70 ctx = auth.ContextWithUser(ctx, user) 71 } else { 72 return s.fallback(ctx, conn, metadata, key[:], E.New("bad request")) 73 } 74 75 err = rw.SkipN(conn, 2) 76 if err != nil { 77 return E.Cause(err, "skip crlf") 78 } 79 80 command, err := rw.ReadByte(conn) 81 if err != nil { 82 return E.Cause(err, "read command") 83 } 84 85 switch command { 86 case CommandTCP, CommandUDP, CommandMux: 87 default: 88 return E.New("unknown command ", command) 89 } 90 91 // var destination M.Socksaddr 92 destination, err := M.SocksaddrSerializer.ReadAddrPort(conn) 93 if err != nil { 94 return E.Cause(err, "read destination") 95 } 96 97 err = rw.SkipN(conn, 2) 98 if err != nil { 99 return E.Cause(err, "skip crlf") 100 } 101 102 metadata.Protocol = "trojan" 103 metadata.Destination = destination 104 105 switch command { 106 case CommandTCP: 107 return s.handler.NewConnection(ctx, conn, metadata) 108 case CommandUDP: 109 return s.handler.NewPacketConnection(ctx, &PacketConn{conn}, metadata) 110 // case CommandMux: 111 default: 112 return HandleMuxConnection(ctx, conn, metadata, s.handler) 113 } 114 } 115 116 func (s *Service[K]) fallback(ctx context.Context, conn net.Conn, metadata M.Metadata, header []byte, err error) error { 117 if s.fallbackHandler == nil { 118 return E.Extend(err, "fallback disabled") 119 } 120 conn = bufio.NewCachedConn(conn, buf.As(header).ToOwned()) 121 return s.fallbackHandler.NewConnection(ctx, conn, metadata) 122 } 123 124 type PacketConn struct { 125 net.Conn 126 } 127 128 func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { 129 return ReadPacket(c.Conn, buffer) 130 } 131 132 func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 133 return WritePacket(c.Conn, buffer, destination) 134 } 135 136 func (c *PacketConn) FrontHeadroom() int { 137 return M.MaxSocksaddrLength + 4 138 } 139 140 func (c *PacketConn) NeedAdditionalReadDeadline() bool { 141 return true 142 } 143 144 func (c *PacketConn) Upstream() any { 145 return c.Conn 146 }