github.com/ezoic/ws@v1.0.4-0.20220713205711-5c1d69e074c5/wsutil/dialer.go (about)

     1  package wsutil
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"context"
     7  	"io"
     8  	"io/ioutil"
     9  	"net"
    10  	"net/http"
    11  
    12  	"github.com/ezoic/ws"
    13  )
    14  
    15  // DebugDialer is a wrapper around ws.Dialer. It tracks i/o of WebSocket
    16  // handshake. That is, it gives ability to receive copied HTTP request and
    17  // response bytes that made inside Dialer.Dial().
    18  //
    19  // Note that it must not be used in production applications that requires
    20  // Dial() to be efficient.
    21  type DebugDialer struct {
    22  	// Dialer contains WebSocket connection establishment options.
    23  	Dialer ws.Dialer
    24  
    25  	// OnRequest and OnResponse are the callbacks that will be called with the
    26  	// HTTP request and response respectively.
    27  	OnRequest, OnResponse func([]byte)
    28  }
    29  
    30  // Dial connects to the url host and upgrades connection to WebSocket. It makes
    31  // it by calling d.Dialer.Dial().
    32  func (d *DebugDialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs ws.Handshake, err error) {
    33  	// Need to copy Dialer to prevent original object mutation.
    34  	dialer := d.Dialer
    35  	var (
    36  		reqBuf bytes.Buffer
    37  		resBuf bytes.Buffer
    38  
    39  		resContentLength int64
    40  	)
    41  	userWrap := dialer.WrapConn
    42  	dialer.WrapConn = func(c net.Conn) net.Conn {
    43  		if userWrap != nil {
    44  			c = userWrap(c)
    45  		}
    46  
    47  		// Save the pointer to the raw connection.
    48  		conn = c
    49  
    50  		var (
    51  			r io.Reader = conn
    52  			w io.Writer = conn
    53  		)
    54  		if d.OnResponse != nil {
    55  			r = &prefetchResponseReader{
    56  				source:        conn,
    57  				buffer:        &resBuf,
    58  				contentLength: &resContentLength,
    59  			}
    60  		}
    61  		if d.OnRequest != nil {
    62  			w = io.MultiWriter(conn, &reqBuf)
    63  		}
    64  		return rwConn{conn, r, w}
    65  	}
    66  
    67  	_, br, hs, err = dialer.Dial(ctx, urlstr)
    68  
    69  	if onRequest := d.OnRequest; onRequest != nil {
    70  		onRequest(reqBuf.Bytes())
    71  	}
    72  	if onResponse := d.OnResponse; onResponse != nil {
    73  		// We must split response inside buffered bytes from other received
    74  		// bytes from server.
    75  		p := resBuf.Bytes()
    76  		n := bytes.Index(p, headEnd)
    77  		h := n + len(headEnd)         // Head end index.
    78  		n = h + int(resContentLength) // Body end index.
    79  
    80  		onResponse(p[:n])
    81  
    82  		if br != nil {
    83  			// If br is non-nil, then it mean two things. First is that
    84  			// handshake is OK and server has sent additional bytes – probably
    85  			// immediate sent frames (or weird but possible response body).
    86  			// Second, the bad one, is that br buffer's source is now rwConn
    87  			// instance from above WrapConn call. It is incorrect, so we must
    88  			// fix it.
    89  			var r io.Reader = conn
    90  			if len(p) > h {
    91  				// Buffer contains more than just HTTP headers bytes.
    92  				r = io.MultiReader(
    93  					bytes.NewReader(p[h:]),
    94  					conn,
    95  				)
    96  			}
    97  			br.Reset(r)
    98  			// Must make br.Buffered() to be non-zero.
    99  			br.Peek(len(p[h:]))
   100  		}
   101  	}
   102  
   103  	return conn, br, hs, err
   104  }
   105  
   106  type rwConn struct {
   107  	net.Conn
   108  
   109  	r io.Reader
   110  	w io.Writer
   111  }
   112  
   113  func (rwc rwConn) Read(p []byte) (int, error) {
   114  	return rwc.r.Read(p)
   115  }
   116  func (rwc rwConn) Write(p []byte) (int, error) {
   117  	return rwc.w.Write(p)
   118  }
   119  
   120  var headEnd = []byte("\r\n\r\n")
   121  
   122  type prefetchResponseReader struct {
   123  	source io.Reader // Original connection source.
   124  	reader io.Reader // Wrapped reader used to read from by clients.
   125  	buffer *bytes.Buffer
   126  
   127  	contentLength *int64
   128  }
   129  
   130  func (r *prefetchResponseReader) Read(p []byte) (int, error) {
   131  	if r.reader == nil {
   132  		resp, err := http.ReadResponse(bufio.NewReader(
   133  			io.TeeReader(r.source, r.buffer),
   134  		), nil)
   135  		if err == nil {
   136  			*r.contentLength, _ = io.Copy(ioutil.Discard, resp.Body)
   137  			resp.Body.Close()
   138  		}
   139  		bts := r.buffer.Bytes()
   140  		r.reader = io.MultiReader(
   141  			bytes.NewReader(bts),
   142  			r.source,
   143  		)
   144  	}
   145  	return r.reader.Read(p)
   146  }