github.com/iDigitalFlame/xmt@v0.5.4/com/udp.go (about)

     1  // Copyright (C) 2020 - 2023 iDigitalFlame
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU General Public License as published by
     5  // the Free Software Foundation, either version 3 of the License, or
     6  // any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU General Public License
    14  // along with this program.  If not, see <https://www.gnu.org/licenses/>.
    15  //
    16  
    17  package com
    18  
    19  import (
    20  	"context"
    21  	"io"
    22  	"net"
    23  	"sync"
    24  	"time"
    25  
    26  	"github.com/iDigitalFlame/xmt/util/bugtrack"
    27  )
    28  
    29  const (
    30  	udpLimit = 4096
    31  
    32  	readOp  = time.Microsecond * 15
    33  	writeOp = time.Microsecond * 35
    34  )
    35  
    36  var (
    37  	empty time.Time
    38  
    39  	udpWake     struct{}
    40  	udpDeadline = new(udpErr)
    41  
    42  	buffers = sync.Pool{
    43  		New: func() interface{} {
    44  			var b [udpLimit]byte
    45  			return &b
    46  		},
    47  	}
    48  )
    49  
// udpErr is the error type behind udpDeadline. Its methods report it as a
// timeout and as temporary, with the message of context.DeadlineExceeded.
type udpErr struct{}

// udpConn is a per-peer connection handed out by (*udpListener).Accept. All
// conns of one listener share the listener's packet socket; the listener's
// read loop routes datagrams to the conn whose dev matches their source.
type udpConn struct {
	// bufs queues pooled datagram buffers from listen() to receive(); it is
	// closed and set to nil by purge() when the conn is removed.
	bufs        chan udpData
	// sock is the owning listener; set to nil by Close() and purge().
	sock        *udpListener
	// wake is signalled (non-blocking) by append() so a waiting Read can
	// re-check buf; closed by purge().
	wake        chan struct{}
	// dev is the remote peer address; it is also this conn's key in l.cons.
	dev         udpAddr
	// buf holds received payload not yet consumed by Read.
	buf         []byte
	// read and write are relative timeouts set by the SetDeadline methods
	// (0 means no timeout).
	read, write time.Duration
	// lock guards buf and the bufs/wake teardown. Close() locks it and
	// purge() unlocks it once the conn has been removed.
	lock        sync.Mutex
}

// udpData is one queued datagram: a pooled buffer and its valid length.
type udpData struct {
	// Zero-size non-comparable field; presumably here to stop udpData values
	// from being compared with ==.
	_ [0]func()
	// b is a buffer obtained from the buffers pool.
	b *[udpLimit]byte
	// n is the number of valid bytes in b.
	n int
}

// udpCompat wraps the listener's packet socket. Listen stores a
// *net.UDPConn in it.
type udpCompat struct {
	udpSock
}

// udpStream is the client-side conn returned by (*udpConnector).Connect.
type udpStream struct {
	net.Conn
	// buf[:size] holds bytes read from Conn but not yet returned by Read.
	buf         []byte
	size        int
	// fails counts readEnoughTimeout calls that ended with no data buffered.
	fails       uint8
	// read and write are relative timeouts set by the SetDeadline methods;
	// Close sets both to -1.
	read, write time.Duration
}

// udpSock is the socket interface used by the listener. udpSockInternal is
// declared elsewhere; presumably it supplies the ReadPacket/WritePacket
// methods used in this file (TODO confirm).
type udpSock interface {
	udpSockInternal
	net.PacketConn
}

