github.com/ezoic/ws@v1.0.4-0.20220713205711-5c1d69e074c5/wsutil/reader_test.go (about)

     1  package wsutil
     2  
     3  import (
     4  	"bytes"
     5  	"io"
     6  	"io/ioutil"
     7  	"testing"
     8  	"unicode/utf8"
     9  
    10  	"github.com/ezoic/ws"
    11  )
    12  
    13  // TODO(ezoic): test continuation discard.
    14  //				 test discard when NextFrame().
    15  
    16  var eofReader = bytes.NewReader(nil)
    17  
    18  func TestReadFromWithIntermediateControl(t *testing.T) {
    19  	var buf bytes.Buffer
    20  
    21  	ws.MustWriteFrame(&buf, ws.NewFrame(ws.OpText, false, []byte("foo")))
    22  	ws.MustWriteFrame(&buf, ws.NewPingFrame([]byte("ping")))
    23  	ws.MustWriteFrame(&buf, ws.NewFrame(ws.OpContinuation, false, []byte("bar")))
    24  	ws.MustWriteFrame(&buf, ws.NewPongFrame([]byte("pong")))
    25  	ws.MustWriteFrame(&buf, ws.NewFrame(ws.OpContinuation, true, []byte("baz")))
    26  
    27  	var intermediate [][]byte
    28  	r := Reader{
    29  		Source: &buf,
    30  		OnIntermediate: func(h ws.Header, r io.Reader) error {
    31  			bts, err := ioutil.ReadAll(r)
    32  			if err != nil {
    33  				t.Fatal(err)
    34  			}
    35  			intermediate = append(
    36  				intermediate,
    37  				append(([]byte)(nil), bts...),
    38  			)
    39  			return nil
    40  		},
    41  	}
    42  
    43  	h, err := r.NextFrame()
    44  	if err != nil {
    45  		t.Fatal(err)
    46  	}
    47  	exp := ws.Header{
    48  		Length: 3,
    49  		Fin:    false,
    50  		OpCode: ws.OpText,
    51  	}
    52  	if act := h; act != exp {
    53  		t.Fatalf("unexpected NextFrame() header: %+v; want %+v", act, exp)
    54  	}
    55  
    56  	act, err := ioutil.ReadAll(&r)
    57  	if err != nil {
    58  		t.Fatal(err)
    59  	}
    60  	if exp := []byte("foobarbaz"); !bytes.Equal(act, exp) {
    61  		t.Errorf("unexpected all bytes: %q; want %q", act, exp)
    62  	}
    63  	if act, exp := len(intermediate), 2; act != exp {
    64  		t.Errorf("unexpected intermediate payload: %d; want %d", act, exp)
    65  	} else {
    66  		for i, exp := range [][]byte{
    67  			[]byte("ping"),
    68  			[]byte("pong"),
    69  		} {
    70  			if act := intermediate[i]; !bytes.Equal(act, exp) {
    71  				t.Errorf(
    72  					"unexpected #%d intermediate payload: %q; want %q",
    73  					i, act, exp,
    74  				)
    75  			}
    76  		}
    77  	}
    78  }
    79  
    80  func TestReaderNoFrameAdvance(t *testing.T) {
    81  	r := Reader{
    82  		Source: eofReader,
    83  	}
    84  	if _, err := r.Read(make([]byte, 10)); err != ErrNoFrameAdvance {
    85  		t.Errorf("Read() returned %v; want %v", err, ErrNoFrameAdvance)
    86  	}
    87  }
    88  
    89  func TestReaderNextFrameAndReadEOF(t *testing.T) {
    90  	for _, test := range []struct {
    91  		source       func() io.Reader
    92  		nextFrameErr error
    93  		readErr      error
    94  	}{
    95  		{
    96  			source:       func() io.Reader { return eofReader },
    97  			nextFrameErr: io.EOF,
    98  			readErr:      ErrNoFrameAdvance,
    99  		},
   100  		{
   101  			source: func() io.Reader {
   102  				// This case tests that ReadMessage still fails after
   103  				// successfully reading header bytes frame via ws.ReadHeader()
   104  				// and non-successfully read of the body.
   105  				var buf bytes.Buffer
   106  				f := ws.NewTextFrame([]byte("this part will be lost"))
   107  				if err := ws.WriteHeader(&buf, f.Header); err != nil {
   108  					panic(err)
   109  				}
   110  				return &buf
   111  			},
   112  			nextFrameErr: nil,
   113  			readErr:      io.ErrUnexpectedEOF,
   114  		},
   115  		{
   116  			source: func() io.Reader {
   117  				var buf bytes.Buffer
   118  				f := ws.NewTextFrame([]byte("foobar"))
   119  				if err := ws.WriteHeader(&buf, f.Header); err != nil {
   120  					panic(err)
   121  				}
   122  				buf.WriteString("foo")
   123  				return &buf
   124  			},
   125  			nextFrameErr: nil,
   126  			readErr:      io.ErrUnexpectedEOF,
   127  		},
   128  		{
   129  			source: func() io.Reader {
   130  				var buf bytes.Buffer
   131  				f := ws.NewFrame(ws.OpText, false, []byte("payload"))
   132  				if err := ws.WriteFrame(&buf, f); err != nil {
   133  					panic(err)
   134  				}
   135  				return &buf
   136  			},
   137  			nextFrameErr: nil,
   138  			readErr:      io.ErrUnexpectedEOF,
   139  		},
   140  	} {
   141  		t.Run("", func(t *testing.T) {
   142  			r := Reader{
   143  				Source: test.source(),
   144  			}
   145  			_, err := r.NextFrame()
   146  			if err != test.nextFrameErr {
   147  				t.Errorf("NextFrame() = %v; want %v", err, test.nextFrameErr)
   148  			}
   149  			var (
   150  				p = make([]byte, 4096)
   151  				i = 0
   152  			)
   153  			for {
   154  				if i == 100 {
   155  					t.Fatal(io.ErrNoProgress)
   156  				}
   157  				_, err := r.Read(p)
   158  				if err == nil {
   159  					continue
   160  				}
   161  				if err != test.readErr {
   162  					t.Errorf("Read() = %v; want %v", err, test.readErr)
   163  				}
   164  				break
   165  			}
   166  		})
   167  	}
   168  
   169  }
   170  
   171  func TestReaderUTF8(t *testing.T) {
   172  	yo := []byte("Ё")
   173  	if !utf8.ValidString(string(yo)) {
   174  		panic("bad fixture")
   175  	}
   176  
   177  	var buf bytes.Buffer
   178  	ws.WriteFrame(&buf,
   179  		ws.NewFrame(ws.OpText, false, yo[:1]),
   180  	)
   181  	ws.WriteFrame(&buf,
   182  		ws.NewFrame(ws.OpContinuation, true, yo[1:]),
   183  	)
   184  
   185  	r := Reader{
   186  		Source:    &buf,
   187  		CheckUTF8: true,
   188  	}
   189  	if _, err := r.NextFrame(); err != nil {
   190  		t.Fatal(err)
   191  	}
   192  	bts, err := ioutil.ReadAll(&r)
   193  	if err != nil {
   194  		t.Errorf("unexpected error: %v", err)
   195  	}
   196  	if !bytes.Equal(bts, yo) {
   197  		t.Errorf("ReadAll(r) = %v; want %v", bts, yo)
   198  	}
   199  }
   200  
   201  func TestNextReader(t *testing.T) {
   202  	for _, test := range []struct {
   203  		name string
   204  		seq  []ws.Frame
   205  		chop int
   206  		exp  []byte
   207  		err  error
   208  	}{
   209  		{
   210  			name: "empty",
   211  			seq:  []ws.Frame{},
   212  			err:  io.EOF,
   213  		},
   214  		{
   215  			name: "single",
   216  			seq: []ws.Frame{
   217  				ws.NewTextFrame([]byte("Привет, Мир!")),
   218  			},
   219  			exp: []byte("Привет, Мир!"),
   220  		},
   221  		{
   222  			name: "single_masked",
   223  			seq: []ws.Frame{
   224  				ws.MaskFrame(ws.NewTextFrame([]byte("Привет, Мир!"))),
   225  			},
   226  			exp: []byte("Привет, Мир!"),
   227  		},
   228  		{
   229  			name: "fragmented",
   230  			seq: []ws.Frame{
   231  				ws.NewFrame(ws.OpText, false, []byte("Привет,")),
   232  				ws.NewFrame(ws.OpContinuation, false, []byte(" о дивный,")),
   233  				ws.NewFrame(ws.OpContinuation, false, []byte(" новый ")),
   234  				ws.NewFrame(ws.OpContinuation, true, []byte("Мир!")),
   235  
   236  				ws.NewTextFrame([]byte("Hello, Brave New World!")),
   237  			},
   238  			exp: []byte("Привет, о дивный, новый Мир!"),
   239  		},
   240  		{
   241  			name: "fragmented_masked",
   242  			seq: []ws.Frame{
   243  				ws.MaskFrame(ws.NewFrame(ws.OpText, false, []byte("Привет,"))),
   244  				ws.MaskFrame(ws.NewFrame(ws.OpContinuation, false, []byte(" о дивный,"))),
   245  				ws.MaskFrame(ws.NewFrame(ws.OpContinuation, false, []byte(" новый "))),
   246  				ws.MaskFrame(ws.NewFrame(ws.OpContinuation, true, []byte("Мир!"))),
   247  
   248  				ws.MaskFrame(ws.NewTextFrame([]byte("Hello, Brave New World!"))),
   249  			},
   250  			exp: []byte("Привет, о дивный, новый Мир!"),
   251  		},
   252  		{
   253  			name: "fragmented_and_control",
   254  			seq: []ws.Frame{
   255  				ws.NewFrame(ws.OpText, false, []byte("Привет,")),
   256  				ws.NewFrame(ws.OpPing, true, nil),
   257  				ws.NewFrame(ws.OpContinuation, false, []byte(" о дивный,")),
   258  				ws.NewFrame(ws.OpPing, true, nil),
   259  				ws.NewFrame(ws.OpContinuation, false, []byte(" новый ")),
   260  				ws.NewFrame(ws.OpPing, true, nil),
   261  				ws.NewFrame(ws.OpPing, true, []byte("ping info")),
   262  				ws.NewFrame(ws.OpContinuation, true, []byte("Мир!")),
   263  			},
   264  			exp: []byte("Привет, о дивный, новый Мир!"),
   265  		},
   266  		{
   267  			name: "fragmented_and_control_mask",
   268  			seq: []ws.Frame{
   269  				ws.MaskFrame(ws.NewFrame(ws.OpText, false, []byte("Привет,"))),
   270  				ws.MaskFrame(ws.NewFrame(ws.OpPing, true, nil)),
   271  				ws.MaskFrame(ws.NewFrame(ws.OpContinuation, false, []byte(" о дивный,"))),
   272  				ws.MaskFrame(ws.NewFrame(ws.OpPing, true, nil)),
   273  				ws.MaskFrame(ws.NewFrame(ws.OpContinuation, false, []byte(" новый "))),
   274  				ws.MaskFrame(ws.NewFrame(ws.OpPing, true, nil)),
   275  				ws.MaskFrame(ws.NewFrame(ws.OpPing, true, []byte("ping info"))),
   276  				ws.MaskFrame(ws.NewFrame(ws.OpContinuation, true, []byte("Мир!"))),
   277  			},
   278  			exp: []byte("Привет, о дивный, новый Мир!"),
   279  		},
   280  	} {
   281  		t.Run(test.name, func(t *testing.T) {
   282  			// Prepare input.
   283  			buf := &bytes.Buffer{}
   284  			for _, f := range test.seq {
   285  				if err := ws.WriteFrame(buf, f); err != nil {
   286  					t.Fatal(err)
   287  				}
   288  			}
   289  
   290  			conn := &chopReader{
   291  				src: bytes.NewReader(buf.Bytes()),
   292  				sz:  test.chop,
   293  			}
   294  
   295  			var bts []byte
   296  			_, reader, err := NextReader(conn, 0)
   297  			if err == nil {
   298  				bts, err = ioutil.ReadAll(reader)
   299  			}
   300  			if err != test.err {
   301  				t.Errorf("unexpected error; got %v; want %v", err, test.err)
   302  				return
   303  			}
   304  			if test.err == nil && !bytes.Equal(bts, test.exp) {
   305  				t.Errorf(
   306  					"ReadAll from reader:\nact:\t%#x\nexp:\t%#x\nact:\t%s\nexp:\t%s\n",
   307  					bts, test.exp, string(bts), string(test.exp),
   308  				)
   309  			}
   310  		})
   311  	}
   312  }
   313  
   314  type chopReader struct {
   315  	src io.Reader
   316  	sz  int
   317  }
   318  
   319  func (c chopReader) Read(p []byte) (n int, err error) {
   320  	sz := c.sz
   321  	if sz == 0 {
   322  		sz = 1
   323  	}
   324  	if sz > len(p) {
   325  		sz = len(p)
   326  	}
   327  	return c.src.Read(p[:sz])
   328  }