github.com/xmplusdev/xray-core@v1.8.10/proxy/trojan/server.go (about) 1 package trojan 2 3 import ( 4 "context" 5 "io" 6 "strconv" 7 "strings" 8 "time" 9 10 "github.com/xmplusdev/xray-core/common" 11 "github.com/xmplusdev/xray-core/common/buf" 12 "github.com/xmplusdev/xray-core/common/errors" 13 "github.com/xmplusdev/xray-core/common/log" 14 "github.com/xmplusdev/xray-core/common/net" 15 "github.com/xmplusdev/xray-core/common/protocol" 16 udp_proto "github.com/xmplusdev/xray-core/common/protocol/udp" 17 "github.com/xmplusdev/xray-core/common/retry" 18 "github.com/xmplusdev/xray-core/common/session" 19 "github.com/xmplusdev/xray-core/common/signal" 20 "github.com/xmplusdev/xray-core/common/task" 21 "github.com/xmplusdev/xray-core/core" 22 "github.com/xmplusdev/xray-core/features/policy" 23 "github.com/xmplusdev/xray-core/features/routing" 24 "github.com/xmplusdev/xray-core/transport/internet/reality" 25 "github.com/xmplusdev/xray-core/transport/internet/stat" 26 "github.com/xmplusdev/xray-core/transport/internet/tls" 27 "github.com/xmplusdev/xray-core/transport/internet/udp" 28 ) 29 30 func init() { 31 common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 32 return NewServer(ctx, config.(*ServerConfig)) 33 })) 34 } 35 36 // Server is an inbound connection handler that handles messages in trojan protocol. 37 type Server struct { 38 policyManager policy.Manager 39 validator *Validator 40 fallbacks map[string]map[string]map[string]*Fallback // or nil 41 cone bool 42 } 43 44 // NewServer creates a new trojan inbound handler. 45 func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) { 46 validator := new(Validator) 47 for _, user := range config.Users { 48 u, err := user.ToMemoryUser() 49 if err != nil { 50 return nil, newError("failed to get trojan user").Base(err).AtError() 51 } 52 53 if err := validator.Add(u); err != nil { 54 return nil, newError("failed to add user").Base(err).AtError() 55 } 56 } 57 58 v := core.MustFromContext(ctx) 59 server := &Server{ 60 policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), 61 validator: validator, 62 cone: ctx.Value("cone").(bool), 63 } 64 65 if config.Fallbacks != nil { 66 server.fallbacks = make(map[string]map[string]map[string]*Fallback) 67 for _, fb := range config.Fallbacks { 68 if server.fallbacks[fb.Name] == nil { 69 server.fallbacks[fb.Name] = make(map[string]map[string]*Fallback) 70 } 71 if server.fallbacks[fb.Name][fb.Alpn] == nil { 72 server.fallbacks[fb.Name][fb.Alpn] = make(map[string]*Fallback) 73 } 74 server.fallbacks[fb.Name][fb.Alpn][fb.Path] = fb 75 } 76 if server.fallbacks[""] != nil { 77 for name, apfb := range server.fallbacks { 78 if name != "" { 79 for alpn := range server.fallbacks[""] { 80 if apfb[alpn] == nil { 81 apfb[alpn] = make(map[string]*Fallback) 82 } 83 } 84 } 85 } 86 } 87 for _, apfb := range server.fallbacks { 88 if apfb[""] != nil { 89 for alpn, pfb := range apfb { 90 if alpn != "" { // && alpn != "h2" { 91 for path, fb := range apfb[""] { 92 if pfb[path] == nil { 93 pfb[path] = fb 94 } 95 } 96 } 97 } 98 } 99 } 100 if server.fallbacks[""] != nil { 101 for name, apfb := range server.fallbacks { 102 if name != "" { 103 for alpn, pfb := range server.fallbacks[""] { 104 for path, fb := range pfb { 105 if apfb[alpn][path] == nil { 106 apfb[alpn][path] = fb 107 } 108 } 109 } 110 } 111 } 112 } 113 } 114 115 return server, nil 116 } 117 118 // AddUser implements proxy.UserManager.AddUser(). 119 func (s *Server) AddUser(ctx context.Context, u *protocol.MemoryUser) error { 120 return s.validator.Add(u) 121 } 122 123 // RemoveUser implements proxy.UserManager.RemoveUser(). 124 func (s *Server) RemoveUser(ctx context.Context, e string) error { 125 return s.validator.Del(e) 126 } 127 128 // Network implements proxy.Inbound.Network(). 129 func (s *Server) Network() []net.Network { 130 return []net.Network{net.Network_TCP, net.Network_UNIX} 131 } 132 133 // Process implements proxy.Inbound.Process(). 134 func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { 135 sid := session.ExportIDToError(ctx) 136 137 iConn := conn 138 statConn, ok := iConn.(*stat.CounterConnection) 139 if ok { 140 iConn = statConn.Connection 141 } 142 143 sessionPolicy := s.policyManager.ForLevel(0) 144 if err := conn.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake)); err != nil { 145 return newError("unable to set read deadline").Base(err).AtWarning() 146 } 147 148 first := buf.FromBytes(make([]byte, buf.Size)) 149 first.Clear() 150 firstLen, err := first.ReadFrom(conn) 151 if err != nil { 152 return newError("failed to read first request").Base(err) 153 } 154 newError("firstLen = ", firstLen).AtInfo().WriteToLog(sid) 155 156 bufferedReader := &buf.BufferedReader{ 157 Reader: buf.NewReader(conn), 158 Buffer: buf.MultiBuffer{first}, 159 } 160 161 var user *protocol.MemoryUser 162 163 napfb := s.fallbacks 164 isfb := napfb != nil 165 166 shouldFallback := false 167 if firstLen < 58 || first.Byte(56) != '\r' { 168 // invalid protocol 169 err = newError("not trojan protocol") 170 log.Record(&log.AccessMessage{ 171 From: conn.RemoteAddr(), 172 To: "", 173 Status: log.AccessRejected, 174 Reason: err, 175 }) 176 177 shouldFallback = true 178 } else { 179 user = s.validator.Get(hexString(first.BytesTo(56))) 180 if user == nil { 181 // invalid user, let's fallback 182 err = newError("not a valid user") 183 log.Record(&log.AccessMessage{ 184 From: conn.RemoteAddr(), 185 To: "", 186 Status: log.AccessRejected, 187 Reason: err, 188 }) 189 190 shouldFallback = true 191 } 192 } 193 194 if isfb && shouldFallback { 195 return s.fallback(ctx, sid, err, sessionPolicy, conn, iConn, napfb, first, firstLen, bufferedReader) 196 } else if shouldFallback { 197 return newError("invalid protocol or invalid user") 198 } 199 200 clientReader := &ConnReader{Reader: bufferedReader} 201 if err := clientReader.ParseHeader(); err != nil { 202 log.Record(&log.AccessMessage{ 203 From: conn.RemoteAddr(), 204 To: "", 205 Status: log.AccessRejected, 206 Reason: err, 207 }) 208 return newError("failed to create request from: ", conn.RemoteAddr()).Base(err) 209 } 210 211 destination := clientReader.Target 212 if err := conn.SetReadDeadline(time.Time{}); err != nil { 213 return newError("unable to set read deadline").Base(err).AtWarning() 214 } 215 216 inbound := session.InboundFromContext(ctx) 217 inbound.Name = "trojan" 218 inbound.SetCanSpliceCopy(3) 219 inbound.User = user 220 sessionPolicy = s.policyManager.ForLevel(user.Level) 221 222 if destination.Network == net.Network_UDP { // handle udp request 223 return s.handleUDPPayload(ctx, &PacketReader{Reader: clientReader}, &PacketWriter{Writer: conn}, dispatcher) 224 } 225 226 ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ 227 From: conn.RemoteAddr(), 228 To: destination, 229 Status: log.AccessAccepted, 230 Reason: "", 231 Email: user.Email, 232 }) 233 234 newError("received request for ", destination).WriteToLog(sid) 235 return s.handleConnection(ctx, sessionPolicy, destination, clientReader, buf.NewWriter(conn), dispatcher) 236 } 237 238 func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error { 239 udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) { 240 udpPayload := packet.Payload 241 if udpPayload.UDP == nil { 242 udpPayload.UDP = &packet.Source 243 } 244 245 if err := clientWriter.WriteMultiBuffer(buf.MultiBuffer{udpPayload}); err != nil { 246 newError("failed to write response").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx)) 247 } 248 }) 249 250 inbound := session.InboundFromContext(ctx) 251 user := inbound.User 252 253 var dest *net.Destination 254 255 for { 256 select { 257 case <-ctx.Done(): 258 return nil 259 default: 260 mb, err := clientReader.ReadMultiBuffer() 261 if err != nil { 262 if errors.Cause(err) != io.EOF { 263 return newError("unexpected EOF").Base(err) 264 } 265 return nil 266 } 267 268 mb2, b := buf.SplitFirst(mb) 269 if b == nil { 270 continue 271 } 272 destination := *b.UDP 273 274 currentPacketCtx := ctx 275 if inbound.Source.IsValid() { 276 currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ 277 From: inbound.Source, 278 To: destination, 279 Status: log.AccessAccepted, 280 Reason: "", 281 Email: user.Email, 282 }) 283 } 284 newError("tunnelling request to ", destination).WriteToLog(session.ExportIDToError(ctx)) 285 286 if !s.cone || dest == nil { 287 dest = &destination 288 } 289 290 udpServer.Dispatch(currentPacketCtx, *dest, b) // first packet 291 for _, payload := range mb2 { 292 udpServer.Dispatch(currentPacketCtx, *dest, payload) 293 } 294 } 295 } 296 } 297 298 func (s *Server) handleConnection(ctx context.Context, sessionPolicy policy.Session, 299 destination net.Destination, 300 clientReader buf.Reader, 301 clientWriter buf.Writer, dispatcher routing.Dispatcher, 302 ) error { 303 ctx, cancel := context.WithCancel(ctx) 304 timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) 305 ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer) 306 307 link, err := dispatcher.Dispatch(ctx, destination) 308 if err != nil { 309 return newError("failed to dispatch request to ", destination).Base(err) 310 } 311 312 requestDone := func() error { 313 defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) 314 if buf.Copy(clientReader, link.Writer, buf.UpdateActivity(timer)) != nil { 315 return newError("failed to transfer request").Base(err) 316 } 317 return nil 318 } 319 320 responseDone := func() error { 321 defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) 322 323 if err := buf.Copy(link.Reader, clientWriter, buf.UpdateActivity(timer)); err != nil { 324 return newError("failed to write response").Base(err) 325 } 326 return nil 327 } 328 329 requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer)) 330 if err := task.Run(ctx, requestDonePost, responseDone); err != nil { 331 common.Must(common.Interrupt(link.Reader)) 332 common.Must(common.Interrupt(link.Writer)) 333 return newError("connection ends").Base(err) 334 } 335 336 return nil 337 } 338 339 func (s *Server) fallback(ctx context.Context, sid errors.ExportOption, err error, sessionPolicy policy.Session, connection stat.Connection, iConn stat.Connection, napfb map[string]map[string]map[string]*Fallback, first *buf.Buffer, firstLen int64, reader buf.Reader) error { 340 if err := connection.SetReadDeadline(time.Time{}); err != nil { 341 newError("unable to set back read deadline").Base(err).AtWarning().WriteToLog(sid) 342 } 343 newError("fallback starts").Base(err).AtInfo().WriteToLog(sid) 344 345 name := "" 346 alpn := "" 347 if tlsConn, ok := iConn.(*tls.Conn); ok { 348 cs := tlsConn.ConnectionState() 349 name = cs.ServerName 350 alpn = cs.NegotiatedProtocol 351 newError("realName = " + name).AtInfo().WriteToLog(sid) 352 newError("realAlpn = " + alpn).AtInfo().WriteToLog(sid) 353 } else if realityConn, ok := iConn.(*reality.Conn); ok { 354 cs := realityConn.ConnectionState() 355 name = cs.ServerName 356 alpn = cs.NegotiatedProtocol 357 newError("realName = " + name).AtInfo().WriteToLog(sid) 358 newError("realAlpn = " + alpn).AtInfo().WriteToLog(sid) 359 } 360 name = strings.ToLower(name) 361 alpn = strings.ToLower(alpn) 362 363 if len(napfb) > 1 || napfb[""] == nil { 364 if name != "" && napfb[name] == nil { 365 match := "" 366 for n := range napfb { 367 if n != "" && strings.Contains(name, n) && len(n) > len(match) { 368 match = n 369 } 370 } 371 name = match 372 } 373 } 374 375 if napfb[name] == nil { 376 name = "" 377 } 378 apfb := napfb[name] 379 if apfb == nil { 380 return newError(`failed to find the default "name" config`).AtWarning() 381 } 382 383 if apfb[alpn] == nil { 384 alpn = "" 385 } 386 pfb := apfb[alpn] 387 if pfb == nil { 388 return newError(`failed to find the default "alpn" config`).AtWarning() 389 } 390 391 path := "" 392 if len(pfb) > 1 || pfb[""] == nil { 393 if firstLen >= 18 && first.Byte(4) != '*' { // not h2c 394 firstBytes := first.Bytes() 395 for i := 4; i <= 8; i++ { // 5 -> 9 396 if firstBytes[i] == '/' && firstBytes[i-1] == ' ' { 397 search := len(firstBytes) 398 if search > 64 { 399 search = 64 // up to about 60 400 } 401 for j := i + 1; j < search; j++ { 402 k := firstBytes[j] 403 if k == '\r' || k == '\n' { // avoid logging \r or \n 404 break 405 } 406 if k == '?' || k == ' ' { 407 path = string(firstBytes[i:j]) 408 newError("realPath = " + path).AtInfo().WriteToLog(sid) 409 if pfb[path] == nil { 410 path = "" 411 } 412 break 413 } 414 } 415 break 416 } 417 } 418 } 419 } 420 fb := pfb[path] 421 if fb == nil { 422 return newError(`failed to find the default "path" config`).AtWarning() 423 } 424 425 ctx, cancel := context.WithCancel(ctx) 426 timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) 427 ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer) 428 429 var conn net.Conn 430 if err := retry.ExponentialBackoff(5, 100).On(func() error { 431 var dialer net.Dialer 432 conn, err = dialer.DialContext(ctx, fb.Type, fb.Dest) 433 if err != nil { 434 return err 435 } 436 return nil 437 }); err != nil { 438 return newError("failed to dial to " + fb.Dest).Base(err).AtWarning() 439 } 440 defer conn.Close() 441 442 serverReader := buf.NewReader(conn) 443 serverWriter := buf.NewWriter(conn) 444 445 postRequest := func() error { 446 defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) 447 if fb.Xver != 0 { 448 ipType := 4 449 remoteAddr, remotePort, err := net.SplitHostPort(connection.RemoteAddr().String()) 450 if err != nil { 451 ipType = 0 452 } 453 localAddr, localPort, err := net.SplitHostPort(connection.LocalAddr().String()) 454 if err != nil { 455 ipType = 0 456 } 457 if ipType == 4 { 458 for i := 0; i < len(remoteAddr); i++ { 459 if remoteAddr[i] == ':' { 460 ipType = 6 461 break 462 } 463 } 464 } 465 pro := buf.New() 466 defer pro.Release() 467 switch fb.Xver { 468 case 1: 469 if ipType == 0 { 470 common.Must2(pro.Write([]byte("PROXY UNKNOWN\r\n"))) 471 break 472 } 473 if ipType == 4 { 474 common.Must2(pro.Write([]byte("PROXY TCP4 " + remoteAddr + " " + localAddr + " " + remotePort + " " + localPort + "\r\n"))) 475 } else { 476 common.Must2(pro.Write([]byte("PROXY TCP6 " + remoteAddr + " " + localAddr + " " + remotePort + " " + localPort + "\r\n"))) 477 } 478 case 2: 479 common.Must2(pro.Write([]byte("\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"))) // signature 480 if ipType == 0 { 481 common.Must2(pro.Write([]byte("\x20\x00\x00\x00"))) // v2 + LOCAL + UNSPEC + UNSPEC + 0 bytes 482 break 483 } 484 if ipType == 4 { 485 common.Must2(pro.Write([]byte("\x21\x11\x00\x0C"))) // v2 + PROXY + AF_INET + STREAM + 12 bytes 486 common.Must2(pro.Write(net.ParseIP(remoteAddr).To4())) 487 common.Must2(pro.Write(net.ParseIP(localAddr).To4())) 488 } else { 489 common.Must2(pro.Write([]byte("\x21\x21\x00\x24"))) // v2 + PROXY + AF_INET6 + STREAM + 36 bytes 490 common.Must2(pro.Write(net.ParseIP(remoteAddr).To16())) 491 common.Must2(pro.Write(net.ParseIP(localAddr).To16())) 492 } 493 p1, _ := strconv.ParseUint(remotePort, 10, 16) 494 p2, _ := strconv.ParseUint(localPort, 10, 16) 495 common.Must2(pro.Write([]byte{byte(p1 >> 8), byte(p1), byte(p2 >> 8), byte(p2)})) 496 } 497 if err := serverWriter.WriteMultiBuffer(buf.MultiBuffer{pro}); err != nil { 498 return newError("failed to set PROXY protocol v", fb.Xver).Base(err).AtWarning() 499 } 500 } 501 if err := buf.Copy(reader, serverWriter, buf.UpdateActivity(timer)); err != nil { 502 return newError("failed to fallback request payload").Base(err).AtInfo() 503 } 504 return nil 505 } 506 507 writer := buf.NewWriter(connection) 508 509 getResponse := func() error { 510 defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) 511 if err := buf.Copy(serverReader, writer, buf.UpdateActivity(timer)); err != nil { 512 return newError("failed to deliver response payload").Base(err).AtInfo() 513 } 514 return nil 515 } 516 517 if err := task.Run(ctx, task.OnSuccess(postRequest, task.Close(serverWriter)), task.OnSuccess(getResponse, task.Close(writer))); err != nil { 518 common.Must(common.Interrupt(serverReader)) 519 common.Must(common.Interrupt(serverWriter)) 520 return newError("fallback ends").Base(err).AtInfo() 521 } 522 523 return nil 524 }