// udpListener is the net.Listener returned by (*udpConnector).Listen. It
// demultiplexes datagrams on one packet socket into per-peer udpConns.
type udpListener struct {
	// err is the error that stopped listen(); returned by Accept.
	err      error
	// ctx and cancel bound the listener's lifetime; ctx is derived from the
	// context passed to Listen.
	ctx      context.Context
	// del carries addresses of closed conns from udpConn.Close to purge().
	del      chan udpAddr
	// new carries newly seen conns from listen() to Accept; closed by
	// listen() on exit.
	new      chan *udpConn
	// cons maps peer addresses to their tracked conns; guarded by lock.
	cons     map[udpAddr]*udpConn
	// sock is the shared packet socket.
	sock     *udpCompat
	cancel   context.CancelFunc
	// deadline is Accept's timeout when positive; it is not set anywhere in
	// the visible code (TODO confirm where it is configured).
	deadline time.Duration
	lock     sync.RWMutex
}

// udpConnector is the UDP Connector returned by NewUDP; dialing goes
// through the embedded net.Dialer.
type udpConnector struct {
	net.Dialer
}
    93  
    94  func (udpErr) Timeout() bool {
    95  	return true
    96  }
    97  func (udpErr) Error() string {
    98  	return context.DeadlineExceeded.Error()
    99  }
