github.com/TeaOSLab/EdgeNode@v1.3.8/internal/nodes/http_request_websocket.go (about)

     1  package nodes
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"errors"
     7  	"github.com/TeaOSLab/EdgeNode/internal/utils"
     8  	"io"
     9  	"net/http"
    10  	"net/url"
    11  )
    12  
    13  // WebsocketResponseReader Websocket响应Reader
    14  type WebsocketResponseReader struct {
    15  	rawReader io.Reader
    16  	buf       []byte
    17  }
    18  
    19  func NewWebsocketResponseReader(rawReader io.Reader) *WebsocketResponseReader {
    20  	return &WebsocketResponseReader{
    21  		rawReader: rawReader,
    22  	}
    23  }
    24  
    25  func (this *WebsocketResponseReader) Read(p []byte) (n int, err error) {
    26  	n, err = this.rawReader.Read(p)
    27  	if n > 0 {
    28  		if len(this.buf) == 0 {
    29  			this.buf = make([]byte, n)
    30  			copy(this.buf, p[:n])
    31  		} else {
    32  			this.buf = append(this.buf, p[:n]...)
    33  		}
    34  	}
    35  	return
    36  }
    37  
    38  // 处理Websocket请求
    39  func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shouldRetry bool) {
    40  	// 设置不缓存
    41  	this.web.Cache = nil
    42  
    43  	if this.web.WebsocketRef == nil || !this.web.WebsocketRef.IsOn || this.web.Websocket == nil || !this.web.Websocket.IsOn {
    44  		this.writer.WriteHeader(http.StatusForbidden)
    45  		this.addError(errors.New("websocket have not been enabled yet"))
    46  		return
    47  	}
    48  
    49  	// TODO 实现handshakeTimeout
    50  
    51  	// 校验来源
    52  	var requestOrigin = this.RawReq.Header.Get("Origin")
    53  	if len(requestOrigin) > 0 {
    54  		u, err := url.Parse(requestOrigin)
    55  		if err == nil {
    56  			if !this.web.Websocket.MatchOrigin(u.Host) {
    57  				this.writer.WriteHeader(http.StatusForbidden)
    58  				this.addError(errors.New("websocket origin '" + requestOrigin + "' not been allowed"))
    59  				return
    60  			}
    61  		}
    62  	}
    63  
    64  	// 标记
    65  	this.isWebsocketResponse = true
    66  
    67  	// 设置指定的来源域
    68  	if !this.web.Websocket.RequestSameOrigin && len(this.web.Websocket.RequestOrigin) > 0 {
    69  		var newRequestOrigin = this.web.Websocket.RequestOrigin
    70  		if this.web.Websocket.RequestOriginHasVariables() {
    71  			newRequestOrigin = this.Format(newRequestOrigin)
    72  		}
    73  		this.RawReq.Header.Set("Origin", newRequestOrigin)
    74  	}
    75  
    76  	// 获取当前连接
    77  	var requestConn = this.RawReq.Context().Value(HTTPConnContextKey)
    78  	if requestConn == nil {
    79  		return
    80  	}
    81  
    82  	// 连接源站
    83  	originConn, _, err := OriginConnect(this.origin, this.requestServerPort(), this.RawReq.RemoteAddr, requestHost)
    84  	if err != nil {
    85  		if isLastRetry {
    86  			this.write50x(err, http.StatusBadGateway, "Failed to connect origin site", "源站连接失败", false)
    87  		}
    88  
    89  		// 增加失败次数
    90  		SharedOriginStateManager.Fail(this.origin, requestHost, this.reverseProxy, func() {
    91  			this.reverseProxy.ResetScheduling()
    92  		})
    93  
    94  		shouldRetry = true
    95  		return
    96  	}
    97  
    98  	if !this.origin.IsOk {
    99  		SharedOriginStateManager.Success(this.origin, func() {
   100  			this.reverseProxy.ResetScheduling()
   101  		})
   102  	}
   103  
   104  	defer func() {
   105  		_ = originConn.Close()
   106  	}()
   107  
   108  	err = this.RawReq.Write(originConn)
   109  	if err != nil {
   110  		this.write50x(err, http.StatusBadGateway, "Failed to write request to origin site", "源站请求初始化失败", false)
   111  		return
   112  	}
   113  
   114  	requestClientConn, ok := requestConn.(ClientConnInterface)
   115  	if ok {
   116  		requestClientConn.SetIsPersistent(true)
   117  	}
   118  
   119  	clientConn, _, err := this.writer.Hijack()
   120  	if err != nil || clientConn == nil {
   121  		this.write50x(err, http.StatusInternalServerError, "Failed to get origin site connection", "获取源站连接失败", false)
   122  		return
   123  	}
   124  	defer func() {
   125  		_ = clientConn.Close()
   126  	}()
   127  
   128  	go func() {
   129  		// 读取第一个响应
   130  		var respReader = NewWebsocketResponseReader(originConn)
   131  		resp, respErr := http.ReadResponse(bufio.NewReader(respReader), this.RawReq)
   132  		if respErr != nil || resp == nil {
   133  			if resp != nil && resp.Body != nil {
   134  				_ = resp.Body.Close()
   135  			}
   136  
   137  			_ = clientConn.Close()
   138  			_ = originConn.Close()
   139  			return
   140  		}
   141  
   142  		this.ProcessResponseHeaders(resp.Header, resp.StatusCode)
   143  		this.writer.statusCode = resp.StatusCode
   144  
   145  		// 将响应写回客户端
   146  		err = resp.Write(clientConn)
   147  		if err != nil {
   148  			if resp.Body != nil {
   149  				_ = resp.Body.Close()
   150  			}
   151  
   152  			_ = clientConn.Close()
   153  			_ = originConn.Close()
   154  			return
   155  		}
   156  
   157  		// 剩余已经从源站读取的内容
   158  		var headerBytes = respReader.buf
   159  		var headerIndex = bytes.Index(headerBytes, []byte{'\r', '\n', '\r', '\n'}) // CRLF
   160  		if headerIndex > 0 {
   161  			var leftBytes = headerBytes[headerIndex+4:]
   162  			if len(leftBytes) > 0 {
   163  				_, writeErr := clientConn.Write(leftBytes)
   164  				if writeErr != nil {
   165  					if resp.Body != nil {
   166  						_ = resp.Body.Close()
   167  					}
   168  
   169  					_ = clientConn.Close()
   170  					_ = originConn.Close()
   171  					return
   172  				}
   173  			}
   174  		}
   175  
   176  		if resp.Body != nil {
   177  			_ = resp.Body.Close()
   178  		}
   179  
   180  		// 复制剩余的数据
   181  		var buf = utils.BytePool4k.Get()
   182  		defer utils.BytePool4k.Put(buf)
   183  		for {
   184  			n, readErr := originConn.Read(buf.Bytes)
   185  			if n > 0 {
   186  				this.writer.sentBodyBytes += int64(n)
   187  				_, writeErr := clientConn.Write(buf.Bytes[:n])
   188  				if writeErr != nil {
   189  					break
   190  				}
   191  			}
   192  			if readErr != nil {
   193  				break
   194  			}
   195  		}
   196  		_ = clientConn.Close()
   197  		_ = originConn.Close()
   198  	}()
   199  
   200  	var buf = utils.BytePool4k.Get()
   201  	_, _ = io.CopyBuffer(originConn, clientConn, buf.Bytes)
   202  	utils.BytePool4k.Put(buf)
   203  
   204  	return
   205  }