github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/transport/internet/websocket/hub.go (about) 1 package websocket 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "encoding/base64" 8 "io" 9 "net/http" 10 "strings" 11 "sync" 12 "time" 13 14 "github.com/gorilla/websocket" 15 "github.com/xmplusdev/xmcore/common" 16 "github.com/xmplusdev/xmcore/common/net" 17 http_proto "github.com/xmplusdev/xmcore/common/protocol/http" 18 "github.com/xmplusdev/xmcore/common/session" 19 "github.com/xmplusdev/xmcore/transport/internet" 20 v2tls "github.com/xmplusdev/xmcore/transport/internet/tls" 21 ) 22 23 type requestHandler struct { 24 host string 25 path string 26 ln *Listener 27 } 28 29 var replacer = strings.NewReplacer("+", "-", "/", "_", "=", "") 30 31 var upgrader = &websocket.Upgrader{ 32 ReadBufferSize: 0, 33 WriteBufferSize: 0, 34 HandshakeTimeout: time.Second * 4, 35 CheckOrigin: func(r *http.Request) bool { 36 return true 37 }, 38 } 39 40 func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { 41 if len(h.host) > 0 && request.Host != h.host { 42 writer.WriteHeader(http.StatusNotFound) 43 return 44 } 45 if request.URL.Path != h.path { 46 writer.WriteHeader(http.StatusNotFound) 47 return 48 } 49 50 var extraReader io.Reader 51 responseHeader := http.Header{} 52 if str := request.Header.Get("Sec-WebSocket-Protocol"); str != "" { 53 if ed, err := base64.RawURLEncoding.DecodeString(replacer.Replace(str)); err == nil && len(ed) > 0 { 54 extraReader = bytes.NewReader(ed) 55 responseHeader.Set("Sec-WebSocket-Protocol", str) 56 } 57 } 58 59 conn, err := upgrader.Upgrade(writer, request, responseHeader) 60 if err != nil { 61 newError("failed to convert to WebSocket connection").Base(err).WriteToLog() 62 return 63 } 64 65 forwardedAddrs := http_proto.ParseXForwardedFor(request.Header) 66 remoteAddr := conn.RemoteAddr() 67 if len(forwardedAddrs) > 0 && forwardedAddrs[0].Family().IsIP() { 68 remoteAddr = &net.TCPAddr{ 69 IP: forwardedAddrs[0].IP(), 70 Port: int(0), 71 } 72 } 73 74 h.ln.addConn(newConnection(conn, remoteAddr, extraReader)) 75 } 76 77 type Listener struct { 78 sync.Mutex 79 server http.Server 80 listener net.Listener 81 config *Config 82 addConn internet.ConnHandler 83 } 84 85 func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) { 86 l := &Listener{ 87 addConn: addConn, 88 } 89 wsSettings := streamSettings.ProtocolSettings.(*Config) 90 l.config = wsSettings 91 if l.config != nil { 92 if streamSettings.SocketSettings == nil { 93 streamSettings.SocketSettings = &internet.SocketConfig{} 94 } 95 streamSettings.SocketSettings.AcceptProxyProtocol = l.config.AcceptProxyProtocol || streamSettings.SocketSettings.AcceptProxyProtocol 96 } 97 var listener net.Listener 98 var err error 99 if port == net.Port(0) { // unix 100 listener, err = internet.ListenSystem(ctx, &net.UnixAddr{ 101 Name: address.Domain(), 102 Net: "unix", 103 }, streamSettings.SocketSettings) 104 if err != nil { 105 return nil, newError("failed to listen unix domain socket(for WS) on ", address).Base(err) 106 } 107 newError("listening unix domain socket(for WS) on ", address).WriteToLog(session.ExportIDToError(ctx)) 108 } else { // tcp 109 listener, err = internet.ListenSystem(ctx, &net.TCPAddr{ 110 IP: address.IP(), 111 Port: int(port), 112 }, streamSettings.SocketSettings) 113 if err != nil { 114 return nil, newError("failed to listen TCP(for WS) on ", address, ":", port).Base(err) 115 } 116 newError("listening TCP(for WS) on ", address, ":", port).WriteToLog(session.ExportIDToError(ctx)) 117 } 118 119 if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.AcceptProxyProtocol { 120 newError("accepting PROXY protocol").AtWarning().WriteToLog(session.ExportIDToError(ctx)) 121 } 122 123 if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil { 124 if tlsConfig := config.GetTLSConfig(); tlsConfig != nil { 125 listener = tls.NewListener(listener, tlsConfig) 126 } 127 } 128 129 l.listener = listener 130 131 l.server = http.Server{ 132 Handler: &requestHandler{ 133 host: wsSettings.Host, 134 path: wsSettings.GetNormalizedPath(), 135 ln: l, 136 }, 137 ReadHeaderTimeout: time.Second * 4, 138 MaxHeaderBytes: 8192, 139 } 140 141 go func() { 142 if err := l.server.Serve(l.listener); err != nil { 143 newError("failed to serve http for WebSocket").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx)) 144 } 145 }() 146 147 return l, err 148 } 149 150 // Addr implements net.Listener.Addr(). 151 func (ln *Listener) Addr() net.Addr { 152 return ln.listener.Addr() 153 } 154 155 // Close implements net.Listener.Close(). 156 func (ln *Listener) Close() error { 157 return ln.listener.Close() 158 } 159 160 func init() { 161 common.Must(internet.RegisterTransportListener(protocolName, ListenWS)) 162 }