github.com/Andyfoo/golang/x/net@v0.0.0-20190901054642-57c1bf301704/websocket/server.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package websocket
     6  
     7  import (
     8  	"bufio"
     9  	"fmt"
    10  	"io"
    11  	"net/http"
    12  )
    13  
    14  func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request, config *Config, handshake func(*Config, *http.Request) error) (conn *Conn, err error) {
    15  	var hs serverHandshaker = &hybiServerHandshaker{Config: config}
    16  	code, err := hs.ReadHandshake(buf.Reader, req)
    17  	if err == ErrBadWebSocketVersion {
    18  		fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
    19  		fmt.Fprintf(buf, "Sec-WebSocket-Version: %s\r\n", SupportedProtocolVersion)
    20  		buf.WriteString("\r\n")
    21  		buf.WriteString(err.Error())
    22  		buf.Flush()
    23  		return
    24  	}
    25  	if err != nil {
    26  		fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
    27  		buf.WriteString("\r\n")
    28  		buf.WriteString(err.Error())
    29  		buf.Flush()
    30  		return
    31  	}
    32  	if handshake != nil {
    33  		err = handshake(config, req)
    34  		if err != nil {
    35  			code = http.StatusForbidden
    36  			fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
    37  			buf.WriteString("\r\n")
    38  			buf.Flush()
    39  			return
    40  		}
    41  	}
    42  	err = hs.AcceptHandshake(buf.Writer)
    43  	if err != nil {
    44  		code = http.StatusBadRequest
    45  		fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
    46  		buf.WriteString("\r\n")
    47  		buf.Flush()
    48  		return
    49  	}
    50  	conn = hs.NewServerConn(buf, rwc, req)
    51  	return
    52  }
    53  
// Server represents a WebSocket server. It implements http.Handler: each
// request is upgraded to a WebSocket connection, which is then passed to
// Handler.
type Server struct {
	// Config is a WebSocket configuration for new WebSocket connection.
	// ServeHTTP uses a value receiver, so the *Config seen by Handshake points
	// at a per-request (shallow) copy of this field. Mutating it does not
	// affect the Server value itself.
	Config

	// Handshake is an optional function called during the WebSocket
	// handshake. It runs after the client's handshake request has been read
	// and before the server's handshake response is written. A non-nil error
	// rejects the connection with 403 Forbidden. If Handshake is nil, no
	// extra check is made. In particular, Server does not check the Origin
	// header unless Handshake does.
	// For example, you can check the Origin header (or deliberately not check
	// it), or you can select config.Protocol.
	Handshake func(*Config, *http.Request) error

	// Handler handles a WebSocket connection. It is called only after the
	// handshake succeeds. The underlying connection is closed when Handler
	// returns.
	Handler
}
    67  
    68  // ServeHTTP implements the http.Handler interface for a WebSocket
    69  func (s Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    70  	s.serveWebSocket(w, req)
    71  }
    72  
    73  func (s Server) serveWebSocket(w http.ResponseWriter, req *http.Request) {
    74  	rwc, buf, err := w.(http.Hijacker).Hijack()
    75  	if err != nil {
    76  		panic("Hijack failed: " + err.Error())
    77  	}
    78  	// The server should abort the WebSocket connection if it finds
    79  	// the client did not send a handshake that matches with protocol
    80  	// specification.
    81  	defer rwc.Close()
    82  	conn, err := newServerConn(rwc, buf, req, &s.Config, s.Handshake)
    83  	if err != nil {
    84  		return
    85  	}
    86  	if conn == nil {
    87  		panic("unexpected nil conn")
    88  	}
    89  	s.Handler(conn)
    90  }
    91  
// Handler is a simple interface to a WebSocket browser client.
// When served via its ServeHTTP method, it checks during the handshake that
// the request carries an Origin header that is a valid URL. If it does not,
// the connection is rejected with 403 Forbidden before the func is called.
// You might want to verify websocket.Conn.Config().Origin in the func.
// If you use Server instead of Handler, you can call websocket.Origin and
// check the origin in your Handshake func. So, if you want to accept
// non-browser clients, which do not send an Origin header, set a
// Server.Handshake that does not check the origin.
type Handler func(*Conn)
   100  
   101  func checkOrigin(config *Config, req *http.Request) (err error) {
   102  	config.Origin, err = Origin(config, req)
   103  	if err == nil && config.Origin == nil {
   104  		return fmt.Errorf("null origin")
   105  	}
   106  	return err
   107  }
   108  
   109  // ServeHTTP implements the http.Handler interface for a WebSocket
   110  func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
   111  	s := Server{Handler: h, Handshake: checkOrigin}
   112  	s.serveWebSocket(w, req)
   113  }