github.com/sagernet/sing-box@v1.9.0-rc.20/transport/trojan/service.go (about) 1 package trojan 2 3 import ( 4 "context" 5 "net" 6 7 "github.com/sagernet/sing/common/auth" 8 "github.com/sagernet/sing/common/buf" 9 "github.com/sagernet/sing/common/bufio" 10 E "github.com/sagernet/sing/common/exceptions" 11 M "github.com/sagernet/sing/common/metadata" 12 N "github.com/sagernet/sing/common/network" 13 "github.com/sagernet/sing/common/rw" 14 ) 15 16 type Handler interface { 17 N.TCPConnectionHandler 18 N.UDPConnectionHandler 19 E.Handler 20 } 21 22 type Service[K comparable] struct { 23 users map[K][56]byte 24 keys map[[56]byte]K 25 handler Handler 26 fallbackHandler N.TCPConnectionHandler 27 } 28 29 func NewService[K comparable](handler Handler, fallbackHandler N.TCPConnectionHandler) *Service[K] { 30 return &Service[K]{ 31 users: make(map[K][56]byte), 32 keys: make(map[[56]byte]K), 33 handler: handler, 34 fallbackHandler: fallbackHandler, 35 } 36 } 37 38 var ErrUserExists = E.New("user already exists") 39 40 func (s *Service[K]) UpdateUsers(userList []K, passwordList []string) error { 41 users := make(map[K][56]byte) 42 keys := make(map[[56]byte]K) 43 for i, user := range userList { 44 if _, loaded := users[user]; loaded { 45 return ErrUserExists 46 } 47 key := Key(passwordList[i]) 48 if oldUser, loaded := keys[key]; loaded { 49 return E.Extend(ErrUserExists, "password used by ", oldUser) 50 } 51 users[user] = key 52 keys[key] = user 53 } 54 s.users = users 55 s.keys = keys 56 return nil 57 } 58 59 func (s *Service[K]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { 60 var key [KeyLength]byte 61 n, err := conn.Read(key[:]) 62 if err != nil { 63 return err 64 } else if n != KeyLength { 65 return s.fallback(ctx, conn, metadata, key[:n], E.New("bad request size")) 66 } 67 68 if user, loaded := s.keys[key]; loaded { 69 ctx = auth.ContextWithUser(ctx, user) 70 } else { 71 return s.fallback(ctx, conn, metadata, key[:], E.New("bad request")) 72 } 73 74 err = rw.SkipN(conn, 2) 75 if err != nil { 76 return E.Cause(err, "skip crlf") 77 } 78 79 command, err := rw.ReadByte(conn) 80 if err != nil { 81 return E.Cause(err, "read command") 82 } 83 84 switch command { 85 case CommandTCP, CommandUDP, CommandMux: 86 default: 87 return E.New("unknown command ", command) 88 } 89 90 // var destination M.Socksaddr 91 destination, err := M.SocksaddrSerializer.ReadAddrPort(conn) 92 if err != nil { 93 return E.Cause(err, "read destination") 94 } 95 96 err = rw.SkipN(conn, 2) 97 if err != nil { 98 return E.Cause(err, "skip crlf") 99 } 100 101 metadata.Protocol = "trojan" 102 metadata.Destination = destination 103 104 switch command { 105 case CommandTCP: 106 return s.handler.NewConnection(ctx, conn, metadata) 107 case CommandUDP: 108 return s.handler.NewPacketConnection(ctx, &PacketConn{Conn: conn}, metadata) 109 // case CommandMux: 110 default: 111 return HandleMuxConnection(ctx, conn, metadata, s.handler) 112 } 113 } 114 115 func (s *Service[K]) fallback(ctx context.Context, conn net.Conn, metadata M.Metadata, header []byte, err error) error { 116 if s.fallbackHandler == nil { 117 return E.Extend(err, "fallback disabled") 118 } 119 conn = bufio.NewCachedConn(conn, buf.As(header).ToOwned()) 120 return s.fallbackHandler.NewConnection(ctx, conn, metadata) 121 } 122 123 type PacketConn struct { 124 net.Conn 125 readWaitOptions N.ReadWaitOptions 126 } 127 128 func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { 129 return ReadPacket(c.Conn, buffer) 130 } 131 132 func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 133 return WritePacket(c.Conn, buffer, destination) 134 } 135 136 func (c *PacketConn) FrontHeadroom() int { 137 return M.MaxSocksaddrLength + 4 138 } 139 140 func (c *PacketConn) NeedAdditionalReadDeadline() bool { 141 return true 142 } 143 144 func (c *PacketConn) Upstream() any { 145 return c.Conn 146 }