github.com/simonmittag/ws@v1.1.0-rc.5.0.20210419231947-82b846128245/example/autobahn/autobahn.go (about)

     1  package main
     2  
     3  import (
     4  	"compress/flate"
     5  	"context"
     6  	"flag"
     7  	"fmt"
     8  	"io"
     9  	"io/ioutil"
    10  	"log"
    11  	"net"
    12  	"net/http"
    13  	"os"
    14  	"os/signal"
    15  	"syscall"
    16  	"time"
    17  
    18  	"github.com/gobwas/httphead"
    19  	"github.com/simonmittag/ws"
    20  	"github.com/simonmittag/ws/wsflate"
    21  	"github.com/simonmittag/ws/wsutil"
    22  )
    23  
    24  const dir = "./example/autobahn"
    25  
    26  var addr = flag.String("listen", ":9001", "addr to listen")
    27  
    28  func main() {
    29  	log.SetFlags(0)
    30  	flag.Parse()
    31  
    32  	http.HandleFunc("/ws", wsHandler)
    33  	http.HandleFunc("/wsutil", wsutilHandler)
    34  	http.HandleFunc("/wsflate", wsflateHandler)
    35  	http.HandleFunc("/helpers/low", helpersLowLevelHandler)
    36  	http.HandleFunc("/helpers/high", helpersHighLevelHandler)
    37  
    38  	ln, err := net.Listen("tcp", *addr)
    39  	if err != nil {
    40  		log.Fatalf("listen %q error: %v", *addr, err)
    41  	}
    42  	log.Printf("listening %s (%q)", ln.Addr(), *addr)
    43  
    44  	var (
    45  		s     = new(http.Server)
    46  		serve = make(chan error, 1)
    47  		sig   = make(chan os.Signal, 1)
    48  	)
    49  	signal.Notify(sig, syscall.SIGTERM)
    50  	go func() { serve <- s.Serve(ln) }()
    51  
    52  	select {
    53  	case err := <-serve:
    54  		log.Fatal(err)
    55  	case sig := <-sig:
    56  		const timeout = 5 * time.Second
    57  
    58  		log.Printf("signal %q received; shutting down with %s timeout", sig, timeout)
    59  
    60  		ctx, _ := context.WithTimeout(context.Background(), timeout)
    61  		if err := s.Shutdown(ctx); err != nil {
    62  			log.Fatal(err)
    63  		}
    64  	}
    65  }
    66  
    67  var (
    68  	closeInvalidPayload = ws.MustCompileFrame(
    69  		ws.NewCloseFrame(ws.NewCloseFrameBody(
    70  			ws.StatusInvalidFramePayloadData, "",
    71  		)),
    72  	)
    73  	closeProtocolError = ws.MustCompileFrame(
    74  		ws.NewCloseFrame(ws.NewCloseFrameBody(
    75  			ws.StatusProtocolError, "",
    76  		)),
    77  	)
    78  )
    79  
    80  func helpersHighLevelHandler(w http.ResponseWriter, r *http.Request) {
    81  	conn, _, _, err := ws.UpgradeHTTP(r, w)
    82  	if err != nil {
    83  		log.Printf("upgrade error: %s", err)
    84  		return
    85  	}
    86  	defer conn.Close()
    87  
    88  	for {
    89  		bts, op, err := wsutil.ReadClientData(conn)
    90  		if err != nil {
    91  			log.Printf("read message error: %v", err)
    92  			return
    93  		}
    94  		err = wsutil.WriteServerMessage(conn, op, bts)
    95  		if err != nil {
    96  			log.Printf("write message error: %v", err)
    97  			return
    98  		}
    99  	}
   100  }
   101  
   102  func helpersLowLevelHandler(w http.ResponseWriter, r *http.Request) {
   103  	conn, _, _, err := ws.UpgradeHTTP(r, w)
   104  	if err != nil {
   105  		log.Printf("upgrade error: %s", err)
   106  		return
   107  	}
   108  	defer conn.Close()
   109  
   110  	msg := make([]wsutil.Message, 0, 4)
   111  
   112  	for {
   113  		msg, err = wsutil.ReadClientMessage(conn, msg[:0])
   114  		if err != nil {
   115  			log.Printf("read message error: %v", err)
   116  			return
   117  		}
   118  		for _, m := range msg {
   119  			if m.OpCode.IsControl() {
   120  				err := wsutil.HandleClientControlMessage(conn, m)
   121  				if err != nil {
   122  					log.Printf("handle control error: %v", err)
   123  					return
   124  				}
   125  				continue
   126  			}
   127  			err := wsutil.WriteServerMessage(conn, m.OpCode, m.Payload)
   128  			if err != nil {
   129  				log.Printf("write message error: %v", err)
   130  				return
   131  			}
   132  		}
   133  	}
   134  }
   135  
   136  func wsutilHandler(res http.ResponseWriter, req *http.Request) {
   137  	conn, _, _, err := ws.UpgradeHTTP(req, res)
   138  	if err != nil {
   139  		log.Printf("upgrade error: %s", err)
   140  		return
   141  	}
   142  	defer conn.Close()
   143  
   144  	state := ws.StateServerSide
   145  
   146  	ch := wsutil.ControlFrameHandler(conn, state)
   147  	r := &wsutil.Reader{
   148  		Source:         conn,
   149  		State:          state,
   150  		CheckUTF8:      true,
   151  		OnIntermediate: ch,
   152  	}
   153  	w := wsutil.NewWriter(conn, state, 0)
   154  
   155  	for {
   156  		h, err := r.NextFrame()
   157  		if err != nil {
   158  			log.Printf("next frame error: %v", err)
   159  			return
   160  		}
   161  		if h.OpCode.IsControl() {
   162  			if err = ch(h, r); err != nil {
   163  				log.Printf("handle control error: %v", err)
   164  				return
   165  			}
   166  			continue
   167  		}
   168  
   169  		w.Reset(conn, state, h.OpCode)
   170  
   171  		if _, err = io.Copy(w, r); err == nil {
   172  			err = w.Flush()
   173  		}
   174  		if err != nil {
   175  			log.Printf("echo error: %s", err)
   176  			return
   177  		}
   178  	}
   179  }
   180  
   181  func wsflateHandler(w http.ResponseWriter, r *http.Request) {
   182  	e := wsflate.Extension{
   183  		Parameters: wsflate.Parameters{
   184  			ServerNoContextTakeover: true,
   185  			ClientNoContextTakeover: true,
   186  		},
   187  	}
   188  	u := ws.HTTPUpgrader{
   189  		Negotiate: e.Negotiate,
   190  	}
   191  	conn, _, _, err := u.Upgrade(r, w)
   192  	if err != nil {
   193  		log.Printf("upgrade error: %s", err)
   194  		return
   195  	}
   196  	defer conn.Close()
   197  
   198  	if _, ok := e.Accepted(); !ok {
   199  		log.Printf("no accepted extension")
   200  		return
   201  	}
   202  
   203  	// Using nil as a destination io.Writer since we will Reset() it in the
   204  	// loop below.
   205  	fw := wsflate.NewWriter(nil, func(w io.Writer) wsflate.Compressor {
   206  		// As flat.NewWriter() docs says:
   207  		//   If level is in the range [-2, 9] then the error returned will
   208  		//   be nil.
   209  		f, _ := flate.NewWriter(w, 9)
   210  		return f
   211  	})
   212  	// Using nil as a source io.Reader since we will Reset() it in the loop
   213  	// below.
   214  	fr := wsflate.NewReader(nil, func(r io.Reader) wsflate.Decompressor {
   215  		return flate.NewReader(r)
   216  	})
   217  
   218  	// MessageState implements wsutil.Extension and is used to check whether
   219  	// received WebSocket message is compressed. That is, it's generally
   220  	// possible to receive uncompressed messaged even if compression extension
   221  	// was negotiated.
   222  	var msg wsflate.MessageState
   223  
   224  	// Note that control frames are all written without compression.
   225  	controlHandler := wsutil.ControlFrameHandler(conn, ws.StateServerSide)
   226  	rd := wsutil.Reader{
   227  		Source:         conn,
   228  		State:          ws.StateServerSide | ws.StateExtended,
   229  		CheckUTF8:      false,
   230  		OnIntermediate: controlHandler,
   231  		Extensions:     []wsutil.RecvExtension{&msg},
   232  	}
   233  
   234  	wr := wsutil.NewWriter(conn, ws.StateServerSide|ws.StateExtended, 0)
   235  	wr.SetExtensions(&msg)
   236  
   237  	for {
   238  		h, err := rd.NextFrame()
   239  		if err != nil {
   240  			log.Printf("next frame error: %v", err)
   241  			return
   242  		}
   243  		if h.OpCode.IsControl() {
   244  			if err := controlHandler(h, &rd); err != nil {
   245  				log.Printf("handle control frame error: %v", err)
   246  				return
   247  			}
   248  			continue
   249  		}
   250  
   251  		wr.ResetOp(h.OpCode)
   252  
   253  		var (
   254  			src io.Reader = &rd
   255  			dst io.Writer = wr
   256  		)
   257  		if msg.IsCompressed() {
   258  			fr.Reset(src)
   259  			fw.Reset(dst)
   260  			src = fr
   261  			dst = fw
   262  		}
   263  		// Copy incoming bytes right into writer, probably through decompressor
   264  		// and compressor.
   265  		if _, err = io.Copy(dst, src); err != nil {
   266  			log.Fatal(err)
   267  		}
   268  		if msg.IsCompressed() {
   269  			// Flush the flate writer.
   270  			if err = fw.Close(); err != nil {
   271  				log.Fatal(err)
   272  			}
   273  		}
   274  		// Flush WebSocket fragment writer. We could send multiple fragments
   275  		// for large messages.
   276  		if err = wr.Flush(); err != nil {
   277  			log.Fatal(err)
   278  		}
   279  	}
   280  }
   281  
   282  func wsHandler(w http.ResponseWriter, r *http.Request) {
   283  	u := ws.HTTPUpgrader{
   284  		Extension: func(opt httphead.Option) bool {
   285  			log.Printf("extension: %s", opt)
   286  			return false
   287  		},
   288  	}
   289  	conn, _, _, err := u.Upgrade(r, w)
   290  	if err != nil {
   291  		log.Printf("upgrade error: %s", err)
   292  		return
   293  	}
   294  	defer conn.Close()
   295  
   296  	state := ws.StateServerSide
   297  
   298  	textPending := false
   299  	utf8Reader := wsutil.NewUTF8Reader(nil)
   300  	cipherReader := wsutil.NewCipherReader(nil, [4]byte{0, 0, 0, 0})
   301  
   302  	for {
   303  		header, err := ws.ReadHeader(conn)
   304  		if err != nil {
   305  			log.Printf("read header error: %s", err)
   306  			break
   307  		}
   308  		if err = ws.CheckHeader(header, state); err != nil {
   309  			log.Printf("header check error: %s", err)
   310  			conn.Write(closeProtocolError)
   311  			return
   312  		}
   313  
   314  		cipherReader.Reset(
   315  			io.LimitReader(conn, header.Length),
   316  			header.Mask,
   317  		)
   318  
   319  		var utf8Fin bool
   320  		var r io.Reader = cipherReader
   321  
   322  		switch header.OpCode {
   323  		case ws.OpPing:
   324  			header.OpCode = ws.OpPong
   325  			header.Masked = false
   326  			ws.WriteHeader(conn, header)
   327  			io.CopyN(conn, cipherReader, header.Length)
   328  			continue
   329  
   330  		case ws.OpPong:
   331  			io.CopyN(ioutil.Discard, conn, header.Length)
   332  			continue
   333  
   334  		case ws.OpClose:
   335  			utf8Fin = true
   336  
   337  		case ws.OpContinuation:
   338  			if textPending {
   339  				utf8Reader.Source = cipherReader
   340  				r = utf8Reader
   341  			}
   342  			if header.Fin {
   343  				state = state.Clear(ws.StateFragmented)
   344  				textPending = false
   345  				utf8Fin = true
   346  			}
   347  
   348  		case ws.OpText:
   349  			utf8Reader.Reset(cipherReader)
   350  			r = utf8Reader
   351  
   352  			if !header.Fin {
   353  				state = state.Set(ws.StateFragmented)
   354  				textPending = true
   355  			} else {
   356  				utf8Fin = true
   357  			}
   358  
   359  		case ws.OpBinary:
   360  			if !header.Fin {
   361  				state = state.Set(ws.StateFragmented)
   362  			}
   363  		}
   364  
   365  		payload := make([]byte, header.Length)
   366  		_, err = io.ReadFull(r, payload)
   367  		if err == nil && utf8Fin && !utf8Reader.Valid() {
   368  			err = wsutil.ErrInvalidUTF8
   369  		}
   370  		if err != nil {
   371  			log.Printf("read payload error: %s", err)
   372  			if err == wsutil.ErrInvalidUTF8 {
   373  				conn.Write(closeInvalidPayload)
   374  			} else {
   375  				conn.Write(ws.CompiledClose)
   376  			}
   377  			return
   378  		}
   379  
   380  		if header.OpCode == ws.OpClose {
   381  			code, reason := ws.ParseCloseFrameData(payload)
   382  			log.Printf("close frame received: %v %v", code, reason)
   383  
   384  			if !code.Empty() {
   385  				switch {
   386  				case code.IsProtocolSpec() && !code.IsProtocolDefined():
   387  					err = fmt.Errorf("close code from spec range is not defined")
   388  				default:
   389  					err = ws.CheckCloseFrameData(code, reason)
   390  				}
   391  				if err != nil {
   392  					log.Printf("invalid close data: %s", err)
   393  					conn.Write(closeProtocolError)
   394  				} else {
   395  					ws.WriteFrame(conn, ws.NewCloseFrame(ws.NewCloseFrameBody(
   396  						code, "",
   397  					)))
   398  				}
   399  				return
   400  			}
   401  
   402  			conn.Write(ws.CompiledClose)
   403  			return
   404  		}
   405  
   406  		header.Masked = false
   407  		ws.WriteHeader(conn, header)
   408  		conn.Write(payload)
   409  	}
   410  }