// purge removes closed conns from the listener.
//
// For every address received on l.del it deletes the matching conn from
// l.cons, closes the conn's bufs and wake channels and clears its
// bufs/wake/sock fields. It then unlocks c.lock. That mutex was locked by
// (*udpConn).Close() before it sent the address, so this unlock completes
// the handoff and lets anything blocked on c.lock proceed. Addresses not
// present in l.cons are ignored.
//
// purge returns once the listener context is done.
func (l *udpListener) purge() {
	for {
		select {
		case d := <-l.del:
			l.lock.Lock()
			if c, ok := l.cons[d]; ok {
				delete(l.cons, d)
				close(c.bufs)
				close(c.wake)
				c.bufs, c.wake, c.sock = nil, nil, nil
				// Release the lock acquired by (*udpConn).Close().
				c.lock.Unlock()
			}
			l.lock.Unlock()
		case <-l.ctx.Done():
			return
		}
	}
}
   118  func (udpErr) Temporary() bool {
   119  	return true
   120  }
   121  func (l *udpListener) listen() {
   122  loop:
   123  	for l.sock.SetReadDeadline(empty); ; l.sock.SetReadDeadline(empty) {
   124  		var (
   125  			b         = buffers.Get().(*[udpLimit]byte)
   126  			n, a, err = l.sock.ReadPacket((*b)[:])
   127  		)
   128  		if bugtrack.Enabled {
   129  			bugtrack.Track("com.(*udpListener).listen(): Accept n=%d, a=%s, err=%s", n, a, err)
   130  		}
   131  		select {
   132  		case <-l.ctx.Done():
   133  			buffers.Put(b)
   134  			break loop
   135  		default:
   136  			if err != nil && !a.IsValid() && n == 0 {
   137  				buffers.Put(b)
   138  				l.err = err
   139  				break loop
   140  			}
   141  			if n == 0 || !a.IsValid() {
   142  				buffers.Put(b)
   143  				continue loop
   144  			}
   145  		}
   146  		if !a.IsValid() {
   147  			buffers.Put(b)
   148  			continue
   149  		}
   150  		l.lock.RLock()
   151  		c, ok := l.cons[a]
   152  		if l.lock.RUnlock(); ok {
   153  			if c.lock.Lock(); c.bufs != nil {
   154  				if bugtrack.Enabled {
   155  					bugtrack.Track("com.(*udpListener).listen(): Pushing n=%d bytes to conn a=%s", n, a.String())
   156  				}
   157  				c.bufs <- udpData{n: n, b: b}
   158  				c.lock.Unlock()
   159  				continue
   160  			}
   161  			c.lock.Unlock()
   162  			c = nil
   163  		}
   164  		if bugtrack.Enabled {
   165  			bugtrack.Track("com.(*udpListener).listen(): New tracked conn a=%s", a.String())
   166  		}
   167  		c = &udpConn{dev: a, sock: l, bufs: make(chan udpData, 256), wake: make(chan struct{}, 1)}
   168  		c.append(n, b, false)
   169  		go c.receive(l.ctx)
   170  		l.lock.Lock()
   171  		l.cons[a] = c
   172  		l.lock.Unlock()
   173  		l.new <- c
   174  	}
   175  	l.cancel()
   176  	if err := l.sock.Close(); err != nil && l.err == nil {
   177  		l.err = err
   178  	}
   179  	l.lock.Lock()
   180  	for _, c := range l.cons {
   181  		c.Close()
   182  	}
   183  	l.lock.Unlock()
   184  	close(l.del)
   185  	close(l.new)
   186  	l.cons = nil
   187  }
   188  func (c *udpConn) Close() error {
   189  	if c.sock == nil {
   190  		return nil
   191  	}
   192  	c.lock.Lock()
   193  	c.sock.del <- c.dev
   194  	c.sock = nil
   195  	return nil
   196  }
   197  func (udpAddr) Network() string {
   198  	return NameUDP
   199  }
   200  func (s *udpStream) Close() error {
   201  	err := s.Conn.Close()
   202  	s.read, s.write = -1, -1
   203  	return err
   204  }
   205  func (l *udpListener) Close() error {
   206  	err := l.sock.Close()
   207  	l.cancel()
   208  	return err
   209  }
   210  func (l *udpListener) Addr() net.Addr {
   211  	return l.sock.LocalAddr()
   212  }
   213  func (c *udpConn) LocalAddr() net.Addr {
   214  	return c.dev
   215  }
   216  
   217  // NewUDP creates a new simple UDP based connector with the supplied timeout.
   218  func NewUDP(t time.Duration) Connector {
   219  	if t < 0 {
   220  		t = DefaultTimeout
   221  	}
   222  	return &udpConnector{Dialer: net.Dialer{Timeout: t, KeepAlive: t}}
   223  }
   224  func (s *udpStream) readEnough() error {
   225  	if s.read > 0 {
   226  		return s.readEnoughTimeout(s.read, 25)
   227  	}
   228  	if s.size > 0 {
   229  		if bugtrack.Enabled {
   230  			bugtrack.Track("com.(*udpStream).readEnough(): Implementing our own timeout for a Read operation.")
   231  		}
   232  		return s.readEnoughTimeout(time.Millisecond*500, 25)
   233  	}
   234  	return s.readEnoughTimeout(time.Second*2, 2)
   235  }
   236  func (c *udpConn) RemoteAddr() net.Addr {
   237  	return c.dev
   238  }
   239  func (c *udpConn) receive(x context.Context) {
   240  	for {
   241  		select {
   242  		case <-x.Done():
   243  			return
   244  		case p, ok := <-c.bufs:
   245  			if !ok {
   246  				return
   247  			}
   248  			c.append(p.n, p.b, true)
   249  		}
   250  	}
   251  }
   252  func (c *udpConn) Read(b []byte) (int, error) {
   253  	if len(c.buf) == 0 && c.bufs == nil {
   254  		if bugtrack.Enabled {
   255  			bugtrack.Track("com.(*udpCon).Read(): read on closed conn.")
   256  		}
   257  		return 0, io.ErrClosedPipe
   258  	}
   259  	var (
   260  		t   *time.Timer
   261  		n   int
   262  		w   <-chan time.Time
   263  		err error
   264  	)
   265  loop:
   266  	for n < len(b) {
   267  		if bugtrack.Enabled {
   268  			bugtrack.Track("com.(*udpCon).Read(): n=%d, len(b)=%d, len(c.buf)=%d", n, len(b), len(c.buf))
   269  		}
   270  		if len(c.buf) > 0 {
   271  			c.lock.Lock()
   272  			v := copy(b[n:], c.buf)
   273  			if bugtrack.Enabled {
   274  				bugtrack.Track("com.(*udpCon).Read(): n=%d, v=%d, len(b)=%d, len(c.buf)=%d", n, v, len(b), len(c.buf))
   275  			}
   276  			if c.buf = c.buf[v:]; len(c.buf) == 0 {
   277  				c.buf = nil
   278  			}
   279  			c.lock.Unlock()
   280  			n += v
   281  			continue
   282  		}
   283  		if n == 0 {
   284  			if c.bufs == nil {
   285  				err = io.EOF
   286  				break
   287  			}
   288  			if t != nil {
   289  				t.Stop()
   290  				t, w = nil, nil
   291  			}
   292  			if c.read > 0 {
   293  				t = time.NewTimer(c.read)
   294  				w = t.C
   295  			}
   296  			select {
   297  			case <-w:
   298  				err = udpDeadline
   299  				break loop
   300  			case <-c.wake:
   301  				continue loop
   302  			case <-c.sock.ctx.Done():
   303  				err = io.ErrClosedPipe
   304  				break loop
   305  			}
   306  		}
   307  		break
   308  	}
   309  	if t != nil {
   310  		t.Stop()
   311  	}
   312  	if bugtrack.Enabled {
   313  		bugtrack.Track("com.(*udpCon).Read(): return n=%d, err=%s", n, err)
   314  	}
   315  	return n, err
   316  }
   317  func (c *udpConn) Write(b []byte) (int, error) {
   318  	if c.sock == nil {
   319  		return 0, io.ErrShortWrite
   320  	}
   321  	var (
   322  		n   int
   323  		t   *time.Timer
   324  		w   <-chan time.Time
   325  		err error
   326  	)
   327  loop:
   328  	for v, s, x := 0, 0, udpLimit; n < len(b) && s < len(b); {
   329  		if t != nil {
   330  			t.Stop()
   331  			w, t = nil, nil
   332  		}
   333  		if x > len(b) {
   334  			x = len(b)
   335  		}
   336  		if c.write > 0 {
   337  			t = time.NewTimer(c.write)
   338  			if w = t.C; bugtrack.Enabled {
   339  				bugtrack.Track("com.(*udpCon).Write(): Created timer with duration c.write=%s, n=%d, len(b)=%d.", c.write, n, len(b))
   340  			}
   341  		}
   342  		v, err = c.sock.sock.WritePacket(b[s:x], c.dev)
   343  		if bugtrack.Enabled {
   344  			bugtrack.Track("com.(*udpCon).Write(): Wrote bytes out n=%d, len(b)=%d, s=%d, x=%d, v=%d.", n, len(b), s, x, v)
   345  		}
   346  		s += v
   347  		x += v
   348  		if n += v; err != nil {
   349  			break
   350  		}
   351  		select {
   352  		case <-w:
   353  			err = udpDeadline
   354  			break loop
   355  		case <-c.sock.ctx.Done():
   356  			err = io.ErrClosedPipe
   357  			break loop
   358  		default:
   359  			time.Sleep(writeOp)
   360  		}
   361  	}
   362  	if t != nil {
   363  		t.Stop()
   364  	}
   365  	return n, err
   366  }
   367  func (s *udpStream) Read(b []byte) (int, error) {
   368  	if s.size == 0 || s.size < len(b) {
   369  		if err := s.readEnough(); err != nil {
   370  			if bugtrack.Enabled {
   371  				bugtrack.Track("com.(*udpStream).Read(): readEnough() err=%s", err)
   372  			}
   373  			return 0, err
   374  		}
   375  	}
   376  	if bugtrack.Enabled {
   377  		bugtrack.Track("com.(*udpStream).Read(): Read s.size=%d, len(s.buf)=%d, len(b)=%d", s.size, len(s.buf), len(b))
   378  	}
   379  	n := copy(b, s.buf[:s.size])
   380  	s.buf = s.buf[n:]
   381  	if s.size -= n; s.size <= 0 {
   382  		s.buf = nil
   383  	}
   384  	if bugtrack.Enabled {
   385  		bugtrack.Track("com.(*udpStream).Read(): Post-read n=%d, s.size=%d, len(s.buf)=%d, len(b)=%d", n, s.size, len(s.buf), len(b))
   386  	}
   387  	return n, nil
   388  }
   389  func (s *udpStream) Write(b []byte) (int, error) {
   390  	var (
   391  		t   *time.Timer
   392  		w   <-chan time.Time
   393  		n   int
   394  		err error
   395  	)
   396  loop:
   397  	for e, c, x := 0, 0, udpLimit; n < len(b) && e < len(b); {
   398  		if t != nil {
   399  			t.Stop()
   400  			w, t = nil, nil
   401  		}
   402  		if x > len(b) {
   403  			x = len(b)
   404  		}
   405  		if s.write > 0 {
   406  			t = time.NewTimer(s.write)
   407  			w = t.C
   408  			s.Conn.SetWriteDeadline(time.Now().Add(s.write))
   409  		}
   410  		if c, err = s.Conn.Write(b[e:x]); bugtrack.Enabled {
   411  			bugtrack.Track("com.(*udpStream).Write(): e=%d, x=%d, c=%d, n=%d, len(b)=%d, err=%s", e, x, c, n, len(b), err)
   412  		}
   413  		e += c
   414  		x += c
   415  		if n += c; err != nil {
   416  			break loop
   417  		}
   418  		select {
   419  		case <-w:
   420  			err = udpDeadline
   421  			break loop
   422  		default:
   423  			time.Sleep(writeOp)
   424  		}
   425  	}
   426  	return n, err
   427  }
   428  func (c *udpConn) SetDeadline(t time.Time) error {
   429  	if t.IsZero() {
   430  		c.read, c.write = 0, 0
   431  		return nil
   432  	}
   433  	d := time.Until(t)
   434  	if d <= 0 {
   435  		c.read, c.write = 0, 0
   436  		return nil
   437  	}
   438  	c.read, c.write = d, d
   439  	return nil
   440  }
   441  func (l *udpListener) Accept() (net.Conn, error) {
   442  	var (
   443  		t *time.Timer
   444  		w <-chan time.Time
   445  	)
   446  	if l.deadline > 0 {
   447  		t = time.NewTimer(l.deadline)
   448  		w = t.C
   449  	}
   450  loop:
   451  	for l.err == nil {
   452  		select {
   453  		case <-w:
   454  			return nil, udpDeadline
   455  		case n := <-l.new:
   456  			return n, nil
   457  		case <-l.ctx.Done():
   458  			break loop
   459  		}
   460  	}
   461  	if t != nil {
   462  		t.Stop()
   463  	}
   464  	return nil, l.err
   465  }
   466  func (s *udpStream) SetDeadline(t time.Time) error {
   467  	if t.IsZero() {
   468  		s.read, s.write = 0, 0
   469  		return s.Conn.SetDeadline(t)
   470  	}
   471  	d := time.Until(t)
   472  	if d <= 0 {
   473  		s.read, s.write = 0, 0
   474  		return s.Conn.SetDeadline(t)
   475  	}
   476  	s.read, s.write = d, d
   477  	return s.Conn.SetDeadline(t)
   478  }
   479  func (c *udpConn) SetReadDeadline(t time.Time) error {
   480  	if t.IsZero() {
   481  		c.read = 0
   482  		return nil
   483  	}
   484  	d := time.Until(t)
   485  	if d <= 0 {
   486  		c.read = 0
   487  		return nil
   488  	}
   489  	c.read = d
   490  	return nil
   491  }
   492  func (c *udpConn) SetWriteDeadline(t time.Time) error {
   493  	if t.IsZero() {
   494  		c.write = 0
   495  		return nil
   496  	}
   497  	d := time.Until(t)
   498  	if d <= 0 {
   499  		c.write = 0
   500  		return nil
   501  	}
   502  	c.write = d
   503  	return nil
   504  }
   505  func (s *udpStream) SetReadDeadline(t time.Time) error {
   506  	if t.IsZero() {
   507  		s.read = 0
   508  		return s.Conn.SetReadDeadline(t)
   509  	}
   510  	d := time.Until(t)
   511  	if d <= 0 {
   512  		s.read = 0
   513  		return s.Conn.SetReadDeadline(t)
   514  	}
   515  	s.read = d
   516  	return s.Conn.SetReadDeadline(t)
   517  }
   518  func (s *udpStream) SetWriteDeadline(t time.Time) error {
   519  	if t.IsZero() {
   520  		s.write = 0
   521  		return s.Conn.SetWriteDeadline(t)
   522  	}
   523  	d := time.Until(t)
   524  	if d <= 0 {
   525  		s.write = 0
   526  		return s.Conn.SetWriteDeadline(t)
   527  	}
   528  	s.write = d
   529  	return s.Conn.SetWriteDeadline(t)
   530  }
   531  func (c *udpConn) append(n int, b *[udpLimit]byte, w bool) {
   532  	if bugtrack.Enabled {
   533  		bugtrack.Track("com.(*udpCon).append(): n=%d, w=%t, len(c.buf)=%d", n, w, len(c.buf))
   534  	}
   535  	c.lock.Lock()
   536  	c.buf = append(c.buf, (*b)[:n]...)
   537  	c.lock.Unlock()
   538  	if buffers.Put(b); w {
   539  		select {
   540  		case c.wake <- udpWake:
   541  			if bugtrack.Enabled {
   542  				bugtrack.Track("com.(*udpCon).append(): Triggering wake.")
   543  			}
   544  		default:
   545  		}
   546  	}
   547  }
