github.com/nikandfor/hacked@v0.0.0-20230429073333-a318d546207a/hnet/listen.go (about)

     1  package hnet
     2  
import (
	"context"
	"errors"
	"io"
	"net"
	"net/netip"
	"time"
)
    10  
    11  type (
    12  	StoppableConn struct {
    13  		context.Context
    14  		net.Conn
    15  	}
    16  
    17  	ReaderFrom interface {
    18  		ReadFrom(p []byte) (int, net.Addr, error)
    19  	}
    20  
    21  	ReaderFromUDP interface {
    22  		ReadFromUDP(p []byte) (int, *net.UDPAddr, error)
    23  	}
    24  
    25  	ReaderFromUDPAddrPort interface {
    26  		ReadFromUDPAddrPort(p []byte) (int, netip.AddrPort, error)
    27  	}
    28  
    29  	ReaderMsgUDP interface {
    30  		ReadMsgUDP(p, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error)
    31  	}
    32  
    33  	ReaderMsgUDPAddrPort interface {
    34  		ReadMsgUDPAddrPort(p, oob []byte) (n, oobn, flags int, addr netip.AddrPort, err error)
    35  	}
    36  )
    37  
    38  func Accept(ctx context.Context, l net.Listener) (net.Conn, error) {
    39  	d, ok := l.(interface {
    40  		SetDeadline(time.Time) error
    41  	})
    42  
    43  	if !ok {
    44  		return l.Accept()
    45  	}
    46  
    47  	stopc := make(chan struct{})
    48  	defer close(stopc)
    49  
    50  	go func() {
    51  		select {
    52  		case <-ctx.Done():
    53  		case <-stopc:
    54  			return
    55  		}
    56  
    57  		_ = d.SetDeadline(time.Unix(1, 0))
    58  	}()
    59  
    60  	c, err := l.Accept()
    61  	if c != nil || !isTimeout(err) {
    62  		return c, err
    63  	}
    64  
    65  	select {
    66  	case <-ctx.Done():
    67  		err = ctx.Err()
    68  	default:
    69  	}
    70  
    71  	return nil, err
    72  }
    73  
    74  func Read(ctx context.Context, r io.Reader, p []byte) (int, error) {
    75  	d, ok := r.(interface {
    76  		SetReadDeadline(time.Time) error
    77  	})
    78  
    79  	if !ok {
    80  		return r.Read(p)
    81  	}
    82  
    83  	stopc := make(chan struct{})
    84  	defer close(stopc)
    85  
    86  	go func() {
    87  		select {
    88  		case <-ctx.Done():
    89  		case <-stopc:
    90  			return
    91  		}
    92  
    93  		_ = d.SetReadDeadline(time.Unix(1, 0))
    94  	}()
    95  
    96  	n, err := r.Read(p)
    97  
    98  	err = fixError(ctx, err)
    99  
   100  	return n, err
   101  }
   102  
   103  func ReadFrom(ctx context.Context, r ReaderFrom, p []byte) (int, net.Addr, error) {
   104  	d, ok := r.(interface {
   105  		SetReadDeadline(time.Time) error
   106  	})
   107  
   108  	if !ok {
   109  		return r.ReadFrom(p)
   110  	}
   111  
   112  	stopc := make(chan struct{})
   113  	defer close(stopc)
   114  
   115  	go func() {
   116  		select {
   117  		case <-ctx.Done():
   118  		case <-stopc:
   119  			return
   120  		}
   121  
   122  		_ = d.SetReadDeadline(time.Unix(1, 0))
   123  	}()
   124  
   125  	n, addr, err := r.ReadFrom(p)
   126  
   127  	err = fixError(ctx, err)
   128  
   129  	return n, addr, err
   130  }
   131  
   132  func ReadFromUDP(ctx context.Context, r ReaderFromUDP, p []byte) (int, *net.UDPAddr, error) {
   133  	d, ok := r.(interface {
   134  		SetReadDeadline(time.Time) error
   135  	})
   136  
   137  	if !ok {
   138  		return r.ReadFromUDP(p)
   139  	}
   140  
   141  	stopc := make(chan struct{})
   142  	defer close(stopc)
   143  
   144  	go func() {
   145  		select {
   146  		case <-ctx.Done():
   147  		case <-stopc:
   148  			return
   149  		}
   150  
   151  		_ = d.SetReadDeadline(time.Unix(1, 0))
   152  	}()
   153  
   154  	n, addr, err := r.ReadFromUDP(p)
   155  
   156  	err = fixError(ctx, err)
   157  
   158  	return n, addr, err
   159  }
   160  
   161  func ReadFromUDPAddrPort(ctx context.Context, r ReaderFromUDPAddrPort, p []byte) (int, netip.AddrPort, error) {
   162  	d, ok := r.(interface {
   163  		SetReadDeadline(time.Time) error
   164  	})
   165  
   166  	if !ok {
   167  		return r.ReadFromUDPAddrPort(p)
   168  	}
   169  
   170  	stopc := make(chan struct{})
   171  	defer close(stopc)
   172  
   173  	go func() {
   174  		select {
   175  		case <-ctx.Done():
   176  		case <-stopc:
   177  			return
   178  		}
   179  
   180  		_ = d.SetReadDeadline(time.Unix(1, 0))
   181  	}()
   182  
   183  	n, addr, err := r.ReadFromUDPAddrPort(p)
   184  
   185  	err = fixError(ctx, err)
   186  
   187  	return n, addr, err
   188  }
   189  
   190  func ReadMsgUDP(ctx context.Context, r ReaderMsgUDP, p, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) {
   191  	d, ok := r.(interface {
   192  		SetReadDeadline(time.Time) error
   193  	})
   194  
   195  	if !ok {
   196  		return r.ReadMsgUDP(p, oob)
   197  	}
   198  
   199  	stopc := make(chan struct{})
   200  	defer close(stopc)
   201  
   202  	go func() {
   203  		select {
   204  		case <-ctx.Done():
   205  		case <-stopc:
   206  			return
   207  		}
   208  
   209  		_ = d.SetReadDeadline(time.Unix(1, 0))
   210  	}()
   211  
   212  	n, oobn, flags, addr, err = r.ReadMsgUDP(p, oob)
   213  
   214  	err = fixError(ctx, err)
   215  
   216  	return
   217  }
   218  
   219  func ReadMsgUDPAddrPort(ctx context.Context, r ReaderMsgUDPAddrPort, p, oob []byte) (n, oobn, flags int, addr netip.AddrPort, err error) {
   220  	d, ok := r.(interface {
   221  		SetReadDeadline(time.Time) error
   222  	})
   223  
   224  	if !ok {
   225  		return r.ReadMsgUDPAddrPort(p, oob)
   226  	}
   227  
   228  	stopc := make(chan struct{})
   229  	defer close(stopc)
   230  
   231  	go func() {
   232  		select {
   233  		case <-ctx.Done():
   234  		case <-stopc:
   235  			return
   236  		}
   237  
   238  		_ = d.SetReadDeadline(time.Unix(1, 0))
   239  	}()
   240  
   241  	n, oobn, flags, addr, err = r.ReadMsgUDPAddrPort(p, oob)
   242  
   243  	err = fixError(ctx, err)
   244  
   245  	return
   246  }
   247  
   248  func NewStoppableConn(ctx context.Context, c net.Conn) net.Conn {
   249  	return StoppableConn{
   250  		Context: ctx,
   251  		Conn:    c,
   252  	}
   253  }
   254  
   255  func (c StoppableConn) Read(p []byte) (n int, err error) {
   256  	defer stopper(c.Context, c.Conn.SetReadDeadline)()
   257  
   258  	n, err = c.Conn.Read(p)
   259  	err = fixError(c.Context, err)
   260  
   261  	return
   262  }
   263  
   264  func (c StoppableConn) Write(p []byte) (n int, err error) {
   265  	defer stopper(c.Context, c.Conn.SetWriteDeadline)()
   266  
   267  	n, err = c.Conn.Write(p)
   268  	err = fixError(c.Context, err)
   269  
   270  	return
   271  }
   272  
   273  func stopper(ctx context.Context, dead func(time.Time) error) func() {
   274  	donec := make(chan struct{})
   275  
   276  	go func() {
   277  		select {
   278  		case <-ctx.Done():
   279  		case <-donec:
   280  			return
   281  		}
   282  
   283  		_ = dead(time.Unix(1, 0))
   284  	}()
   285  
   286  	return func() {
   287  		close(donec)
   288  	}
   289  }
   290  
   291  func isTimeout(err error) bool {
   292  	to, ok := err.(interface{ Timeout() bool })
   293  
   294  	return ok && to.Timeout()
   295  }
   296  
   297  func fixError(ctx context.Context, err error) error {
   298  	if isTimeout(err) {
   299  		select {
   300  		case <-ctx.Done():
   301  			err = ctx.Err()
   302  		default:
   303  		}
   304  	}
   305  
   306  	return err
   307  }