github.com/nikandfor/hacked@v0.0.0-20230429073333-a318d546207a/hnet/listen.go (about) 1 package hnet 2 3 import ( 4 "context" 5 "io" 6 "net" 7 "net/netip" 8 "time" 9 ) 10 11 type ( 12 StoppableConn struct { 13 context.Context 14 net.Conn 15 } 16 17 ReaderFrom interface { 18 ReadFrom(p []byte) (int, net.Addr, error) 19 } 20 21 ReaderFromUDP interface { 22 ReadFromUDP(p []byte) (int, *net.UDPAddr, error) 23 } 24 25 ReaderFromUDPAddrPort interface { 26 ReadFromUDPAddrPort(p []byte) (int, netip.AddrPort, error) 27 } 28 29 ReaderMsgUDP interface { 30 ReadMsgUDP(p, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) 31 } 32 33 ReaderMsgUDPAddrPort interface { 34 ReadMsgUDPAddrPort(p, oob []byte) (n, oobn, flags int, addr netip.AddrPort, err error) 35 } 36 ) 37 38 func Accept(ctx context.Context, l net.Listener) (net.Conn, error) { 39 d, ok := l.(interface { 40 SetDeadline(time.Time) error 41 }) 42 43 if !ok { 44 return l.Accept() 45 } 46 47 stopc := make(chan struct{}) 48 defer close(stopc) 49 50 go func() { 51 select { 52 case <-ctx.Done(): 53 case <-stopc: 54 return 55 } 56 57 _ = d.SetDeadline(time.Unix(1, 0)) 58 }() 59 60 c, err := l.Accept() 61 if c != nil || !isTimeout(err) { 62 return c, err 63 } 64 65 select { 66 case <-ctx.Done(): 67 err = ctx.Err() 68 default: 69 } 70 71 return nil, err 72 } 73 74 func Read(ctx context.Context, r io.Reader, p []byte) (int, error) { 75 d, ok := r.(interface { 76 SetReadDeadline(time.Time) error 77 }) 78 79 if !ok { 80 return r.Read(p) 81 } 82 83 stopc := make(chan struct{}) 84 defer close(stopc) 85 86 go func() { 87 select { 88 case <-ctx.Done(): 89 case <-stopc: 90 return 91 } 92 93 _ = d.SetReadDeadline(time.Unix(1, 0)) 94 }() 95 96 n, err := r.Read(p) 97 98 err = fixError(ctx, err) 99 100 return n, err 101 } 102 103 func ReadFrom(ctx context.Context, r ReaderFrom, p []byte) (int, net.Addr, error) { 104 d, ok := r.(interface { 105 SetReadDeadline(time.Time) error 106 }) 107 108 if !ok { 109 return r.ReadFrom(p) 110 } 111 112 stopc := make(chan struct{}) 113 defer close(stopc) 114 115 go func() { 116 select { 117 case <-ctx.Done(): 118 case <-stopc: 119 return 120 } 121 122 _ = d.SetReadDeadline(time.Unix(1, 0)) 123 }() 124 125 n, addr, err := r.ReadFrom(p) 126 127 err = fixError(ctx, err) 128 129 return n, addr, err 130 } 131 132 func ReadFromUDP(ctx context.Context, r ReaderFromUDP, p []byte) (int, *net.UDPAddr, error) { 133 d, ok := r.(interface { 134 SetReadDeadline(time.Time) error 135 }) 136 137 if !ok { 138 return r.ReadFromUDP(p) 139 } 140 141 stopc := make(chan struct{}) 142 defer close(stopc) 143 144 go func() { 145 select { 146 case <-ctx.Done(): 147 case <-stopc: 148 return 149 } 150 151 _ = d.SetReadDeadline(time.Unix(1, 0)) 152 }() 153 154 n, addr, err := r.ReadFromUDP(p) 155 156 err = fixError(ctx, err) 157 158 return n, addr, err 159 } 160 161 func ReadFromUDPAddrPort(ctx context.Context, r ReaderFromUDPAddrPort, p []byte) (int, netip.AddrPort, error) { 162 d, ok := r.(interface { 163 SetReadDeadline(time.Time) error 164 }) 165 166 if !ok { 167 return r.ReadFromUDPAddrPort(p) 168 } 169 170 stopc := make(chan struct{}) 171 defer close(stopc) 172 173 go func() { 174 select { 175 case <-ctx.Done(): 176 case <-stopc: 177 return 178 } 179 180 _ = d.SetReadDeadline(time.Unix(1, 0)) 181 }() 182 183 n, addr, err := r.ReadFromUDPAddrPort(p) 184 185 err = fixError(ctx, err) 186 187 return n, addr, err 188 } 189 190 func ReadMsgUDP(ctx context.Context, r ReaderMsgUDP, p, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) { 191 d, ok := r.(interface { 192 SetReadDeadline(time.Time) error 193 }) 194 195 if !ok { 196 return r.ReadMsgUDP(p, oob) 197 } 198 199 stopc := make(chan struct{}) 200 defer close(stopc) 201 202 go func() { 203 select { 204 case <-ctx.Done(): 205 case <-stopc: 206 return 207 } 208 209 _ = d.SetReadDeadline(time.Unix(1, 0)) 210 }() 211 212 n, oobn, flags, addr, err = r.ReadMsgUDP(p, oob) 213 214 err = fixError(ctx, err) 215 216 return 217 } 218 219 func ReadMsgUDPAddrPort(ctx context.Context, r ReaderMsgUDPAddrPort, p, oob []byte) (n, oobn, flags int, addr netip.AddrPort, err error) { 220 d, ok := r.(interface { 221 SetReadDeadline(time.Time) error 222 }) 223 224 if !ok { 225 return r.ReadMsgUDPAddrPort(p, oob) 226 } 227 228 stopc := make(chan struct{}) 229 defer close(stopc) 230 231 go func() { 232 select { 233 case <-ctx.Done(): 234 case <-stopc: 235 return 236 } 237 238 _ = d.SetReadDeadline(time.Unix(1, 0)) 239 }() 240 241 n, oobn, flags, addr, err = r.ReadMsgUDPAddrPort(p, oob) 242 243 err = fixError(ctx, err) 244 245 return 246 } 247 248 func NewStoppableConn(ctx context.Context, c net.Conn) net.Conn { 249 return StoppableConn{ 250 Context: ctx, 251 Conn: c, 252 } 253 } 254 255 func (c StoppableConn) Read(p []byte) (n int, err error) { 256 defer stopper(c.Context, c.Conn.SetReadDeadline)() 257 258 n, err = c.Conn.Read(p) 259 err = fixError(c.Context, err) 260 261 return 262 } 263 264 func (c StoppableConn) Write(p []byte) (n int, err error) { 265 defer stopper(c.Context, c.Conn.SetWriteDeadline)() 266 267 n, err = c.Conn.Write(p) 268 err = fixError(c.Context, err) 269 270 return 271 } 272 273 func stopper(ctx context.Context, dead func(time.Time) error) func() { 274 donec := make(chan struct{}) 275 276 go func() { 277 select { 278 case <-ctx.Done(): 279 case <-donec: 280 return 281 } 282 283 _ = dead(time.Unix(1, 0)) 284 }() 285 286 return func() { 287 close(donec) 288 } 289 } 290 291 func isTimeout(err error) bool { 292 to, ok := err.(interface{ Timeout() bool }) 293 294 return ok && to.Timeout() 295 } 296 297 func fixError(ctx context.Context, err error) error { 298 if isTimeout(err) { 299 select { 300 case <-ctx.Done(): 301 err = ctx.Err() 302 default: 303 } 304 } 305 306 return err 307 }