github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/transport/v2raywebsocket/conn.go (about) 1 package v2raywebsocket 2 3 import ( 4 "context" 5 "encoding/base64" 6 "io" 7 "net" 8 "net/http" 9 "os" 10 "time" 11 12 C "github.com/inazumav/sing-box/constant" 13 "github.com/sagernet/sing/common" 14 "github.com/sagernet/sing/common/buf" 15 E "github.com/sagernet/sing/common/exceptions" 16 "github.com/sagernet/websocket" 17 ) 18 19 type WebsocketConn struct { 20 *websocket.Conn 21 *Writer 22 remoteAddr net.Addr 23 reader io.Reader 24 } 25 26 func NewServerConn(wsConn *websocket.Conn, remoteAddr net.Addr) *WebsocketConn { 27 return &WebsocketConn{ 28 Conn: wsConn, 29 remoteAddr: remoteAddr, 30 Writer: NewWriter(wsConn, true), 31 } 32 } 33 34 func (c *WebsocketConn) Close() error { 35 err := c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(C.TCPTimeout)) 36 if err != nil { 37 return c.Conn.Close() 38 } 39 return nil 40 } 41 42 func (c *WebsocketConn) Read(b []byte) (n int, err error) { 43 for { 44 if c.reader == nil { 45 _, c.reader, err = c.NextReader() 46 if err != nil { 47 err = wrapError(err) 48 return 49 } 50 } 51 n, err = c.reader.Read(b) 52 if E.IsMulti(err, io.EOF) { 53 c.reader = nil 54 continue 55 } 56 err = wrapError(err) 57 return 58 } 59 } 60 61 func (c *WebsocketConn) RemoteAddr() net.Addr { 62 if c.remoteAddr != nil { 63 return c.remoteAddr 64 } 65 return c.Conn.RemoteAddr() 66 } 67 68 func (c *WebsocketConn) SetDeadline(t time.Time) error { 69 return os.ErrInvalid 70 } 71 72 func (c *WebsocketConn) SetReadDeadline(t time.Time) error { 73 return os.ErrInvalid 74 } 75 76 func (c *WebsocketConn) SetWriteDeadline(t time.Time) error { 77 return os.ErrInvalid 78 } 79 80 func (c *WebsocketConn) NeedAdditionalReadDeadline() bool { 81 return true 82 } 83 84 func (c *WebsocketConn) Upstream() any { 85 return c.Conn.NetConn() 86 } 87 88 func (c *WebsocketConn) UpstreamWriter() any { 89 return c.Writer 90 } 91 92 type EarlyWebsocketConn struct { 93 *Client 94 ctx context.Context 95 conn *WebsocketConn 96 create chan struct{} 97 err error 98 } 99 100 func (c *EarlyWebsocketConn) Read(b []byte) (n int, err error) { 101 if c.conn == nil { 102 <-c.create 103 if c.err != nil { 104 return 0, c.err 105 } 106 } 107 return c.conn.Read(b) 108 } 109 110 func (c *EarlyWebsocketConn) writeRequest(content []byte) error { 111 var ( 112 earlyData []byte 113 lateData []byte 114 conn *websocket.Conn 115 response *http.Response 116 err error 117 ) 118 if len(content) > int(c.maxEarlyData) { 119 earlyData = content[:c.maxEarlyData] 120 lateData = content[c.maxEarlyData:] 121 } else { 122 earlyData = content 123 } 124 if len(earlyData) > 0 { 125 earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData) 126 if c.earlyDataHeaderName == "" { 127 requestURL := c.requestURL 128 requestURL.Path += earlyDataString 129 conn, response, err = c.dialer.DialContext(c.ctx, requestURL.String(), c.headers) 130 } else { 131 headers := c.headers.Clone() 132 headers.Set(c.earlyDataHeaderName, earlyDataString) 133 conn, response, err = c.dialer.DialContext(c.ctx, c.requestURLString, headers) 134 } 135 } else { 136 conn, response, err = c.dialer.DialContext(c.ctx, c.requestURLString, c.headers) 137 } 138 if err != nil { 139 return wrapDialError(response, err) 140 } 141 c.conn = &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)} 142 if len(lateData) > 0 { 143 _, err = c.conn.Write(lateData) 144 } 145 return err 146 } 147 148 func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) { 149 if c.conn != nil { 150 return c.conn.Write(b) 151 } 152 err = c.writeRequest(b) 153 c.err = err 154 close(c.create) 155 if err != nil { 156 return 157 } 158 return len(b), nil 159 } 160 161 func (c *EarlyWebsocketConn) WriteBuffer(buffer *buf.Buffer) error { 162 if c.conn != nil { 163 return c.conn.WriteBuffer(buffer) 164 } 165 err := c.writeRequest(buffer.Bytes()) 166 c.err = err 167 close(c.create) 168 return err 169 } 170 171 func (c *EarlyWebsocketConn) Close() error { 172 if c.conn == nil { 173 return nil 174 } 175 return c.conn.Close() 176 } 177 178 func (c *EarlyWebsocketConn) LocalAddr() net.Addr { 179 if c.conn == nil { 180 return nil 181 } 182 return c.conn.LocalAddr() 183 } 184 185 func (c *EarlyWebsocketConn) RemoteAddr() net.Addr { 186 if c.conn == nil { 187 return nil 188 } 189 return c.conn.RemoteAddr() 190 } 191 192 func (c *EarlyWebsocketConn) SetDeadline(t time.Time) error { 193 return os.ErrInvalid 194 } 195 196 func (c *EarlyWebsocketConn) SetReadDeadline(t time.Time) error { 197 return os.ErrInvalid 198 } 199 200 func (c *EarlyWebsocketConn) SetWriteDeadline(t time.Time) error { 201 return os.ErrInvalid 202 } 203 204 func (c *EarlyWebsocketConn) NeedAdditionalReadDeadline() bool { 205 return true 206 } 207 208 func (c *EarlyWebsocketConn) Upstream() any { 209 return common.PtrOrNil(c.conn) 210 } 211 212 func (c *EarlyWebsocketConn) LazyHeadroom() bool { 213 return c.conn == nil 214 } 215 216 func wrapError(err error) error { 217 if websocket.IsCloseError(err, websocket.CloseNormalClosure) { 218 return io.EOF 219 } 220 if websocket.IsCloseError(err, websocket.CloseAbnormalClosure) { 221 return net.ErrClosed 222 } 223 return err 224 }