github.com/ezoic/ws@v1.0.4-0.20220713205711-5c1d69e074c5/wsutil/writer_test.go (about) 1 package wsutil 2 3 import ( 4 "bytes" 5 "fmt" 6 "io" 7 "reflect" 8 "strconv" 9 "testing" 10 "unsafe" 11 12 "github.com/ezoic/ws" 13 ) 14 15 // TODO(ezoic): test NewWriterSize on edge cases for offset. 16 17 const ( 18 bitsize = 32 << (^uint(0) >> 63) 19 maxint = int(^uint(1 << (bitsize - 1))) 20 ) 21 22 func TestControlWriter(t *testing.T) { 23 const ( 24 server = ws.StateServerSide 25 client = ws.StateClientSide 26 ) 27 for _, test := range []struct { 28 name string 29 size int 30 write []byte 31 state ws.State 32 op ws.OpCode 33 exp ws.Frame 34 err bool 35 }{ 36 { 37 state: server, 38 op: ws.OpPing, 39 exp: ws.NewPingFrame(nil), 40 }, 41 { 42 write: []byte("0123456789"), 43 state: server, 44 op: ws.OpPing, 45 exp: ws.NewPingFrame([]byte("0123456789")), 46 }, 47 { 48 size: 10 + reserve(server, 10), 49 write: []byte("0123456789"), 50 state: server, 51 op: ws.OpPing, 52 exp: ws.NewPingFrame([]byte("0123456789")), 53 }, 54 { 55 size: 10 + reserve(server, 10), 56 write: []byte("0123456789a"), 57 state: server, 58 op: ws.OpPing, 59 err: true, 60 }, 61 { 62 write: bytes.Repeat([]byte{'x'}, ws.MaxControlFramePayloadSize+1), 63 state: server, 64 op: ws.OpPing, 65 err: true, 66 }, 67 } { 68 t.Run(test.name, func(t *testing.T) { 69 var buf bytes.Buffer 70 var w *ControlWriter 71 if n := test.size; n == 0 { 72 w = NewControlWriter(&buf, test.state, test.op) 73 } else { 74 p := make([]byte, n) 75 w = NewControlWriterBuffer(&buf, test.state, test.op, p) 76 } 77 78 _, err := w.Write(test.write) 79 if err == nil { 80 err = w.Flush() 81 } 82 if test.err { 83 if err == nil { 84 t.Errorf("want error") 85 } 86 return 87 } 88 if !test.err && err != nil { 89 t.Errorf("unexpected error: %v", err) 90 return 91 } 92 93 act, err := ws.ReadFrame(&buf) 94 if err != nil { 95 t.Fatal(err) 96 } 97 98 act = omitMask(act) 99 exp := omitMask(test.exp) 100 if !reflect.DeepEqual(act, exp) { 101 t.Errorf("unexpected frame:\nflushed: %v\nwant: %v", pretty(act), pretty(exp)) 102 } 103 }) 104 } 105 } 106 107 type reserveTestCase struct { 108 name string 109 buf int 110 state ws.State 111 expOffset int 112 panic bool 113 } 114 115 func genReserveTestCases(s ws.State, n, m, exp int) []reserveTestCase { 116 ret := make([]reserveTestCase, m-n) 117 for i := n; i < m; i++ { 118 var suffix string 119 if s.ClientSide() { 120 suffix = " masked" 121 } 122 123 ret[i-n] = reserveTestCase{ 124 name: "gen " + strconv.Itoa(i) + suffix, 125 buf: i, 126 state: s, 127 expOffset: exp, 128 } 129 } 130 return ret 131 } 132 133 func fakeMake(n int) (r []byte) { 134 rh := (*reflect.SliceHeader)(unsafe.Pointer(&r)) 135 *rh = reflect.SliceHeader{ 136 Len: n, 137 Cap: n, 138 } 139 return r 140 } 141 142 var reserveTestCases = []reserveTestCase{ 143 { 144 name: "len7", 145 buf: int(len7) + 2, 146 expOffset: 2, 147 }, 148 { 149 name: "len16", 150 buf: int(len16) + 4, 151 expOffset: 4, 152 }, 153 { 154 name: "maxint", 155 buf: maxint, 156 expOffset: 10, 157 }, 158 { 159 name: "len7 masked", 160 buf: int(len7) + 6, 161 state: ws.StateClientSide, 162 expOffset: 6, 163 }, 164 { 165 name: "len16 masked", 166 buf: int(len16) + 8, 167 state: ws.StateClientSide, 168 expOffset: 8, 169 }, 170 { 171 name: "maxint masked", 172 buf: maxint, 173 state: ws.StateClientSide, 174 expOffset: 14, 175 }, 176 { 177 name: "split case", 178 buf: 128, 179 expOffset: 4, 180 }, 181 } 182 183 func TestNewWriterBuffer(t *testing.T) { 184 cases := append( 185 reserveTestCases, 186 reserveTestCase{ 187 name: "panic", 188 buf: 2, 189 panic: true, 190 }, 191 reserveTestCase{ 192 name: "panic", 193 buf: 6, 194 state: ws.StateClientSide, 195 panic: true, 196 }, 197 ) 198 cases = append(cases, genReserveTestCases(0, int(len7)-2, int(len7)+2, 2)...) 199 cases = append(cases, genReserveTestCases(0, int(len16)-4, int(len16)+4, 4)...) 200 cases = append(cases, genReserveTestCases(0, maxint-10, maxint, 10)...) 201 202 cases = append(cases, genReserveTestCases(ws.StateClientSide, int(len7)-6, int(len7)+6, 6)...) 203 cases = append(cases, genReserveTestCases(ws.StateClientSide, int(len16)-8, int(len16)+8, 8)...) 204 cases = append(cases, genReserveTestCases(ws.StateClientSide, maxint-14, maxint, 14)...) 205 206 for _, test := range cases { 207 t.Run(test.name, func(t *testing.T) { 208 defer func() { 209 thePanic := recover() 210 if test.panic && thePanic == nil { 211 t.Errorf("expected panic") 212 } 213 if !test.panic && thePanic != nil { 214 t.Errorf("unexpected panic: %v", thePanic) 215 } 216 }() 217 w := NewWriterBuffer(nil, test.state, 0, fakeMake(test.buf)) 218 if act, exp := len(w.raw)-len(w.buf), test.expOffset; act != exp { 219 t.Errorf( 220 "NewWriteBuffer(%d bytes) has offset %d; want %d", 221 test.buf, act, exp, 222 ) 223 } 224 }) 225 } 226 } 227 228 func TestWriter(t *testing.T) { 229 for i, test := range []struct { 230 label string 231 size int 232 state ws.State 233 data [][]byte 234 expFrm []ws.Frame 235 expBts []byte 236 }{ 237 // No Write(), no frames. 238 {}, 239 240 { 241 data: [][]byte{ 242 {}, 243 }, 244 expBts: ws.MustCompileFrame(ws.NewTextFrame(nil)), 245 }, 246 { 247 data: [][]byte{ 248 []byte("hello, world!"), 249 }, 250 expBts: ws.MustCompileFrame(ws.NewTextFrame([]byte("hello, world!"))), 251 }, 252 { 253 state: ws.StateClientSide, 254 data: [][]byte{ 255 []byte("hello, world!"), 256 }, 257 expFrm: []ws.Frame{ws.MaskFrame(ws.NewTextFrame([]byte("hello, world!")))}, 258 }, 259 { 260 size: 5, 261 data: [][]byte{ 262 []byte("hello"), 263 []byte(", wor"), 264 []byte("ld!"), 265 }, 266 expBts: bytes.Join( 267 bts( 268 ws.MustCompileFrame(ws.Frame{ 269 Header: ws.Header{ 270 Fin: false, 271 OpCode: ws.OpText, 272 Length: 5, 273 }, 274 Payload: []byte("hello"), 275 }), 276 ws.MustCompileFrame(ws.Frame{ 277 Header: ws.Header{ 278 Fin: false, 279 OpCode: ws.OpContinuation, 280 Length: 5, 281 }, 282 Payload: []byte(", wor"), 283 }), 284 ws.MustCompileFrame(ws.Frame{ 285 Header: ws.Header{ 286 Fin: true, 287 OpCode: ws.OpContinuation, 288 Length: 3, 289 }, 290 Payload: []byte("ld!"), 291 }), 292 ), 293 nil, 294 ), 295 }, 296 { // Large write case. 297 size: 5, 298 data: [][]byte{ 299 []byte("hello, world!"), 300 }, 301 expBts: bytes.Join( 302 bts( 303 ws.MustCompileFrame(ws.Frame{ 304 Header: ws.Header{ 305 Fin: false, 306 OpCode: ws.OpText, 307 Length: 13, 308 }, 309 Payload: []byte("hello, world!"), 310 }), 311 ws.MustCompileFrame(ws.Frame{ 312 Header: ws.Header{ 313 Fin: true, 314 OpCode: ws.OpContinuation, 315 Length: 0, 316 }, 317 }), 318 ), 319 nil, 320 ), 321 }, 322 } { 323 t.Run(fmt.Sprintf("%s#%d", test.label, i), func(t *testing.T) { 324 buf := &bytes.Buffer{} 325 w := NewWriterSize(buf, test.state, ws.OpText, test.size) 326 327 for _, p := range test.data { 328 _, err := w.Write(p) 329 if err != nil { 330 t.Fatalf("unexpected Write() error: %s", err) 331 } 332 } 333 if err := w.Flush(); err != nil { 334 t.Fatalf("unexpected Flush() error: %s", err) 335 } 336 if test.expBts != nil { 337 if bts := buf.Bytes(); !bytes.Equal(test.expBts, bts) { 338 t.Errorf( 339 "wrote bytes:\nact:\t%#x\nexp:\t%#x\nacth:\t%s\nexph:\t%s\n", bts, test.expBts, 340 pretty(frames(bts)...), pretty(frames(test.expBts)...), 341 ) 342 } 343 } 344 if test.expFrm != nil { 345 act := omitMasks(frames(buf.Bytes())) 346 exp := omitMasks(test.expFrm) 347 348 if !reflect.DeepEqual(act, exp) { 349 t.Errorf( 350 "wrote frames (mask omitted):\nact:\t%s\nexp:\t%s\n", 351 pretty(act...), pretty(exp...), 352 ) 353 } 354 } 355 }) 356 } 357 } 358 359 func TestWriterReadFrom(t *testing.T) { 360 for i, test := range []struct { 361 label string 362 chop int 363 size int 364 data []byte 365 exp []ws.Frame 366 n int64 367 }{ 368 { 369 chop: 1, 370 size: 1, 371 data: []byte("golang"), 372 exp: []ws.Frame{ 373 {Header: ws.Header{Fin: false, Length: 1, OpCode: ws.OpText}, Payload: []byte{'g'}}, 374 {Header: ws.Header{Fin: false, Length: 1, OpCode: ws.OpContinuation}, Payload: []byte{'o'}}, 375 {Header: ws.Header{Fin: false, Length: 1, OpCode: ws.OpContinuation}, Payload: []byte{'l'}}, 376 {Header: ws.Header{Fin: false, Length: 1, OpCode: ws.OpContinuation}, Payload: []byte{'a'}}, 377 {Header: ws.Header{Fin: false, Length: 1, OpCode: ws.OpContinuation}, Payload: []byte{'n'}}, 378 {Header: ws.Header{Fin: false, Length: 1, OpCode: ws.OpContinuation}, Payload: []byte{'g'}}, 379 {Header: ws.Header{Fin: true, Length: 0, OpCode: ws.OpContinuation}}, 380 }, 381 n: 6, 382 }, 383 { 384 chop: 1, 385 size: 4, 386 data: []byte("golang"), 387 exp: []ws.Frame{ 388 {Header: ws.Header{Fin: false, Length: 4, OpCode: ws.OpText}, Payload: []byte("gola")}, 389 {Header: ws.Header{Fin: true, Length: 2, OpCode: ws.OpContinuation}, Payload: []byte("ng")}, 390 }, 391 n: 6, 392 }, 393 { 394 size: 64, 395 data: []byte{}, 396 exp: []ws.Frame{ 397 {Header: ws.Header{Fin: true, Length: 0, OpCode: ws.OpText}}, 398 }, 399 n: 0, 400 }, 401 } { 402 t.Run(fmt.Sprintf("%s#%d", test.label, i), func(t *testing.T) { 403 dst := &bytes.Buffer{} 404 wr := NewWriterSize(dst, 0, ws.OpText, test.size) 405 406 chop := test.chop 407 if chop == 0 { 408 chop = 128 409 } 410 src := &chopReader{bytes.NewReader(test.data), chop} 411 412 n, err := wr.ReadFrom(src) 413 if err == nil { 414 err = wr.Flush() 415 } 416 if err != nil { 417 t.Fatalf("unexpected error: %s", err) 418 } 419 if n != test.n { 420 t.Errorf("ReadFrom() read out %d; want %d", n, test.n) 421 } 422 if frames := frames(dst.Bytes()); !reflect.DeepEqual(frames, test.exp) { 423 t.Errorf("ReadFrom() read frames:\n\tact:\t%s\n\texp:\t%s\n", pretty(frames...), pretty(test.exp...)) 424 } 425 }) 426 } 427 } 428 429 func TestWriterWriteCount(t *testing.T) { 430 for _, test := range []struct { 431 name string 432 cap int 433 exp int 434 write []int // For ability to avoid large write inside Write()'s "if". 435 }{ 436 { 437 name: "one frame", 438 cap: 10, 439 write: []int{10}, 440 exp: 1, 441 }, 442 { 443 name: "two frames", 444 cap: 10, 445 write: []int{5, 7}, 446 exp: 2, 447 }, 448 } { 449 t.Run(test.name, func(t *testing.T) { 450 n := writeCounter{} 451 w := NewWriterSize(&n, 0, ws.OpText, test.cap) 452 453 for _, n := range test.write { 454 text := bytes.Repeat([]byte{'x'}, n) 455 if _, err := w.Write(text); err != nil { 456 t.Fatal(err) 457 } 458 } 459 460 if err := w.Flush(); err != nil { 461 t.Fatal(err) 462 } 463 464 if act, exp := n.n, test.exp; act != exp { 465 t.Errorf("made %d Write() calls to dest writer; want %d", act, exp) 466 } 467 }) 468 } 469 } 470 471 func TestWriterNoPreemtiveFlush(t *testing.T) { 472 n := writeCounter{} 473 w := NewWriterSize(&n, 0, 0, 10) 474 475 // Fill buffer. 476 if _, err := w.Write([]byte("0123456789")); err != nil { 477 t.Fatal(err) 478 } 479 if n.n != 0 { 480 t.Fatalf( 481 "after filling up Writer got %d writes to the dest; want 0", 482 n.n, 483 ) 484 } 485 } 486 487 type writeCounter struct { 488 n int 489 } 490 491 func (w *writeCounter) Write(p []byte) (int, error) { 492 w.n++ 493 return len(p), nil 494 } 495 496 func frames(p []byte) (ret []ws.Frame) { 497 r := bytes.NewReader(p) 498 for stop := false; !stop; { 499 f, err := ws.ReadFrame(r) 500 if err != nil { 501 if err == io.EOF { 502 break 503 } 504 panic(err) 505 } 506 ret = append(ret, f) 507 } 508 return 509 } 510 511 func pretty(f ...ws.Frame) string { 512 str := "\n" 513 for _, f := range f { 514 str += fmt.Sprintf("\t%#v\n\t%#x (%s)\n\t----\n", f.Header, f.Payload, f.Payload) 515 } 516 return str 517 } 518 519 func omitMask(f ws.Frame) ws.Frame { 520 if f.Header.Masked { 521 p := make([]byte, int(f.Header.Length)) 522 copy(p, f.Payload) 523 524 ws.Cipher(p, f.Header.Mask, 0) 525 526 f.Header.Mask = [4]byte{0, 0, 0, 0} 527 f.Payload = p 528 } 529 return f 530 } 531 532 func omitMasks(f []ws.Frame) []ws.Frame { 533 for i := 0; i < len(f); i++ { 534 f[i] = omitMask(f[i]) 535 } 536 return f 537 } 538 539 func bts(b ...[]byte) [][]byte { return b }