github.com/v2fly/v2ray-core/v5@v5.16.2-0.20240507031116-8191faa6e095/proxy/trojan/server.go (about) 1 package trojan 2 3 import ( 4 "context" 5 "io" 6 "strconv" 7 "time" 8 9 core "github.com/v2fly/v2ray-core/v5" 10 "github.com/v2fly/v2ray-core/v5/common" 11 "github.com/v2fly/v2ray-core/v5/common/buf" 12 "github.com/v2fly/v2ray-core/v5/common/errors" 13 "github.com/v2fly/v2ray-core/v5/common/log" 14 "github.com/v2fly/v2ray-core/v5/common/net" 15 "github.com/v2fly/v2ray-core/v5/common/net/packetaddr" 16 "github.com/v2fly/v2ray-core/v5/common/protocol" 17 udp_proto "github.com/v2fly/v2ray-core/v5/common/protocol/udp" 18 "github.com/v2fly/v2ray-core/v5/common/retry" 19 "github.com/v2fly/v2ray-core/v5/common/session" 20 "github.com/v2fly/v2ray-core/v5/common/signal" 21 "github.com/v2fly/v2ray-core/v5/common/task" 22 "github.com/v2fly/v2ray-core/v5/features/policy" 23 "github.com/v2fly/v2ray-core/v5/features/routing" 24 "github.com/v2fly/v2ray-core/v5/transport/internet" 25 "github.com/v2fly/v2ray-core/v5/transport/internet/tls" 26 "github.com/v2fly/v2ray-core/v5/transport/internet/udp" 27 ) 28 29 func init() { 30 common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 31 return NewServer(ctx, config.(*ServerConfig)) 32 })) 33 } 34 35 // Server is an inbound connection handler that handles messages in trojan protocol. 36 type Server struct { 37 policyManager policy.Manager 38 validator *Validator 39 fallbacks map[string]map[string]*Fallback // or nil 40 packetEncoding packetaddr.PacketAddrType 41 } 42 43 // NewServer creates a new trojan inbound handler. 
//
// Every entry in config.Users is converted to a memory user and added to the
// validator; the first failure aborts construction with an error. Fallback
// entries are indexed by ALPN and then by path. Entries registered under the
// empty ALPN act as defaults: they are copied into every other ALPN table for
// each path that table does not define itself.
func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
	validator := new(Validator)
	for _, account := range config.Users {
		memUser, convErr := account.ToMemoryUser()
		if convErr != nil {
			return nil, newError("failed to get trojan user").Base(convErr).AtError()
		}
		if addErr := validator.Add(memUser); addErr != nil {
			return nil, newError("failed to add user").Base(addErr).AtError()
		}
	}

	instance := core.MustFromContext(ctx)
	manager := instance.GetFeature(policy.ManagerType()).(policy.Manager)

	srv := &Server{
		policyManager:  manager,
		validator:      validator,
		packetEncoding: config.PacketEncoding,
	}

	// A nil fallback list leaves srv.fallbacks nil, which disables fallback.
	if config.Fallbacks == nil {
		return srv, nil
	}

	table := make(map[string]map[string]*Fallback)
	for _, entry := range config.Fallbacks {
		byPath, found := table[entry.Alpn]
		if !found {
			byPath = make(map[string]*Fallback)
			table[entry.Alpn] = byPath
		}
		byPath[entry.Path] = entry
	}

	// Propagate default-ALPN paths into ALPN-specific tables that lack them.
	if defaults := table[""]; defaults != nil {
		for alpn, byPath := range table {
			if alpn == "" {
				continue
			}
			for path, entry := range defaults {
				if byPath[path] == nil {
					byPath[path] = entry
				}
			}
		}
	}

	srv.fallbacks = table
	return srv, nil
}

// AddUser implements proxy.UserManager.AddUser(). The user is handed to the
// server's validator.
func (s *Server) AddUser(ctx context.Context, u *protocol.MemoryUser) error {
	err := s.validator.Add(u)
	return err
}

// RemoveUser implements proxy.UserManager.RemoveUser(). The user identified
// by e is deleted from the server's validator.
func (s *Server) RemoveUser(ctx context.Context, e string) error {
	err := s.validator.Del(e)
	return err
}

// Network implements proxy.Inbound.Network(). Trojan is accepted over TCP
// and Unix domain sockets.
func (s *Server) Network() []net.Network {
	supported := []net.Network{net.Network_TCP, net.Network_UNIX}
	return supported
}

