github.com/docker/docker@v299999999.0.0-20200612211812-aaf470eca7b5+incompatible/client/hijack_test.go (about)

     1  package client
     2  
import (
	"context"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"net/http"
	"net/http/httptest"
	"net/url"
	"testing"

	"github.com/docker/docker/api/server/httputils"
	"github.com/docker/docker/api/types"
	"github.com/pkg/errors"
	"gotest.tools/v3/assert"
)
    18  
    19  func TestTLSCloseWriter(t *testing.T) {
    20  	t.Parallel()
    21  
    22  	var chErr chan error
    23  	ts := &httptest.Server{Config: &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
    24  		chErr = make(chan error, 1)
    25  		defer close(chErr)
    26  		if err := httputils.ParseForm(req); err != nil {
    27  			chErr <- errors.Wrap(err, "error parsing form")
    28  			http.Error(w, err.Error(), http.StatusInternalServerError)
    29  			return
    30  		}
    31  		r, rw, err := httputils.HijackConnection(w)
    32  		if err != nil {
    33  			chErr <- errors.Wrap(err, "error hijacking connection")
    34  			http.Error(w, err.Error(), http.StatusInternalServerError)
    35  			return
    36  		}
    37  		defer r.Close()
    38  
    39  		fmt.Fprint(rw, "HTTP/1.1 101 UPGRADED\r\nContent-Type: application/vnd.docker.raw-stream\r\nConnection: Upgrade\r\nUpgrade: tcp\r\n\n")
    40  
    41  		buf := make([]byte, 5)
    42  		_, err = r.Read(buf)
    43  		if err != nil {
    44  			chErr <- errors.Wrap(err, "error reading from client")
    45  			return
    46  		}
    47  		_, err = rw.Write(buf)
    48  		if err != nil {
    49  			chErr <- errors.Wrap(err, "error writing to client")
    50  			return
    51  		}
    52  	})}}
    53  
    54  	var (
    55  		l   net.Listener
    56  		err error
    57  	)
    58  	for i := 1024; i < 10000; i++ {
    59  		l, err = net.Listen("tcp4", fmt.Sprintf("127.0.0.1:%d", i))
    60  		if err == nil {
    61  			break
    62  		}
    63  	}
    64  	assert.NilError(t, err)
    65  
    66  	ts.Listener = l
    67  	defer l.Close()
    68  
    69  	defer func() {
    70  		if chErr != nil {
    71  			assert.Assert(t, <-chErr)
    72  		}
    73  	}()
    74  
    75  	ts.StartTLS()
    76  	defer ts.Close()
    77  
    78  	serverURL, err := url.Parse(ts.URL)
    79  	assert.NilError(t, err)
    80  
    81  	client, err := NewClientWithOpts(WithHost("tcp://"+serverURL.Host), WithHTTPClient(ts.Client()))
    82  	assert.NilError(t, err)
    83  
    84  	resp, err := client.postHijacked(context.Background(), "/asdf", url.Values{}, nil, map[string][]string{"Content-Type": {"text/plain"}})
    85  	assert.NilError(t, err)
    86  	defer resp.Close()
    87  
    88  	if _, ok := resp.Conn.(types.CloseWriter); !ok {
    89  		t.Fatal("tls conn did not implement the CloseWrite interface")
    90  	}
    91  
    92  	_, err = resp.Conn.Write([]byte("hello"))
    93  	assert.NilError(t, err)
    94  
    95  	b, err := ioutil.ReadAll(resp.Reader)
    96  	assert.NilError(t, err)
    97  	assert.Assert(t, string(b) == "hello")
    98  	assert.Assert(t, resp.CloseWrite())
    99  
   100  	// This should error since writes are closed
   101  	_, err = resp.Conn.Write([]byte("no"))
   102  	assert.Assert(t, err != nil)
   103  }