github.com/nikandfor/hacked@v0.0.0-20231207014854-3b383967fdf4/hnet/listen.go (about)

     1  package hnet
     2  
import (
	"context"
	"errors"
	"io"
	"net"
	"net/netip"
	"time"
)
    10  
    11  type (
    12  	StoppableListener struct {
    13  		context.Context
    14  		net.Listener
    15  	}
    16  
    17  	StoppableConn struct {
    18  		context.Context
    19  		net.Conn
    20  	}
    21  
    22  	ReaderFrom interface {
    23  		ReadFrom(p []byte) (int, net.Addr, error)
    24  	}
    25  
    26  	ReaderFromUDP interface {
    27  		ReadFromUDP(p []byte) (int, *net.UDPAddr, error)
    28  	}
    29  
    30  	ReaderFromUDPAddrPort interface {
    31  		ReadFromUDPAddrPort(p []byte) (int, netip.AddrPort, error)
    32  	}
    33  
    34  	ReaderMsgUDP interface {
    35  		ReadMsgUDP(p, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error)
    36  	}
    37  
    38  	ReaderMsgUDPAddrPort interface {
    39  		ReadMsgUDPAddrPort(p, oob []byte) (n, oobn, flags int, addr netip.AddrPort, err error)
    40  	}
    41  )
    42  
    43  func Accept(ctx context.Context, l net.Listener) (net.Conn, error) {
    44  	d, ok := l.(interface {
    45  		SetDeadline(time.Time) error
    46  	})
    47  
    48  	if !ok {
    49  		return l.Accept()
    50  	}
    51  
    52  	stopc := make(chan struct{})
    53  	defer close(stopc)
    54  
    55  	go func() {
    56  		select {
    57  		case <-ctx.Done():
    58  		case <-stopc:
    59  			return
    60  		}
    61  
    62  		_ = d.SetDeadline(time.Unix(1, 0))
    63  	}()
    64  
    65  	c, err := l.Accept()
    66  	if c != nil || !isTimeout(err) {
    67  		return c, err
    68  	}
    69  
    70  	select {
    71  	case <-ctx.Done():
    72  		err = ctx.Err()
    73  	default:
    74  	}
    75  
    76  	return nil, err
    77  }
    78  
    79  func Read(ctx context.Context, r io.Reader, p []byte) (int, error) {
    80  	d, ok := r.(interface {
    81  		SetReadDeadline(time.Time) error
    82  	})
    83  
    84  	if !ok {
    85  		return r.Read(p)
    86  	}
    87  
    88  	stopc := make(chan struct{})
    89  	defer close(stopc)
    90  
    91  	go func() {
    92  		select {
    93  		case <-ctx.Done():
    94  		case <-stopc:
    95  			return
    96  		}
    97  
    98  		_ = d.SetReadDeadline(time.Unix(1, 0))
    99  	}()
   100  
   101  	n, err := r.Read(p)
   102  
   103  	err = fixError(ctx, err)
   104  
   105  	return n, err
   106  }
   107  
   108  func ReadFrom(ctx context.Context, r ReaderFrom, p []byte) (int, net.Addr, error) {
   109  	d, ok := r.(interface {
   110  		SetReadDeadline(time.Time) error
   111  	})
   112  
   113  	if !ok {
   114  		return r.ReadFrom(p)
   115  	}
   116  
   117  	stopc := make(chan struct{})
   118  	defer close(stopc)
   119  
   120  	go func() {
   121  		select {
   122  		case <-ctx.Done():
   123  		case <-stopc:
   124  			return
   125  		}
   126  
   127  		_ = d.SetReadDeadline(time.Unix(1, 0))
   128  	}()
   129  
   130  	n, addr, err := r.ReadFrom(p)
   131  
   132  	err = fixError(ctx, err)
   133  
   134  	return n, addr, err
   135  }
   136  
   137  func ReadFromUDP(ctx context.Context, r ReaderFromUDP, p []byte) (int, *net.UDPAddr, error) {
   138  	d, ok := r.(interface {
   139  		SetReadDeadline(time.Time) error
   140  	})
   141  
   142  	if !ok {
   143  		return r.ReadFromUDP(p)
   144  	}
   145  
   146  	stopc := make(chan struct{})
   147  	defer close(stopc)
   148  
   149  	go func() {
   150  		select {
   151  		case <-ctx.Done():
   152  		case <-stopc:
   153  			return
   154  		}
   155  
   156  		_ = d.SetReadDeadline(time.Unix(1, 0))
   157  	}()
   158  
   159  	n, addr, err := r.ReadFromUDP(p)
   160  
   161  	err = fixError(ctx, err)
   162  
   163  	return n, addr, err
   164  }
   165  
   166  func ReadFromUDPAddrPort(ctx context.Context, r ReaderFromUDPAddrPort, p []byte) (int, netip.AddrPort, error) {
   167  	d, ok := r.(interface {
   168  		SetReadDeadline(time.Time) error
   169  	})
   170  
   171  	if !ok {
   172  		return r.ReadFromUDPAddrPort(p)
   173  	}
   174  
   175  	stopc := make(chan struct{})
   176  	defer close(stopc)
   177  
   178  	go func() {
   179  		select {
   180  		case <-ctx.Done():
   181  		case <-stopc:
   182  			return
   183  		}
   184  
   185  		_ = d.SetReadDeadline(time.Unix(1, 0))
   186  	}()
   187  
   188  	n, addr, err := r.ReadFromUDPAddrPort(p)
   189  
   190  	err = fixError(ctx, err)
   191  
   192  	return n, addr, err
   193  }
   194  
   195  func ReadMsgUDP(ctx context.Context, r ReaderMsgUDP, p, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) {
   196  	d, ok := r.(interface {
   197  		SetReadDeadline(time.Time) error
   198  	})
   199  
   200  	if !ok {
   201  		return r.ReadMsgUDP(p, oob)
   202  	}
   203  
   204  	stopc := make(chan struct{})
   205  	defer close(stopc)
   206  
   207  	go func() {
   208  		select {
   209  		case <-ctx.Done():
   210  		case <-stopc:
   211  			return
   212  		}
   213  
   214  		_ = d.SetReadDeadline(time.Unix(1, 0))
   215  	}()
   216  
   217  	n, oobn, flags, addr, err = r.ReadMsgUDP(p, oob)
   218  
   219  	err = fixError(ctx, err)
   220  
   221  	return
   222  }
   223  
   224  func ReadMsgUDPAddrPort(ctx context.Context, r ReaderMsgUDPAddrPort, p, oob []byte) (n, oobn, flags int, addr netip.AddrPort, err error) {
   225  	d, ok := r.(interface {
   226  		SetReadDeadline(time.Time) error
   227  	})
   228  
   229  	if !ok {
   230  		return r.ReadMsgUDPAddrPort(p, oob)
   231  	}
   232  
   233  	stopc := make(chan struct{})
   234  	defer close(stopc)
   235  
   236  	go func() {
   237  		select {
   238  		case <-ctx.Done():
   239  		case <-stopc:
   240  			return
   241  		}
   242  
   243  		_ = d.SetReadDeadline(time.Unix(1, 0))
   244  	}()
   245  
   246  	n, oobn, flags, addr, err = r.ReadMsgUDPAddrPort(p, oob)
   247  
   248  	err = fixError(ctx, err)
   249  
   250  	return
   251  }
   252  
   253  func NewStoppableListener(ctx context.Context, l net.Listener) net.Listener {
   254  	return StoppableListener{
   255  		Context:  ctx,
   256  		Listener: l,
   257  	}
   258  }
   259  
   260  func (l StoppableListener) Accept() (net.Conn, error) {
   261  	return Accept(l.Context, l.Listener)
   262  }
   263  
   264  func NewStoppableConn(ctx context.Context, c net.Conn) net.Conn {
   265  	return StoppableConn{
   266  		Context: ctx,
   267  		Conn:    c,
   268  	}
   269  }
   270  
   271  func (c StoppableConn) Read(p []byte) (n int, err error) {
   272  	defer stopper(c.Context, c.Conn.SetReadDeadline)()
   273  
   274  	n, err = c.Conn.Read(p)
   275  	err = fixError(c.Context, err)
   276  
   277  	return
   278  }
   279  
   280  func (c StoppableConn) Write(p []byte) (n int, err error) {
   281  	defer stopper(c.Context, c.Conn.SetWriteDeadline)()
   282  
   283  	n, err = c.Conn.Write(p)
   284  	err = fixError(c.Context, err)
   285  
   286  	return
   287  }
   288  
   289  func stopper(ctx context.Context, dead func(time.Time) error) func() {
   290  	donec := make(chan struct{})
   291  
   292  	go func() {
   293  		select {
   294  		case <-ctx.Done():
   295  		case <-donec:
   296  			return
   297  		}
   298  
   299  		_ = dead(time.Unix(1, 0))
   300  	}()
   301  
   302  	return func() {
   303  		close(donec)
   304  	}
   305  }
   306  
   307  func isTimeout(err error) bool {
   308  	to, ok := err.(interface{ Timeout() bool })
   309  
   310  	return ok && to.Timeout()
   311  }
   312  
   313  func fixError(ctx context.Context, err error) error {
   314  	if isTimeout(err) {
   315  		select {
   316  		case <-ctx.Done():
   317  			err = ctx.Err()
   318  		default:
   319  		}
   320  	}
   321  
   322  	return err
   323  }