github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/websocket/x/server.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package websocket
     6  
     7  import (
     8  	"context"
     9  	"fmt"
    10  	"net/http"
    11  	"strings"
    12  
    13  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    14  )
    15  
// Request holds the state of a server-side WebSocket opening handshake.
// It is filled in by ServerHandshaker.ReadHandshake and may be inspected or
// adjusted by the handshake callback passed to NewServerConn before
// ServerHandshaker.AcceptHandshake writes the 101 response.
type Request struct {
	// Request is the client's HTTP upgrade request.
	Request *http.Request
	// SecWebSocketKey is the raw value of the client's Sec-WebSocket-Key
	// header; AcceptHandshake derives Sec-WebSocket-Accept from it.
	SecWebSocketKey string
	// Protocol lists the subprotocols offered in Sec-WebSocket-Protocol.
	// AcceptHandshake rejects the handshake if more than one remains and
	// echoes Protocol[0] back when exactly one is left.
	Protocol []string
	// Header holds extra headers to add to the 101 response. Entries whose
	// key is in handshakeHeader are skipped by AcceptHandshake.
	Header http.Header
}
    22  
    23  func NewServerConn(w http.ResponseWriter, req *http.Request, handshake func(*Request) error) (conn *Conn, err error) {
    24  	var hs = &ServerHandshaker{
    25  		Request: &Request{
    26  			Request: req,
    27  		},
    28  	}
    29  	code, err := hs.ReadHandshake(req)
    30  	if err != nil {
    31  		if err == ErrBadWebSocketVersion {
    32  			w.Header().Set("Sec-WebSocket-Version", SupportedProtocolVersion)
    33  		}
    34  		w.WriteHeader(code)
    35  		_, _ = w.Write([]byte(err.Error()))
    36  		return
    37  	}
    38  
    39  	if handshake != nil {
    40  		err = handshake(hs.Request)
    41  		if err != nil {
    42  			w.WriteHeader(http.StatusForbidden)
    43  			return
    44  		}
    45  	}
    46  
    47  	err = hs.AcceptHandshake(w)
    48  	if err != nil {
    49  		w.WriteHeader(http.StatusBadRequest)
    50  		return
    51  	}
    52  
    53  	rwc, buf, err := http.NewResponseController(w).Hijack()
    54  	if err != nil {
    55  		err = fmt.Errorf("failed to hijack connection: %w", err)
    56  		http.Error(w, err.Error(), http.StatusInternalServerError)
    57  		return nil, err
    58  	}
    59  
    60  	if err := buf.Writer.Flush(); err != nil {
    61  		return nil, err
    62  	}
    63  
    64  	rwc, err = netapi.MergeBufioReaderConn(rwc, buf.Reader)
    65  	if err != nil {
    66  		return nil, err
    67  	}
    68  
    69  	PutBufioReader(buf.Reader)
    70  	putBufioWriter(buf.Writer)
    71  
    72  	return newConn(rwc, true), nil
    73  }
    74  
// ServerHandshaker performs the server side of a WebSocket opening handshake
// using the hybi protocol (the draft series that became RFC 6455).
//
// It embeds *Request, so the parsed handshake state (SecWebSocketKey,
// Protocol, Header) is accessed directly on the handshaker. ReadHandshake
// fills that state from the client request, and AcceptHandshake writes the
// 101 response from it.
type ServerHandshaker struct {
	*Request
}
    79  
    80  func (c *ServerHandshaker) ReadHandshake(req *http.Request) (code int, err error) {
    81  	if req.Method != "GET" {
    82  		return http.StatusMethodNotAllowed, ErrBadRequestMethod
    83  	}
    84  	// HTTP version can be safely ignored.
    85  
    86  	if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
    87  		return http.StatusBadRequest, ErrNotWebSocket
    88  	}
    89  
    90  	c.SecWebSocketKey = req.Header.Get("Sec-Websocket-Key")
    91  	if c.SecWebSocketKey == "" {
    92  		return http.StatusBadRequest, ErrChallengeResponse
    93  	}
    94  
    95  	version := req.Header.Get("Sec-Websocket-Version")
    96  	switch version {
    97  	case SupportedProtocolVersion:
    98  	default:
    99  		return http.StatusBadRequest, ErrBadWebSocketVersion
   100  	}
   101  
   102  	protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol"))
   103  	if protocol != "" {
   104  		for _, v := range strings.Split(protocol, ",") {
   105  			c.Protocol = append(c.Protocol, strings.TrimSpace(v))
   106  		}
   107  	}
   108  
   109  	return http.StatusSwitchingProtocols, nil
   110  }
   111  
   112  func (c *ServerHandshaker) AcceptHandshake(w http.ResponseWriter) (err error) {
   113  	if len(c.Protocol) > 0 && len(c.Protocol) != 1 {
   114  		// You need choose a Protocol in Handshake func in Server.
   115  		return ErrBadWebSocketProtocol
   116  	}
   117  
   118  	w.Header().Set("Upgrade", "websocket")
   119  	w.Header().Set("Connection", "Upgrade")
   120  	w.Header().Set("Sec-WebSocket-Accept", getNonceAccept(c.SecWebSocketKey))
   121  	if len(c.Protocol) > 0 {
   122  		w.Header().Set("Sec-WebSocket-Protocol", c.Protocol[0])
   123  	}
   124  	// TODO(ukai): send Sec-WebSocket-Extensions.
   125  	if c.Header != nil {
   126  		for k, v := range c.Header {
   127  			if handshakeHeader[k] {
   128  				continue
   129  			}
   130  			for _, vv := range v {
   131  				w.Header().Add(k, vv)
   132  			}
   133  		}
   134  	}
   135  	w.WriteHeader(http.StatusSwitchingProtocols)
   136  	return nil
   137  }
   138  
   139  func ServeHTTP(w http.ResponseWriter, req *http.Request, Handler func(context.Context, *Conn) error) error {
   140  	conn, err := NewServerConn(w, req, nil)
   141  	if err != nil {
   142  		return err
   143  	}
   144  	return Handler(req.Context(), conn)
   145  }