github.com/simonmittag/ws@v1.1.0-rc.5.0.20210419231947-82b846128245/wsutil/writer_test.go (about)

     1  package wsutil
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"reflect"
     8  	"strconv"
     9  	"testing"
    10  	"unsafe"
    11  
    12  	"github.com/simonmittag/ws"
    13  )
    14  
    15  // TODO(gobwas): test NewWriterSize on edge cases for offset.
    16  
    17  const (
    18  	bitsize = 32 << (^uint(0) >> 63)
    19  	maxint  = int(^uint(1 << (bitsize - 1)))
    20  )
    21  
    22  func TestControlWriter(t *testing.T) {
    23  	const (
    24  		server = ws.StateServerSide
    25  		client = ws.StateClientSide
    26  	)
    27  	for _, test := range []struct {
    28  		name  string
    29  		size  int
    30  		write []byte
    31  		state ws.State
    32  		op    ws.OpCode
    33  		exp   ws.Frame
    34  		err   bool
    35  	}{
    36  		{
    37  			state: server,
    38  			op:    ws.OpPing,
    39  			exp:   ws.NewPingFrame(nil),
    40  		},
    41  		{
    42  			write: []byte("0123456789"),
    43  			state: server,
    44  			op:    ws.OpPing,
    45  			exp:   ws.NewPingFrame([]byte("0123456789")),
    46  		},
    47  		{
    48  			size:  10 + reserve(server, 10),
    49  			write: []byte("0123456789"),
    50  			state: server,
    51  			op:    ws.OpPing,
    52  			exp:   ws.NewPingFrame([]byte("0123456789")),
    53  		},
    54  		{
    55  			size:  10 + reserve(server, 10),
    56  			write: []byte("0123456789a"),
    57  			state: server,
    58  			op:    ws.OpPing,
    59  			err:   true,
    60  		},
    61  		{
    62  			write: bytes.Repeat([]byte{'x'}, ws.MaxControlFramePayloadSize+1),
    63  			state: server,
    64  			op:    ws.OpPing,
    65  			err:   true,
    66  		},
    67  	} {
    68  		t.Run(test.name, func(t *testing.T) {
    69  			var buf bytes.Buffer
    70  			var w *ControlWriter
    71  			if n := test.size; n == 0 {
    72  				w = NewControlWriter(&buf, test.state, test.op)
    73  			} else {
    74  				p := make([]byte, n)
    75  				w = NewControlWriterBuffer(&buf, test.state, test.op, p)
    76  			}
    77  
    78  			_, err := w.Write(test.write)
    79  			if err == nil {
    80  				err = w.Flush()
    81  			}
    82  			if test.err {
    83  				if err == nil {
    84  					t.Errorf("want error")
    85  				}
    86  				return
    87  			}
    88  			if !test.err && err != nil {
    89  				t.Errorf("unexpected error: %v", err)
    90  				return
    91  			}
    92  
    93  			act, err := ws.ReadFrame(&buf)
    94  			if err != nil {
    95  				t.Fatal(err)
    96  			}
    97  
    98  			act = omitMask(act)
    99  			exp := omitMask(test.exp)
   100  			if !reflect.DeepEqual(act, exp) {
   101  				t.Errorf("unexpected frame:\nflushed: %v\nwant: %v", pretty(act), pretty(exp))
   102  			}
   103  		})
   104  	}
   105  }
   106  
// reserveTestCase describes one NewWriterBuffer check. A buffer of buf bytes,
// used with the given state, should leave expOffset bytes in front of the
// payload area (measured as len(w.raw)-len(w.buf)). If panic is set,
// NewWriterBuffer is instead expected to panic.
type reserveTestCase struct {
	name      string   // subtest name
	buf       int      // length of the buffer handed to NewWriterBuffer
	state     ws.State // writer state; ws.StateClientSide marks the "masked" cases
	expOffset int      // expected len(w.raw)-len(w.buf)
	panic     bool     // whether NewWriterBuffer must panic
}
   114  
   115  func genReserveTestCases(s ws.State, n, m, exp int) []reserveTestCase {
   116  	ret := make([]reserveTestCase, m-n)
   117  	for i := n; i < m; i++ {
   118  		var suffix string
   119  		if s.ClientSide() {
   120  			suffix = " masked"
   121  		}
   122  
   123  		ret[i-n] = reserveTestCase{
   124  			name:      "gen " + strconv.Itoa(i) + suffix,
   125  			buf:       i,
   126  			state:     s,
   127  			expOffset: exp,
   128  		}
   129  	}
   130  	return ret
   131  }
   132  
