github.com/sagernet/sing-box@v1.2.7/transport/vless/service.go (about) 1 package vless 2 3 import ( 4 "context" 5 "encoding/binary" 6 "io" 7 "net" 8 9 "github.com/sagernet/sing-vmess" 10 "github.com/sagernet/sing/common/auth" 11 "github.com/sagernet/sing/common/buf" 12 "github.com/sagernet/sing/common/bufio" 13 E "github.com/sagernet/sing/common/exceptions" 14 "github.com/sagernet/sing/common/logger" 15 M "github.com/sagernet/sing/common/metadata" 16 N "github.com/sagernet/sing/common/network" 17 18 "github.com/gofrs/uuid/v5" 19 ) 20 21 type Service[T comparable] struct { 22 userMap map[[16]byte]T 23 userFlow map[T]string 24 logger logger.Logger 25 handler Handler 26 } 27 28 type Handler interface { 29 N.TCPConnectionHandler 30 N.UDPConnectionHandler 31 E.Handler 32 } 33 34 func NewService[T comparable](logger logger.Logger, handler Handler) *Service[T] { 35 return &Service[T]{ 36 logger: logger, 37 handler: handler, 38 } 39 } 40 41 func (s *Service[T]) UpdateUsers(userList []T, userUUIDList []string, userFlowList []string) { 42 userMap := make(map[[16]byte]T) 43 userFlowMap := make(map[T]string) 44 for i, userName := range userList { 45 userID := uuid.FromStringOrNil(userUUIDList[i]) 46 if userID == uuid.Nil { 47 userID = uuid.NewV5(uuid.Nil, userUUIDList[i]) 48 } 49 userMap[userID] = userName 50 userFlowMap[userName] = userFlowList[i] 51 } 52 s.userMap = userMap 53 s.userFlow = userFlowMap 54 } 55 56 var _ N.TCPConnectionHandler = (*Service[int])(nil) 57 58 func (s *Service[T]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { 59 request, err := ReadRequest(conn) 60 if err != nil { 61 return err 62 } 63 user, loaded := s.userMap[request.UUID] 64 if !loaded { 65 return E.New("unknown UUID: ", uuid.FromBytesOrNil(request.UUID[:])) 66 } 67 ctx = auth.ContextWithUser(ctx, user) 68 metadata.Destination = request.Destination 69 70 userFlow := s.userFlow[user] 71 72 var responseWriter io.Writer 73 if request.Command == vmess.CommandTCP { 74 if request.Flow != userFlow { 75 return E.New("flow mismatch: expected ", flowName(userFlow), ", but got ", flowName(request.Flow)) 76 } 77 switch userFlow { 78 case "": 79 case FlowVision: 80 responseWriter = conn 81 conn, err = NewVisionConn(conn, request.UUID, s.logger) 82 if err != nil { 83 return E.Cause(err, "initialize vision") 84 } 85 } 86 } 87 88 switch request.Command { 89 case vmess.CommandTCP: 90 return s.handler.NewConnection(ctx, &serverConn{Conn: conn, responseWriter: responseWriter}, metadata) 91 case vmess.CommandUDP: 92 return s.handler.NewPacketConnection(ctx, &serverPacketConn{ExtendedConn: bufio.NewExtendedConn(conn), destination: request.Destination}, metadata) 93 case vmess.CommandMux: 94 return vmess.HandleMuxConnection(ctx, &serverConn{Conn: conn, responseWriter: responseWriter}, s.handler) 95 default: 96 return E.New("unknown command: ", request.Command) 97 } 98 } 99 100 func flowName(value string) string { 101 if value == "" { 102 return "none" 103 } 104 return value 105 } 106 107 type serverConn struct { 108 net.Conn 109 responseWriter io.Writer 110 responseWritten bool 111 } 112 113 func (c *serverConn) Read(b []byte) (n int, err error) { 114 return c.Conn.Read(b) 115 } 116 117 func (c *serverConn) Write(b []byte) (n int, err error) { 118 if !c.responseWritten { 119 if c.responseWriter == nil { 120 _, err = bufio.WriteVectorised(bufio.NewVectorisedWriter(c.Conn), [][]byte{{Version, 0}, b}) 121 if err == nil { 122 n = len(b) 123 } 124 c.responseWritten = true 125 return 126 } else { 127 _, err = c.responseWriter.Write([]byte{Version, 0}) 128 if err != nil { 129 return 130 } 131 c.responseWritten = true 132 } 133 } 134 return c.Conn.Write(b) 135 } 136 137 func (c *serverConn) NeedAdditionalReadDeadline() bool { 138 return true 139 } 140 141 func (c *serverConn) Upstream() any { 142 return c.Conn 143 } 144 145 type serverPacketConn struct { 146 N.ExtendedConn 147 responseWriter io.Writer 148 responseWritten bool 149 destination M.Socksaddr 150 } 151 152 func (c *serverPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 153 n, err = c.ExtendedConn.Read(p) 154 if err != nil { 155 return 156 } 157 if c.destination.IsFqdn() { 158 addr = c.destination 159 } else { 160 addr = c.destination.UDPAddr() 161 } 162 return 163 } 164 165 func (c *serverPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 166 if !c.responseWritten { 167 if c.responseWriter == nil { 168 var packetLen [2]byte 169 binary.BigEndian.PutUint16(packetLen[:], uint16(len(p))) 170 _, err = bufio.WriteVectorised(bufio.NewVectorisedWriter(c.ExtendedConn), [][]byte{{Version, 0}, packetLen[:], p}) 171 if err == nil { 172 n = len(p) 173 } 174 c.responseWritten = true 175 return 176 } else { 177 _, err = c.responseWriter.Write([]byte{Version, 0}) 178 if err != nil { 179 return 180 } 181 c.responseWritten = true 182 } 183 } 184 return c.ExtendedConn.Write(p) 185 } 186 187 func (c *serverPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { 188 var packetLen uint16 189 err = binary.Read(c.ExtendedConn, binary.BigEndian, &packetLen) 190 if err != nil { 191 return 192 } 193 194 _, err = buffer.ReadFullFrom(c.ExtendedConn, int(packetLen)) 195 if err != nil { 196 return 197 } 198 199 destination = c.destination 200 return 201 } 202 203 func (c *serverPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 204 if !c.responseWritten { 205 if c.responseWriter == nil { 206 var packetLen [2]byte 207 binary.BigEndian.PutUint16(packetLen[:], uint16(buffer.Len())) 208 err := bufio.NewVectorisedWriter(c.ExtendedConn).WriteVectorised([]*buf.Buffer{buf.As([]byte{Version, 0}), buf.As(packetLen[:]), buffer}) 209 c.responseWritten = true 210 return err 211 } else { 212 _, err := c.responseWriter.Write([]byte{Version, 0}) 213 if err != nil { 214 return err 215 } 216 c.responseWritten = true 217 } 218 } 219 packetLen := buffer.Len() 220 binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(packetLen)) 221 return c.ExtendedConn.WriteBuffer(buffer) 222 } 223 224 func (c *serverPacketConn) FrontHeadroom() int { 225 return 2 226 } 227 228 func (c *serverPacketConn) Upstream() any { 229 return c.ExtendedConn 230 }