github.com/simonmittag/ws@v1.1.0-rc.5.0.20210419231947-82b846128245/wsutil/writer_test.go (about) 1 package wsutil 2 3 import ( 4 "bytes" 5 "fmt" 6 "io" 7 "reflect" 8 "strconv" 9 "testing" 10 "unsafe" 11 12 "github.com/simonmittag/ws" 13 ) 14 15 // TODO(gobwas): test NewWriterSize on edge cases for offset. 16 17 const ( 18 bitsize = 32 << (^uint(0) >> 63) 19 maxint = int(^uint(1 << (bitsize - 1))) 20 ) 21 22 func TestControlWriter(t *testing.T) { 23 const ( 24 server = ws.StateServerSide 25 client = ws.StateClientSide 26 ) 27 for _, test := range []struct { 28 name string 29 size int 30 write []byte 31 state ws.State 32 op ws.OpCode 33 exp ws.Frame 34 err bool 35 }{ 36 { 37 state: server, 38 op: ws.OpPing, 39 exp: ws.NewPingFrame(nil), 40 }, 41 { 42 write: []byte("0123456789"), 43 state: server, 44 op: ws.OpPing, 45 exp: ws.NewPingFrame([]byte("0123456789")), 46 }, 47 { 48 size: 10 + reserve(server, 10), 49 write: []byte("0123456789"), 50 state: server, 51 op: ws.OpPing, 52 exp: ws.NewPingFrame([]byte("0123456789")), 53 }, 54 { 55 size: 10 + reserve(server, 10), 56 write: []byte("0123456789a"), 57 state: server, 58 op: ws.OpPing, 59 err: true, 60 }, 61 { 62 write: bytes.Repeat([]byte{'x'}, ws.MaxControlFramePayloadSize+1), 63 state: server, 64 op: ws.OpPing, 65 err: true, 66 }, 67 } { 68 t.Run(test.name, func(t *testing.T) { 69 var buf bytes.Buffer 70 var w *ControlWriter 71 if n := test.size; n == 0 { 72 w = NewControlWriter(&buf, test.state, test.op) 73 } else { 74 p := make([]byte, n) 75 w = NewControlWriterBuffer(&buf, test.state, test.op, p) 76 } 77 78 _, err := w.Write(test.write) 79 if err == nil { 80 err = w.Flush() 81 } 82 if test.err { 83 if err == nil { 84 t.Errorf("want error") 85 } 86 return 87 } 88 if !test.err && err != nil { 89 t.Errorf("unexpected error: %v", err) 90 return 91 } 92 93 act, err := ws.ReadFrame(&buf) 94 if err != nil { 95 t.Fatal(err) 96 } 97 98 act = omitMask(act) 99 exp := omitMask(test.exp) 100 if !reflect.DeepEqual(act, exp) { 101 t.Errorf("unexpected frame:\nflushed: %v\nwant: %v", pretty(act), pretty(exp)) 102 } 103 }) 104 } 105 } 106 107 type reserveTestCase struct { 108 name string 109 buf int 110 state ws.State 111 expOffset int 112 panic bool 113 } 114 115 func genReserveTestCases(s ws.State, n, m, exp int) []reserveTestCase { 116 ret := make([]reserveTestCase, m-n) 117 for i := n; i < m; i++ { 118 var suffix string 119 if s.ClientSide() { 120 suffix = " masked" 121 } 122 123 ret[i-n] = reserveTestCase{ 124 name: "gen " + strconv.Itoa(i) + suffix, 125 buf: i, 126 state: s, 127 expOffset: exp, 128 } 129 } 130 return ret 131 } 132 133 func fakeMake(n int) (r []byte) { 134 rh := (*reflect.SliceHeader)(unsafe.Pointer(&r)) 135 *rh = reflect.SliceHeader{ 136 Len: n, 137 Cap: n, 138 } 139 return r 140 } 141 142 var reserveTestCases = []reserveTestCase{ 143 { 144 name: "len7", 145 buf: int(len7) + 2, 146 expOffset: 2, 147 }, 148 { 149 name: "len16", 150 buf: int(len16) + 4, 151 expOffset: 4, 152 }, 153 { 154 name: "maxint", 155 buf: maxint, 156 expOffset: 10, 157 }, 158 { 159 name: "len7 masked", 160 buf: int(len7) + 6, 161 state: ws.StateClientSide, 162 expOffset: 6, 163 }, 164 { 165 name: "len16 masked", 166 buf: int(len16) + 8, 167 state: ws.StateClientSide, 168 expOffset: 8, 169 }, 170 { 171 name: "maxint masked", 172 buf: maxint, 173 state: ws.StateClientSide, 174 expOffset: 14, 175 }, 176 { 177 name: "split case", 178 buf: 128, 179 expOffset: 4, 180 }, 181 } 182 183 func TestNewWriterBuffer(t *testing.T) { 184 cases := append( 185 reserveTestCases, 186 reserveTestCase{ 187 name: "panic", 188 buf: 2, 189 panic: true, 190 }, 191 reserveTestCase{ 192 name: "panic", 193 buf: 6, 194 state: ws.StateClientSide, 195 panic: true, 196 }, 197 ) 198 cases = append(cases, genReserveTestCases(0, int(len7)-2, int(len7)+2, 2)...) 199 cases = append(cases, genReserveTestCases(0, int(len16)-4, int(len16)+4, 4)...) 200 cases = append(cases, genReserveTestCases(0, maxint-10, maxint, 10)...) 201 202 cases = append(cases, genReserveTestCases(ws.StateClientSide, int(len7)-6, int(len7)+6, 6)...) 203 cases = append(cases, genReserveTestCases(ws.StateClientSide, int(len16)-8, int(len16)+8, 8)...) 204 cases = append(cases, genReserveTestCases(ws.StateClientSide, maxint-14, maxint, 14)...) 205 206 for _, test := range cases { 207 t.Run(test.name, func(t *testing.T) { 208 defer func() { 209 thePanic := recover() 210 if test.panic && thePanic == nil { 211 t.Errorf("expected panic") 212 } 213 if !test.panic && thePanic != nil { 214 t.Errorf("unexpected panic: %v", thePanic) 215 } 216 }() 217 w := NewWriterBuffer(nil, test.state, 0, fakeMake(test.buf)) 218 if act, exp := len(w.raw)-len(w.buf), test.expOffset; act != exp { 219 t.Errorf( 220 "NewWriteBuffer(%d bytes) has offset %d; want %d", 221 test.buf, act, exp, 222 ) 223 } 224 }) 225 } 226 } 227 228 func TestWriter(t *testing.T) { 229 for i, test := range []struct { 230 label string 231 size int 232 state ws.State 233 data [][]byte 234 expFrm []ws.Frame 235 expBts []byte 236 }{ 237 // No Write(), no frames. 238 {}, 239 240 { 241 data: [][]byte{ 242 {}, 243 }, 244 expBts: ws.MustCompileFrame(ws.NewTextFrame(nil)), 245 }, 246 { 247 data: [][]byte{ 248 []byte("hello, world!"), 249 }, 250 expBts: ws.MustCompileFrame(ws.NewTextFrame([]byte("hello, world!"))), 251 }, 252 { 253 state: ws.StateClientSide, 254 data: [][]byte{ 255 []byte("hello, world!"), 256 }, 257 expFrm: []ws.Frame{ws.MaskFrame(ws.NewTextFrame([]byte("hello, world!")))}, 258 }, 259 { 260 size: 5, 261 data: [][]byte{ 262 []byte("hello"), 263 []byte(", wor"), 264 []byte("ld!"), 265 }, 266 expBts: bytes.Join( 267 bts( 268 ws.MustCompileFrame(ws.Frame{ 269 Header: ws.Header{ 270 Fin: false, 271 OpCode: ws.OpText, 272 Length: 5, 273 }, 274 Payload: []byte("hello"), 275 }), 276 ws.MustCompileFrame(ws.Frame{ 277 Header: ws.Header{ 278 Fin: false, 279 OpCode: ws.OpContinuation, 280 Length: 5, 281 }, 282 Payload: []byte(", wor"), 283 }), 284 ws.MustCompileFrame(ws.Frame{ 285 Header: ws.Header{ 286 Fin: true, 287 OpCode: ws.OpContinuation, 288 Length: 3, 289 }, 290 Payload: []byte("ld!"), 291 }), 292 ), 293 nil, 294 ), 295 }, 296 { // Large write case. 297 size: 5, 298 data: [][]byte{ 299 []byte("hello, world!"), 300 }, 301 expBts: bytes.Join( 302 bts( 303 ws.MustCompileFrame(ws.Frame{ 304 Header: ws.Header{ 305 Fin: false, 306 OpCode: ws.OpText, 307 Length: 13, 308 }, 309 Payload: []byte("hello, world!"), 310 }), 311 ws.MustCompileFrame(ws.Frame{ 312 Header: ws.Header{ 313 Fin: true, 314 OpCode: ws.OpContinuation, 315 Length: 0, 316 }, 317 }), 318 ), 319 nil, 320 ), 321 }, 322 } { 323 t.Run(fmt.Sprintf("%s#%d", test.label, i), func(t *testing.T) { 324 buf := &bytes.Buffer{} 325 w := NewWriterSize(buf, test.state, ws.OpText, test.size) 326 327 for _, p := range test.data { 328 _, err := w.Write(p) 329 if err != nil { 330 t.Fatalf("unexpected Write() error: %s", err) 331 } 332 } 333 if err := w.Flush(); err != nil { 334 t.Fatalf("unexpected Flush() error: %s", err) 335 } 336 if test.expBts != nil { 337 if bts := buf.Bytes(); !bytes.Equal(test.expBts, bts) { 338 t.Errorf( 339 "wrote bytes:\nact:\t%#x\nexp:\t%#x\nacth:\t%s\nexph:\t%s\n", bts, test.expBts, 340 pretty(frames(bts)...), pretty(frames(test.expBts)...), 341 ) 342 } 343 } 344 if test.expFrm != nil { 345 act := omitMasks(frames(buf.Bytes())) 346 exp := omitMasks(test.expFrm) 347 348 if !reflect.DeepEqual(act, exp) { 349 t.Errorf( 350 "wrote frames (mask omitted):\nact:\t%s\nexp:\t%s\n", 351 pretty(act...), pretty(exp...), 352 ) 353 } 354 } 355 }) 356 } 357 } 358 359 func TestWriterLargeWrite(t *testing.T) { 360 var dest bytes.Buffer 361 w := NewWriterSize(&dest, 0, 0, 16) 362 363 // Test that event for big writes extensions set their bits. 364 var rsv = [3]bool{true, true, false} 365 w.SetExtensions(SendExtensionFunc(func(h ws.Header) (ws.Header, error) { 366 h.Rsv = ws.Rsv(rsv[0], rsv[1], rsv[2]) 367 return h, nil 368 })) 369 370 // Write message with size twice bigger than writer's internal buffer. 371 // We expect Writer to write it directly without buffering since we didn't 372 // write anything before (no data in internal buffer). 373 bts := make([]byte, 2*w.Size()) 374 if _, err := w.Write(bts); err != nil { 375 t.Fatal(err) 376 } 377 if err := w.Flush(); err != nil { 378 t.Fatal(err) 379 } 380 381 frame, err := ws.ReadFrame(&dest) 382 if err != nil { 383 t.Fatalf("can't read frame: %v", err) 384 } 385 386 var act [3]bool 387 act[0], act[1], act[2] = ws.RsvBits(frame.Header.Rsv) 388 if act != rsv { 389 t.Fatalf("unexpected rsv bits sent: %v; extension set %v", act, rsv) 390 } 391 } 392 393 func TestWriterReadFrom(t *testing.T) { 394 for i, test := range []struct { 395 label string 396 chop int 397 size int 398 data []byte 399 exp []ws.Frame 400 n int64 401 }{ 402 { 403 chop: 1, 404 size: 1, 405 data: []byte("golang"), 406 exp: []ws.Frame{ 407 {Header: ws.Header{Fin: false, Length: 1, OpCode: ws.OpText}, Payload: []byte{'g'}}, 408 {Header: ws.Header{Fin: false, Length: 1, OpCode: ws.OpContinuation}, Payload: []byte{'o'}}, 409 {Header: ws.Header{Fin: false, Length: 1, OpCode: ws.OpContinuation}, Payload: []byte{'l'}}, 410 {Header: ws.Header{Fin: false, Length: 1, OpCode: ws.OpContinuation}, Payload: []byte{'a'}}, 411 {Header: ws.Header{Fin: false, Length: 1, OpCode: ws.OpContinuation}, Payload: []byte{'n'}}, 412 {Header: ws.Header{Fin: false, Length: 1, OpCode: ws.OpContinuation}, Payload: []byte{'g'}}, 413 {Header: ws.Header{Fin: true, Length: 0, OpCode: ws.OpContinuation}}, 414 }, 415 n: 6, 416 }, 417 { 418 chop: 1, 419 size: 4, 420 data: []byte("golang"), 421 exp: []ws.Frame{ 422 {Header: ws.Header{Fin: false, Length: 4, OpCode: ws.OpText}, Payload: []byte("gola")}, 423 {Header: ws.Header{Fin: true, Length: 2, OpCode: ws.OpContinuation}, Payload: []byte("ng")}, 424 }, 425 n: 6, 426 }, 427 { 428 size: 64, 429 data: []byte{}, 430 exp: []ws.Frame{ 431 {Header: ws.Header{Fin: true, Length: 0, OpCode: ws.OpText}}, 432 }, 433 n: 0, 434 }, 435 } { 436 t.Run(fmt.Sprintf("%s#%d", test.label, i), func(t *testing.T) { 437 dst := &bytes.Buffer{} 438 wr := NewWriterSize(dst, 0, ws.OpText, test.size) 439 440 chop := test.chop 441 if chop == 0 { 442 chop = 128 443 } 444 src := &chopReader{bytes.NewReader(test.data), chop} 445 446 n, err := wr.ReadFrom(src) 447 if err == nil { 448 err = wr.Flush() 449 } 450 if err != nil { 451 t.Fatalf("unexpected error: %s", err) 452 } 453 if n != test.n { 454 t.Errorf("ReadFrom() read out %d; want %d", n, test.n) 455 } 456 if frames := frames(dst.Bytes()); !reflect.DeepEqual(frames, test.exp) { 457 t.Errorf("ReadFrom() read frames:\n\tact:\t%s\n\texp:\t%s\n", pretty(frames...), pretty(test.exp...)) 458 } 459 }) 460 } 461 } 462 463 func TestWriterWriteCount(t *testing.T) { 464 for _, test := range []struct { 465 name string 466 cap int 467 exp int 468 write []int // For ability to avoid large write inside Write()'s "if". 469 }{ 470 { 471 name: "one frame", 472 cap: 10, 473 write: []int{10}, 474 exp: 1, 475 }, 476 { 477 name: "two frames", 478 cap: 10, 479 write: []int{5, 7}, 480 exp: 2, 481 }, 482 } { 483 t.Run(test.name, func(t *testing.T) { 484 n := writeCounter{} 485 w := NewWriterSize(&n, 0, ws.OpText, test.cap) 486 487 for _, n := range test.write { 488 text := bytes.Repeat([]byte{'x'}, n) 489 if _, err := w.Write(text); err != nil { 490 t.Fatal(err) 491 } 492 } 493 494 if err := w.Flush(); err != nil { 495 t.Fatal(err) 496 } 497 498 if act, exp := n.n, test.exp; act != exp { 499 t.Errorf("made %d Write() calls to dest writer; want %d", act, exp) 500 } 501 }) 502 } 503 } 504 505 func TestWriterNoPreemtiveFlush(t *testing.T) { 506 n := writeCounter{} 507 w := NewWriterSize(&n, 0, 0, 10) 508 509 // Fill buffer. 510 if _, err := w.Write([]byte("0123456789")); err != nil { 511 t.Fatal(err) 512 } 513 if n.n != 0 { 514 t.Fatalf( 515 "after filling up Writer got %d writes to the dest; want 0", 516 n.n, 517 ) 518 } 519 } 520 521 type writeCounter struct { 522 n int 523 } 524 525 func (w *writeCounter) Write(p []byte) (int, error) { 526 w.n++ 527 return len(p), nil 528 } 529 530 func frames(p []byte) (ret []ws.Frame) { 531 r := bytes.NewReader(p) 532 for stop := false; !stop; { 533 f, err := ws.ReadFrame(r) 534 if err != nil { 535 if err == io.EOF { 536 break 537 } 538 panic(err) 539 } 540 ret = append(ret, f) 541 } 542 return 543 } 544 545 func pretty(f ...ws.Frame) string { 546 str := "\n" 547 for _, f := range f { 548 str += fmt.Sprintf("\t%#v\n\t%#x (%#q)\n\t----\n", f.Header, f.Payload, f.Payload) 549 } 550 return str 551 } 552 553 func omitMask(f ws.Frame) ws.Frame { 554 if f.Header.Masked { 555 p := make([]byte, int(f.Header.Length)) 556 copy(p, f.Payload) 557 558 ws.Cipher(p, f.Header.Mask, 0) 559 560 f.Header.Mask = [4]byte{0, 0, 0, 0} 561 f.Payload = p 562 } 563 return f 564 } 565 566 func omitMasks(f []ws.Frame) []ws.Frame { 567 for i := 0; i < len(f); i++ { 568 f[i] = omitMask(f[i]) 569 } 570 return f 571 } 572 573 func bts(b ...[]byte) [][]byte { return b }