github.com/moqsien/xraycore@v1.8.5/proxy/trojan/server.go (about) 1 package trojan 2 3 import ( 4 "context" 5 "io" 6 "strconv" 7 "strings" 8 "time" 9 10 "github.com/moqsien/xraycore/common" 11 "github.com/moqsien/xraycore/common/buf" 12 "github.com/moqsien/xraycore/common/errors" 13 "github.com/moqsien/xraycore/common/log" 14 "github.com/moqsien/xraycore/common/net" 15 "github.com/moqsien/xraycore/common/protocol" 16 udp_proto "github.com/moqsien/xraycore/common/protocol/udp" 17 "github.com/moqsien/xraycore/common/retry" 18 "github.com/moqsien/xraycore/common/session" 19 "github.com/moqsien/xraycore/common/signal" 20 "github.com/moqsien/xraycore/common/task" 21 "github.com/moqsien/xraycore/core" 22 "github.com/moqsien/xraycore/features/policy" 23 "github.com/moqsien/xraycore/features/routing" 24 "github.com/moqsien/xraycore/transport/internet/reality" 25 "github.com/moqsien/xraycore/transport/internet/stat" 26 "github.com/moqsien/xraycore/transport/internet/tls" 27 "github.com/moqsien/xraycore/transport/internet/udp" 28 ) 29 30 func init() { 31 common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 32 return NewServer(ctx, config.(*ServerConfig)) 33 })) 34 } 35 36 // Server is an inbound connection handler that handles messages in trojan protocol. 37 type Server struct { 38 policyManager policy.Manager 39 validator *Validator 40 fallbacks map[string]map[string]map[string]*Fallback // or nil 41 cone bool 42 } 43 44 // NewServer creates a new trojan inbound handler. 45 func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) { 46 validator := new(Validator) 47 for _, user := range config.Users { 48 u, err := user.ToMemoryUser() 49 if err != nil { 50 return nil, newError("failed to get trojan user").Base(err).AtError() 51 } 52 53 if err := validator.Add(u); err != nil { 54 return nil, newError("failed to add user").Base(err).AtError() 55 } 56 } 57 58 v := core.MustFromContext(ctx) 59 server := &Server{ 60 policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), 61 validator: validator, 62 cone: ctx.Value("cone").(bool), 63 } 64 65 if config.Fallbacks != nil { 66 server.fallbacks = make(map[string]map[string]map[string]*Fallback) 67 for _, fb := range config.Fallbacks { 68 if server.fallbacks[fb.Name] == nil { 69 server.fallbacks[fb.Name] = make(map[string]map[string]*Fallback) 70 } 71 if server.fallbacks[fb.Name][fb.Alpn] == nil { 72 server.fallbacks[fb.Name][fb.Alpn] = make(map[string]*Fallback) 73 } 74 server.fallbacks[fb.Name][fb.Alpn][fb.Path] = fb 75 } 76 if server.fallbacks[""] != nil { 77 for name, apfb := range server.fallbacks { 78 if name != "" { 79 for alpn := range server.fallbacks[""] { 80 if apfb[alpn] == nil { 81 apfb[alpn] = make(map[string]*Fallback) 82 } 83 } 84 } 85 } 86 } 87 for _, apfb := range server.fallbacks { 88 if apfb[""] != nil { 89 for alpn, pfb := range apfb { 90 if alpn != "" { // && alpn != "h2" { 91 for path, fb := range apfb[""] { 92 if pfb[path] == nil { 93 pfb[path] = fb 94 } 95 } 96 } 97 } 98 } 99 } 100 if server.fallbacks[""] != nil { 101 for name, apfb := range server.fallbacks { 102 if name != "" { 103 for alpn, pfb := range server.fallbacks[""] { 104 for path, fb := range pfb { 105 if apfb[alpn][path] == nil { 106 apfb[alpn][path] = fb 107 } 108 } 109 } 110 } 111 } 112 } 113 } 114 115 return server, nil 116 } 117 118 // AddUser implements proxy.UserManager.AddUser(). 119 func (s *Server) AddUser(ctx context.Context, u *protocol.MemoryUser) error { 120 return s.validator.Add(u) 121 } 122 123 // RemoveUser implements proxy.UserManager.RemoveUser(). 124 func (s *Server) RemoveUser(ctx context.Context, e string) error { 125 return s.validator.Del(e) 126 } 127 128 // Network implements proxy.Inbound.Network(). 129 func (s *Server) Network() []net.Network { 130 return []net.Network{net.Network_TCP, net.Network_UNIX} 131 } 132 133 // Process implements proxy.Inbound.Process(). 134 func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { 135 sid := session.ExportIDToError(ctx) 136 137 iConn := conn 138 statConn, ok := iConn.(*stat.CounterConnection) 139 if ok { 140 iConn = statConn.Connection 141 } 142 143 sessionPolicy := s.policyManager.ForLevel(0) 144 if err := conn.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake)); err != nil { 145 return newError("unable to set read deadline").Base(err).AtWarning() 146 } 147 148 first := buf.FromBytes(make([]byte, buf.Size)) 149 first.Clear() 150 firstLen, err := first.ReadFrom(conn) 151 if err != nil { 152 return newError("failed to read first request").Base(err) 153 } 154 newError("firstLen = ", firstLen).AtInfo().WriteToLog(sid) 155 156 bufferedReader := &buf.BufferedReader{ 157 Reader: buf.NewReader(conn), 158 Buffer: buf.MultiBuffer{first}, 159 } 160 161 var user *protocol.MemoryUser 162 163 napfb := s.fallbacks 164 isfb := napfb != nil 165 166 shouldFallback := false 167 if firstLen < 58 || first.Byte(56) != '\r' { 168 // invalid protocol 169 err = newError("not trojan protocol") 170 log.Record(&log.AccessMessage{ 171 From: conn.RemoteAddr(), 172 To: "", 173 Status: log.AccessRejected, 174 Reason: err, 175 }) 176 177 shouldFallback = true 178 } else { 179 user = s.validator.Get(hexString(first.BytesTo(56))) 180 if user == nil { 181 // invalid user, let's fallback 182 err = newError("not a valid user") 183 log.Record(&log.AccessMessage{ 184 From: conn.RemoteAddr(), 185 To: "", 186 Status: log.AccessRejected, 187 Reason: err, 188 }) 189 190 shouldFallback = true 191 } 192 } 193 194 if isfb && shouldFallback { 195 return s.fallback(ctx, sid, err, sessionPolicy, conn, iConn, napfb, first, firstLen, bufferedReader) 196 } else if shouldFallback { 197 return newError("invalid protocol or invalid user") 198 } 199 200 clientReader := &ConnReader{Reader: bufferedReader} 201 if err := clientReader.ParseHeader(); err != nil { 202 log.Record(&log.AccessMessage{ 203 From: conn.RemoteAddr(), 204 To: "", 205 Status: log.AccessRejected, 206 Reason: err, 207 }) 208 return newError("failed to create request from: ", conn.RemoteAddr()).Base(err) 209 } 210 211 destination := clientReader.Target 212 if err := conn.SetReadDeadline(time.Time{}); err != nil { 213 return newError("unable to set read deadline").Base(err).AtWarning() 214 } 215 216 inbound := session.InboundFromContext(ctx) 217 if inbound == nil { 218 panic("no inbound metadata") 219 } 220 inbound.Name = "trojan" 221 inbound.User = user 222 sessionPolicy = s.policyManager.ForLevel(user.Level) 223 224 if destination.Network == net.Network_UDP { // handle udp request 225 return s.handleUDPPayload(ctx, &PacketReader{Reader: clientReader}, &PacketWriter{Writer: conn}, dispatcher) 226 } 227 228 ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ 229 From: conn.RemoteAddr(), 230 To: destination, 231 Status: log.AccessAccepted, 232 Reason: "", 233 Email: user.Email, 234 }) 235 236 newError("received request for ", destination).WriteToLog(sid) 237 return s.handleConnection(ctx, sessionPolicy, destination, clientReader, buf.NewWriter(conn), dispatcher) 238 } 239 240 func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error { 241 udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) { 242 udpPayload := packet.Payload 243 if udpPayload.UDP == nil { 244 udpPayload.UDP = &packet.Source 245 } 246 247 if err := clientWriter.WriteMultiBuffer(buf.MultiBuffer{udpPayload}); err != nil { 248 newError("failed to write response").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx)) 249 } 250 }) 251 252 inbound := session.InboundFromContext(ctx) 253 user := inbound.User 254 255 var dest *net.Destination 256 257 for { 258 select { 259 case <-ctx.Done(): 260 return nil 261 default: 262 mb, err := clientReader.ReadMultiBuffer() 263 if err != nil { 264 if errors.Cause(err) != io.EOF { 265 return newError("unexpected EOF").Base(err) 266 } 267 return nil 268 } 269 270 mb2, b := buf.SplitFirst(mb) 271 if b == nil { 272 continue 273 } 274 destination := *b.UDP 275 276 currentPacketCtx := ctx 277 if inbound.Source.IsValid() { 278 currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ 279 From: inbound.Source, 280 To: destination, 281 Status: log.AccessAccepted, 282 Reason: "", 283 Email: user.Email, 284 }) 285 } 286 newError("tunnelling request to ", destination).WriteToLog(session.ExportIDToError(ctx)) 287 288 if !s.cone || dest == nil { 289 dest = &destination 290 } 291 292 udpServer.Dispatch(currentPacketCtx, *dest, b) // first packet 293 for _, payload := range mb2 { 294 udpServer.Dispatch(currentPacketCtx, *dest, payload) 295 } 296 } 297 } 298 } 299 300 func (s *Server) handleConnection(ctx context.Context, sessionPolicy policy.Session, 301 destination net.Destination, 302 clientReader buf.Reader, 303 clientWriter buf.Writer, dispatcher routing.Dispatcher, 304 ) error { 305 ctx, cancel := context.WithCancel(ctx) 306 timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) 307 ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer) 308 309 link, err := dispatcher.Dispatch(ctx, destination) 310 if err != nil { 311 return newError("failed to dispatch request to ", destination).Base(err) 312 } 313 314 requestDone := func() error { 315 defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) 316 if buf.Copy(clientReader, link.Writer, buf.UpdateActivity(timer)) != nil { 317 return newError("failed to transfer request").Base(err) 318 } 319 return nil 320 } 321 322 responseDone := func() error { 323 defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) 324 325 if err := buf.Copy(link.Reader, clientWriter, buf.UpdateActivity(timer)); err != nil { 326 return newError("failed to write response").Base(err) 327 } 328 return nil 329 } 330 331 requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer)) 332 if err := task.Run(ctx, requestDonePost, responseDone); err != nil { 333 common.Must(common.Interrupt(link.Reader)) 334 common.Must(common.Interrupt(link.Writer)) 335 return newError("connection ends").Base(err) 336 } 337 338 return nil 339 } 340 341 func (s *Server) fallback(ctx context.Context, sid errors.ExportOption, err error, sessionPolicy policy.Session, connection stat.Connection, iConn stat.Connection, napfb map[string]map[string]map[string]*Fallback, first *buf.Buffer, firstLen int64, reader buf.Reader) error { 342 if err := connection.SetReadDeadline(time.Time{}); err != nil { 343 newError("unable to set back read deadline").Base(err).AtWarning().WriteToLog(sid) 344 } 345 newError("fallback starts").Base(err).AtInfo().WriteToLog(sid) 346 347 name := "" 348 alpn := "" 349 if tlsConn, ok := iConn.(*tls.Conn); ok { 350 cs := tlsConn.ConnectionState() 351 name = cs.ServerName 352 alpn = cs.NegotiatedProtocol 353 newError("realName = " + name).AtInfo().WriteToLog(sid) 354 newError("realAlpn = " + alpn).AtInfo().WriteToLog(sid) 355 } else if realityConn, ok := iConn.(*reality.Conn); ok { 356 cs := realityConn.ConnectionState() 357 name = cs.ServerName 358 alpn = cs.NegotiatedProtocol 359 newError("realName = " + name).AtInfo().WriteToLog(sid) 360 newError("realAlpn = " + alpn).AtInfo().WriteToLog(sid) 361 } 362 name = strings.ToLower(name) 363 alpn = strings.ToLower(alpn) 364 365 if len(napfb) > 1 || napfb[""] == nil { 366 if name != "" && napfb[name] == nil { 367 match := "" 368 for n := range napfb { 369 if n != "" && strings.Contains(name, n) && len(n) > len(match) { 370 match = n 371 } 372 } 373 name = match 374 } 375 } 376 377 if napfb[name] == nil { 378 name = "" 379 } 380 apfb := napfb[name] 381 if apfb == nil { 382 return newError(`failed to find the default "name" config`).AtWarning() 383 } 384 385 if apfb[alpn] == nil { 386 alpn = "" 387 } 388 pfb := apfb[alpn] 389 if pfb == nil { 390 return newError(`failed to find the default "alpn" config`).AtWarning() 391 } 392 393 path := "" 394 if len(pfb) > 1 || pfb[""] == nil { 395 if firstLen >= 18 && first.Byte(4) != '*' { // not h2c 396 firstBytes := first.Bytes() 397 for i := 4; i <= 8; i++ { // 5 -> 9 398 if firstBytes[i] == '/' && firstBytes[i-1] == ' ' { 399 search := len(firstBytes) 400 if search > 64 { 401 search = 64 // up to about 60 402 } 403 for j := i + 1; j < search; j++ { 404 k := firstBytes[j] 405 if k == '\r' || k == '\n' { // avoid logging \r or \n 406 break 407 } 408 if k == '?' || k == ' ' { 409 path = string(firstBytes[i:j]) 410 newError("realPath = " + path).AtInfo().WriteToLog(sid) 411 if pfb[path] == nil { 412 path = "" 413 } 414 break 415 } 416 } 417 break 418 } 419 } 420 } 421 } 422 fb := pfb[path] 423 if fb == nil { 424 return newError(`failed to find the default "path" config`).AtWarning() 425 } 426 427 ctx, cancel := context.WithCancel(ctx) 428 timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) 429 ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer) 430 431 var conn net.Conn 432 if err := retry.ExponentialBackoff(5, 100).On(func() error { 433 var dialer net.Dialer 434 conn, err = dialer.DialContext(ctx, fb.Type, fb.Dest) 435 if err != nil { 436 return err 437 } 438 return nil 439 }); err != nil { 440 return newError("failed to dial to " + fb.Dest).Base(err).AtWarning() 441 } 442 defer conn.Close() 443 444 serverReader := buf.NewReader(conn) 445 serverWriter := buf.NewWriter(conn) 446 447 postRequest := func() error { 448 defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) 449 if fb.Xver != 0 { 450 ipType := 4 451 remoteAddr, remotePort, err := net.SplitHostPort(connection.RemoteAddr().String()) 452 if err != nil { 453 ipType = 0 454 } 455 localAddr, localPort, err := net.SplitHostPort(connection.LocalAddr().String()) 456 if err != nil { 457 ipType = 0 458 } 459 if ipType == 4 { 460 for i := 0; i < len(remoteAddr); i++ { 461 if remoteAddr[i] == ':' { 462 ipType = 6 463 break 464 } 465 } 466 } 467 pro := buf.New() 468 defer pro.Release() 469 switch fb.Xver { 470 case 1: 471 if ipType == 0 { 472 common.Must2(pro.Write([]byte("PROXY UNKNOWN\r\n"))) 473 break 474 } 475 if ipType == 4 { 476 common.Must2(pro.Write([]byte("PROXY TCP4 " + remoteAddr + " " + localAddr + " " + remotePort + " " + localPort + "\r\n"))) 477 } else { 478 common.Must2(pro.Write([]byte("PROXY TCP6 " + remoteAddr + " " + localAddr + " " + remotePort + " " + localPort + "\r\n"))) 479 } 480 case 2: 481 common.Must2(pro.Write([]byte("\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"))) // signature 482 if ipType == 0 { 483 common.Must2(pro.Write([]byte("\x20\x00\x00\x00"))) // v2 + LOCAL + UNSPEC + UNSPEC + 0 bytes 484 break 485 } 486 if ipType == 4 { 487 common.Must2(pro.Write([]byte("\x21\x11\x00\x0C"))) // v2 + PROXY + AF_INET + STREAM + 12 bytes 488 common.Must2(pro.Write(net.ParseIP(remoteAddr).To4())) 489 common.Must2(pro.Write(net.ParseIP(localAddr).To4())) 490 } else { 491 common.Must2(pro.Write([]byte("\x21\x21\x00\x24"))) // v2 + PROXY + AF_INET6 + STREAM + 36 bytes 492 common.Must2(pro.Write(net.ParseIP(remoteAddr).To16())) 493 common.Must2(pro.Write(net.ParseIP(localAddr).To16())) 494 } 495 p1, _ := strconv.ParseUint(remotePort, 10, 16) 496 p2, _ := strconv.ParseUint(localPort, 10, 16) 497 common.Must2(pro.Write([]byte{byte(p1 >> 8), byte(p1), byte(p2 >> 8), byte(p2)})) 498 } 499 if err := serverWriter.WriteMultiBuffer(buf.MultiBuffer{pro}); err != nil { 500 return newError("failed to set PROXY protocol v", fb.Xver).Base(err).AtWarning() 501 } 502 } 503 if err := buf.Copy(reader, serverWriter, buf.UpdateActivity(timer)); err != nil { 504 return newError("failed to fallback request payload").Base(err).AtInfo() 505 } 506 return nil 507 } 508 509 writer := buf.NewWriter(connection) 510 511 getResponse := func() error { 512 defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) 513 if err := buf.Copy(serverReader, writer, buf.UpdateActivity(timer)); err != nil { 514 return newError("failed to deliver response payload").Base(err).AtInfo() 515 } 516 return nil 517 } 518 519 if err := task.Run(ctx, task.OnSuccess(postRequest, task.Close(serverWriter)), task.OnSuccess(getResponse, task.Close(writer))); err != nil { 520 common.Must(common.Interrupt(serverReader)) 521 common.Must(common.Interrupt(serverWriter)) 522 return newError("fallback ends").Base(err).AtInfo() 523 } 524 525 return nil 526 }