// fakeMake returns a []byte whose length and capacity are both n but whose
// data pointer is nil. Nothing is allocated, so tests can hand "buffers" of
// huge sizes (up to maxint) to constructors that only look at len/cap.
//
// Never index, read or write the result: its elements live at addresses near
// zero, so any access will crash.
//
// NOTE(review): reflect.SliceHeader is deprecated in recent Go releases and
// vet may flag it. unsafe.Slice is not a drop-in replacement here, because it
// panics for a nil pointer with a non-zero length.
func fakeMake(n int) (r []byte) {
	// Overwrite r's header in place: Data stays 0 and only Len/Cap are set.
	rh := (*reflect.SliceHeader)(unsafe.Pointer(&r))
	*rh = reflect.SliceHeader{
		Len: n,
		Cap: n,
	}
	return r
}
   141  
   142  var reserveTestCases = []reserveTestCase{
   143  	{
   144  		name:      "len7",
   145  		buf:       int(len7) + 2,
   146  		expOffset: 2,
   147  	},
   148  	{
   149  		name:      "len16",
   150  		buf:       int(len16) + 4,
   151  		expOffset: 4,
   152  	},
   153  	{
   154  		name:      "maxint",
   155  		buf:       maxint,
   156  		expOffset: 10,
   157  	},
   158  	{
   159  		name:      "len7 masked",
   160  		buf:       int(len7) + 6,
   161  		state:     ws.StateClientSide,
   162  		expOffset: 6,
   163  	},
   164  	{
   165  		name:      "len16 masked",
   166  		buf:       int(len16) + 8,
   167  		state:     ws.StateClientSide,
   168  		expOffset: 8,
   169  	},
   170  	{
   171  		name:      "maxint masked",
   172  		buf:       maxint,
   173  		state:     ws.StateClientSide,
   174  		expOffset: 14,
   175  	},
   176  	{
   177  		name:      "split case",
   178  		buf:       128,
   179  		expOffset: 4,
   180  	},
   181  }
   182  
   183  func TestNewWriterBuffer(t *testing.T) {
   184  	cases := append(
   185  		reserveTestCases,
   186  		reserveTestCase{
   187  			name:  "panic",
   188  			buf:   2,
   189  			panic: true,
   190  		},
   191  		reserveTestCase{
   192  			name:  "panic",
   193  			buf:   6,
   194  			state: ws.StateClientSide,
   195  			panic: true,
   196  		},
   197  	)
   198  	cases = append(cases, genReserveTestCases(0, int(len7)-2, int(len7)+2, 2)...)
   199  	cases = append(cases, genReserveTestCases(0, int(len16)-4, int(len16)+4, 4)...)
   200  	cases = append(cases, genReserveTestCases(0, maxint-10, maxint, 10)...)
   201  
   202  	cases = append(cases, genReserveTestCases(ws.StateClientSide, int(len7)-6, int(len7)+6, 6)...)
   203  	cases = append(cases, genReserveTestCases(ws.StateClientSide, int(len16)-8, int(len16)+8, 8)...)
   204  	cases = append(cases, genReserveTestCases(ws.StateClientSide, maxint-14, maxint, 14)...)
   205  
   206  	for _, test := range cases {
   207  		t.Run(test.name, func(t *testing.T) {
   208  			defer func() {
   209  				thePanic := recover()
   210  				if test.panic && thePanic == nil {
   211  					t.Errorf("expected panic")
   212  				}
   213  				if !test.panic && thePanic != nil {
   214  					t.Errorf("unexpected panic: %v", thePanic)
   215  				}
   216  			}()
   217  			w := NewWriterBuffer(nil, test.state, 0, fakeMake(test.buf))
   218  			if act, exp := len(w.raw)-len(w.buf), test.expOffset; act != exp {
   219  				t.Errorf(
   220  					"NewWriteBuffer(%d bytes) has offset %d; want %d",
   221  					test.buf, act, exp,
   222  				)
   223  			}
   224  		})
   225  	}
   226  }
   227  
   228  func TestWriter(t *testing.T) {
   229  	for i, test := range []struct {
   230  		label  string
   231  		size   int
   232  		state  ws.State
   233  		data   [][]byte
   234  		expFrm []ws.Frame
   235  		expBts []byte
   236  	}{
   237  		// No Write(), no frames.
   238  		{},
   239  
   240  		{
   241  			data: [][]byte{
   242  				{},
   243  			},
   244  			expBts: ws.MustCompileFrame(ws.NewTextFrame(nil)),
   245  		},
   246  		{
   247  			data: [][]byte{
   248  				[]byte("hello, world!"),
   249  			},
   250  			expBts: ws.MustCompileFrame(ws.NewTextFrame([]byte("hello, world!"))),
   251  		},
   252  		{
   253  			state: ws.StateClientSide,
   254  			data: [][]byte{
   255  				[]byte("hello, world!"),
   256  			},
   257  			expFrm: []ws.Frame{ws.MaskFrame(ws.NewTextFrame([]byte("hello, world!")))},
   258  		},
   259  		{
   260  			size: 5,
   261  			data: [][]byte{
   262  				[]byte("hello"),
   263  				[]byte(", wor"),
   264  				[]byte("ld!"),
   265  			},
   266  			expBts: bytes.Join(
   267  				bts(
   268  					ws.MustCompileFrame(ws.Frame{
   269  						Header: ws.Header{
   270  							Fin:    false,
   271  							OpCode: ws.OpText,
   272  							Length: 5,
   273  						},
   274  						Payload: []byte("hello"),
   275  					}),
   276  					ws.MustCompileFrame(ws.Frame{
   277  						Header: ws.Header{
   278  							Fin:    false,
   279  							OpCode: ws.OpContinuation,
   280  							Length: 5,
   281  						},
   282  						Payload: []byte(", wor"),
   283  					}),
   284  					ws.MustCompileFrame(ws.Frame{
   285  						Header: ws.Header{
   286  							Fin:    true,
   287  							OpCode: ws.OpContinuation,
   288  							Length: 3,
   289  						},
   290  						Payload: []byte("ld!"),
   291  					}),
   292  				),
   293  				nil,
   294  			),
   295  		},
   296  		{ // Large write case.
   297  			size: 5,
   298  			data: [][]byte{
   299  				[]byte("hello, world!"),
   300  			},
   301  			expBts: bytes.Join(
   302  				bts(
   303  					ws.MustCompileFrame(ws.Frame{
   304  						Header: ws.Header{
   305  							Fin:    false,
   306  							OpCode: ws.OpText,
   307  							Length: 13,
   308  						},
   309  						Payload: []byte("hello, world!"),
   310  					}),
   311  					ws.MustCompileFrame(ws.Frame{
   312  						Header: ws.Header{
   313  							Fin:    true,
   314  							OpCode: ws.OpContinuation,
   315  							Length: 0,
   316  						},
   317  					}),
   318  				),
   319  				nil,
   320  			),
   321  		},
   322  	} {
   323  		t.Run(fmt.Sprintf("%s#%d", test.label, i), func(t *testing.T) {
   324  			buf := &bytes.Buffer{}
   325  			w := NewWriterSize(buf, test.state, ws.OpText, test.size)
   326  
   327  			for _, p := range test.data {
   328  				_, err := w.Write(p)
   329  				if err != nil {
   330  					t.Fatalf("unexpected Write() error: %s", err)
   331  				}
   332  			}
   333  			if err := w.Flush(); err != nil {
   334  				t.Fatalf("unexpected Flush() error: %s", err)
   335  			}
   336  			if test.expBts != nil {
   337  				if bts := buf.Bytes(); !bytes.Equal(test.expBts, bts) {
   338  					t.Errorf(
   339  						"wrote bytes:\nact:\t%#x\nexp:\t%#x\nacth:\t%s\nexph:\t%s\n", bts, test.expBts,
   340  						pretty(frames(bts)...), pretty(frames(test.expBts)...),
   341  					)
   342  				}
   343  			}
   344  			if test.expFrm != nil {
   345  				act := omitMasks(frames(buf.Bytes()))
   346  				exp := omitMasks(test.expFrm)
   347  
   348  				if !reflect.DeepEqual(act, exp) {
   349  					t.Errorf(
   350  						"wrote frames (mask omitted):\nact:\t%s\nexp:\t%s\n",
   351  						pretty(act...), pretty(exp...),
   352  					)
   353  				}
   354  			}
   355  		})
   356  	}
   357  }
   358  
   359  func TestWriterLargeWrite(t *testing.T) {
   360  	var dest bytes.Buffer
   361  	w := NewWriterSize(&dest, 0, 0, 16)
   362  
   363  	// Test that event for big writes extensions set their bits.
   364  	var rsv = [3]bool{true, true, false}
   365  	w.SetExtensions(SendExtensionFunc(func(h ws.Header) (ws.Header, error) {
   366  		h.Rsv = ws.Rsv(rsv[0], rsv[1], rsv[2])
   367  		return h, nil
   368  	}))
   369  
   370  	// Write message with size twice bigger than writer's internal buffer.
   371  	// We expect Writer to write it directly without buffering since we didn't
   372  	// write anything before (no data in internal buffer).
   373  	bts := make([]byte, 2*w.Size())
   374  	if _, err := w.Write(bts); err != nil {
   375  		t.Fatal(err)
   376  	}
   377  	if err := w.Flush(); err != nil {
   378  		t.Fatal(err)
   379  	}
   380  
   381  	frame, err := ws.ReadFrame(&dest)
   382  	if err != nil {
   383  		t.Fatalf("can't read frame: %v", err)
   384  	}
   385  
   386  	var act [3]bool
   387  	act[0], act[1], act[2] = ws.RsvBits(frame.Header.Rsv)
   388  	if act != rsv {
   389  		t.Fatalf("unexpected rsv bits sent: %v; extension set %v", act, rsv)
   390  	}
   391  }
   392  
   393  func TestWriterReadFrom(t *testing.T) {
   394  	for i, test := range []struct {
   395  		label string
   396  		chop  int
   397  		size  int
   398  		data  []byte
   399  		exp   []ws.Frame
   400  		n     int64
   401  	}{
   402  		{
   403  			chop: 1,
   404  			size: 1,
   405  			data: []byte("golang"),
   406  			exp: []ws.Frame{
   407  				{Header: ws.Header{Fin: false, Length: 1, OpCode: ws.OpText}, Payload: []byte{'g'}},
   408  				{Header: ws.Header{Fin: false, Length: 1, OpCode: ws.OpContinuation}, Payload: []byte{'o'}},
   409  				{Header: ws.Header{Fin: false, Length: 1, OpCode: ws.OpContinuation}, Payload: []byte{'l'}},
   410  				{Header: ws.Header{Fin: false, Length: 1, OpCode: ws.OpContinuation}, Payload: []byte{'a'}},
   411  				{Header: ws.Header{Fin: false, Length: 1, OpCode: ws.OpContinuation}, Payload: []byte{'n'}},
   412  				{Header: ws.Header{Fin: false, Length: 1, OpCode: ws.OpContinuation}, Payload: []byte{'g'}},
   413  				{Header: ws.Header{Fin: true, Length: 0, OpCode: ws.OpContinuation}},
   414  			},
   415  			n: 6,
   416  		},
   417  		{
   418  			chop: 1,
   419  			size: 4,
   420  			data: []byte("golang"),
   421  			exp: []ws.Frame{
   422  				{Header: ws.Header{Fin: false, Length: 4, OpCode: ws.OpText}, Payload: []byte("gola")},
   423  				{Header: ws.Header{Fin: true, Length: 2, OpCode: ws.OpContinuation}, Payload: []byte("ng")},
   424  			},
   425  			n: 6,
   426  		},
   427  		{
   428  			size: 64,
   429  			data: []byte{},
   430  			exp: []ws.Frame{
   431  				{Header: ws.Header{Fin: true, Length: 0, OpCode: ws.OpText}},
   432  			},
   433  			n: 0,
   434  		},
   435  	} {
   436  		t.Run(fmt.Sprintf("%s#%d", test.label, i), func(t *testing.T) {
   437  			dst := &bytes.Buffer{}
   438  			wr := NewWriterSize(dst, 0, ws.OpText, test.size)
   439  
   440  			chop := test.chop
   441  			if chop == 0 {
   442  				chop = 128
   443  			}
   444  			src := &chopReader{bytes.NewReader(test.data), chop}
   445  
   446  			n, err := wr.ReadFrom(src)
   447  			if err == nil {
   448  				err = wr.Flush()
   449  			}
   450  			if err != nil {
   451  				t.Fatalf("unexpected error: %s", err)
   452  			}
   453  			if n != test.n {
   454  				t.Errorf("ReadFrom() read out %d; want %d", n, test.n)
   455  			}
   456  			if frames := frames(dst.Bytes()); !reflect.DeepEqual(frames, test.exp) {
   457  				t.Errorf("ReadFrom() read frames:\n\tact:\t%s\n\texp:\t%s\n", pretty(frames...), pretty(test.exp...))
   458  			}
   459  		})
   460  	}
   461  }
   462  
   463  func TestWriterWriteCount(t *testing.T) {
   464  	for _, test := range []struct {
   465  		name  string
   466  		cap   int
   467  		exp   int
   468  		write []int // For ability to avoid large write inside Write()'s "if".
   469  	}{
   470  		{
   471  			name:  "one frame",
   472  			cap:   10,
   473  			write: []int{10},
   474  			exp:   1,
   475  		},
   476  		{
   477  			name:  "two frames",
   478  			cap:   10,
   479  			write: []int{5, 7},
   480  			exp:   2,
   481  		},
   482  	} {
   483  		t.Run(test.name, func(t *testing.T) {
   484  			n := writeCounter{}
   485  			w := NewWriterSize(&n, 0, ws.OpText, test.cap)
   486  
   487  			for _, n := range test.write {
   488  				text := bytes.Repeat([]byte{'x'}, n)
   489  				if _, err := w.Write(text); err != nil {
   490  					t.Fatal(err)
   491  				}
   492  			}
   493  
   494  			if err := w.Flush(); err != nil {
   495  				t.Fatal(err)
   496  			}
   497  
   498  			if act, exp := n.n, test.exp; act != exp {
   499  				t.Errorf("made %d Write() calls to dest writer; want %d", act, exp)
   500  			}
   501  		})
   502  	}
   503  }
   504  
   505  func TestWriterNoPreemtiveFlush(t *testing.T) {
   506  	n := writeCounter{}
   507  	w := NewWriterSize(&n, 0, 0, 10)
   508  
   509  	// Fill buffer.
   510  	if _, err := w.Write([]byte("0123456789")); err != nil {
   511  		t.Fatal(err)
   512  	}
   513  	if n.n != 0 {
   514  		t.Fatalf(
   515  			"after filling up Writer got %d writes to the dest; want 0",
   516  			n.n,
   517  		)
   518  	}
   519  }
   520  
   521  type writeCounter struct {
   522  	n int
   523  }
   524  
   525  func (w *writeCounter) Write(p []byte) (int, error) {
   526  	w.n++
   527  	return len(p), nil
   528  }
   529  
   530  func frames(p []byte) (ret []ws.Frame) {
   531  	r := bytes.NewReader(p)
   532  	for stop := false; !stop; {
   533  		f, err := ws.ReadFrame(r)
   534  		if err != nil {
   535  			if err == io.EOF {
   536  				break
   537  			}
   538  			panic(err)
   539  		}
   540  		ret = append(ret, f)
   541  	}
   542  	return
   543  }
   544  
   545  func pretty(f ...ws.Frame) string {
   546  	str := "\n"
   547  	for _, f := range f {
   548  		str += fmt.Sprintf("\t%#v\n\t%#x (%#q)\n\t----\n", f.Header, f.Payload, f.Payload)
   549  	}
   550  	return str
   551  }
   552  
   553  func omitMask(f ws.Frame) ws.Frame {
   554  	if f.Header.Masked {
   555  		p := make([]byte, int(f.Header.Length))
   556  		copy(p, f.Payload)
   557  
   558  		ws.Cipher(p, f.Header.Mask, 0)
   559  
   560  		f.Header.Mask = [4]byte{0, 0, 0, 0}
   561  		f.Payload = p
   562  	}
   563  	return f
   564  }
   565  
   566  func omitMasks(f []ws.Frame) []ws.Frame {
   567  	for i := 0; i < len(f); i++ {
   568  		f[i] = omitMask(f[i])
   569  	}
   570  	return f
   571  }
   572  
   573  func bts(b ...[]byte) [][]byte { return b }