github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/transport/tuic/server.go (about) 1 //go:build with_quic 2 3 package tuic 4 5 import ( 6 "bytes" 7 "context" 8 "encoding/binary" 9 "github.com/inazumav/sing-box/option" 10 "github.com/sagernet/quic-go" 11 "io" 12 "net" 13 "runtime" 14 "strings" 15 "sync" 16 "time" 17 18 "github.com/inazumav/sing-box/common/baderror" 19 "github.com/inazumav/sing-box/common/qtls" 20 "github.com/inazumav/sing-box/common/tls" 21 "github.com/sagernet/sing/common" 22 "github.com/sagernet/sing/common/auth" 23 "github.com/sagernet/sing/common/buf" 24 "github.com/sagernet/sing/common/bufio" 25 E "github.com/sagernet/sing/common/exceptions" 26 "github.com/sagernet/sing/common/logger" 27 M "github.com/sagernet/sing/common/metadata" 28 N "github.com/sagernet/sing/common/network" 29 30 "github.com/gofrs/uuid/v5" 31 ) 32 33 type ServerOptions struct { 34 Context context.Context 35 Logger logger.Logger 36 TLSConfig tls.ServerConfig 37 Users []User 38 CongestionControl string 39 AuthTimeout time.Duration 40 ZeroRTTHandshake bool 41 Heartbeat time.Duration 42 Handler ServerHandler 43 } 44 45 type User struct { 46 Name string 47 UUID uuid.UUID 48 Password string 49 } 50 51 type ServerHandler interface { 52 N.TCPConnectionHandler 53 N.UDPConnectionHandler 54 } 55 56 type Server struct { 57 ctx context.Context 58 logger logger.Logger 59 tlsConfig tls.ServerConfig 60 heartbeat time.Duration 61 quicConfig *quic.Config 62 userMap map[uuid.UUID]User 63 congestionControl string 64 authTimeout time.Duration 65 handler ServerHandler 66 67 quicListener io.Closer 68 } 69 70 func NewServer(options ServerOptions) (*Server, error) { 71 if options.AuthTimeout == 0 { 72 options.AuthTimeout = 3 * time.Second 73 } 74 if options.Heartbeat == 0 { 75 options.Heartbeat = 10 * time.Second 76 } 77 quicConfig := &quic.Config{ 78 DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"), 79 MaxDatagramFrameSize: 1400, 80 EnableDatagrams: true, 81 Allow0RTT: options.ZeroRTTHandshake, 82 MaxIncomingStreams: 1 << 60, 83 MaxIncomingUniStreams: 1 << 60, 84 } 85 switch options.CongestionControl { 86 case "": 87 options.CongestionControl = "cubic" 88 case "cubic", "new_reno", "bbr": 89 default: 90 return nil, E.New("unknown congestion control algorithm: ", options.CongestionControl) 91 } 92 if len(options.Users) == 0 { 93 return nil, E.New("missing users") 94 } 95 userMap := make(map[uuid.UUID]User) 96 for _, user := range options.Users { 97 userMap[user.UUID] = user 98 } 99 return &Server{ 100 ctx: options.Context, 101 logger: options.Logger, 102 tlsConfig: options.TLSConfig, 103 heartbeat: options.Heartbeat, 104 quicConfig: quicConfig, 105 userMap: userMap, 106 congestionControl: options.CongestionControl, 107 authTimeout: options.AuthTimeout, 108 handler: options.Handler, 109 }, nil 110 } 111 112 func (s *Server) Start(conn net.PacketConn) error { 113 if !s.quicConfig.Allow0RTT { 114 listener, err := qtls.Listen(conn, s.tlsConfig, s.quicConfig) 115 if err != nil { 116 return err 117 } 118 s.quicListener = listener 119 go func() { 120 for { 121 connection, hErr := listener.Accept(s.ctx) 122 if hErr != nil { 123 if strings.Contains(hErr.Error(), "server closed") { 124 s.logger.Debug(E.Cause(hErr, "listener closed")) 125 } else { 126 s.logger.Error(E.Cause(hErr, "listener closed")) 127 } 128 return 129 } 130 go s.handleConnection(connection) 131 } 132 }() 133 } else { 134 listener, err := qtls.ListenEarly(conn, s.tlsConfig, s.quicConfig) 135 if err != nil { 136 return err 137 } 138 s.quicListener = listener 139 go func() { 140 for { 141 connection, hErr := listener.Accept(s.ctx) 142 if hErr != nil { 143 if strings.Contains(hErr.Error(), "server closed") { 144 s.logger.Debug(E.Cause(hErr, "listener closed")) 145 } else { 146 s.logger.Error(E.Cause(hErr, "listener closed")) 147 } 148 return 149 } 150 go s.handleConnection(connection) 151 } 152 }() 153 } 154 return nil 155 } 156 157 func (s *Server) Close() error { 158 return common.Close( 159 s.quicListener, 160 ) 161 } 162 163 func (s *Server) AddUsers(users []option.TUICUser) error { 164 for _, u := range users { 165 uuid, err := uuid.FromString(u.UUID) 166 if err != nil { 167 return E.Cause(err, "invalid uuid for user ", u.UUID) 168 } 169 s.userMap[uuid] = User{ 170 Name: u.Name, 171 UUID: uuid, 172 Password: u.Password, 173 } 174 } 175 return nil 176 } 177 178 func (s *Server) DelUsers(uuids []string) error { 179 for _, u := range uuids { 180 ud, err := uuid.FromString(u) 181 if err != nil { 182 return E.Cause(err, "invalid uuid for user ", ud) 183 } 184 delete(s.userMap, ud) 185 } 186 return nil 187 } 188 189 func (s *Server) handleConnection(connection quic.Connection) { 190 setCongestion(s.ctx, connection, s.congestionControl) 191 session := &serverSession{ 192 Server: s, 193 ctx: s.ctx, 194 quicConn: connection, 195 source: M.SocksaddrFromNet(connection.RemoteAddr()), 196 connDone: make(chan struct{}), 197 authDone: make(chan struct{}), 198 udpConnMap: make(map[uint16]*udpPacketConn), 199 } 200 session.handle() 201 } 202 203 type serverSession struct { 204 *Server 205 ctx context.Context 206 quicConn quic.Connection 207 source M.Socksaddr 208 connAccess sync.Mutex 209 connDone chan struct{} 210 connErr error 211 authDone chan struct{} 212 authUser *User 213 udpAccess sync.RWMutex 214 udpConnMap map[uint16]*udpPacketConn 215 } 216 217 func (s *serverSession) handle() { 218 if s.ctx.Done() != nil { 219 go func() { 220 select { 221 case <-s.ctx.Done(): 222 s.closeWithError(s.ctx.Err()) 223 case <-s.connDone: 224 } 225 }() 226 } 227 go s.loopUniStreams() 228 go s.loopStreams() 229 go s.loopMessages() 230 go s.handleAuthTimeout() 231 go s.loopHeartbeats() 232 } 233 234 func (s *serverSession) loopUniStreams() { 235 for { 236 uniStream, err := s.quicConn.AcceptUniStream(s.ctx) 237 if err != nil { 238 return 239 } 240 go func() { 241 err = s.handleUniStream(uniStream) 242 if err != nil { 243 s.closeWithError(E.Cause(err, "handle uni stream")) 244 } 245 }() 246 } 247 } 248 249 func (s *serverSession) handleUniStream(stream quic.ReceiveStream) error { 250 defer stream.CancelRead(0) 251 buffer := buf.New() 252 defer buffer.Release() 253 _, err := buffer.ReadAtLeastFrom(stream, 2) 254 if err != nil { 255 return E.Cause(err, "read request") 256 } 257 version := buffer.Byte(0) 258 if version != Version { 259 return E.New("unknown version ", buffer.Byte(0)) 260 } 261 command := buffer.Byte(1) 262 switch command { 263 case CommandAuthenticate: 264 select { 265 case <-s.authDone: 266 return E.New("authentication: multiple authentication requests") 267 default: 268 } 269 if buffer.Len() < AuthenticateLen { 270 _, err = buffer.ReadFullFrom(stream, AuthenticateLen-buffer.Len()) 271 if err != nil { 272 return E.Cause(err, "authentication: read request") 273 } 274 } 275 userUUID := uuid.FromBytesOrNil(buffer.Range(2, 2+16)) 276 user, loaded := s.userMap[userUUID] 277 if !loaded { 278 return E.New("authentication: unknown user ", userUUID) 279 } 280 handshakeState := s.quicConn.ConnectionState() 281 tuicToken, err := handshakeState.ExportKeyingMaterial(string(user.UUID[:]), []byte(user.Password), 32) 282 if err != nil { 283 return E.Cause(err, "authentication: export keying material") 284 } 285 if !bytes.Equal(tuicToken, buffer.Range(2+16, 2+16+32)) { 286 return E.New("authentication: token mismatch") 287 } 288 s.authUser = &user 289 close(s.authDone) 290 return nil 291 case CommandPacket: 292 select { 293 case <-s.connDone: 294 return s.connErr 295 case <-s.authDone: 296 } 297 message := udpMessagePool.Get().(*udpMessage) 298 err = readUDPMessage(message, io.MultiReader(bytes.NewReader(buffer.From(2)), stream)) 299 if err != nil { 300 message.release() 301 return err 302 } 303 s.handleUDPMessage(message, true) 304 return nil 305 case CommandDissociate: 306 select { 307 case <-s.connDone: 308 return s.connErr 309 case <-s.authDone: 310 } 311 if buffer.Len() > 4 { 312 return E.New("invalid dissociate message") 313 } 314 var sessionID uint16 315 err = binary.Read(io.MultiReader(bytes.NewReader(buffer.From(2)), stream), binary.BigEndian, &sessionID) 316 if err != nil { 317 return err 318 } 319 s.udpAccess.RLock() 320 udpConn, loaded := s.udpConnMap[sessionID] 321 s.udpAccess.RUnlock() 322 if loaded { 323 udpConn.closeWithError(E.New("remote closed")) 324 s.udpAccess.Lock() 325 delete(s.udpConnMap, sessionID) 326 s.udpAccess.Unlock() 327 } 328 return nil 329 default: 330 return E.New("unknown command ", command) 331 } 332 } 333 334 func (s *serverSession) handleAuthTimeout() { 335 select { 336 case <-s.connDone: 337 case <-s.authDone: 338 case <-time.After(s.authTimeout): 339 s.closeWithError(E.New("authentication timeout")) 340 } 341 } 342 343 func (s *serverSession) loopStreams() { 344 for { 345 stream, err := s.quicConn.AcceptStream(s.ctx) 346 if err != nil { 347 return 348 } 349 go func() { 350 err = s.handleStream(stream) 351 if err != nil { 352 stream.CancelRead(0) 353 stream.Close() 354 s.logger.Error(E.Cause(err, "handle stream request")) 355 } 356 }() 357 } 358 } 359 360 func (s *serverSession) handleStream(stream quic.Stream) error { 361 buffer := buf.NewSize(2 + M.MaxSocksaddrLength) 362 defer buffer.Release() 363 _, err := buffer.ReadAtLeastFrom(stream, 2) 364 if err != nil { 365 return E.Cause(err, "read request") 366 } 367 version, _ := buffer.ReadByte() 368 if version != Version { 369 return E.New("unknown version ", buffer.Byte(0)) 370 } 371 command, _ := buffer.ReadByte() 372 if command != CommandConnect { 373 return E.New("unsupported stream command ", command) 374 } 375 destination, err := addressSerializer.ReadAddrPort(io.MultiReader(buffer, stream)) 376 if err != nil { 377 return E.Cause(err, "read request destination") 378 } 379 select { 380 case <-s.connDone: 381 return s.connErr 382 case <-s.authDone: 383 } 384 var conn net.Conn = &serverConn{ 385 Stream: stream, 386 destination: destination, 387 } 388 if buffer.IsEmpty() { 389 buffer.Release() 390 } else { 391 conn = bufio.NewCachedConn(conn, buffer) 392 } 393 ctx := s.ctx 394 if s.authUser.Name != "" { 395 ctx = auth.ContextWithUser(s.ctx, s.authUser.Name) 396 } 397 _ = s.handler.NewConnection(ctx, conn, M.Metadata{ 398 Source: s.source, 399 Destination: destination, 400 }) 401 return nil 402 } 403 404 func (s *serverSession) loopHeartbeats() { 405 ticker := time.NewTicker(s.heartbeat) 406 defer ticker.Stop() 407 for { 408 select { 409 case <-s.connDone: 410 return 411 case <-ticker.C: 412 err := s.quicConn.SendMessage([]byte{Version, CommandHeartbeat}) 413 if err != nil { 414 s.closeWithError(E.Cause(err, "send heartbeat")) 415 } 416 } 417 } 418 } 419 420 func (s *serverSession) closeWithError(err error) { 421 s.connAccess.Lock() 422 defer s.connAccess.Unlock() 423 select { 424 case <-s.connDone: 425 return 426 default: 427 s.connErr = err 428 close(s.connDone) 429 } 430 if E.IsClosedOrCanceled(err) { 431 s.logger.Debug(E.Cause(err, "connection failed")) 432 } else { 433 s.logger.Error(E.Cause(err, "connection failed")) 434 } 435 _ = s.quicConn.CloseWithError(0, "") 436 } 437 438 type serverConn struct { 439 quic.Stream 440 destination M.Socksaddr 441 } 442 443 func (c *serverConn) Read(p []byte) (n int, err error) { 444 n, err = c.Stream.Read(p) 445 return n, baderror.WrapQUIC(err) 446 } 447 448 func (c *serverConn) Write(p []byte) (n int, err error) { 449 n, err = c.Stream.Write(p) 450 return n, baderror.WrapQUIC(err) 451 } 452 453 func (c *serverConn) LocalAddr() net.Addr { 454 return c.destination 455 } 456 457 func (c *serverConn) RemoteAddr() net.Addr { 458 return M.Socksaddr{} 459 } 460 461 func (c *serverConn) Close() error { 462 c.Stream.CancelRead(0) 463 return c.Stream.Close() 464 }