// readEnoughTimeout polls the underlying connection. Whatever it reads is
// written to s.buf[s.size:] and s.size is advanced.
//
// d is the overall time budget and m the number of slices it is split into;
// each individual Read gets a deadline of d/m. Before every read, s.buf is
// grown by udpLimit bytes if less than udpLimit bytes are free, and the loop
// sleeps for readOp. If s.read becomes positive and differs from the last
// value seen (a deadline "bump"), the budget restarts from s.read and the
// timeout count is reset.
//
// Outcomes:
//   - nil on the second read that returned data or no error;
//   - io.ErrClosedPipe if s.read is -1 after a read (the stream was closed);
//   - on read timeouts, the timeout count c is incremented:
//       - after the overall deadline has passed, polling stops once c > m
//         or any data is buffered;
//       - before it, polling stops once c > m;
//       - timeouts themselves are not returned as errors;
//   - io.EOF is treated as success; any other read error is returned.
//
// When polling ends with no data buffered:
//   - if s.fails is already above 1, io.ErrNoProgress is returned;
//   - otherwise s.fails is incremented and nil is returned.
//
// NOTE(review): s.fails is never reset in this function when data arrives;
// confirm whether the count is meant to be cumulative.
func (s *udpStream) readEnoughTimeout(d time.Duration, m int) error {
	var (
		n   int
		err error
		l   = d // "Canary" value for timeout.
	)
	// q: per-read deadline, y: overall deadline, c: timeout count,
	// k: count of reads that returned data or no error.
	for q, y, c, k := d/time.Duration(m), time.Now().Add(d), 0, 0; ; {
		if len(s.buf) == 0 || len(s.buf)-s.size < udpLimit {
			if bugtrack.Enabled {
				bugtrack.Track("com.(*udpStream).readEnoughTimeout(): Expanding socket buffer free=%d, len(s.buf)=%d, s.size=%d.", len(s.buf)-s.size, len(s.buf), s.size)
			}
			s.buf = append(s.buf, make([]byte, udpLimit)...)
		}
		if time.Sleep(readOp); bugtrack.Enabled {
			bugtrack.Track("com.(*udpStream).readEnoughTimeout(): Pre-read s.size=%d, len(s.buf)=%d, q=%s, n=%d, d=%s, c=%d, s.fails=%d", s.size, len(s.buf), q, n, d, c, s.fails)
		}
		if s.read > 0 && l != s.read {
			// When in channel mode, this is set by 'SetDeadline', which allows
			// the writer Goroutine to "bump" the timeout on the reader and allow
			// it to NOT get caught in an infinite read Op.
			l, c, q, y = s.read, 0, s.read/time.Duration(m), time.Now().Add(s.read)
			if bugtrack.Enabled {
				bugtrack.Track("com.(*udpStream).readEnoughTimeout(): ReadDeadline was bumped to %s, c=0, q=%s", l, q)
			}
		}
		s.Conn.SetReadDeadline(time.Now().Add(q))
		if n, err = s.Conn.Read(s.buf[s.size:]); bugtrack.Enabled {
			bugtrack.Track("com.(*udpStream).readEnoughTimeout(): Post-read n=%d, err=%s", n, err)
		}
		// Close() sets s.read to -1; stop as soon as that is observed.
		if s.size += n; s.read == -1 {
			return io.ErrClosedPipe
		}
		if n > 0 || err == nil {
			if k++; k > 1 {
				return nil
			}
			continue
		}
		if e, ok := err.(net.Error); ok && e.Timeout() {
			if time.Now().After(y) {
				// Past the overall deadline: finish once enough timeouts have
				// accumulated or some data has been buffered.
				err = nil
				if c++; c > m || s.size > 0 {
					if bugtrack.Enabled {
						bugtrack.Track("com.(*udpStream).readEnoughTimeout(): Read timeout hit, n=%d, s.size=%d, len(s.buf)=%d, c=%d, s.fails=%d", n, s.size, len(s.buf), c, s.fails)
					}
					break
				}
				continue
			}
			if c++; c > m {
				err = nil
				break
			}
			continue
		}
		if err == io.EOF {
			err = nil
		}
		break
	}
	if bugtrack.Enabled {
		bugtrack.Track("com.(*udpStream).readEnoughTimeout(): Read return n=%d, s.size=%d, len(s.buf)=%d, err=%s, s.fails=%d.", n, s.size, len(s.buf), err, s.fails)
	}
	if err != nil {
		return err
	}
	if s.fails > 1 && s.size == 0 {
		if bugtrack.Enabled {
			bugtrack.Track("com.(*udpStream).readEnoughTimeout(): Fail count reached with no progress! s.fails=%d, s.size=%d.", s.fails, s.size)
		}
		return io.ErrNoProgress
	}
	if s.size == 0 {
		if s.fails++; bugtrack.Enabled {
			bugtrack.Track("com.(*udpStream).readEnoughTimeout(): Increasing fail count! s.fails=%d.", s.fails)
		}
	}
	return nil
}
   627  func (c *udpConnector) Connect(x context.Context, s string) (net.Conn, error) {
   628  	v, err := c.DialContext(x, NameUDP, s)
   629  	if err != nil {
   630  		return nil, err
   631  	}
   632  	return &udpStream{Conn: v}, nil
   633  }
   634  func (*udpConnector) Listen(x context.Context, s string) (net.Listener, error) {
   635  	c, err := ListenConfig.ListenPacket(x, NameUDP, s)
   636  	if err != nil {
   637  		return nil, err
   638  	}
   639  	l := &udpListener{
   640  		new:  make(chan *udpConn, 16),
   641  		del:  make(chan udpAddr, 16),
   642  		cons: make(map[udpAddr]*udpConn),
   643  		sock: &udpCompat{c.(*net.UDPConn)},
   644  	}
   645  	l.ctx, l.cancel = context.WithCancel(x)
   646  	go l.purge()
   647  	go l.listen()
   648  	return l, nil
   649  }