github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/vendor_skip/nhooyr.io/websocket/write.go (about) 1 // +build !js 2 3 package websocket 4 5 import ( 6 "bufio" 7 "context" 8 "crypto/rand" 9 "encoding/binary" 10 "errors" 11 "fmt" 12 "io" 13 "time" 14 15 "github.com/klauspost/compress/flate" 16 17 "nhooyr.io/websocket/internal/errd" 18 ) 19 20 // Writer returns a writer bounded by the context that will write 21 // a WebSocket message of type dataType to the connection. 22 // 23 // You must close the writer once you have written the entire message. 24 // 25 // Only one writer can be open at a time, multiple calls will block until the previous writer 26 // is closed. 27 func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { 28 w, err := c.writer(ctx, typ) 29 if err != nil { 30 return nil, fmt.Errorf("failed to get writer: %w", err) 31 } 32 return w, nil 33 } 34 35 // Write writes a message to the connection. 36 // 37 // See the Writer method if you want to stream a message. 38 // 39 // If compression is disabled or the threshold is not met, then it 40 // will write the message in a single frame. 41 func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { 42 _, err := c.write(ctx, typ, p) 43 if err != nil { 44 return fmt.Errorf("failed to write msg: %w", err) 45 } 46 return nil 47 } 48 49 type msgWriter struct { 50 mw *msgWriterState 51 closed bool 52 } 53 54 func (mw *msgWriter) Write(p []byte) (int, error) { 55 if mw.closed { 56 return 0, errors.New("cannot use closed writer") 57 } 58 return mw.mw.Write(p) 59 } 60 61 func (mw *msgWriter) Close() error { 62 if mw.closed { 63 return errors.New("cannot use closed writer") 64 } 65 mw.closed = true 66 return mw.mw.Close() 67 } 68 69 type msgWriterState struct { 70 c *Conn 71 72 mu *mu 73 writeMu *mu 74 75 ctx context.Context 76 opcode opcode 77 flate bool 78 79 trimWriter *trimLastFourBytesWriter 80 dict slidingWindow 81 } 82 83 func newMsgWriterState(c *Conn) *msgWriterState { 84 mw := &msgWriterState{ 85 c: c, 86 mu: newMu(c), 87 writeMu: newMu(c), 88 } 89 return mw 90 } 91 92 func (mw *msgWriterState) ensureFlate() { 93 if mw.trimWriter == nil { 94 mw.trimWriter = &trimLastFourBytesWriter{ 95 w: writerFunc(mw.write), 96 } 97 } 98 99 mw.dict.init(8192) 100 mw.flate = true 101 } 102 103 func (mw *msgWriterState) flateContextTakeover() bool { 104 if mw.c.client { 105 return !mw.c.copts.clientNoContextTakeover 106 } 107 return !mw.c.copts.serverNoContextTakeover 108 } 109 110 func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { 111 err := c.msgWriterState.reset(ctx, typ) 112 if err != nil { 113 return nil, err 114 } 115 return &msgWriter{ 116 mw: c.msgWriterState, 117 closed: false, 118 }, nil 119 } 120 121 func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) { 122 mw, err := c.writer(ctx, typ) 123 if err != nil { 124 return 0, err 125 } 126 127 if !c.flate() { 128 defer c.msgWriterState.mu.unlock() 129 return c.writeFrame(ctx, true, false, c.msgWriterState.opcode, p) 130 } 131 132 n, err := mw.Write(p) 133 if err != nil { 134 return n, err 135 } 136 137 err = mw.Close() 138 return n, err 139 } 140 141 func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error { 142 err := mw.mu.lock(ctx) 143 if err != nil { 144 return err 145 } 146 147 mw.ctx = ctx 148 mw.opcode = opcode(typ) 149 mw.flate = false 150 151 mw.trimWriter.reset() 152 153 return nil 154 } 155 156 // Write writes the given bytes to the WebSocket connection. 157 func (mw *msgWriterState) Write(p []byte) (_ int, err error) { 158 err = mw.writeMu.lock(mw.ctx) 159 if err != nil { 160 return 0, fmt.Errorf("failed to write: %w", err) 161 } 162 defer mw.writeMu.unlock() 163 164 defer func() { 165 if err != nil { 166 err = fmt.Errorf("failed to write: %w", err) 167 mw.c.close(err) 168 } 169 }() 170 171 if mw.c.flate() { 172 // Only enables flate if the length crosses the 173 // threshold on the first frame 174 if mw.opcode != opContinuation && len(p) >= mw.c.flateThreshold { 175 mw.ensureFlate() 176 } 177 } 178 179 if mw.flate { 180 err = flate.StatelessDeflate(mw.trimWriter, p, false, mw.dict.buf) 181 if err != nil { 182 return 0, err 183 } 184 mw.dict.write(p) 185 return len(p), nil 186 } 187 188 return mw.write(p) 189 } 190 191 func (mw *msgWriterState) write(p []byte) (int, error) { 192 n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p) 193 if err != nil { 194 return n, fmt.Errorf("failed to write data frame: %w", err) 195 } 196 mw.opcode = opContinuation 197 return n, nil 198 } 199 200 // Close flushes the frame to the connection. 201 func (mw *msgWriterState) Close() (err error) { 202 defer errd.Wrap(&err, "failed to close writer") 203 204 err = mw.writeMu.lock(mw.ctx) 205 if err != nil { 206 return err 207 } 208 defer mw.writeMu.unlock() 209 210 _, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil) 211 if err != nil { 212 return fmt.Errorf("failed to write fin frame: %w", err) 213 } 214 215 if mw.flate && !mw.flateContextTakeover() { 216 mw.dict.close() 217 } 218 mw.mu.unlock() 219 return nil 220 } 221 222 func (mw *msgWriterState) close() { 223 if mw.c.client { 224 mw.c.writeFrameMu.forceLock() 225 putBufioWriter(mw.c.bw) 226 } 227 228 mw.writeMu.forceLock() 229 mw.dict.close() 230 } 231 232 func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { 233 ctx, cancel := context.WithTimeout(ctx, time.Second*5) 234 defer cancel() 235 236 _, err := c.writeFrame(ctx, true, false, opcode, p) 237 if err != nil { 238 return fmt.Errorf("failed to write control frame %v: %w", opcode, err) 239 } 240 return nil 241 } 242 243 // frame handles all writes to the connection. 244 func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) { 245 err = c.writeFrameMu.lock(ctx) 246 if err != nil { 247 return 0, err 248 } 249 defer c.writeFrameMu.unlock() 250 251 // If the state says a close has already been written, we wait until 252 // the connection is closed and return that error. 253 // 254 // However, if the frame being written is a close, that means its the close from 255 // the state being set so we let it go through. 256 c.closeMu.Lock() 257 wroteClose := c.wroteClose 258 c.closeMu.Unlock() 259 if wroteClose && opcode != opClose { 260 select { 261 case <-ctx.Done(): 262 return 0, ctx.Err() 263 case <-c.closed: 264 return 0, c.closeErr 265 } 266 } 267 268 select { 269 case <-c.closed: 270 return 0, c.closeErr 271 case c.writeTimeout <- ctx: 272 } 273 274 defer func() { 275 if err != nil { 276 select { 277 case <-c.closed: 278 err = c.closeErr 279 case <-ctx.Done(): 280 err = ctx.Err() 281 } 282 c.close(err) 283 err = fmt.Errorf("failed to write frame: %w", err) 284 } 285 }() 286 287 c.writeHeader.fin = fin 288 c.writeHeader.opcode = opcode 289 c.writeHeader.payloadLength = int64(len(p)) 290 291 if c.client { 292 c.writeHeader.masked = true 293 _, err = io.ReadFull(rand.Reader, c.writeHeaderBuf[:4]) 294 if err != nil { 295 return 0, fmt.Errorf("failed to generate masking key: %w", err) 296 } 297 c.writeHeader.maskKey = binary.LittleEndian.Uint32(c.writeHeaderBuf[:]) 298 } 299 300 c.writeHeader.rsv1 = false 301 if flate && (opcode == opText || opcode == opBinary) { 302 c.writeHeader.rsv1 = true 303 } 304 305 err = writeFrameHeader(c.writeHeader, c.bw, c.writeHeaderBuf[:]) 306 if err != nil { 307 return 0, err 308 } 309 310 n, err := c.writeFramePayload(p) 311 if err != nil { 312 return n, err 313 } 314 315 if c.writeHeader.fin { 316 err = c.bw.Flush() 317 if err != nil { 318 return n, fmt.Errorf("failed to flush: %w", err) 319 } 320 } 321 322 select { 323 case <-c.closed: 324 return n, c.closeErr 325 case c.writeTimeout <- context.Background(): 326 } 327 328 return n, nil 329 } 330 331 func (c *Conn) writeFramePayload(p []byte) (n int, err error) { 332 defer errd.Wrap(&err, "failed to write frame payload") 333 334 if !c.writeHeader.masked { 335 return c.bw.Write(p) 336 } 337 338 maskKey := c.writeHeader.maskKey 339 for len(p) > 0 { 340 // If the buffer is full, we need to flush. 341 if c.bw.Available() == 0 { 342 err = c.bw.Flush() 343 if err != nil { 344 return n, err 345 } 346 } 347 348 // Start of next write in the buffer. 349 i := c.bw.Buffered() 350 351 j := len(p) 352 if j > c.bw.Available() { 353 j = c.bw.Available() 354 } 355 356 _, err := c.bw.Write(p[:j]) 357 if err != nil { 358 return n, err 359 } 360 361 maskKey = mask(maskKey, c.writeBuf[i:c.bw.Buffered()]) 362 363 p = p[j:] 364 n += j 365 } 366 367 return n, nil 368 } 369 370 type writerFunc func(p []byte) (int, error) 371 372 func (f writerFunc) Write(p []byte) (int, error) { 373 return f(p) 374 } 375 376 // extractBufioWriterBuf grabs the []byte backing a *bufio.Writer 377 // and returns it. 378 func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte { 379 var writeBuf []byte 380 bw.Reset(writerFunc(func(p2 []byte) (int, error) { 381 writeBuf = p2[:cap(p2)] 382 return len(p2), nil 383 })) 384 385 bw.WriteByte(0) 386 bw.Flush() 387 388 bw.Reset(w) 389 390 return writeBuf 391 } 392 393 func (c *Conn) writeError(code StatusCode, err error) { 394 c.setCloseErr(err) 395 c.writeClose(code, err.Error()) 396 c.close(nil) 397 }