github.com/simonmittag/ws@v1.1.0-rc.5.0.20210419231947-82b846128245/wsflate/writer_test.go (about)

     1  package wsflate
     2  
     3  import (
     4  	"bytes"
     5  	"compress/flate"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"net"
    10  	"net/url"
    11  	"testing"
    12  
    13  	"github.com/gobwas/httphead"
    14  	"github.com/simonmittag/ws"
    15  )
    16  
    17  func TestWriter(t *testing.T) {
    18  	var buf bytes.Buffer
    19  	w := NewWriter(&buf, func(w io.Writer) Compressor {
    20  		fw, _ := flate.NewWriter(w, 9)
    21  		return fw
    22  	})
    23  	data := []byte("hello, flate!")
    24  	for _, p := range bytes.SplitAfter(data, []byte{','}) {
    25  		w.Write(p)
    26  		w.Flush()
    27  	}
    28  	if err := w.Close(); err != nil {
    29  		t.Fatalf("unexpected Close() error: %v", err)
    30  	}
    31  	if err := w.Err(); err != nil {
    32  		t.Fatalf("unexpected Writer error: %v", err)
    33  	}
    34  
    35  	r := NewReader(&buf, func(r io.Reader) Decompressor {
    36  		return flate.NewReader(r)
    37  	})
    38  	act, err := ioutil.ReadAll(r)
    39  	if err != nil {
    40  		t.Fatalf("unexpected Reader error: %v", err)
    41  	}
    42  	if exp := data; !bytes.Equal(act, exp) {
    43  		t.Fatalf("unexpected bytes: %#q; want %#q", act, exp)
    44  	}
    45  }
    46  
    47  func TestExtensionNegotiation(t *testing.T) {
    48  	client, server := net.Pipe()
    49  
    50  	done := make(chan error)
    51  	go func() {
    52  		defer close(done)
    53  		var (
    54  			req bytes.Buffer
    55  			res bytes.Buffer
    56  		)
    57  		conn := struct {
    58  			io.Reader
    59  			io.Writer
    60  		}{
    61  			io.TeeReader(server, &req),
    62  			io.MultiWriter(server, &res),
    63  		}
    64  		e := Extension{
    65  			Parameters: Parameters{
    66  				ServerNoContextTakeover: true,
    67  				ClientNoContextTakeover: true,
    68  			},
    69  		}
    70  		u := ws.Upgrader{
    71  			Negotiate: e.Negotiate,
    72  		}
    73  		hs, err := u.Upgrade(&conn)
    74  		if err != nil {
    75  			done <- err
    76  			return
    77  		}
    78  
    79  		p, ok := e.Accepted()
    80  		t.Logf("accepted: %t %+v", ok, p)
    81  
    82  		fmt.Println(req.String())
    83  		fmt.Println(res.String())
    84  		t.Logf("server: %+v", hs)
    85  	}()
    86  
    87  	d := ws.Dialer{
    88  		Extensions: []httphead.Option{
    89  			(Parameters{
    90  				ServerNoContextTakeover: true,
    91  				ClientNoContextTakeover: true,
    92  				ClientMaxWindowBits:     8,
    93  				ServerMaxWindowBits:     10,
    94  			}).Option(),
    95  			(Parameters{
    96  				ClientMaxWindowBits: 1,
    97  			}).Option(),
    98  			(Parameters{}).Option(),
    99  		},
   100  	}
   101  
   102  	uri, err := url.Parse("ws://example.com")
   103  	if err != nil {
   104  		t.Fatal(err)
   105  	}
   106  	_, hs, err := d.Upgrade(client, uri)
   107  	if err != nil {
   108  		t.Fatalf("client: %v", err)
   109  	}
   110  	if n := len(hs.Extensions); n != 1 {
   111  		t.Fatalf("unexpected number of accepted extensions: %d", n)
   112  	}
   113  	var p Parameters
   114  	if err := p.Parse(hs.Extensions[0]); err != nil {
   115  		t.Fatalf("parse extension error: %v", err)
   116  	}
   117  	t.Logf("client params: %+v", p)
   118  	if err := <-done; err != nil {
   119  		t.Fatalf("server Upgrade() error: %v", err)
   120  	}
   121  }