github.com/cantara/gober@v0.18.8/websocket/server.go (about) 1 package websocket 2 3 import ( 4 "context" 5 "encoding/binary" 6 "errors" 7 log "github.com/cantara/bragi/sbragi" 8 "github.com/gin-gonic/gin" 9 "github.com/gobwas/ws" 10 jsoniter "github.com/json-iterator/go" 11 "io" 12 "net" 13 "reflect" 14 "sync" 15 "time" 16 ) 17 18 var json = jsoniter.ConfigDefault 19 20 var BufferSize = 100 21 22 func Serve[T any](r *gin.RouterGroup, path string, acceptFunc func(c *gin.Context) bool, wsfunc WSHandler[T]) { 23 r.GET(path, func(c *gin.Context) { 24 if acceptFunc != nil && !acceptFunc(c) { 25 return //Could be smart to have some check of weather or not the statuscode code has been set. 26 } 27 conn, _, _, err := ws.UpgradeHTTP(c.Request, c.Writer) 28 if err != nil { 29 log.WithError(err).Fatal("while accepting websocket", "request", c.Request) 30 } 31 ctx, cancel := context.WithCancel(c.Request.Context()) 32 defer cancel() 33 clientClosed := false 34 reader := make(chan T, BufferSize) 35 writer := make(chan Write[T], BufferSize) 36 tick := time.Second * 50 37 sucker := webSucker[T]{ 38 pingTimout: tick, 39 pingTicker: time.NewTicker(tick), 40 writeLock: sync.Mutex{}, 41 conn: conn, 42 } 43 /* 44 connWriter := make(chan []byte, 1) 45 go func() { 46 defer func() { 47 if !clientClosed { 48 err = ws.WriteFrame(conn, ws.NewCloseFrame(ws.NewCloseFrameBody(ws.StatusNormalClosure, "writer closed"))) 49 log.WithError(err).Info("writing client websocket close frame") 50 } 51 log.WithError(conn.Close()).Info("closing client net conn") 52 }() 53 tickD := time.Second * 50 54 tkr := time.NewTicker(tickD) 55 defer tkr.Stop() 56 for { 57 select { 58 case write, ok := <-connWriter: 59 if !ok { 60 return 61 } 62 n, err := conn.Write(write) 63 total := n 64 for err == nil && total < len(write) { 65 n, err = conn.Write(write[total:]) 66 total += n 67 } 68 if err != nil { 69 log.WithError(err).Error("while writing to websocket", "path", path, "type", reflect.TypeOf(write).String(), "data", write) // This could end up logging person sensitive data. 70 return 71 } 72 tkr.Reset(tickD) 73 case <-tkr.C: 74 connWriter <- ws.CompiledPing 75 log.WithError(err).Info("wrote ping from server") 76 } 77 } 78 }() 79 */ 80 go func() { 81 defer func() { 82 if !clientClosed { 83 err = ws.WriteFrame(conn, ws.NewCloseFrame(ws.NewCloseFrameBody(ws.StatusNormalClosure, "writer closed"))) 84 log.WithError(err).Info("writing client websocket close frame") 85 } 86 log.WithError(conn.Close()).Info("closing client net conn") 87 }() 88 for { 89 select { 90 case <-ctx.Done(): 91 return 92 case write, ok := <-writer: 93 if !ok { 94 return 95 } 96 //err := WriteWebsocket[T](connWriter, write) 97 err := sucker.Write(write) 98 if err != nil { 99 if errors.Is(err, net.ErrClosed) { 100 clientClosed = true 101 cancel() 102 return 103 } 104 log.WithError(err).Error("while writing to websocket", "path", path, "request", c.Request, "type", reflect.TypeOf(write).String()) // This could end up logging person sensitive data. 105 return 106 } 107 case <-sucker.pingTicker.C: 108 err = sucker.Ping() 109 if err != nil { 110 if errors.Is(err, ErrNoErrorHandled) { 111 //log.Debug("no ping already waiting for pong from client") 112 continue 113 } 114 if errors.Is(err, net.ErrClosed) { 115 clientClosed = true 116 cancel() 117 return 118 } 119 } 120 log.WithError(err).Debug("wrote ping from server") 121 } 122 } 123 }() 124 go func() { 125 defer close(reader) 126 var read T 127 var err error 128 for { 129 select { 130 case <-ctx.Done(): 131 return 132 default: 133 //read, err = ReadWebsocket[T](conn, connWriter) 134 read, err = sucker.Read() 135 if err != nil { 136 if errors.Is(err, ErrNoErrorHandled) { 137 continue 138 } 139 if errors.Is(err, ErrNotImplemented) { 140 log.WithError(err).Warning("continuing after packet is discarded") 141 continue 142 } 143 if errors.Is(err, net.ErrClosed) { 144 clientClosed = true 145 cancel() 146 return 147 } 148 if errors.Is(err, io.EOF) { 149 clientClosed = true 150 cancel() 151 log.Info("websocket is closed, server closing...") //This works, but gave a wrong impression, changed slightly 152 return 153 } 154 log.WithError(err).Error("while server reading from websocket", "path", path, "request", c.Request, "type", reflect.TypeOf(read).String()) // This could end up logging person sensitive data. 155 return 156 } 157 reader <- read 158 } 159 } 160 }() 161 wsfunc(reader, writer, c.Params, ctx) 162 }) 163 } 164 165 type webSucker[T any] struct { 166 pingTimout time.Duration 167 pingTicker *time.Ticker 168 pingLock sync.Mutex 169 writeLock sync.Mutex 170 conn net.Conn 171 } 172 173 func (sucker *webSucker[T]) Ping() (err error) { 174 if !sucker.pingLock.TryLock() { 175 return ErrNoErrorHandled 176 } 177 return sucker.WriteConn(ws.CompiledPing) 178 } 179 180 func (sucker *webSucker[T]) WriteConn(write []byte) (err error) { 181 defer sucker.pingTicker.Reset(sucker.pingTimout) 182 sucker.writeLock.Lock() 183 defer sucker.writeLock.Unlock() 184 var n int 185 n, err = sucker.conn.Write(write) 186 total := n 187 for err == nil && total < len(write) { 188 n, err = sucker.conn.Write(write[total:]) 189 total += n 190 } 191 return 192 } 193 194 func (sucker *webSucker[T]) Write(write Write[T]) (err error) { 195 defer func() { 196 if write.Err != nil { 197 close(write.Err) 198 } 199 }() 200 payload, err := json.Marshal(write.Data) 201 if err != nil { 202 if write.Err != nil { 203 write.Err <- err 204 } 205 return err 206 } 207 var frame []byte 208 frame, err = ws.CompileFrame(ws.NewTextFrame(payload)) 209 if err != nil { 210 return 211 } 212 err = sucker.WriteConn(frame) 213 /* 214 err = sucker.WriteConn(append(websocketHeaderBytes(ws.Header{ 215 Fin: true, 216 Rsv: 0, 217 OpCode: ws.OpText, 218 Masked: false, 219 Mask: [4]byte{}, 220 Length: int64(len(payload)), 221 }), payload...)) 222 */ 223 if err != nil { 224 if write.Err != nil { 225 write.Err <- err 226 } 227 return err 228 } 229 return 230 } 231 232 func (sucker *webSucker[T]) Read() (out T, err error) { 233 //defer sucker.pingTicker.Reset(sucker.pingTimout) 234 header, err := ws.ReadHeader(sucker.conn) 235 if err != nil { 236 if errors.Is(err, net.ErrClosed) { 237 err = io.EOF 238 return 239 } 240 return 241 } 242 log.Trace("packet received", "type", packetTypeToString(header.OpCode)) 243 sucker.pingTicker.Stop() 244 defer sucker.pingTicker.Reset(sucker.pingTimout) 245 if header.OpCode == ws.OpClose { 246 err = io.EOF 247 return 248 } 249 if header.OpCode == ws.OpPing { 250 log.Debug("ping received, ponging...") 251 payload := make([]byte, header.Length) 252 _, err = io.ReadFull(sucker.conn, payload) 253 if err != nil { 254 return 255 } 256 /* 257 var frame []byte 258 frame, err = ws.CompileFrame(ws.NewPongFrame(payload)) 259 if err != nil { 260 return 261 } 262 */ 263 err = sucker.WriteConn(ws.CompiledPong) 264 /* 265 err = sucker.WriteConn(append(websocketHeaderBytes(ws.Header{ 266 Fin: true, 267 Rsv: 0, 268 OpCode: ws.OpPong, 269 Masked: false, 270 Mask: [4]byte{}, 271 Length: header.Length, 272 }), payload...)) 273 */ 274 log.WithError(err).Trace("while ponging") 275 err = ErrNoErrorHandled 276 return 277 } 278 279 /* 280 1. Should verify against outstanding ping TODO 281 2. Should ignore if no outstanding ping 282 */ 283 if header.OpCode == ws.OpPong { 284 log.Debug("pong received") 285 sucker.pingLock.Unlock() 286 if header.Length == 0 { 287 err = ErrNoErrorHandled 288 return 289 } 290 _, err = io.CopyN(io.Discard, sucker.conn, header.Length) 291 err = ErrNoErrorHandled 292 return 293 } 294 295 if header.OpCode == ws.OpContinuation { 296 _, err = io.CopyN(io.Discard, sucker.conn, header.Length) 297 err = ErrNotImplemented 298 return 299 } 300 301 if header.OpCode == ws.OpBinary { 302 _, err = io.CopyN(io.Discard, sucker.conn, header.Length) 303 err = ErrNotImplemented 304 return 305 } 306 307 payload := make([]byte, header.Length) 308 _, err = io.ReadFull(sucker.conn, payload) 309 if err != nil { 310 if errors.Is(err, net.ErrClosed) { 311 err = io.EOF 312 return 313 } 314 return 315 } 316 if header.Masked { 317 ws.Cipher(payload, header.Mask, 0) 318 } 319 err = json.Unmarshal(payload, &out) 320 return 321 } 322 323 /* 324 func ReadWebsocket[T any](conn io.Reader, writer chan<- []byte) (out T, err error) { 325 header, err := ws.ReadHeader(conn) 326 if err != nil { 327 if errors.Is(err, net.ErrClosed) { 328 err = io.EOF 329 return 330 } 331 return 332 } 333 if header.OpCode == ws.OpClose { 334 err = io.EOF 335 return 336 } 337 if header.OpCode == ws.OpPing { 338 log.Info("ping received, ponging...") 339 //Could also use ws.NewPingFrame(body) 340 payload := make([]byte, header.Length) 341 _, err = io.ReadFull(conn, payload) 342 if err != nil { 343 return 344 } 345 346 writer <- append(websocketHeaderBytes(ws.Header{ //This can write to a closed channel 347 Fin: true, 348 Rsv: 0, 349 OpCode: ws.OpPong, 350 Masked: false, 351 Mask: [4]byte{}, 352 Length: header.Length, 353 }), payload...) 354 /* 355 err = ws.WriteHeader(conn, ws.Header{ 356 Fin: true, 357 Rsv: 0, 358 OpCode: ws.OpPong, 359 Masked: false, 360 Mask: [4]byte{}, 361 Length: header.Length, 362 }) 363 if err != nil { 364 return 365 } 366 _, err = io.CopyN(conn, conn, header.Length) 367 */ /* 368 err = ErrNoErrorHandled 369 return 370 } 371 372 /* 373 1. Should verify against outstanding ping TODO 374 2. Should ignore if no outstanding ping 375 */ /* 376 if header.OpCode == ws.OpPong { 377 log.Info("pong received") 378 if header.Length == 0 { 379 err = ErrNoErrorHandled 380 return 381 } 382 _, err = io.CopyN(io.Discard, conn, header.Length) 383 err = ErrNoErrorHandled 384 return 385 } 386 387 if header.OpCode == ws.OpContinuation { 388 _, err = io.CopyN(io.Discard, conn, header.Length) 389 err = ErrNotImplemented 390 return 391 } 392 393 if header.OpCode == ws.OpBinary { 394 _, err = io.CopyN(io.Discard, conn, header.Length) 395 err = ErrNotImplemented 396 return 397 } 398 399 payload := make([]byte, header.Length) 400 _, err = io.ReadFull(conn, payload) //Could be an idea to change this to ReadAll to not have EOF errors. Or silence them ourselves 401 /* 402 total, err := conn.Read(payload) 403 var n int 404 for err == nil && total < int(header.Length) { 405 n, err = conn.Read(payload[total:]) 406 total += n 407 } 408 */ /* 409 if err != nil { 410 if errors.Is(err, net.ErrClosed) { 411 err = io.EOF 412 return 413 } 414 return 415 } 416 if header.Masked { 417 ws.Cipher(payload, header.Mask, 0) 418 } 419 err = json.Unmarshal(payload, &out) 420 return 421 } 422 423 424 func WriteWebsocket[T any](writer chan<- []byte, write Write[T]) error { 425 defer func() { 426 if write.Err != nil { 427 close(write.Err) 428 } 429 }() 430 payload, err := json.Marshal(write.Data) 431 if err != nil { 432 if write.Err != nil { 433 write.Err <- err 434 } 435 return err 436 } 437 writer <- append(websocketHeaderBytes(ws.Header{ 438 Fin: true, 439 Rsv: 0, 440 OpCode: ws.OpText, 441 Masked: false, 442 Mask: [4]byte{}, 443 Length: int64(len(payload)), 444 }), payload...) 445 /* 446 err = ws.WriteFrame(conn, ws.Frame{ 447 Header: ws.Header{ 448 Fin: true, 449 Rsv: 0, 450 OpCode: ws.OpText, 451 Masked: false, 452 Mask: [4]byte{}, 453 Length: int64(len(payload)), 454 }, 455 Payload: payload, 456 }) 457 //_, err = conn.Write(payload) 458 if err != nil { 459 if write.Err != nil { 460 write.Err <- err 461 } 462 return err 463 } 464 */ /* 465 return nil 466 } 467 */ 468 469 func websocketHeaderBytes(h ws.Header) []byte { 470 bts := make([]byte, ws.MaxHeaderSize) 471 472 if h.Fin { 473 bts[0] |= bit0 474 } 475 bts[0] |= h.Rsv << 4 476 bts[0] |= byte(h.OpCode) 477 478 var n int 479 switch { 480 case h.Length <= len7: 481 bts[1] = byte(h.Length) 482 n = 2 483 484 case h.Length <= len16: 485 bts[1] = 126 486 binary.BigEndian.PutUint16(bts[2:4], uint16(h.Length)) 487 n = 4 488 489 case h.Length <= len64: 490 bts[1] = 127 491 binary.BigEndian.PutUint64(bts[2:10], uint64(h.Length)) 492 n = 10 493 494 default: 495 log.WithError(ws.ErrHeaderLengthUnexpected).Fatal("while creating websocket header bytes") 496 } 497 498 if h.Masked { 499 bts[1] |= bit0 500 n += copy(bts[n:], h.Mask[:]) 501 } 502 return bts[:n] 503 } 504 505 type WSHandler[T any] func(<-chan T, chan<- Write[T], gin.Params, context.Context) 506 507 var ErrNotImplemented = errors.New("operation not implemented") 508 var ErrNoErrorHandled = errors.New("handled") 509 510 const ( 511 bit0 = 0x80 512 513 len7 = int64(125) 514 len16 = int64(^(uint16(0))) 515 len64 = int64(^(uint64(0)) >> 1) 516 ) 517 518 func packetTypeToString(code ws.OpCode) string { 519 switch code { 520 case ws.OpText: 521 return "text" 522 case ws.OpBinary: 523 return "binary" 524 case ws.OpClose: 525 return "close" 526 case ws.OpPing: 527 return "ping" 528 case ws.OpPong: 529 return "pong" 530 case ws.OpContinuation: 531 return "continuation" 532 } 533 return "" 534 }