github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/websocket/x/server.go (about) 1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package websocket 6 7 import ( 8 "context" 9 "fmt" 10 "net/http" 11 "strings" 12 13 "github.com/Asutorufa/yuhaiin/pkg/net/netapi" 14 ) 15 16 type Request struct { 17 Request *http.Request 18 SecWebSocketKey string 19 Protocol []string 20 Header http.Header 21 } 22 23 func NewServerConn(w http.ResponseWriter, req *http.Request, handshake func(*Request) error) (conn *Conn, err error) { 24 var hs = &ServerHandshaker{ 25 Request: &Request{ 26 Request: req, 27 }, 28 } 29 code, err := hs.ReadHandshake(req) 30 if err != nil { 31 if err == ErrBadWebSocketVersion { 32 w.Header().Set("Sec-WebSocket-Version", SupportedProtocolVersion) 33 } 34 w.WriteHeader(code) 35 _, _ = w.Write([]byte(err.Error())) 36 return 37 } 38 39 if handshake != nil { 40 err = handshake(hs.Request) 41 if err != nil { 42 w.WriteHeader(http.StatusForbidden) 43 return 44 } 45 } 46 47 err = hs.AcceptHandshake(w) 48 if err != nil { 49 w.WriteHeader(http.StatusBadRequest) 50 return 51 } 52 53 rwc, buf, err := http.NewResponseController(w).Hijack() 54 if err != nil { 55 err = fmt.Errorf("failed to hijack connection: %w", err) 56 http.Error(w, err.Error(), http.StatusInternalServerError) 57 return nil, err 58 } 59 60 if err := buf.Writer.Flush(); err != nil { 61 return nil, err 62 } 63 64 rwc, err = netapi.MergeBufioReaderConn(rwc, buf.Reader) 65 if err != nil { 66 return nil, err 67 } 68 69 PutBufioReader(buf.Reader) 70 putBufioWriter(buf.Writer) 71 72 return newConn(rwc, true), nil 73 } 74 75 // A HybiServerHandshaker performs a server handshake using hybi draft protocol. 76 type ServerHandshaker struct { 77 *Request 78 } 79 80 func (c *ServerHandshaker) ReadHandshake(req *http.Request) (code int, err error) { 81 if req.Method != "GET" { 82 return http.StatusMethodNotAllowed, ErrBadRequestMethod 83 } 84 // HTTP version can be safely ignored. 85 86 if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") { 87 return http.StatusBadRequest, ErrNotWebSocket 88 } 89 90 c.SecWebSocketKey = req.Header.Get("Sec-Websocket-Key") 91 if c.SecWebSocketKey == "" { 92 return http.StatusBadRequest, ErrChallengeResponse 93 } 94 95 version := req.Header.Get("Sec-Websocket-Version") 96 switch version { 97 case SupportedProtocolVersion: 98 default: 99 return http.StatusBadRequest, ErrBadWebSocketVersion 100 } 101 102 protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol")) 103 if protocol != "" { 104 for _, v := range strings.Split(protocol, ",") { 105 c.Protocol = append(c.Protocol, strings.TrimSpace(v)) 106 } 107 } 108 109 return http.StatusSwitchingProtocols, nil 110 } 111 112 func (c *ServerHandshaker) AcceptHandshake(w http.ResponseWriter) (err error) { 113 if len(c.Protocol) > 0 && len(c.Protocol) != 1 { 114 // You need choose a Protocol in Handshake func in Server. 115 return ErrBadWebSocketProtocol 116 } 117 118 w.Header().Set("Upgrade", "websocket") 119 w.Header().Set("Connection", "Upgrade") 120 w.Header().Set("Sec-WebSocket-Accept", getNonceAccept(c.SecWebSocketKey)) 121 if len(c.Protocol) > 0 { 122 w.Header().Set("Sec-WebSocket-Protocol", c.Protocol[0]) 123 } 124 // TODO(ukai): send Sec-WebSocket-Extensions. 125 if c.Header != nil { 126 for k, v := range c.Header { 127 if handshakeHeader[k] { 128 continue 129 } 130 for _, vv := range v { 131 w.Header().Add(k, vv) 132 } 133 } 134 } 135 w.WriteHeader(http.StatusSwitchingProtocols) 136 return nil 137 } 138 139 func ServeHTTP(w http.ResponseWriter, req *http.Request, Handler func(context.Context, *Conn) error) error { 140 conn, err := NewServerConn(w, req, nil) 141 if err != nil { 142 return err 143 } 144 return Handler(req.Context(), conn) 145 }