// Process implements proxy.Inbound.Process().
//
// It reads the first chunk from conn and expects a trojan header: 56 bytes of
// hex-encoded user key followed by '\r'. If that check fails, or the key does
// not belong to a known user, the connection is handed to the configured
// fallbacks (or rejected when there are none). Otherwise the request target
// is parsed and traffic is relayed either as a TCP stream or as tunnelled
// UDP packets.
func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher routing.Dispatcher) error {
	sid := session.ExportIDToError(ctx)

	// Strip the stats wrapper so fallback can type-assert the TLS connection
	// and read its negotiated ALPN.
	iConn := conn
	if statConn, ok := iConn.(*internet.StatCouterConnection); ok {
		iConn = statConn.Connection
	}

	// The user (and therefore its level) is unknown until the header is
	// validated, so the handshake runs under the level-0 policy.
	sessionPolicy := s.policyManager.ForLevel(0)
	if err := conn.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake)); err != nil {
		return newError("unable to set read deadline").Base(err).AtWarning()
	}

	first := buf.New()
	defer first.Release()

	firstLen, err := first.ReadFrom(conn)
	if err != nil {
		return newError("failed to read first request").Base(err)
	}
	newError("firstLen = ", firstLen).AtInfo().WriteToLog(sid)

	// Replays the already-consumed first chunk before the rest of conn.
	bufferedReader := &buf.BufferedReader{
		Reader: buf.NewReader(conn),
		Buffer: buf.MultiBuffer{first},
	}

	var user *protocol.MemoryUser

	apfb := s.fallbacks
	isfb := apfb != nil

	shouldFallback := false
	if firstLen < 58 || first.Byte(56) != '\r' {
		// invalid protocol
		err = newError("not trojan protocol")
		log.Record(&log.AccessMessage{
			From:   conn.RemoteAddr(),
			To:     "",
			Status: log.AccessRejected,
			Reason: err,
		})

		shouldFallback = true
	} else {
		user = s.validator.Get(hexString(first.BytesTo(56)))
		if user == nil {
			// invalid user, let's fallback
			err = newError("not a valid user")
			log.Record(&log.AccessMessage{
				From:   conn.RemoteAddr(),
				To:     "",
				Status: log.AccessRejected,
				Reason: err,
			})

			shouldFallback = true
		}
	}

	if shouldFallback {
		if isfb {
			return s.fallback(ctx, sid, err, sessionPolicy, conn, iConn, apfb, first, firstLen, bufferedReader)
		}
		return newError("invalid protocol or invalid user")
	}

	clientReader := &ConnReader{Reader: bufferedReader}
	if err := clientReader.ParseHeader(); err != nil {
		log.Record(&log.AccessMessage{
			From:   conn.RemoteAddr(),
			To:     "",
			Status: log.AccessRejected,
			Reason: err,
		})
		return newError("failed to create request from: ", conn.RemoteAddr()).Base(err)
	}

	destination := clientReader.Target
	// Handshake finished: remove the handshake read deadline.
	if err := conn.SetReadDeadline(time.Time{}); err != nil {
		return newError("unable to set read deadline").Base(err).AtWarning()
	}

	inbound := session.InboundFromContext(ctx)
	if inbound == nil {
		panic("no inbound metadata")
	}
	inbound.User = user
	sessionPolicy = s.policyManager.ForLevel(user.Level)

	if destination.Network == net.Network_UDP { // handle udp request
		return s.handleUDPPayload(ctx, &PacketReader{Reader: clientReader}, &PacketWriter{Writer: conn}, dispatcher)
	}

	ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
		From:   conn.RemoteAddr(),
		To:     destination,
		Status: log.AccessAccepted,
		Reason: "",
		Email:  user.Email,
	})

	newError("received request for ", destination).WriteToLog(sid)
	return s.handleConnection(ctx, sessionPolicy, destination, clientReader, buf.NewWriter(conn), dispatcher)
}

// handleUDPPayload relays UDP packets tunnelled inside an authenticated
// trojan connection.
//
// The dispatcher flavor is chosen by s.packetEncoding. Every packet read from
// clientReader is dispatched to its own target, and responses are written
// back through clientWriter tagged with their source address. The loop checks
// ctx between packets and returns nil once it is done. It also returns nil
// when the client stream ends with io.EOF, and returns an error for any other
// read failure.
func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error {
	udpDispatcherConstructor := udp.NewSplitDispatcher
	switch s.packetEncoding {
	case packetaddr.PacketAddrType_None:
	case packetaddr.PacketAddrType_Packet:
		packetAddrDispatcherFactory := udp.NewPacketAddrDispatcherCreator(ctx)
		udpDispatcherConstructor = packetAddrDispatcherFactory.NewPacketAddrDispatcher
	}

	udpServer := udpDispatcherConstructor(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
		if err := clientWriter.WriteMultiBufferWithMetadata(buf.MultiBuffer{packet.Payload}, packet.Source); err != nil {
			newError("failed to write response").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx))
		}
	})

	inbound := session.InboundFromContext(ctx)
	user := inbound.User

	for {
		select {
		case <-ctx.Done():
			return nil
		default:
			p, err := clientReader.ReadMultiBufferWithMetadata()
			if err != nil {
				if errors.Cause(err) != io.EOF {
					// Any error other than a clean EOF is a real read failure.
					return newError("failed to read UDP packet").Base(err)
				}
				return nil
			}

			// Each packet may go to a different target, so each gets its own
			// access-log entry.
			currentPacketCtx := log.ContextWithAccessMessage(ctx, &log.AccessMessage{
				From:   inbound.Source,
				To:     p.Target,
				Status: log.AccessAccepted,
				Reason: "",
				Email:  user.Email,
			})
			newError("tunnelling request to ", p.Target).WriteToLog(session.ExportIDToError(ctx))

			for _, b := range p.Buffer {
				udpServer.Dispatch(currentPacketCtx, p.Target, b)
			}
		}
	}
}

