github.com/ezoic/ws@v1.0.4-0.20220713205711-5c1d69e074c5/wsutil/writer_test.go (about)

     1  package wsutil
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"reflect"
     8  	"strconv"
     9  	"testing"
    10  	"unsafe"
    11  
    12  	"github.com/ezoic/ws"
    13  )
    14  
    15  // TODO(ezoic): test NewWriterSize on edge cases for offset.
    16  
    17  const (
    18  	bitsize = 32 << (^uint(0) >> 63)
    19  	maxint  = int(^uint(1 << (bitsize - 1)))
    20  )
    21  
    22  func TestControlWriter(t *testing.T) {
    23  	const (
    24  		server = ws.StateServerSide
    25  		client = ws.StateClientSide
    26  	)
    27  	for _, test := range []struct {
    28  		name  string
    29  		size  int
    30  		write []byte
    31  		state ws.State
    32  		op    ws.OpCode
    33  		exp   ws.Frame
    34  		err   bool
    35  	}{
    36  		{
    37  			state: server,
    38  			op:    ws.OpPing,
    39  			exp:   ws.NewPingFrame(nil),
    40  		},
    41  		{
    42  			write: []byte("0123456789"),
    43  			state: server,
    44  			op:    ws.OpPing,
    45  			exp:   ws.NewPingFrame([]byte("0123456789")),
    46  		},
    47  		{
    48  			size:  10 + reserve(server, 10),
    49  			write: []byte("0123456789"),
    50  			state: server,
    51  			op:    ws.OpPing,
    52  			exp:   ws.NewPingFrame([]byte("0123456789")),
    53  		},
    54  		{
    55  			size:  10 + reserve(server, 10),
    56  			write: []byte("0123456789a"),
    57  			state: server,
    58  			op:    ws.OpPing,
    59  			err:   true,
    60  		},
    61  		{
    62  			write: bytes.Repeat([]byte{'x'}, ws.MaxControlFramePayloadSize+1),
    63  			state: server,
    64  			op:    ws.OpPing,
    65  			err:   true,
    66  		},
    67  	} {
    68  		t.Run(test.name, func(t *testing.T) {
    69  			var buf bytes.Buffer
    70  			var w *ControlWriter
    71  			if n := test.size; n == 0 {
    72  				w = NewControlWriter(&buf, test.state, test.op)
    73  			} else {
    74  				p := make([]byte, n)
    75  				w = NewControlWriterBuffer(&buf, test.state, test.op, p)
    76  			}
    77  
    78  			_, err := w.Write(test.write)
    79  			if err == nil {
    80  				err = w.Flush()
    81  			}
    82  			if test.err {
    83  				if err == nil {
    84  					t.Errorf("want error")
    85  				}
    86  				return
    87  			}
    88  			if !test.err && err != nil {
    89  				t.Errorf("unexpected error: %v", err)
    90  				return
    91  			}
    92  
    93  			act, err := ws.ReadFrame(&buf)
    94  			if err != nil {
    95  				t.Fatal(err)
    96  			}
    97  
    98  			act = omitMask(act)
    99  			exp := omitMask(test.exp)
   100  			if !reflect.DeepEqual(act, exp) {
   101  				t.Errorf("unexpected frame:\nflushed: %v\nwant: %v", pretty(act), pretty(exp))
   102  			}
   103  		})
   104  	}
   105  }
   106  
   107  type reserveTestCase struct {
   108  	name      string
   109  	buf       int
   110  	state     ws.State
   111  	expOffset int
   112  	panic     bool
   113  }
   114  
   115  func genReserveTestCases(s ws.State, n, m, exp int) []reserveTestCase {
   116  	ret := make([]reserveTestCase, m-n)
   117  	for i := n; i < m; i++ {
   118  		var suffix string
   119  		if s.ClientSide() {
   120  			suffix = " masked"
   121  		}
   122  
   123  		ret[i-n] = reserveTestCase{
   124  			name:      "gen " + strconv.Itoa(i) + suffix,
   125  			buf:       i,
   126  			state:     s,
   127  			expOffset: exp,
   128  		}
   129  	}
   130  	return ret
   131  }
   132  
   133  func fakeMake(n int) (r []byte) {
   134  	rh := (*reflect.SliceHeader)(unsafe.Pointer(&r))
   135  	*rh = reflect.SliceHeader{
   136  		Len: n,
   137  		Cap: n,
   138  	}
   139  	return r
   140  }
   141  
   142  var reserveTestCases = []reserveTestCase{
   143  	{
   144  		name:      "len7",
   145  		buf:       int(len7) + 2,
   146  		expOffset: 2,
   147  	},
   148  	{
   149  		name:      "len16",
   150  		buf:       int(len16) + 4,
   151  		expOffset: 4,
   152  	},
   153  	{
   154  		name:      "maxint",
   155  		buf:       maxint,
   156  		expOffset: 10,
   157  	},
   158  	{
   159  		name:      "len7 masked",
   160  		buf:       int(len7) + 6,
   161  		state:     ws.StateClientSide,
   162  		expOffset: 6,
   163  	},
   164  	{
   165  		name:      "len16 masked",
   166  		buf:       int(len16) + 8,
   167  		state:     ws.StateClientSide,
   168  		expOffset: 8,
   169  	},
   170  	{
   171  		name:      "maxint masked",
   172  		buf:       maxint,
   173  		state:     ws.StateClientSide,
   174  		expOffset: 14,
   175  	},
   176  	{
   177  		name:      "split case",
   178  		buf:       128,
   179  		expOffset: 4,
   180  	},
   181  }
   182  
   183  func TestNewWriterBuffer(t *testing.T) {
   184  	cases := append(
   185  		reserveTestCases,
   186  		reserveTestCase{
   187  			name:  "panic",
   188  			buf:   2,
   189  			panic: true,
   190  		},
   191  		reserveTestCase{
   192  			name:  "panic",
   193  			buf:   6,
   194  			state: ws.StateClientSide,
   195  			panic: true,
   196  		},
   197  	)
   198  	cases = append(cases, genReserveTestCases(0, int(len7)-2, int(len7)+2, 2)...)
   199  	cases = append(cases, genReserveTestCases(0, int(len16)-4, int(len16)+4, 4)...)
   200  	cases = append(cases, genReserveTestCases(0, maxint-10, maxint, 10)...)
   201  
   202  	cases = append(cases, genReserveTestCases(ws.StateClientSide, int(len7)-6, int(len7)+6, 6)...)
   203  	cases = append(cases, genReserveTestCases(ws.StateClientSide, int(len16)-8, int(len16)+8, 8)...)
   204  	cases = append(cases, genReserveTestCases(ws.StateClientSide, maxint-14, maxint, 14)...)
   205  
   206  	for _, test := range cases {
   207  		t.Run(test.name, func(t *testing.T) {
   208  			defer func() {
   209  				thePanic := recover()
   210  				if test.panic && thePanic == nil {
   211  					t.Errorf("expected panic")
   212  				}
   213  				if !test.panic && thePanic != nil {
   214  					t.Errorf("unexpected panic: %v", thePanic)
   215  				}
   216  			}()
   217  			w := NewWriterBuffer(nil, test.state, 0, fakeMake(test.buf))
   218  			if act, exp := len(w.raw)-len(w.buf), test.expOffset; act != exp {
   219  				t.Errorf(
   220  					"NewWriteBuffer(%d bytes) has offset %d; want %d",
   221  					test.buf, act, exp,
   222  				)
   223  			}
   224  		})
   225  	}
   226  }
   227  
   228  func TestWriter(t *testing.T) {
   229  	for i, test := range []struct {
   230  		label  string
   231  		size   int
   232  		state  ws.State
   233  		data   [][]byte
   234  		expFrm []ws.Frame
   235  		expBts []byte
   236  	}{
   237  		// No Write(), no frames.
   238  		{},
   239  
   240  		{
   241  			data: [][]byte{
   242  				{},
   243  			},
   244  			expBts: ws.MustCompileFrame(ws.NewTextFrame(nil)),
   245  		},
   246  		{
   247  			data: [][]byte{
   248  				[]byte("hello, world!"),
   249  			},
   250  			expBts: ws.MustCompileFrame(ws.NewTextFrame([]byte("hello, world!"))),
   251  		},
   252  		{
   253  			state: ws.StateClientSide,
   254  			data: [][]byte{
   255  				[]byte("hello, world!"),
   256  			},
   257  			expFrm: []ws.Frame{ws.MaskFrame(ws.NewTextFrame([]byte("hello, world!")))},
   258  		},
   259  		{
   260  			size: 5,
   261  			data: [][]byte{
   262  				[]byte("hello"),
   263  				[]byte(", wor"),
   264  				[]byte("ld!"),
   265  			},
   266  			expBts: bytes.Join(
   267  				bts(
   268  					ws.MustCompileFrame(ws.Frame{
   269  						Header: ws.Header{
   270  							Fin:    false,
   271  							OpCode: ws.OpText,
   272  							Length: 5,
   273  						},
   274  						Payload: []byte("hello"),
   275  					}),
   276  					ws.MustCompileFrame(ws.Frame{
   277  						Header: ws.Header{
   278  							Fin:    false,
   279  							OpCode: ws.OpContinuation,
   280  							Length: 5,
   281  						},
   282  						Payload: []byte(", wor"),
   283  					}),
   284  					ws.MustCompileFrame(ws.Frame{
   285  						Header: ws.Header{
   286  							Fin:    true,
   287  							OpCode: ws.OpContinuation,
   288  							Length: 3,
   289  						},
   290  						Payload: []byte("ld!"),
   291  					}),
   292  				),
   293  				nil,
   294  			),
   295  		},
   296  		{ // Large write case.
   297  			size: 5,
   298  			data: [][]byte{
   299  				[]byte("hello, world!"),
   300  			},
   301  			expBts: bytes.Join(
   302  				bts(
   303  					ws.MustCompileFrame(ws.Frame{
   304  						Header: ws.Header{
   305  							Fin:    false,
   306  							OpCode: ws.OpText,
   307  							Length: 13,
   308  						},
   309  						Payload: []byte("hello, world!"),
   310  					}),
   311  					ws.MustCompileFrame(ws.Frame{
   312  						Header: ws.Header{
   313  							Fin:    true,
   314  							OpCode: ws.OpContinuation,
   315  							Length: 0,
   316  						},
   317  					}),
   318  				),
   319  				nil,
   320  			),
   321  		},
   322  	} {
   323  		t.Run(fmt.Sprintf("%s#%d", test.label, i), func(t *testing.T) {
   324  			buf := &bytes.Buffer{}
   325  			w := NewWriterSize(buf, test.state, ws.OpText, test.size)
   326  
   327  			for _, p := range test.data {
   328  				_, err := w.Write(p)
   329  				if err != nil {
   330  					t.Fatalf("unexpected Write() error: %s", err)
   331  				}
   332  			}
   333  			if err := w.Flush(); err != nil {
   334  				t.Fatalf("unexpected Flush() error: %s", err)
   335  			}
   336  			if test.expBts != nil {
   337  				if bts := buf.Bytes(); !bytes.Equal(test.expBts, bts) {
   338  					t.Errorf(
   339  						"wrote bytes:\nact:\t%#x\nexp:\t%#x\nacth:\t%s\nexph:\t%s\n", bts, test.expBts,
   340  						pretty(frames(bts)...), pretty(frames(test.expBts)...),
   341  					)
   342  				}
   343  			}
   344  			if test.expFrm != nil {
   345  				act := omitMasks(frames(buf.Bytes()))
   346  				exp := omitMasks(test.expFrm)
   347  
   348  				if !reflect.DeepEqual(act, exp) {
   349  					t.Errorf(
   350  						"wrote frames (mask omitted):\nact:\t%s\nexp:\t%s\n",
   351  						pretty(act...), pretty(exp...),
   352  					)
   353  				}
   354  			}
   355  		})
   356  	}
   357  }
   358  
   359  func TestWriterReadFrom(t *testing.T) {
   360  	for i, test := range []struct {
   361  		label string
   362  		chop  int
   363  		size  int
   364  		data  []byte
   365  		exp   []ws.Frame
   366  		n     int64
   367  	}{
   368  		{
   369  			chop: 1,
   370  			size: 1,
   371  			data: []byte("golang"),
   372  			exp: []ws.Frame{
   373  				{Header: ws.Header{Fin: false, Length: 1, OpCode: ws.OpText}, Payload: []byte{'g'}},
   374  				{Header: ws.Header{Fin: false, Length: 1, OpCode: ws.OpContinuation}, Payload: []byte{'o'}},
   375  				{Header: ws.Header{Fin: false, Length: 1, OpCode: ws.OpContinuation}, Payload: []byte{'l'}},
   376  				{Header: ws.Header{Fin: false, Length: 1, OpCode: ws.OpContinuation}, Payload: []byte{'a'}},
   377  				{Header: ws.Header{Fin: false, Length: 1, OpCode: ws.OpContinuation}, Payload: []byte{'n'}},
   378  				{Header: ws.Header{Fin: false, Length: 1, OpCode: ws.OpContinuation}, Payload: []byte{'g'}},
   379  				{Header: ws.Header{Fin: true, Length: 0, OpCode: ws.OpContinuation}},
   380  			},
   381  			n: 6,
   382  		},
   383  		{
   384  			chop: 1,
   385  			size: 4,
   386  			data: []byte("golang"),
   387  			exp: []ws.Frame{
   388  				{Header: ws.Header{Fin: false, Length: 4, OpCode: ws.OpText}, Payload: []byte("gola")},
   389  				{Header: ws.Header{Fin: true, Length: 2, OpCode: ws.OpContinuation}, Payload: []byte("ng")},
   390  			},
   391  			n: 6,
   392  		},
   393  		{
   394  			size: 64,
   395  			data: []byte{},
   396  			exp: []ws.Frame{
   397  				{Header: ws.Header{Fin: true, Length: 0, OpCode: ws.OpText}},
   398  			},
   399  			n: 0,
   400  		},
   401  	} {
   402  		t.Run(fmt.Sprintf("%s#%d", test.label, i), func(t *testing.T) {
   403  			dst := &bytes.Buffer{}
   404  			wr := NewWriterSize(dst, 0, ws.OpText, test.size)
   405  
   406  			chop := test.chop
   407  			if chop == 0 {
   408  				chop = 128
   409  			}
   410  			src := &chopReader{bytes.NewReader(test.data), chop}
   411  
   412  			n, err := wr.ReadFrom(src)
   413  			if err == nil {
   414  				err = wr.Flush()
   415  			}
   416  			if err != nil {
   417  				t.Fatalf("unexpected error: %s", err)
   418  			}
   419  			if n != test.n {
   420  				t.Errorf("ReadFrom() read out %d; want %d", n, test.n)
   421  			}
   422  			if frames := frames(dst.Bytes()); !reflect.DeepEqual(frames, test.exp) {
   423  				t.Errorf("ReadFrom() read frames:\n\tact:\t%s\n\texp:\t%s\n", pretty(frames...), pretty(test.exp...))
   424  			}
   425  		})
   426  	}
   427  }
   428  
   429  func TestWriterWriteCount(t *testing.T) {
   430  	for _, test := range []struct {
   431  		name  string
   432  		cap   int
   433  		exp   int
   434  		write []int // For ability to avoid large write inside Write()'s "if".
   435  	}{
   436  		{
   437  			name:  "one frame",
   438  			cap:   10,
   439  			write: []int{10},
   440  			exp:   1,
   441  		},
   442  		{
   443  			name:  "two frames",
   444  			cap:   10,
   445  			write: []int{5, 7},
   446  			exp:   2,
   447  		},
   448  	} {
   449  		t.Run(test.name, func(t *testing.T) {
   450  			n := writeCounter{}
   451  			w := NewWriterSize(&n, 0, ws.OpText, test.cap)
   452  
   453  			for _, n := range test.write {
   454  				text := bytes.Repeat([]byte{'x'}, n)
   455  				if _, err := w.Write(text); err != nil {
   456  					t.Fatal(err)
   457  				}
   458  			}
   459  
   460  			if err := w.Flush(); err != nil {
   461  				t.Fatal(err)
   462  			}
   463  
   464  			if act, exp := n.n, test.exp; act != exp {
   465  				t.Errorf("made %d Write() calls to dest writer; want %d", act, exp)
   466  			}
   467  		})
   468  	}
   469  }
   470  
   471  func TestWriterNoPreemtiveFlush(t *testing.T) {
   472  	n := writeCounter{}
   473  	w := NewWriterSize(&n, 0, 0, 10)
   474  
   475  	// Fill buffer.
   476  	if _, err := w.Write([]byte("0123456789")); err != nil {
   477  		t.Fatal(err)
   478  	}
   479  	if n.n != 0 {
   480  		t.Fatalf(
   481  			"after filling up Writer got %d writes to the dest; want 0",
   482  			n.n,
   483  		)
   484  	}
   485  }
   486  
   487  type writeCounter struct {
   488  	n int
   489  }
   490  
   491  func (w *writeCounter) Write(p []byte) (int, error) {
   492  	w.n++
   493  	return len(p), nil
   494  }
   495  
   496  func frames(p []byte) (ret []ws.Frame) {
   497  	r := bytes.NewReader(p)
   498  	for stop := false; !stop; {
   499  		f, err := ws.ReadFrame(r)
   500  		if err != nil {
   501  			if err == io.EOF {
   502  				break
   503  			}
   504  			panic(err)
   505  		}
   506  		ret = append(ret, f)
   507  	}
   508  	return
   509  }
   510  
   511  func pretty(f ...ws.Frame) string {
   512  	str := "\n"
   513  	for _, f := range f {
   514  		str += fmt.Sprintf("\t%#v\n\t%#x (%s)\n\t----\n", f.Header, f.Payload, f.Payload)
   515  	}
   516  	return str
   517  }
   518  
   519  func omitMask(f ws.Frame) ws.Frame {
   520  	if f.Header.Masked {
   521  		p := make([]byte, int(f.Header.Length))
   522  		copy(p, f.Payload)
   523  
   524  		ws.Cipher(p, f.Header.Mask, 0)
   525  
   526  		f.Header.Mask = [4]byte{0, 0, 0, 0}
   527  		f.Payload = p
   528  	}
   529  	return f
   530  }
   531  
   532  func omitMasks(f []ws.Frame) []ws.Frame {
   533  	for i := 0; i < len(f); i++ {
   534  		f[i] = omitMask(f[i])
   535  	}
   536  	return f
   537  }
   538  
   539  func bts(b ...[]byte) [][]byte { return b }