github.com/nikandfor/hacked@v0.0.0-20231207014854-3b383967fdf4/hnet/listen.go (about) 1 package hnet 2 3 import ( 4 "context" 5 "io" 6 "net" 7 "net/netip" 8 "time" 9 ) 10 11 type ( 12 StoppableListener struct { 13 context.Context 14 net.Listener 15 } 16 17 StoppableConn struct { 18 context.Context 19 net.Conn 20 } 21 22 ReaderFrom interface { 23 ReadFrom(p []byte) (int, net.Addr, error) 24 } 25 26 ReaderFromUDP interface { 27 ReadFromUDP(p []byte) (int, *net.UDPAddr, error) 28 } 29 30 ReaderFromUDPAddrPort interface { 31 ReadFromUDPAddrPort(p []byte) (int, netip.AddrPort, error) 32 } 33 34 ReaderMsgUDP interface { 35 ReadMsgUDP(p, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) 36 } 37 38 ReaderMsgUDPAddrPort interface { 39 ReadMsgUDPAddrPort(p, oob []byte) (n, oobn, flags int, addr netip.AddrPort, err error) 40 } 41 ) 42 43 func Accept(ctx context.Context, l net.Listener) (net.Conn, error) { 44 d, ok := l.(interface { 45 SetDeadline(time.Time) error 46 }) 47 48 if !ok { 49 return l.Accept() 50 } 51 52 stopc := make(chan struct{}) 53 defer close(stopc) 54 55 go func() { 56 select { 57 case <-ctx.Done(): 58 case <-stopc: 59 return 60 } 61 62 _ = d.SetDeadline(time.Unix(1, 0)) 63 }() 64 65 c, err := l.Accept() 66 if c != nil || !isTimeout(err) { 67 return c, err 68 } 69 70 select { 71 case <-ctx.Done(): 72 err = ctx.Err() 73 default: 74 } 75 76 return nil, err 77 } 78 79 func Read(ctx context.Context, r io.Reader, p []byte) (int, error) { 80 d, ok := r.(interface { 81 SetReadDeadline(time.Time) error 82 }) 83 84 if !ok { 85 return r.Read(p) 86 } 87 88 stopc := make(chan struct{}) 89 defer close(stopc) 90 91 go func() { 92 select { 93 case <-ctx.Done(): 94 case <-stopc: 95 return 96 } 97 98 _ = d.SetReadDeadline(time.Unix(1, 0)) 99 }() 100 101 n, err := r.Read(p) 102 103 err = fixError(ctx, err) 104 105 return n, err 106 } 107 108 func ReadFrom(ctx context.Context, r ReaderFrom, p []byte) (int, net.Addr, error) { 109 d, ok := r.(interface { 110 SetReadDeadline(time.Time) error 111 }) 112 113 if !ok { 114 return r.ReadFrom(p) 115 } 116 117 stopc := make(chan struct{}) 118 defer close(stopc) 119 120 go func() { 121 select { 122 case <-ctx.Done(): 123 case <-stopc: 124 return 125 } 126 127 _ = d.SetReadDeadline(time.Unix(1, 0)) 128 }() 129 130 n, addr, err := r.ReadFrom(p) 131 132 err = fixError(ctx, err) 133 134 return n, addr, err 135 } 136 137 func ReadFromUDP(ctx context.Context, r ReaderFromUDP, p []byte) (int, *net.UDPAddr, error) { 138 d, ok := r.(interface { 139 SetReadDeadline(time.Time) error 140 }) 141 142 if !ok { 143 return r.ReadFromUDP(p) 144 } 145 146 stopc := make(chan struct{}) 147 defer close(stopc) 148 149 go func() { 150 select { 151 case <-ctx.Done(): 152 case <-stopc: 153 return 154 } 155 156 _ = d.SetReadDeadline(time.Unix(1, 0)) 157 }() 158 159 n, addr, err := r.ReadFromUDP(p) 160 161 err = fixError(ctx, err) 162 163 return n, addr, err 164 } 165 166 func ReadFromUDPAddrPort(ctx context.Context, r ReaderFromUDPAddrPort, p []byte) (int, netip.AddrPort, error) { 167 d, ok := r.(interface { 168 SetReadDeadline(time.Time) error 169 }) 170 171 if !ok { 172 return r.ReadFromUDPAddrPort(p) 173 } 174 175 stopc := make(chan struct{}) 176 defer close(stopc) 177 178 go func() { 179 select { 180 case <-ctx.Done(): 181 case <-stopc: 182 return 183 } 184 185 _ = d.SetReadDeadline(time.Unix(1, 0)) 186 }() 187 188 n, addr, err := r.ReadFromUDPAddrPort(p) 189 190 err = fixError(ctx, err) 191 192 return n, addr, err 193 } 194 195 func ReadMsgUDP(ctx context.Context, r ReaderMsgUDP, p, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) { 196 d, ok := r.(interface { 197 SetReadDeadline(time.Time) error 198 }) 199 200 if !ok { 201 return r.ReadMsgUDP(p, oob) 202 } 203 204 stopc := make(chan struct{}) 205 defer close(stopc) 206 207 go func() { 208 select { 209 case <-ctx.Done(): 210 case <-stopc: 211 return 212 } 213 214 _ = d.SetReadDeadline(time.Unix(1, 0)) 215 }() 216 217 n, oobn, flags, addr, err = r.ReadMsgUDP(p, oob) 218 219 err = fixError(ctx, err) 220 221 return 222 } 223 224 func ReadMsgUDPAddrPort(ctx context.Context, r ReaderMsgUDPAddrPort, p, oob []byte) (n, oobn, flags int, addr netip.AddrPort, err error) { 225 d, ok := r.(interface { 226 SetReadDeadline(time.Time) error 227 }) 228 229 if !ok { 230 return r.ReadMsgUDPAddrPort(p, oob) 231 } 232 233 stopc := make(chan struct{}) 234 defer close(stopc) 235 236 go func() { 237 select { 238 case <-ctx.Done(): 239 case <-stopc: 240 return 241 } 242 243 _ = d.SetReadDeadline(time.Unix(1, 0)) 244 }() 245 246 n, oobn, flags, addr, err = r.ReadMsgUDPAddrPort(p, oob) 247 248 err = fixError(ctx, err) 249 250 return 251 } 252 253 func NewStoppableListener(ctx context.Context, l net.Listener) net.Listener { 254 return StoppableListener{ 255 Context: ctx, 256 Listener: l, 257 } 258 } 259 260 func (l StoppableListener) Accept() (net.Conn, error) { 261 return Accept(l.Context, l.Listener) 262 } 263 264 func NewStoppableConn(ctx context.Context, c net.Conn) net.Conn { 265 return StoppableConn{ 266 Context: ctx, 267 Conn: c, 268 } 269 } 270 271 func (c StoppableConn) Read(p []byte) (n int, err error) { 272 defer stopper(c.Context, c.Conn.SetReadDeadline)() 273 274 n, err = c.Conn.Read(p) 275 err = fixError(c.Context, err) 276 277 return 278 } 279 280 func (c StoppableConn) Write(p []byte) (n int, err error) { 281 defer stopper(c.Context, c.Conn.SetWriteDeadline)() 282 283 n, err = c.Conn.Write(p) 284 err = fixError(c.Context, err) 285 286 return 287 } 288 289 func stopper(ctx context.Context, dead func(time.Time) error) func() { 290 donec := make(chan struct{}) 291 292 go func() { 293 select { 294 case <-ctx.Done(): 295 case <-donec: 296 return 297 } 298 299 _ = dead(time.Unix(1, 0)) 300 }() 301 302 return func() { 303 close(donec) 304 } 305 } 306 307 func isTimeout(err error) bool { 308 to, ok := err.(interface{ Timeout() bool }) 309 310 return ok && to.Timeout() 311 } 312 313 func fixError(ctx context.Context, err error) error { 314 if isTimeout(err) { 315 select { 316 case <-ctx.Done(): 317 err = ctx.Err() 318 default: 319 } 320 } 321 322 return err 323 }