github.com/ezoic/ws@v1.0.4-0.20220713205711-5c1d69e074c5/wsutil/dialer.go (about) 1 package wsutil 2 3 import ( 4 "bufio" 5 "bytes" 6 "context" 7 "io" 8 "io/ioutil" 9 "net" 10 "net/http" 11 12 "github.com/ezoic/ws" 13 ) 14 15 // DebugDialer is a wrapper around ws.Dialer. It tracks i/o of WebSocket 16 // handshake. That is, it gives ability to receive copied HTTP request and 17 // response bytes that made inside Dialer.Dial(). 18 // 19 // Note that it must not be used in production applications that requires 20 // Dial() to be efficient. 21 type DebugDialer struct { 22 // Dialer contains WebSocket connection establishment options. 23 Dialer ws.Dialer 24 25 // OnRequest and OnResponse are the callbacks that will be called with the 26 // HTTP request and response respectively. 27 OnRequest, OnResponse func([]byte) 28 } 29 30 // Dial connects to the url host and upgrades connection to WebSocket. It makes 31 // it by calling d.Dialer.Dial(). 32 func (d *DebugDialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs ws.Handshake, err error) { 33 // Need to copy Dialer to prevent original object mutation. 34 dialer := d.Dialer 35 var ( 36 reqBuf bytes.Buffer 37 resBuf bytes.Buffer 38 39 resContentLength int64 40 ) 41 userWrap := dialer.WrapConn 42 dialer.WrapConn = func(c net.Conn) net.Conn { 43 if userWrap != nil { 44 c = userWrap(c) 45 } 46 47 // Save the pointer to the raw connection. 48 conn = c 49 50 var ( 51 r io.Reader = conn 52 w io.Writer = conn 53 ) 54 if d.OnResponse != nil { 55 r = &prefetchResponseReader{ 56 source: conn, 57 buffer: &resBuf, 58 contentLength: &resContentLength, 59 } 60 } 61 if d.OnRequest != nil { 62 w = io.MultiWriter(conn, &reqBuf) 63 } 64 return rwConn{conn, r, w} 65 } 66 67 _, br, hs, err = dialer.Dial(ctx, urlstr) 68 69 if onRequest := d.OnRequest; onRequest != nil { 70 onRequest(reqBuf.Bytes()) 71 } 72 if onResponse := d.OnResponse; onResponse != nil { 73 // We must split response inside buffered bytes from other received 74 // bytes from server. 75 p := resBuf.Bytes() 76 n := bytes.Index(p, headEnd) 77 h := n + len(headEnd) // Head end index. 78 n = h + int(resContentLength) // Body end index. 79 80 onResponse(p[:n]) 81 82 if br != nil { 83 // If br is non-nil, then it mean two things. First is that 84 // handshake is OK and server has sent additional bytes – probably 85 // immediate sent frames (or weird but possible response body). 86 // Second, the bad one, is that br buffer's source is now rwConn 87 // instance from above WrapConn call. It is incorrect, so we must 88 // fix it. 89 var r io.Reader = conn 90 if len(p) > h { 91 // Buffer contains more than just HTTP headers bytes. 92 r = io.MultiReader( 93 bytes.NewReader(p[h:]), 94 conn, 95 ) 96 } 97 br.Reset(r) 98 // Must make br.Buffered() to be non-zero. 99 br.Peek(len(p[h:])) 100 } 101 } 102 103 return conn, br, hs, err 104 } 105 106 type rwConn struct { 107 net.Conn 108 109 r io.Reader 110 w io.Writer 111 } 112 113 func (rwc rwConn) Read(p []byte) (int, error) { 114 return rwc.r.Read(p) 115 } 116 func (rwc rwConn) Write(p []byte) (int, error) { 117 return rwc.w.Write(p) 118 } 119 120 var headEnd = []byte("\r\n\r\n") 121 122 type prefetchResponseReader struct { 123 source io.Reader // Original connection source. 124 reader io.Reader // Wrapped reader used to read from by clients. 125 buffer *bytes.Buffer 126 127 contentLength *int64 128 } 129 130 func (r *prefetchResponseReader) Read(p []byte) (int, error) { 131 if r.reader == nil { 132 resp, err := http.ReadResponse(bufio.NewReader( 133 io.TeeReader(r.source, r.buffer), 134 ), nil) 135 if err == nil { 136 *r.contentLength, _ = io.Copy(ioutil.Discard, resp.Body) 137 resp.Body.Close() 138 } 139 bts := r.buffer.Bytes() 140 r.reader = io.MultiReader( 141 bytes.NewReader(bts), 142 r.source, 143 ) 144 } 145 return r.reader.Read(p) 146 }