github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/net/websocket/server.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package websocket
     6  
     7  import (
     8  	"bufio"
     9  	"fmt"
    10  	"io"
    11  
    12  	http "github.com/hxx258456/ccgo/gmhttp"
    13  )
    14  
    15  func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request, config *Config, handshake func(*Config, *http.Request) error) (conn *Conn, err error) {
    16  	var hs serverHandshaker = &hybiServerHandshaker{Config: config}
    17  	code, err := hs.ReadHandshake(buf.Reader, req)
    18  	if err == ErrBadWebSocketVersion {
    19  		fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
    20  		fmt.Fprintf(buf, "Sec-WebSocket-Version: %s\r\n", SupportedProtocolVersion)
    21  		buf.WriteString("\r\n")
    22  		buf.WriteString(err.Error())
    23  		buf.Flush()
    24  		return
    25  	}
    26  	if err != nil {
    27  		fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
    28  		buf.WriteString("\r\n")
    29  		buf.WriteString(err.Error())
    30  		buf.Flush()
    31  		return
    32  	}
    33  	if handshake != nil {
    34  		err = handshake(config, req)
    35  		if err != nil {
    36  			code = http.StatusForbidden
    37  			fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
    38  			buf.WriteString("\r\n")
    39  			buf.Flush()
    40  			return
    41  		}
    42  	}
    43  	err = hs.AcceptHandshake(buf.Writer)
    44  	if err != nil {
    45  		code = http.StatusBadRequest
    46  		fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
    47  		buf.WriteString("\r\n")
    48  		buf.Flush()
    49  		return
    50  	}
    51  	conn = hs.NewServerConn(buf, rwc, req)
    52  	return
    53  }
    54  
// Server represents a server of a WebSocket.
//
// Server is used by value. Each request is served from a copy of the Server,
// and Handshake receives a pointer to that copy's Config. Reassigning Config
// fields inside Handshake (for example choosing config.Protocol) therefore
// affects only the connection being negotiated.
type Server struct {
	// Config is a WebSocket configuration for new WebSocket connection.
	Config

	// Handshake is an optional function called during the WebSocket
	// handshake. It runs after the client's handshake request has been read
	// successfully and before the server's acceptance response is written.
	// A non-nil error rejects the connection with 403 Forbidden.
	// For example, it can check (or deliberately not check) the Origin
	// header, or select config.Protocol. If nil, no extra check is done.
	Handshake func(*Config, *http.Request) error

	// Handler handles a WebSocket connection once the handshake succeeded.
	Handler
}
    68  
    69  // ServeHTTP implements the http.Handler interface for a WebSocket
    70  func (s Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    71  	s.serveWebSocket(w, req)
    72  }
    73  
    74  func (s Server) serveWebSocket(w http.ResponseWriter, req *http.Request) {
    75  	rwc, buf, err := w.(http.Hijacker).Hijack()
    76  	if err != nil {
    77  		panic("Hijack failed: " + err.Error())
    78  	}
    79  	// The server should abort the WebSocket connection if it finds
    80  	// the client did not send a handshake that matches with protocol
    81  	// specification.
    82  	defer rwc.Close()
    83  	conn, err := newServerConn(rwc, buf, req, &s.Config, s.Handshake)
    84  	if err != nil {
    85  		return
    86  	}
    87  	if conn == nil {
    88  		panic("unexpected nil conn")
    89  	}
    90  	s.Handler(conn)
    91  }
    92  
// Handler is a simple interface to a WebSocket browser client.
// Its ServeHTTP method checks the Origin header by default. The handshake is
// rejected with 403 Forbidden when websocket.Origin returns an error or
// reports no origin at all.
// You might want to verify websocket.Conn.Config().Origin in the func.
// If you use Server instead of Handler, you could call websocket.Origin and
// check the origin in your Handshake func. So, if you want to accept
// non-browser clients, which do not send an Origin header, set a
// Server.Handshake that does not check the origin.
type Handler func(*Conn)
   101  
   102  func checkOrigin(config *Config, req *http.Request) (err error) {
   103  	config.Origin, err = Origin(config, req)
   104  	if err == nil && config.Origin == nil {
   105  		return fmt.Errorf("null origin")
   106  	}
   107  	return err
   108  }
   109  
   110  // ServeHTTP implements the http.Handler interface for a WebSocket
   111  func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
   112  	s := Server{Handler: h, Handshake: checkOrigin}
   113  	s.serveWebSocket(w, req)
   114  }