github.com/simonmittag/ws@v1.1.0-rc.5.0.20210419231947-82b846128245/wsutil/reader_test.go (about)

     1  package wsutil
     2  
     3  import (
     4  	"bytes"
     5  	"io"
     6  	"io/ioutil"
     7  	"testing"
     8  	"unicode/utf8"
     9  
    10  	"github.com/simonmittag/ws"
    11  )
    12  
// TODO(gobwas): test discarding of continuation frames;
// test discarding of unread payload when NextFrame() is called.
    15  
// eofReader is a reader over a nil byte slice, shared by the tests below as a
// Source that holds no frame data.
var eofReader = bytes.NewReader(nil)
    17  
    18  func TestReadFromWithIntermediateControl(t *testing.T) {
    19  	var buf bytes.Buffer
    20  
    21  	ws.MustWriteFrame(&buf, ws.NewFrame(ws.OpText, false, []byte("foo")))
    22  	ws.MustWriteFrame(&buf, ws.NewPingFrame([]byte("ping")))
    23  	ws.MustWriteFrame(&buf, ws.NewFrame(ws.OpContinuation, false, []byte("bar")))
    24  	ws.MustWriteFrame(&buf, ws.NewPongFrame([]byte("pong")))
    25  	ws.MustWriteFrame(&buf, ws.NewFrame(ws.OpContinuation, true, []byte("baz")))
    26  
    27  	var intermediate [][]byte
    28  	r := Reader{
    29  		Source: &buf,
    30  		OnIntermediate: func(h ws.Header, r io.Reader) error {
    31  			bts, err := ioutil.ReadAll(r)
    32  			if err != nil {
    33  				t.Fatal(err)
    34  			}
    35  			intermediate = append(
    36  				intermediate,
    37  				append(([]byte)(nil), bts...),
    38  			)
    39  			return nil
    40  		},
    41  	}
    42  
    43  	h, err := r.NextFrame()
    44  	if err != nil {
    45  		t.Fatal(err)
    46  	}
    47  	exp := ws.Header{
    48  		Length: 3,
    49  		Fin:    false,
    50  		OpCode: ws.OpText,
    51  	}
    52  	if act := h; act != exp {
    53  		t.Fatalf("unexpected NextFrame() header: %+v; want %+v", act, exp)
    54  	}
    55  
    56  	act, err := ioutil.ReadAll(&r)
    57  	if err != nil {
    58  		t.Fatal(err)
    59  	}
    60  	if exp := []byte("foobarbaz"); !bytes.Equal(act, exp) {
    61  		t.Errorf("unexpected all bytes: %q; want %q", act, exp)
    62  	}
    63  	if act, exp := len(intermediate), 2; act != exp {
    64  		t.Errorf("unexpected intermediate payload: %d; want %d", act, exp)
    65  	} else {
    66  		for i, exp := range [][]byte{
    67  			[]byte("ping"),
    68  			[]byte("pong"),
    69  		} {
    70  			if act := intermediate[i]; !bytes.Equal(act, exp) {
    71  				t.Errorf(
    72  					"unexpected #%d intermediate payload: %q; want %q",
    73  					i, act, exp,
    74  				)
    75  			}
    76  		}
    77  	}
    78  }
    79  
    80  func TestReaderNoFrameAdvance(t *testing.T) {
    81  	r := Reader{
    82  		Source: eofReader,
    83  	}
    84  	if _, err := r.Read(make([]byte, 10)); err != ErrNoFrameAdvance {
    85  		t.Errorf("Read() returned %v; want %v", err, ErrNoFrameAdvance)
    86  	}
    87  }
    88  
    89  func TestReaderNextFrameAndReadEOF(t *testing.T) {
    90  	for _, test := range []struct {
    91  		source       func() io.Reader
    92  		nextFrameErr error
    93  		readErr      error
    94  	}{
    95  		{
    96  			source:       func() io.Reader { return eofReader },
    97  			nextFrameErr: io.EOF,
    98  			readErr:      ErrNoFrameAdvance,
    99  		},
   100  		{
   101  			source: func() io.Reader {
   102  				// This case tests that ReadMessage still fails after
   103  				// successfully reading header bytes frame via ws.ReadHeader()
   104  				// and non-successfully read of the body.
   105  				var buf bytes.Buffer
   106  				f := ws.NewTextFrame([]byte("this part will be lost"))
   107  				if err := ws.WriteHeader(&buf, f.Header); err != nil {
   108  					panic(err)
   109  				}
   110  				return &buf
   111  			},
   112  			nextFrameErr: nil,
   113  			readErr:      io.ErrUnexpectedEOF,
   114  		},
   115  		{
   116  			source: func() io.Reader {
   117  				var buf bytes.Buffer
   118  				f := ws.NewTextFrame([]byte("foobar"))
   119  				if err := ws.WriteHeader(&buf, f.Header); err != nil {
   120  					panic(err)
   121  				}
   122  				buf.WriteString("foo")
   123  				return &buf
   124  			},
   125  			nextFrameErr: nil,
   126  			readErr:      io.ErrUnexpectedEOF,
   127  		},
   128  		{
   129  			source: func() io.Reader {
   130  				var buf bytes.Buffer
   131  				f := ws.NewFrame(ws.OpText, false, []byte("payload"))
   132  				if err := ws.WriteFrame(&buf, f); err != nil {
   133  					panic(err)
   134  				}
   135  				return &buf
   136  			},
   137  			nextFrameErr: nil,
   138  			readErr:      io.ErrUnexpectedEOF,
   139  		},
   140  	} {
   141  		t.Run("", func(t *testing.T) {
   142  			r := Reader{
   143  				Source: test.source(),
   144  			}
   145  			_, err := r.NextFrame()
   146  			if err != test.nextFrameErr {
   147  				t.Errorf("NextFrame() = %v; want %v", err, test.nextFrameErr)
   148  			}
   149  			var (
   150  				p = make([]byte, 4096)
   151  				i = 0
   152  			)
   153  			for {
   154  				if i == 100 {
   155  					t.Fatal(io.ErrNoProgress)
   156  				}
   157  				_, err := r.Read(p)
   158  				if err == nil {
   159  					continue
   160  				}
   161  				if err != test.readErr {
   162  					t.Errorf("Read() = %v; want %v", err, test.readErr)
   163  				}
   164  				break
   165  			}
   166  		})
   167  	}
   168  
   169  }
   170  
   171  func TestMaxFrameSize(t *testing.T) {
   172  	var buf bytes.Buffer
   173  	msg := []byte("small frame")
   174  	f := ws.NewTextFrame(msg)
   175  	if err := ws.WriteFrame(&buf, f); err != nil {
   176  		t.Fatal(err)
   177  	}
   178  	r := Reader{
   179  		Source:       &buf,
   180  		MaxFrameSize: int64(len(msg)) - 1,
   181  	}
   182  
   183  	_, err := r.NextFrame()
   184  	if got, want := err, ErrFrameTooLarge; got != want {
   185  		t.Errorf("NextFrame() error = %v; want %v", got, want)
   186  	}
   187  
   188  	p := make([]byte, 100)
   189  	n, err := r.Read(p)
   190  	if got, want := err, ErrNoFrameAdvance; got != want {
   191  		t.Errorf("Read() error = %v; want %v", got, want)
   192  	}
   193  	if got, want := n, 0; got != want {
   194  		t.Errorf("Read() bytes returned = %v; want %v", got, want)
   195  	}
   196  }
   197  
   198  func TestReaderUTF8(t *testing.T) {
   199  	yo := []byte("Ё")
   200  	if !utf8.ValidString(string(yo)) {
   201  		panic("bad fixture")
   202  	}
   203  
   204  	var buf bytes.Buffer
   205  	ws.WriteFrame(&buf,
   206  		ws.NewFrame(ws.OpText, false, yo[:1]),
   207  	)
   208  	ws.WriteFrame(&buf,
   209  		ws.NewFrame(ws.OpContinuation, true, yo[1:]),
   210  	)
   211  
   212  	r := Reader{
   213  		Source:    &buf,
   214  		CheckUTF8: true,
   215  	}
   216  	if _, err := r.NextFrame(); err != nil {
   217  		t.Fatal(err)
   218  	}
   219  	bts, err := ioutil.ReadAll(&r)
   220  	if err != nil {
   221  		t.Errorf("unexpected error: %v", err)
   222  	}
   223  	if !bytes.Equal(bts, yo) {
   224  		t.Errorf("ReadAll(r) = %v; want %v", bts, yo)
   225  	}
   226  }
   227  
   228  func TestNextReader(t *testing.T) {
   229  	for _, test := range []struct {
   230  		name string
   231  		seq  []ws.Frame
   232  		chop int
   233  		exp  []byte
   234  		err  error
   235  	}{
   236  		{
   237  			name: "empty",
   238  			seq:  []ws.Frame{},
   239  			err:  io.EOF,
   240  		},
   241  		{
   242  			name: "single",
   243  			seq: []ws.Frame{
   244  				ws.NewTextFrame([]byte("Привет, Мир!")),
   245  			},
   246  			exp: []byte("Привет, Мир!"),
   247  		},
   248  		{
   249  			name: "single_masked",
   250  			seq: []ws.Frame{
   251  				ws.MaskFrame(ws.NewTextFrame([]byte("Привет, Мир!"))),
   252  			},
   253  			exp: []byte("Привет, Мир!"),
   254  		},
   255  		{
   256  			name: "fragmented",
   257  			seq: []ws.Frame{
   258  				ws.NewFrame(ws.OpText, false, []byte("Привет,")),
   259  				ws.NewFrame(ws.OpContinuation, false, []byte(" о дивный,")),
   260  				ws.NewFrame(ws.OpContinuation, false, []byte(" новый ")),
   261  				ws.NewFrame(ws.OpContinuation, true, []byte("Мир!")),
   262  
   263  				ws.NewTextFrame([]byte("Hello, Brave New World!")),
   264  			},
   265  			exp: []byte("Привет, о дивный, новый Мир!"),
   266  		},
   267  		{
   268  			name: "fragmented_masked",
   269  			seq: []ws.Frame{
   270  				ws.MaskFrame(ws.NewFrame(ws.OpText, false, []byte("Привет,"))),
   271  				ws.MaskFrame(ws.NewFrame(ws.OpContinuation, false, []byte(" о дивный,"))),
   272  				ws.MaskFrame(ws.NewFrame(ws.OpContinuation, false, []byte(" новый "))),
   273  				ws.MaskFrame(ws.NewFrame(ws.OpContinuation, true, []byte("Мир!"))),
   274  
   275  				ws.MaskFrame(ws.NewTextFrame([]byte("Hello, Brave New World!"))),
   276  			},
   277  			exp: []byte("Привет, о дивный, новый Мир!"),
   278  		},
   279  		{
   280  			name: "fragmented_and_control",
   281  			seq: []ws.Frame{
   282  				ws.NewFrame(ws.OpText, false, []byte("Привет,")),
   283  				ws.NewFrame(ws.OpPing, true, nil),
   284  				ws.NewFrame(ws.OpContinuation, false, []byte(" о дивный,")),
   285  				ws.NewFrame(ws.OpPing, true, nil),
   286  				ws.NewFrame(ws.OpContinuation, false, []byte(" новый ")),
   287  				ws.NewFrame(ws.OpPing, true, nil),
   288  				ws.NewFrame(ws.OpPing, true, []byte("ping info")),
   289  				ws.NewFrame(ws.OpContinuation, true, []byte("Мир!")),
   290  			},
   291  			exp: []byte("Привет, о дивный, новый Мир!"),
   292  		},
   293  		{
   294  			name: "fragmented_and_control_mask",
   295  			seq: []ws.Frame{
   296  				ws.MaskFrame(ws.NewFrame(ws.OpText, false, []byte("Привет,"))),
   297  				ws.MaskFrame(ws.NewFrame(ws.OpPing, true, nil)),
   298  				ws.MaskFrame(ws.NewFrame(ws.OpContinuation, false, []byte(" о дивный,"))),
   299  				ws.MaskFrame(ws.NewFrame(ws.OpPing, true, nil)),
   300  				ws.MaskFrame(ws.NewFrame(ws.OpContinuation, false, []byte(" новый "))),
   301  				ws.MaskFrame(ws.NewFrame(ws.OpPing, true, nil)),
   302  				ws.MaskFrame(ws.NewFrame(ws.OpPing, true, []byte("ping info"))),
   303  				ws.MaskFrame(ws.NewFrame(ws.OpContinuation, true, []byte("Мир!"))),
   304  			},
   305  			exp: []byte("Привет, о дивный, новый Мир!"),
   306  		},
   307  	} {
   308  		t.Run(test.name, func(t *testing.T) {
   309  			// Prepare input.
   310  			buf := &bytes.Buffer{}
   311  			for _, f := range test.seq {
   312  				if err := ws.WriteFrame(buf, f); err != nil {
   313  					t.Fatal(err)
   314  				}
   315  			}
   316  
   317  			conn := &chopReader{
   318  				src: bytes.NewReader(buf.Bytes()),
   319  				sz:  test.chop,
   320  			}
   321  
   322  			var bts []byte
   323  			_, reader, err := NextReader(conn, 0)
   324  			if err == nil {
   325  				bts, err = ioutil.ReadAll(reader)
   326  			}
   327  			if err != test.err {
   328  				t.Errorf("unexpected error; got %v; want %v", err, test.err)
   329  				return
   330  			}
   331  			if test.err == nil && !bytes.Equal(bts, test.exp) {
   332  				t.Errorf(
   333  					"ReadAll from reader:\nact:\t%#x\nexp:\t%#x\nact:\t%s\nexp:\t%s\n",
   334  					bts, test.exp, string(bts), string(test.exp),
   335  				)
   336  			}
   337  		})
   338  	}
   339  }
   340  
   341  type chopReader struct {
   342  	src io.Reader
   343  	sz  int
   344  }
   345  
   346  func (c chopReader) Read(p []byte) (n int, err error) {
   347  	sz := c.sz
   348  	if sz == 0 {
   349  		sz = 1
   350  	}
   351  	if sz > len(p) {
   352  		sz = len(p)
   353  	}
   354  	return c.src.Read(p[:sz])
   355  }