// handleConnection relays one TCP stream between the client and destination.
//
// The request direction copies clientReader into the dispatched link. The
// response direction copies the link back into clientWriter. Both refresh an
// inactivity timer that cancels ctx after ConnectionIdle without traffic.
// When the request side finishes, the timer switches to DownlinkOnly; when
// the response side finishes, it switches to UplinkOnly.
func (s *Server) handleConnection(ctx context.Context, sessionPolicy policy.Session,
	destination net.Destination,
	clientReader buf.Reader,
	clientWriter buf.Writer, dispatcher routing.Dispatcher,
) error {
	ctx, cancel := context.WithCancel(ctx)
	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
	ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer)

	link, err := dispatcher.Dispatch(ctx, destination)
	if err != nil {
		// No traffic will ever refresh this timer. Finish it now so the
		// cancelable context is released immediately instead of after
		// ConnectionIdle. NOTE: relies on SetTimeout(0) finishing the
		// timer, the same behavior the Downlink/UplinkOnly defers below
		// depend on when those timeouts are 0.
		timer.SetTimeout(0)
		return newError("failed to dispatch request to ", destination).Base(err)
	}

	requestDone := func() error {
		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)

		if err := buf.Copy(clientReader, link.Writer, buf.UpdateActivity(timer)); err != nil {
			return newError("failed to transfer request").Base(err)
		}
		return nil
	}

	responseDone := func() error {
		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)

		if err := buf.Copy(link.Reader, clientWriter, buf.UpdateActivity(timer)); err != nil {
			return newError("failed to write response").Base(err)
		}
		return nil
	}

	// Closing link.Writer after a successful request copy signals EOF to the
	// outbound side.
	requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
	if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
		common.Must(common.Interrupt(link.Reader))
		common.Must(common.Interrupt(link.Writer))
		return newError("connection ends").Base(err)
	}

	return nil
}

