github.com/iDigitalFlame/xmt@v0.5.4/com/udp.go (about) 1 // Copyright (C) 2020 - 2023 iDigitalFlame 2 // 3 // This program is free software: you can redistribute it and/or modify 4 // it under the terms of the GNU General Public License as published by 5 // the Free Software Foundation, either version 3 of the License, or 6 // any later version. 7 // 8 // This program is distributed in the hope that it will be useful, 9 // but WITHOUT ANY WARRANTY; without even the implied warranty of 10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 // GNU General Public License for more details. 12 // 13 // You should have received a copy of the GNU General Public License 14 // along with this program. If not, see <https://www.gnu.org/licenses/>. 15 // 16 17 package com 18 19 import ( 20 "context" 21 "io" 22 "net" 23 "sync" 24 "time" 25 26 "github.com/iDigitalFlame/xmt/util/bugtrack" 27 ) 28 29 const ( 30 udpLimit = 4096 31 32 readOp = time.Microsecond * 15 33 writeOp = time.Microsecond * 35 34 ) 35 36 var ( 37 empty time.Time 38 39 udpWake struct{} 40 udpDeadline = new(udpErr) 41 42 buffers = sync.Pool{ 43 New: func() interface{} { 44 var b [udpLimit]byte 45 return &b 46 }, 47 } 48 ) 49 50 type udpErr struct{} 51 type udpConn struct { 52 bufs chan udpData 53 sock *udpListener 54 wake chan struct{} 55 dev udpAddr 56 buf []byte 57 read, write time.Duration 58 lock sync.Mutex 59 } 60 type udpData struct { 61 _ [0]func() 62 b *[udpLimit]byte 63 n int 64 } 65 type udpCompat struct { 66 udpSock 67 } 68 type udpStream struct { 69 net.Conn 70 buf []byte 71 size int 72 fails uint8 73 read, write time.Duration 74 } 75 type udpSock interface { 76 udpSockInternal 77 net.PacketConn 78 } 79 type udpListener struct { 80 err error 81 ctx context.Context 82 del chan udpAddr 83 new chan *udpConn 84 cons map[udpAddr]*udpConn 85 sock *udpCompat 86 cancel context.CancelFunc 87 deadline time.Duration 88 lock sync.RWMutex 89 } 90 type udpConnector struct { 91 net.Dialer 92 } 93 94 func (udpErr) Timeout() bool { 95 return true 96 } 97 func (udpErr) Error() string { 98 return context.DeadlineExceeded.Error() 99 } 100 func (l *udpListener) purge() { 101 for { 102 select { 103 case d := <-l.del: 104 l.lock.Lock() 105 if c, ok := l.cons[d]; ok { 106 delete(l.cons, d) 107 close(c.bufs) 108 close(c.wake) 109 c.bufs, c.wake, c.sock = nil, nil, nil 110 c.lock.Unlock() 111 } 112 l.lock.Unlock() 113 case <-l.ctx.Done(): 114 return 115 } 116 } 117 } 118 func (udpErr) Temporary() bool { 119 return true 120 } 121 func (l *udpListener) listen() { 122 loop: 123 for l.sock.SetReadDeadline(empty); ; l.sock.SetReadDeadline(empty) { 124 var ( 125 b = buffers.Get().(*[udpLimit]byte) 126 n, a, err = l.sock.ReadPacket((*b)[:]) 127 ) 128 if bugtrack.Enabled { 129 bugtrack.Track("com.(*udpListener).listen(): Accept n=%d, a=%s, err=%s", n, a, err) 130 } 131 select { 132 case <-l.ctx.Done(): 133 buffers.Put(b) 134 break loop 135 default: 136 if err != nil && !a.IsValid() && n == 0 { 137 buffers.Put(b) 138 l.err = err 139 break loop 140 } 141 if n == 0 || !a.IsValid() { 142 buffers.Put(b) 143 continue loop 144 } 145 } 146 if !a.IsValid() { 147 buffers.Put(b) 148 continue 149 } 150 l.lock.RLock() 151 c, ok := l.cons[a] 152 if l.lock.RUnlock(); ok { 153 if c.lock.Lock(); c.bufs != nil { 154 if bugtrack.Enabled { 155 bugtrack.Track("com.(*udpListener).listen(): Pushing n=%d bytes to conn a=%s", n, a.String()) 156 } 157 c.bufs <- udpData{n: n, b: b} 158 c.lock.Unlock() 159 continue 160 } 161 c.lock.Unlock() 162 c = nil 163 } 164 if bugtrack.Enabled { 165 bugtrack.Track("com.(*udpListener).listen(): New tracked conn a=%s", a.String()) 166 } 167 c = &udpConn{dev: a, sock: l, bufs: make(chan udpData, 256), wake: make(chan struct{}, 1)} 168 c.append(n, b, false) 169 go c.receive(l.ctx) 170 l.lock.Lock() 171 l.cons[a] = c 172 l.lock.Unlock() 173 l.new <- c 174 } 175 l.cancel() 176 if err := l.sock.Close(); err != nil && l.err == nil { 177 l.err = err 178 } 179 l.lock.Lock() 180 for _, c := range l.cons { 181 c.Close() 182 } 183 l.lock.Unlock() 184 close(l.del) 185 close(l.new) 186 l.cons = nil 187 } 188 func (c *udpConn) Close() error { 189 if c.sock == nil { 190 return nil 191 } 192 c.lock.Lock() 193 c.sock.del <- c.dev 194 c.sock = nil 195 return nil 196 } 197 func (udpAddr) Network() string { 198 return NameUDP 199 } 200 func (s *udpStream) Close() error { 201 err := s.Conn.Close() 202 s.read, s.write = -1, -1 203 return err 204 } 205 func (l *udpListener) Close() error { 206 err := l.sock.Close() 207 l.cancel() 208 return err 209 } 210 func (l *udpListener) Addr() net.Addr { 211 return l.sock.LocalAddr() 212 } 213 func (c *udpConn) LocalAddr() net.Addr { 214 return c.dev 215 } 216 217 // NewUDP creates a new simple UDP based connector with the supplied timeout. 218 func NewUDP(t time.Duration) Connector { 219 if t < 0 { 220 t = DefaultTimeout 221 } 222 return &udpConnector{Dialer: net.Dialer{Timeout: t, KeepAlive: t}} 223 } 224 func (s *udpStream) readEnough() error { 225 if s.read > 0 { 226 return s.readEnoughTimeout(s.read, 25) 227 } 228 if s.size > 0 { 229 if bugtrack.Enabled { 230 bugtrack.Track("com.(*udpStream).readEnough(): Implementing our own timeout for a Read operation.") 231 } 232 return s.readEnoughTimeout(time.Millisecond*500, 25) 233 } 234 return s.readEnoughTimeout(time.Second*2, 2) 235 } 236 func (c *udpConn) RemoteAddr() net.Addr { 237 return c.dev 238 } 239 func (c *udpConn) receive(x context.Context) { 240 for { 241 select { 242 case <-x.Done(): 243 return 244 case p, ok := <-c.bufs: 245 if !ok { 246 return 247 } 248 c.append(p.n, p.b, true) 249 } 250 } 251 } 252 func (c *udpConn) Read(b []byte) (int, error) { 253 if len(c.buf) == 0 && c.bufs == nil { 254 if bugtrack.Enabled { 255 bugtrack.Track("com.(*udpCon).Read(): read on closed conn.") 256 } 257 return 0, io.ErrClosedPipe 258 } 259 var ( 260 t *time.Timer 261 n int 262 w <-chan time.Time 263 err error 264 ) 265 loop: 266 for n < len(b) { 267 if bugtrack.Enabled { 268 bugtrack.Track("com.(*udpCon).Read(): n=%d, len(b)=%d, len(c.buf)=%d", n, len(b), len(c.buf)) 269 } 270 if len(c.buf) > 0 { 271 c.lock.Lock() 272 v := copy(b[n:], c.buf) 273 if bugtrack.Enabled { 274 bugtrack.Track("com.(*udpCon).Read(): n=%d, v=%d, len(b)=%d, len(c.buf)=%d", n, v, len(b), len(c.buf)) 275 } 276 if c.buf = c.buf[v:]; len(c.buf) == 0 { 277 c.buf = nil 278 } 279 c.lock.Unlock() 280 n += v 281 continue 282 } 283 if n == 0 { 284 if c.bufs == nil { 285 err = io.EOF 286 break 287 } 288 if t != nil { 289 t.Stop() 290 t, w = nil, nil 291 } 292 if c.read > 0 { 293 t = time.NewTimer(c.read) 294 w = t.C 295 } 296 select { 297 case <-w: 298 err = udpDeadline 299 break loop 300 case <-c.wake: 301 continue loop 302 case <-c.sock.ctx.Done(): 303 err = io.ErrClosedPipe 304 break loop 305 } 306 } 307 break 308 } 309 if t != nil { 310 t.Stop() 311 } 312 if bugtrack.Enabled { 313 bugtrack.Track("com.(*udpCon).Read(): return n=%d, err=%s", n, err) 314 } 315 return n, err 316 } 317 func (c *udpConn) Write(b []byte) (int, error) { 318 if c.sock == nil { 319 return 0, io.ErrShortWrite 320 } 321 var ( 322 n int 323 t *time.Timer 324 w <-chan time.Time 325 err error 326 ) 327 loop: 328 for v, s, x := 0, 0, udpLimit; n < len(b) && s < len(b); { 329 if t != nil { 330 t.Stop() 331 w, t = nil, nil 332 } 333 if x > len(b) { 334 x = len(b) 335 } 336 if c.write > 0 { 337 t = time.NewTimer(c.write) 338 if w = t.C; bugtrack.Enabled { 339 bugtrack.Track("com.(*udpCon).Write(): Created timer with duration c.write=%s, n=%d, len(b)=%d.", c.write, n, len(b)) 340 } 341 } 342 v, err = c.sock.sock.WritePacket(b[s:x], c.dev) 343 if bugtrack.Enabled { 344 bugtrack.Track("com.(*udpCon).Write(): Wrote bytes out n=%d, len(b)=%d, s=%d, x=%d, v=%d.", n, len(b), s, x, v) 345 } 346 s += v 347 x += v 348 if n += v; err != nil { 349 break 350 } 351 select { 352 case <-w: 353 err = udpDeadline 354 break loop 355 case <-c.sock.ctx.Done(): 356 err = io.ErrClosedPipe 357 break loop 358 default: 359 time.Sleep(writeOp) 360 } 361 } 362 if t != nil { 363 t.Stop() 364 } 365 return n, err 366 } 367 func (s *udpStream) Read(b []byte) (int, error) { 368 if s.size == 0 || s.size < len(b) { 369 if err := s.readEnough(); err != nil { 370 if bugtrack.Enabled { 371 bugtrack.Track("com.(*udpStream).Read(): readEnough() err=%s", err) 372 } 373 return 0, err 374 } 375 } 376 if bugtrack.Enabled { 377 bugtrack.Track("com.(*udpStream).Read(): Read s.size=%d, len(s.buf)=%d, len(b)=%d", s.size, len(s.buf), len(b)) 378 } 379 n := copy(b, s.buf[:s.size]) 380 s.buf = s.buf[n:] 381 if s.size -= n; s.size <= 0 { 382 s.buf = nil 383 } 384 if bugtrack.Enabled { 385 bugtrack.Track("com.(*udpStream).Read(): Post-read n=%d, s.size=%d, len(s.buf)=%d, len(b)=%d", n, s.size, len(s.buf), len(b)) 386 } 387 return n, nil 388 } 389 func (s *udpStream) Write(b []byte) (int, error) { 390 var ( 391 t *time.Timer 392 w <-chan time.Time 393 n int 394 err error 395 ) 396 loop: 397 for e, c, x := 0, 0, udpLimit; n < len(b) && e < len(b); { 398 if t != nil { 399 t.Stop() 400 w, t = nil, nil 401 } 402 if x > len(b) { 403 x = len(b) 404 } 405 if s.write > 0 { 406 t = time.NewTimer(s.write) 407 w = t.C 408 s.Conn.SetWriteDeadline(time.Now().Add(s.write)) 409 } 410 if c, err = s.Conn.Write(b[e:x]); bugtrack.Enabled { 411 bugtrack.Track("com.(*udpStream).Write(): e=%d, x=%d, c=%d, n=%d, len(b)=%d, err=%s", e, x, c, n, len(b), err) 412 } 413 e += c 414 x += c 415 if n += c; err != nil { 416 break loop 417 } 418 select { 419 case <-w: 420 err = udpDeadline 421 break loop 422 default: 423 time.Sleep(writeOp) 424 } 425 } 426 return n, err 427 } 428 func (c *udpConn) SetDeadline(t time.Time) error { 429 if t.IsZero() { 430 c.read, c.write = 0, 0 431 return nil 432 } 433 d := time.Until(t) 434 if d <= 0 { 435 c.read, c.write = 0, 0 436 return nil 437 } 438 c.read, c.write = d, d 439 return nil 440 } 441 func (l *udpListener) Accept() (net.Conn, error) { 442 var ( 443 t *time.Timer 444 w <-chan time.Time 445 ) 446 if l.deadline > 0 { 447 t = time.NewTimer(l.deadline) 448 w = t.C 449 } 450 loop: 451 for l.err == nil { 452 select { 453 case <-w: 454 return nil, udpDeadline 455 case n := <-l.new: 456 return n, nil 457 case <-l.ctx.Done(): 458 break loop 459 } 460 } 461 if t != nil { 462 t.Stop() 463 } 464 return nil, l.err 465 } 466 func (s *udpStream) SetDeadline(t time.Time) error { 467 if t.IsZero() { 468 s.read, s.write = 0, 0 469 return s.Conn.SetDeadline(t) 470 } 471 d := time.Until(t) 472 if d <= 0 { 473 s.read, s.write = 0, 0 474 return s.Conn.SetDeadline(t) 475 } 476 s.read, s.write = d, d 477 return s.Conn.SetDeadline(t) 478 } 479 func (c *udpConn) SetReadDeadline(t time.Time) error { 480 if t.IsZero() { 481 c.read = 0 482 return nil 483 } 484 d := time.Until(t) 485 if d <= 0 { 486 c.read = 0 487 return nil 488 } 489 c.read = d 490 return nil 491 } 492 func (c *udpConn) SetWriteDeadline(t time.Time) error { 493 if t.IsZero() { 494 c.write = 0 495 return nil 496 } 497 d := time.Until(t) 498 if d <= 0 { 499 c.write = 0 500 return nil 501 } 502 c.write = d 503 return nil 504 } 505 func (s *udpStream) SetReadDeadline(t time.Time) error { 506 if t.IsZero() { 507 s.read = 0 508 return s.Conn.SetReadDeadline(t) 509 } 510 d := time.Until(t) 511 if d <= 0 { 512 s.read = 0 513 return s.Conn.SetReadDeadline(t) 514 } 515 s.read = d 516 return s.Conn.SetReadDeadline(t) 517 } 518 func (s *udpStream) SetWriteDeadline(t time.Time) error { 519 if t.IsZero() { 520 s.write = 0 521 return s.Conn.SetWriteDeadline(t) 522 } 523 d := time.Until(t) 524 if d <= 0 { 525 s.write = 0 526 return s.Conn.SetWriteDeadline(t) 527 } 528 s.write = d 529 return s.Conn.SetWriteDeadline(t) 530 } 531 func (c *udpConn) append(n int, b *[udpLimit]byte, w bool) { 532 if bugtrack.Enabled { 533 bugtrack.Track("com.(*udpCon).append(): n=%d, w=%t, len(c.buf)=%d", n, w, len(c.buf)) 534 } 535 c.lock.Lock() 536 c.buf = append(c.buf, (*b)[:n]...) 537 c.lock.Unlock() 538 if buffers.Put(b); w { 539 select { 540 case c.wake <- udpWake: 541 if bugtrack.Enabled { 542 bugtrack.Track("com.(*udpCon).append(): Triggering wake.") 543 } 544 default: 545 } 546 } 547 } 548 func (s *udpStream) readEnoughTimeout(d time.Duration, m int) error { 549 var ( 550 n int 551 err error 552 l = d // "Canary" value for timeout. 553 ) 554 for q, y, c, k := d/time.Duration(m), time.Now().Add(d), 0, 0; ; { 555 if len(s.buf) == 0 || len(s.buf)-s.size < udpLimit { 556 if bugtrack.Enabled { 557 bugtrack.Track("com.(*udpStream).readEnoughTimeout(): Expanding socket buffer free=%d, len(s.buf)=%d, s.size=%d.", len(s.buf)-s.size, len(s.buf), s.size) 558 } 559 s.buf = append(s.buf, make([]byte, udpLimit)...) 560 } 561 if time.Sleep(readOp); bugtrack.Enabled { 562 bugtrack.Track("com.(*udpStream).readEnoughTimeout(): Pre-read s.size=%d, len(s.buf)=%d, q=%s, n=%d, d=%s, c=%d, s.fails=%d", s.size, len(s.buf), q, n, d, c, s.fails) 563 } 564 if s.read > 0 && l != s.read { 565 // When in channel mode, this is set by 'SetDeadline', which allows 566 // the writer Goroutine to "bump" the timeout on the reader and allow 567 // it to NOT get caught in an infinate read Op. 568 l, c, q, y = s.read, 0, s.read/time.Duration(m), time.Now().Add(s.read) 569 if bugtrack.Enabled { 570 bugtrack.Track("com.(*udpStream).readEnoughTimeout(): ReadDeadline was bumped to %s, c=0, q=%s", l, q) 571 } 572 } 573 s.Conn.SetReadDeadline(time.Now().Add(q)) 574 if n, err = s.Conn.Read(s.buf[s.size:]); bugtrack.Enabled { 575 bugtrack.Track("com.(*udpStream).readEnoughTimeout(): Post-read n=%d, err=%s", n, err) 576 } 577 if s.size += n; s.read == -1 { 578 return io.ErrClosedPipe 579 } 580 if n > 0 || err == nil { 581 if k++; k > 1 { 582 return nil 583 } 584 continue 585 } 586 if e, ok := err.(net.Error); ok && e.Timeout() { 587 if time.Now().After(y) { 588 err = nil 589 if c++; c > m || s.size > 0 { 590 if bugtrack.Enabled { 591 bugtrack.Track("com.(*udpStream).readEnoughTimeout(): Read timeout hit, n=%d, s.size=%d, len(s.buf)=%d, c=%d, s.fails=%d", n, s.size, len(s.buf), c, s.fails) 592 } 593 break 594 } 595 continue 596 } 597 if c++; c > m { 598 err = nil 599 break 600 } 601 continue 602 } 603 if err == io.EOF { 604 err = nil 605 } 606 break 607 } 608 if bugtrack.Enabled { 609 bugtrack.Track("com.(*udpStream).readEnoughTimeout(): Read return n=%d, s.size=%d, len(s.buf)=%d, err=%s, s.fails=%d.", n, s.size, len(s.buf), err, s.fails) 610 } 611 if err != nil { 612 return err 613 } 614 if s.fails > 1 && s.size == 0 { 615 if bugtrack.Enabled { 616 bugtrack.Track("com.(*udpStream).readEnoughTimeout(): Fail count reached with no progress! s.fails=%d, s.size=%d.", s.fails, s.size) 617 } 618 return io.ErrNoProgress 619 } 620 if s.size == 0 { 621 if s.fails++; bugtrack.Enabled { 622 bugtrack.Track("com.(*udpStream).readEnoughTimeout(): Increasing fail count! s.fails=%d.", s.fails) 623 } 624 } 625 return nil 626 } 627 func (c *udpConnector) Connect(x context.Context, s string) (net.Conn, error) { 628 v, err := c.DialContext(x, NameUDP, s) 629 if err != nil { 630 return nil, err 631 } 632 return &udpStream{Conn: v}, nil 633 } 634 func (*udpConnector) Listen(x context.Context, s string) (net.Listener, error) { 635 c, err := ListenConfig.ListenPacket(x, NameUDP, s) 636 if err != nil { 637 return nil, err 638 } 639 l := &udpListener{ 640 new: make(chan *udpConn, 16), 641 del: make(chan udpAddr, 16), 642 cons: make(map[udpAddr]*udpConn), 643 sock: &udpCompat{c.(*net.UDPConn)}, 644 } 645 l.ctx, l.cancel = context.WithCancel(x) 646 go l.purge() 647 go l.listen() 648 return l, nil 649 }