github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/transport/internet/httpupgrade/hub.go (about) 1 package httpupgrade 2 3 import ( 4 "bufio" 5 "context" 6 "crypto/tls" 7 "net/http" 8 "strings" 9 10 "github.com/xmplusdev/xmcore/common" 11 "github.com/xmplusdev/xmcore/common/net" 12 http_proto "github.com/xmplusdev/xmcore/common/protocol/http" 13 "github.com/xmplusdev/xmcore/common/session" 14 "github.com/xmplusdev/xmcore/transport/internet" 15 "github.com/xmplusdev/xmcore/transport/internet/stat" 16 v2tls "github.com/xmplusdev/xmcore/transport/internet/tls" 17 ) 18 19 type server struct { 20 config *Config 21 addConn internet.ConnHandler 22 innnerListener net.Listener 23 } 24 25 func (s *server) Close() error { 26 return s.innnerListener.Close() 27 } 28 29 func (s *server) Addr() net.Addr { 30 return nil 31 } 32 33 func (s *server) Handle(conn net.Conn) (stat.Connection, error) { 34 connReader := bufio.NewReader(conn) 35 req, err := http.ReadRequest(connReader) 36 if err != nil { 37 return nil, err 38 } 39 40 if s.config != nil { 41 host := req.Host 42 if len(s.config.Host) > 0 && host != s.config.Host { 43 return nil, newError("bad host: ", host) 44 } 45 path := s.config.GetNormalizedPath() 46 if req.URL.Path != path { 47 return nil, newError("bad path: ", req.URL.Path) 48 } 49 } 50 51 connection := strings.ToLower(req.Header.Get("Connection")) 52 upgrade := strings.ToLower(req.Header.Get("Upgrade")) 53 if connection != "upgrade" || upgrade != "websocket" { 54 _ = conn.Close() 55 return nil, newError("unrecognized request") 56 } 57 resp := &http.Response{ 58 Status: "101 Switching Protocols", 59 StatusCode: 101, 60 Proto: "HTTP/1.1", 61 ProtoMajor: 1, 62 ProtoMinor: 1, 63 Header: http.Header{}, 64 } 65 resp.Header.Set("Connection", "upgrade") 66 resp.Header.Set("Upgrade", "websocket") 67 err = resp.Write(conn) 68 if err != nil { 69 _ = conn.Close() 70 return nil, err 71 } 72 73 forwardedAddrs := http_proto.ParseXForwardedFor(req.Header) 74 remoteAddr := conn.RemoteAddr() 75 if len(forwardedAddrs) > 0 && forwardedAddrs[0].Family().IsIP() { 76 remoteAddr = &net.TCPAddr{ 77 IP: forwardedAddrs[0].IP(), 78 Port: int(0), 79 } 80 } 81 if remoteAddr == nil { 82 return nil, newError("remoteAddr is nil") 83 } 84 85 conn = newConnection(conn, remoteAddr) 86 return stat.Connection(conn), nil 87 } 88 89 func (s *server) keepAccepting() { 90 for { 91 conn, err := s.innnerListener.Accept() 92 if err != nil { 93 return 94 } 95 handledConn, err := s.Handle(conn) 96 if err != nil { 97 newError("failed to handle request").Base(err).WriteToLog() 98 continue 99 } 100 s.addConn(handledConn) 101 } 102 } 103 104 func listenHTTPUpgrade(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) { 105 transportConfiguration := streamSettings.ProtocolSettings.(*Config) 106 if transportConfiguration != nil { 107 if streamSettings.SocketSettings == nil { 108 streamSettings.SocketSettings = &internet.SocketConfig{} 109 } 110 streamSettings.SocketSettings.AcceptProxyProtocol = transportConfiguration.AcceptProxyProtocol || streamSettings.SocketSettings.AcceptProxyProtocol 111 } 112 var listener net.Listener 113 var err error 114 if port == net.Port(0) { // unix 115 listener, err = internet.ListenSystem(ctx, &net.UnixAddr{ 116 Name: address.Domain(), 117 Net: "unix", 118 }, streamSettings.SocketSettings) 119 if err != nil { 120 return nil, newError("failed to listen unix domain socket(for HttpUpgrade) on ", address).Base(err) 121 } 122 newError("listening unix domain socket(for HttpUpgrade) on ", address).WriteToLog(session.ExportIDToError(ctx)) 123 } else { // tcp 124 listener, err = internet.ListenSystem(ctx, &net.TCPAddr{ 125 IP: address.IP(), 126 Port: int(port), 127 }, streamSettings.SocketSettings) 128 if err != nil { 129 return nil, newError("failed to listen TCP(for HttpUpgrade) on ", address, ":", port).Base(err) 130 } 131 newError("listening TCP(for HttpUpgrade) on ", address, ":", port).WriteToLog(session.ExportIDToError(ctx)) 132 } 133 134 if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.AcceptProxyProtocol { 135 newError("accepting PROXY protocol").AtWarning().WriteToLog(session.ExportIDToError(ctx)) 136 } 137 138 if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil { 139 if tlsConfig := config.GetTLSConfig(); tlsConfig != nil { 140 listener = tls.NewListener(listener, tlsConfig) 141 } 142 } 143 144 serverInstance := &server{ 145 config: transportConfiguration, 146 addConn: addConn, 147 innnerListener: listener, 148 } 149 go serverInstance.keepAccepting() 150 return serverInstance, nil 151 } 152 153 func init() { 154 common.Must(internet.RegisterTransportListener(protocolName, listenHTTPUpgrade)) 155 }