// fallback forwards a connection that is not valid trojan traffic to a
// configured fallback destination.
//
// The table apfb is indexed by ALPN and then by HTTP request path; unknown
// ALPNs and paths resolve to the "" entry. The ALPN is read from iConn when it
// is a TLS connection. The path is sniffed from the request line held in
// first ("METHOD /path ..."), scanning at most the first 64 bytes. If the
// selected entry has Xver 1 or 2, a PROXY protocol header of that version
// carrying the client and local addresses is sent before the client bytes.
// reader is expected to yield first followed by the rest of the client
// stream, which is what Process passes.
func (s *Server) fallback(ctx context.Context, sid errors.ExportOption, err error, sessionPolicy policy.Session, connection internet.Connection, iConn internet.Connection, apfb map[string]map[string]*Fallback, first *buf.Buffer, firstLen int64, reader buf.Reader) error {
	if err := connection.SetReadDeadline(time.Time{}); err != nil {
		newError("unable to set back read deadline").Base(err).AtWarning().WriteToLog(sid)
	}
	newError("fallback starts").Base(err).AtInfo().WriteToLog(sid)

	alpn := ""
	// The negotiated ALPN only matters when it could select a non-default table.
	if len(apfb) > 1 || apfb[""] == nil {
		if tlsConn, ok := iConn.(*tls.Conn); ok {
			alpn = tlsConn.ConnectionState().NegotiatedProtocol
			newError("realAlpn = " + alpn).AtInfo().WriteToLog(sid)
		}
		if apfb[alpn] == nil {
			alpn = ""
		}
	}
	pfb := apfb[alpn]
	if pfb == nil {
		return newError(`failed to find the default "alpn" config`).AtWarning()
	}

	path := ""
	if len(pfb) > 1 || pfb[""] == nil {
		if firstLen >= 18 && first.Byte(4) != '*' { // not h2c
			firstBytes := first.Bytes()
			// An HTTP method is 3-7 letters, so "/" following a space sits at index 4..8.
			for i := 4; i <= 8; i++ { // 5 -> 9
				if firstBytes[i] == '/' && firstBytes[i-1] == ' ' {
					search := len(firstBytes)
					if search > 64 {
						search = 64 // up to about 60
					}
					for j := i + 1; j < search; j++ {
						k := firstBytes[j]
						if k == '\r' || k == '\n' { // avoid logging \r or \n
							break
						}
						if k == ' ' {
							path = string(firstBytes[i:j])
							newError("realPath = " + path).AtInfo().WriteToLog(sid)
							if pfb[path] == nil {
								path = ""
							}
							break
						}
					}
					break
				}
			}
		}
	}
	fb := pfb[path]
	if fb == nil {
		return newError(`failed to find the default "path" config`).AtWarning()
	}

	ctx, cancel := context.WithCancel(ctx)
	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
	ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer)

	var conn net.Conn
	if err := retry.ExponentialBackoff(5, 100).On(func() error {
		var dialer net.Dialer
		// Use a local error so the caller-supplied err parameter is not clobbered.
		c, dialErr := dialer.DialContext(ctx, fb.Type, fb.Dest)
		if dialErr != nil {
			return dialErr
		}
		conn = c
		return nil
	}); err != nil {
		// Nothing will use this context; release the inactivity timer now.
		timer.SetTimeout(0)
		return newError("failed to dial to " + fb.Dest).Base(err).AtWarning()
	}
	defer conn.Close()

	serverReader := buf.NewReader(conn)
	serverWriter := buf.NewWriter(conn)

	postRequest := func() error {
		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
		if fb.Xver != 0 {
			remoteAddr, remotePort, err := net.SplitHostPort(connection.RemoteAddr().String())
			if err != nil {
				return err
			}
			localAddr, localPort, err := net.SplitHostPort(connection.LocalAddr().String())
			if err != nil {
				return err
			}
			// A ':' in the host part means an IPv6 literal.
			ipv4 := true
			for i := 0; i < len(remoteAddr); i++ {
				if remoteAddr[i] == ':' {
					ipv4 = false
					break
				}
			}
			pro := buf.New()
			defer pro.Release()
			switch fb.Xver {
			case 1:
				if ipv4 {
					common.Must2(pro.Write([]byte("PROXY TCP4 " + remoteAddr + " " + localAddr + " " + remotePort + " " + localPort + "\r\n")))
				} else {
					common.Must2(pro.Write([]byte("PROXY TCP6 " + remoteAddr + " " + localAddr + " " + remotePort + " " + localPort + "\r\n")))
				}
			case 2:
				common.Must2(pro.Write([]byte("\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A\x21"))) // signature + v2 + PROXY
				if ipv4 {
					common.Must2(pro.Write([]byte("\x11\x00\x0C"))) // AF_INET + STREAM + 12 bytes
					common.Must2(pro.Write(net.ParseIP(remoteAddr).To4()))
					common.Must2(pro.Write(net.ParseIP(localAddr).To4()))
				} else {
					common.Must2(pro.Write([]byte("\x21\x00\x24"))) // AF_INET6 + STREAM + 36 bytes
					common.Must2(pro.Write(net.ParseIP(remoteAddr).To16()))
					common.Must2(pro.Write(net.ParseIP(localAddr).To16()))
				}
				p1, _ := strconv.ParseUint(remotePort, 10, 16)
				p2, _ := strconv.ParseUint(localPort, 10, 16)
				// Ports are written big-endian: source first, then destination.
				common.Must2(pro.Write([]byte{byte(p1 >> 8), byte(p1), byte(p2 >> 8), byte(p2)}))
			}
			if err := serverWriter.WriteMultiBuffer(buf.MultiBuffer{pro}); err != nil {
				return newError("failed to set PROXY protocol v", fb.Xver).Base(err).AtWarning()
			}
		}
		if err := buf.Copy(reader, serverWriter, buf.UpdateActivity(timer)); err != nil {
			return newError("failed to fallback request payload").Base(err).AtInfo()
		}
		return nil
	}

	writer := buf.NewWriter(connection)

	getResponse := func() error {
		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
		if err := buf.Copy(serverReader, writer, buf.UpdateActivity(timer)); err != nil {
			return newError("failed to deliver response payload").Base(err).AtInfo()
		}
		return nil
	}

	if err := task.Run(ctx, task.OnSuccess(postRequest, task.Close(serverWriter)), task.OnSuccess(getResponse, task.Close(writer))); err != nil {
		common.Must(common.Interrupt(serverReader))
		common.Must(common.Interrupt(serverWriter))
		return newError("fallback ends").Base(err).AtInfo()
	}

	return nil
}