github.com/TeaOSLab/EdgeNode@v1.3.8/internal/nodes/http_request_websocket.go (about) 1 package nodes 2 3 import ( 4 "bufio" 5 "bytes" 6 "errors" 7 "github.com/TeaOSLab/EdgeNode/internal/utils" 8 "io" 9 "net/http" 10 "net/url" 11 ) 12 13 // WebsocketResponseReader Websocket响应Reader 14 type WebsocketResponseReader struct { 15 rawReader io.Reader 16 buf []byte 17 } 18 19 func NewWebsocketResponseReader(rawReader io.Reader) *WebsocketResponseReader { 20 return &WebsocketResponseReader{ 21 rawReader: rawReader, 22 } 23 } 24 25 func (this *WebsocketResponseReader) Read(p []byte) (n int, err error) { 26 n, err = this.rawReader.Read(p) 27 if n > 0 { 28 if len(this.buf) == 0 { 29 this.buf = make([]byte, n) 30 copy(this.buf, p[:n]) 31 } else { 32 this.buf = append(this.buf, p[:n]...) 33 } 34 } 35 return 36 } 37 38 // 处理Websocket请求 39 func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shouldRetry bool) { 40 // 设置不缓存 41 this.web.Cache = nil 42 43 if this.web.WebsocketRef == nil || !this.web.WebsocketRef.IsOn || this.web.Websocket == nil || !this.web.Websocket.IsOn { 44 this.writer.WriteHeader(http.StatusForbidden) 45 this.addError(errors.New("websocket have not been enabled yet")) 46 return 47 } 48 49 // TODO 实现handshakeTimeout 50 51 // 校验来源 52 var requestOrigin = this.RawReq.Header.Get("Origin") 53 if len(requestOrigin) > 0 { 54 u, err := url.Parse(requestOrigin) 55 if err == nil { 56 if !this.web.Websocket.MatchOrigin(u.Host) { 57 this.writer.WriteHeader(http.StatusForbidden) 58 this.addError(errors.New("websocket origin '" + requestOrigin + "' not been allowed")) 59 return 60 } 61 } 62 } 63 64 // 标记 65 this.isWebsocketResponse = true 66 67 // 设置指定的来源域 68 if !this.web.Websocket.RequestSameOrigin && len(this.web.Websocket.RequestOrigin) > 0 { 69 var newRequestOrigin = this.web.Websocket.RequestOrigin 70 if this.web.Websocket.RequestOriginHasVariables() { 71 newRequestOrigin = this.Format(newRequestOrigin) 72 } 73 this.RawReq.Header.Set("Origin", newRequestOrigin) 74 } 75 76 // 获取当前连接 77 var requestConn = this.RawReq.Context().Value(HTTPConnContextKey) 78 if requestConn == nil { 79 return 80 } 81 82 // 连接源站 83 originConn, _, err := OriginConnect(this.origin, this.requestServerPort(), this.RawReq.RemoteAddr, requestHost) 84 if err != nil { 85 if isLastRetry { 86 this.write50x(err, http.StatusBadGateway, "Failed to connect origin site", "源站连接失败", false) 87 } 88 89 // 增加失败次数 90 SharedOriginStateManager.Fail(this.origin, requestHost, this.reverseProxy, func() { 91 this.reverseProxy.ResetScheduling() 92 }) 93 94 shouldRetry = true 95 return 96 } 97 98 if !this.origin.IsOk { 99 SharedOriginStateManager.Success(this.origin, func() { 100 this.reverseProxy.ResetScheduling() 101 }) 102 } 103 104 defer func() { 105 _ = originConn.Close() 106 }() 107 108 err = this.RawReq.Write(originConn) 109 if err != nil { 110 this.write50x(err, http.StatusBadGateway, "Failed to write request to origin site", "源站请求初始化失败", false) 111 return 112 } 113 114 requestClientConn, ok := requestConn.(ClientConnInterface) 115 if ok { 116 requestClientConn.SetIsPersistent(true) 117 } 118 119 clientConn, _, err := this.writer.Hijack() 120 if err != nil || clientConn == nil { 121 this.write50x(err, http.StatusInternalServerError, "Failed to get origin site connection", "获取源站连接失败", false) 122 return 123 } 124 defer func() { 125 _ = clientConn.Close() 126 }() 127 128 go func() { 129 // 读取第一个响应 130 var respReader = NewWebsocketResponseReader(originConn) 131 resp, respErr := http.ReadResponse(bufio.NewReader(respReader), this.RawReq) 132 if respErr != nil || resp == nil { 133 if resp != nil && resp.Body != nil { 134 _ = resp.Body.Close() 135 } 136 137 _ = clientConn.Close() 138 _ = originConn.Close() 139 return 140 } 141 142 this.ProcessResponseHeaders(resp.Header, resp.StatusCode) 143 this.writer.statusCode = resp.StatusCode 144 145 // 将响应写回客户端 146 err = resp.Write(clientConn) 147 if err != nil { 148 if resp.Body != nil { 149 _ = resp.Body.Close() 150 } 151 152 _ = clientConn.Close() 153 _ = originConn.Close() 154 return 155 } 156 157 // 剩余已经从源站读取的内容 158 var headerBytes = respReader.buf 159 var headerIndex = bytes.Index(headerBytes, []byte{'\r', '\n', '\r', '\n'}) // CRLF 160 if headerIndex > 0 { 161 var leftBytes = headerBytes[headerIndex+4:] 162 if len(leftBytes) > 0 { 163 _, writeErr := clientConn.Write(leftBytes) 164 if writeErr != nil { 165 if resp.Body != nil { 166 _ = resp.Body.Close() 167 } 168 169 _ = clientConn.Close() 170 _ = originConn.Close() 171 return 172 } 173 } 174 } 175 176 if resp.Body != nil { 177 _ = resp.Body.Close() 178 } 179 180 // 复制剩余的数据 181 var buf = utils.BytePool4k.Get() 182 defer utils.BytePool4k.Put(buf) 183 for { 184 n, readErr := originConn.Read(buf.Bytes) 185 if n > 0 { 186 this.writer.sentBodyBytes += int64(n) 187 _, writeErr := clientConn.Write(buf.Bytes[:n]) 188 if writeErr != nil { 189 break 190 } 191 } 192 if readErr != nil { 193 break 194 } 195 } 196 _ = clientConn.Close() 197 _ = originConn.Close() 198 }() 199 200 var buf = utils.BytePool4k.Get() 201 _, _ = io.CopyBuffer(originConn, clientConn, buf.Bytes) 202 utils.BytePool4k.Put(buf) 203 204 return 205 }