github.com/simonmittag/ws@v1.1.0-rc.5.0.20210419231947-82b846128245/example/autobahn/autobahn.go (about) 1 package main 2 3 import ( 4 "compress/flate" 5 "context" 6 "flag" 7 "fmt" 8 "io" 9 "io/ioutil" 10 "log" 11 "net" 12 "net/http" 13 "os" 14 "os/signal" 15 "syscall" 16 "time" 17 18 "github.com/gobwas/httphead" 19 "github.com/simonmittag/ws" 20 "github.com/simonmittag/ws/wsflate" 21 "github.com/simonmittag/ws/wsutil" 22 ) 23 24 const dir = "./example/autobahn" 25 26 var addr = flag.String("listen", ":9001", "addr to listen") 27 28 func main() { 29 log.SetFlags(0) 30 flag.Parse() 31 32 http.HandleFunc("/ws", wsHandler) 33 http.HandleFunc("/wsutil", wsutilHandler) 34 http.HandleFunc("/wsflate", wsflateHandler) 35 http.HandleFunc("/helpers/low", helpersLowLevelHandler) 36 http.HandleFunc("/helpers/high", helpersHighLevelHandler) 37 38 ln, err := net.Listen("tcp", *addr) 39 if err != nil { 40 log.Fatalf("listen %q error: %v", *addr, err) 41 } 42 log.Printf("listening %s (%q)", ln.Addr(), *addr) 43 44 var ( 45 s = new(http.Server) 46 serve = make(chan error, 1) 47 sig = make(chan os.Signal, 1) 48 ) 49 signal.Notify(sig, syscall.SIGTERM) 50 go func() { serve <- s.Serve(ln) }() 51 52 select { 53 case err := <-serve: 54 log.Fatal(err) 55 case sig := <-sig: 56 const timeout = 5 * time.Second 57 58 log.Printf("signal %q received; shutting down with %s timeout", sig, timeout) 59 60 ctx, _ := context.WithTimeout(context.Background(), timeout) 61 if err := s.Shutdown(ctx); err != nil { 62 log.Fatal(err) 63 } 64 } 65 } 66 67 var ( 68 closeInvalidPayload = ws.MustCompileFrame( 69 ws.NewCloseFrame(ws.NewCloseFrameBody( 70 ws.StatusInvalidFramePayloadData, "", 71 )), 72 ) 73 closeProtocolError = ws.MustCompileFrame( 74 ws.NewCloseFrame(ws.NewCloseFrameBody( 75 ws.StatusProtocolError, "", 76 )), 77 ) 78 ) 79 80 func helpersHighLevelHandler(w http.ResponseWriter, r *http.Request) { 81 conn, _, _, err := ws.UpgradeHTTP(r, w) 82 if err != nil { 83 log.Printf("upgrade error: %s", err) 84 return 85 } 86 defer conn.Close() 87 88 for { 89 bts, op, err := wsutil.ReadClientData(conn) 90 if err != nil { 91 log.Printf("read message error: %v", err) 92 return 93 } 94 err = wsutil.WriteServerMessage(conn, op, bts) 95 if err != nil { 96 log.Printf("write message error: %v", err) 97 return 98 } 99 } 100 } 101 102 func helpersLowLevelHandler(w http.ResponseWriter, r *http.Request) { 103 conn, _, _, err := ws.UpgradeHTTP(r, w) 104 if err != nil { 105 log.Printf("upgrade error: %s", err) 106 return 107 } 108 defer conn.Close() 109 110 msg := make([]wsutil.Message, 0, 4) 111 112 for { 113 msg, err = wsutil.ReadClientMessage(conn, msg[:0]) 114 if err != nil { 115 log.Printf("read message error: %v", err) 116 return 117 } 118 for _, m := range msg { 119 if m.OpCode.IsControl() { 120 err := wsutil.HandleClientControlMessage(conn, m) 121 if err != nil { 122 log.Printf("handle control error: %v", err) 123 return 124 } 125 continue 126 } 127 err := wsutil.WriteServerMessage(conn, m.OpCode, m.Payload) 128 if err != nil { 129 log.Printf("write message error: %v", err) 130 return 131 } 132 } 133 } 134 } 135 136 func wsutilHandler(res http.ResponseWriter, req *http.Request) { 137 conn, _, _, err := ws.UpgradeHTTP(req, res) 138 if err != nil { 139 log.Printf("upgrade error: %s", err) 140 return 141 } 142 defer conn.Close() 143 144 state := ws.StateServerSide 145 146 ch := wsutil.ControlFrameHandler(conn, state) 147 r := &wsutil.Reader{ 148 Source: conn, 149 State: state, 150 CheckUTF8: true, 151 OnIntermediate: ch, 152 } 153 w := wsutil.NewWriter(conn, state, 0) 154 155 for { 156 h, err := r.NextFrame() 157 if err != nil { 158 log.Printf("next frame error: %v", err) 159 return 160 } 161 if h.OpCode.IsControl() { 162 if err = ch(h, r); err != nil { 163 log.Printf("handle control error: %v", err) 164 return 165 } 166 continue 167 } 168 169 w.Reset(conn, state, h.OpCode) 170 171 if _, err = io.Copy(w, r); err == nil { 172 err = w.Flush() 173 } 174 if err != nil { 175 log.Printf("echo error: %s", err) 176 return 177 } 178 } 179 } 180 181 func wsflateHandler(w http.ResponseWriter, r *http.Request) { 182 e := wsflate.Extension{ 183 Parameters: wsflate.Parameters{ 184 ServerNoContextTakeover: true, 185 ClientNoContextTakeover: true, 186 }, 187 } 188 u := ws.HTTPUpgrader{ 189 Negotiate: e.Negotiate, 190 } 191 conn, _, _, err := u.Upgrade(r, w) 192 if err != nil { 193 log.Printf("upgrade error: %s", err) 194 return 195 } 196 defer conn.Close() 197 198 if _, ok := e.Accepted(); !ok { 199 log.Printf("no accepted extension") 200 return 201 } 202 203 // Using nil as a destination io.Writer since we will Reset() it in the 204 // loop below. 205 fw := wsflate.NewWriter(nil, func(w io.Writer) wsflate.Compressor { 206 // As flat.NewWriter() docs says: 207 // If level is in the range [-2, 9] then the error returned will 208 // be nil. 209 f, _ := flate.NewWriter(w, 9) 210 return f 211 }) 212 // Using nil as a source io.Reader since we will Reset() it in the loop 213 // below. 214 fr := wsflate.NewReader(nil, func(r io.Reader) wsflate.Decompressor { 215 return flate.NewReader(r) 216 }) 217 218 // MessageState implements wsutil.Extension and is used to check whether 219 // received WebSocket message is compressed. That is, it's generally 220 // possible to receive uncompressed messaged even if compression extension 221 // was negotiated. 222 var msg wsflate.MessageState 223 224 // Note that control frames are all written without compression. 225 controlHandler := wsutil.ControlFrameHandler(conn, ws.StateServerSide) 226 rd := wsutil.Reader{ 227 Source: conn, 228 State: ws.StateServerSide | ws.StateExtended, 229 CheckUTF8: false, 230 OnIntermediate: controlHandler, 231 Extensions: []wsutil.RecvExtension{&msg}, 232 } 233 234 wr := wsutil.NewWriter(conn, ws.StateServerSide|ws.StateExtended, 0) 235 wr.SetExtensions(&msg) 236 237 for { 238 h, err := rd.NextFrame() 239 if err != nil { 240 log.Printf("next frame error: %v", err) 241 return 242 } 243 if h.OpCode.IsControl() { 244 if err := controlHandler(h, &rd); err != nil { 245 log.Printf("handle control frame error: %v", err) 246 return 247 } 248 continue 249 } 250 251 wr.ResetOp(h.OpCode) 252 253 var ( 254 src io.Reader = &rd 255 dst io.Writer = wr 256 ) 257 if msg.IsCompressed() { 258 fr.Reset(src) 259 fw.Reset(dst) 260 src = fr 261 dst = fw 262 } 263 // Copy incoming bytes right into writer, probably through decompressor 264 // and compressor. 265 if _, err = io.Copy(dst, src); err != nil { 266 log.Fatal(err) 267 } 268 if msg.IsCompressed() { 269 // Flush the flate writer. 270 if err = fw.Close(); err != nil { 271 log.Fatal(err) 272 } 273 } 274 // Flush WebSocket fragment writer. We could send multiple fragments 275 // for large messages. 276 if err = wr.Flush(); err != nil { 277 log.Fatal(err) 278 } 279 } 280 } 281 282 func wsHandler(w http.ResponseWriter, r *http.Request) { 283 u := ws.HTTPUpgrader{ 284 Extension: func(opt httphead.Option) bool { 285 log.Printf("extension: %s", opt) 286 return false 287 }, 288 } 289 conn, _, _, err := u.Upgrade(r, w) 290 if err != nil { 291 log.Printf("upgrade error: %s", err) 292 return 293 } 294 defer conn.Close() 295 296 state := ws.StateServerSide 297 298 textPending := false 299 utf8Reader := wsutil.NewUTF8Reader(nil) 300 cipherReader := wsutil.NewCipherReader(nil, [4]byte{0, 0, 0, 0}) 301 302 for { 303 header, err := ws.ReadHeader(conn) 304 if err != nil { 305 log.Printf("read header error: %s", err) 306 break 307 } 308 if err = ws.CheckHeader(header, state); err != nil { 309 log.Printf("header check error: %s", err) 310 conn.Write(closeProtocolError) 311 return 312 } 313 314 cipherReader.Reset( 315 io.LimitReader(conn, header.Length), 316 header.Mask, 317 ) 318 319 var utf8Fin bool 320 var r io.Reader = cipherReader 321 322 switch header.OpCode { 323 case ws.OpPing: 324 header.OpCode = ws.OpPong 325 header.Masked = false 326 ws.WriteHeader(conn, header) 327 io.CopyN(conn, cipherReader, header.Length) 328 continue 329 330 case ws.OpPong: 331 io.CopyN(ioutil.Discard, conn, header.Length) 332 continue 333 334 case ws.OpClose: 335 utf8Fin = true 336 337 case ws.OpContinuation: 338 if textPending { 339 utf8Reader.Source = cipherReader 340 r = utf8Reader 341 } 342 if header.Fin { 343 state = state.Clear(ws.StateFragmented) 344 textPending = false 345 utf8Fin = true 346 } 347 348 case ws.OpText: 349 utf8Reader.Reset(cipherReader) 350 r = utf8Reader 351 352 if !header.Fin { 353 state = state.Set(ws.StateFragmented) 354 textPending = true 355 } else { 356 utf8Fin = true 357 } 358 359 case ws.OpBinary: 360 if !header.Fin { 361 state = state.Set(ws.StateFragmented) 362 } 363 } 364 365 payload := make([]byte, header.Length) 366 _, err = io.ReadFull(r, payload) 367 if err == nil && utf8Fin && !utf8Reader.Valid() { 368 err = wsutil.ErrInvalidUTF8 369 } 370 if err != nil { 371 log.Printf("read payload error: %s", err) 372 if err == wsutil.ErrInvalidUTF8 { 373 conn.Write(closeInvalidPayload) 374 } else { 375 conn.Write(ws.CompiledClose) 376 } 377 return 378 } 379 380 if header.OpCode == ws.OpClose { 381 code, reason := ws.ParseCloseFrameData(payload) 382 log.Printf("close frame received: %v %v", code, reason) 383 384 if !code.Empty() { 385 switch { 386 case code.IsProtocolSpec() && !code.IsProtocolDefined(): 387 err = fmt.Errorf("close code from spec range is not defined") 388 default: 389 err = ws.CheckCloseFrameData(code, reason) 390 } 391 if err != nil { 392 log.Printf("invalid close data: %s", err) 393 conn.Write(closeProtocolError) 394 } else { 395 ws.WriteFrame(conn, ws.NewCloseFrame(ws.NewCloseFrameBody( 396 code, "", 397 ))) 398 } 399 return 400 } 401 402 conn.Write(ws.CompiledClose) 403 return 404 } 405 406 header.Masked = false 407 ws.WriteHeader(conn, header) 408 conn.Write(payload) 409 } 410 }