github.com/simonmittag/ws@v1.1.0-rc.5.0.20210419231947-82b846128245/wsutil/dialer_test.go (about)

     1  package wsutil
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"context"
     7  	"io"
     8  	"io/ioutil"
     9  	"net"
    10  	"net/http"
    11  	"testing"
    12  
    13  	"github.com/simonmittag/ws"
    14  )
    15  
    16  var bg = context.Background()
    17  
    18  func TestDebugDialer(t *testing.T) {
    19  	for _, test := range []struct {
    20  		name string
    21  		resp *http.Response
    22  		body []byte
    23  		err  error
    24  	}{
    25  		{
    26  			name: "base",
    27  		},
    28  		{
    29  			name: "base with footer",
    30  			body: []byte("hello, additional bytes!"),
    31  		},
    32  		{
    33  			name: "fail",
    34  			resp: &http.Response{
    35  				StatusCode: 101,
    36  				ProtoMajor: 1,
    37  				ProtoMinor: 1,
    38  			},
    39  			err: ws.ErrHandshakeBadUpgrade,
    40  		},
    41  		{
    42  			name: "fail",
    43  			resp: &http.Response{
    44  				StatusCode: 400,
    45  				ProtoMajor: 42,
    46  				ProtoMinor: 1,
    47  			},
    48  			err: ws.ErrHandshakeBadProtocol,
    49  		},
    50  		{
    51  			name: "fail",
    52  			resp: &http.Response{
    53  				StatusCode: 400,
    54  				ProtoMajor: 1,
    55  				ProtoMinor: 1,
    56  			},
    57  			err: ws.StatusError(400),
    58  		},
    59  		{
    60  			name: "fail footer",
    61  			resp: &http.Response{
    62  				StatusCode: 400,
    63  				ProtoMajor: 1,
    64  				ProtoMinor: 1,
    65  			},
    66  			err: ws.StatusError(400),
    67  		},
    68  
    69  		{
    70  			name: "big response",
    71  			// This test expects that even when server sent unsuccessful
    72  			// response with body that does not fit to Dialer read buffer,
    73  			// OnResponse will still be called with full response bytes.
    74  			resp: &http.Response{
    75  				StatusCode: 200,
    76  				ProtoMajor: 1,
    77  				ProtoMinor: 1,
    78  				Body: ioutil.NopCloser(bytes.NewReader(
    79  					bytes.Repeat([]byte("x"), 5000),
    80  				)),
    81  				ContentLength: 5000,
    82  			},
    83  			// Additional data sent. We expect it will not be shown in
    84  			// OnResponse.
    85  			body: bytes.Repeat([]byte("y"), 1000),
    86  			err:  ws.StatusError(200),
    87  		},
    88  	} {
    89  		t.Run(test.name, func(t *testing.T) {
    90  			client, server := net.Pipe()
    91  
    92  			var (
    93  				actReq, actRes []byte
    94  				expReq, expRes []byte
    95  			)
    96  			dd := DebugDialer{
    97  				Dialer: ws.Dialer{
    98  					NetDial: func(_ context.Context, _, _ string) (net.Conn, error) {
    99  						return client, nil
   100  					},
   101  				},
   102  				OnRequest:  func(p []byte) { actReq = p },
   103  				OnResponse: func(p []byte) { actRes = p },
   104  			}
   105  			go func() {
   106  				var (
   107  					reqBuf bytes.Buffer
   108  					resBuf bytes.Buffer
   109  				)
   110  				var (
   111  					tr = io.TeeReader(server, &reqBuf)
   112  					bw = bufio.NewWriterSize(server, 65536)
   113  					mw = io.MultiWriter(bw, &resBuf)
   114  				)
   115  				conn := struct {
   116  					io.Reader
   117  					io.Writer
   118  				}{
   119  					tr, mw,
   120  				}
   121  				if test.resp == nil {
   122  					_, err := ws.Upgrade(conn)
   123  					if err != nil {
   124  						t.Fatal(err)
   125  					}
   126  				} else {
   127  					if _, err := http.ReadRequest(bufio.NewReader(conn)); err != nil {
   128  						t.Fatal(err)
   129  					}
   130  					if err := test.resp.Write(conn); err != nil {
   131  						t.Fatal(err)
   132  					}
   133  				}
   134  
   135  				expReq = reqBuf.Bytes()
   136  				expRes = resBuf.Bytes()
   137  
   138  				if test.body != nil {
   139  					bw.Write(test.body)
   140  				}
   141  				bw.Flush()
   142  				server.Close()
   143  			}()
   144  
   145  			conn, br, _, err := dd.Dial(bg, "ws://stub")
   146  			if err != test.err {
   147  				t.Fatalf("unexpected error: %v; want %v", err, test.err)
   148  			}
   149  			if conn != client {
   150  				t.Errorf("returned connection is non raw")
   151  			}
   152  			if br != nil {
   153  				body, err := ioutil.ReadAll(br)
   154  				if err != nil {
   155  					t.Fatal(err)
   156  				}
   157  				if !bytes.Equal(body, test.body) {
   158  					t.Errorf("unexpected buffered body: %q; want %q", body, test.body)
   159  				}
   160  			}
   161  			if !bytes.Equal(actReq, expReq) {
   162  				t.Errorf(
   163  					"unexpected request bytes:\nact %d bytes:\n%s\nexp %d bytes:\n%s\n",
   164  					len(actReq), actReq, len(expReq), expReq,
   165  				)
   166  			}
   167  			if !bytes.Equal(actRes, expRes) {
   168  				t.Errorf(
   169  					"unexpected response bytes:\nact %d bytes:\n%s\nexp %d bytes:\n%s\n",
   170  					len(actRes), actRes, len(expRes), expRes,
   171  				)
   172  			}
   173  		})
   174  	}
   175  }