github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/bufio/deadline/packet_reader.go (about) 1 package deadline 2 3 import ( 4 "net" 5 "os" 6 "time" 7 8 "github.com/sagernet/sing/common/atomic" 9 "github.com/sagernet/sing/common/buf" 10 M "github.com/sagernet/sing/common/metadata" 11 N "github.com/sagernet/sing/common/network" 12 ) 13 14 type TimeoutPacketReader interface { 15 N.NetPacketReader 16 SetReadDeadline(t time.Time) error 17 } 18 19 type PacketReader interface { 20 TimeoutPacketReader 21 N.WithUpstreamReader 22 N.ReaderWithUpstream 23 } 24 25 type packetReader struct { 26 TimeoutPacketReader 27 deadline atomic.TypedValue[time.Time] 28 pipeDeadline pipeDeadline 29 result chan *packetReadResult 30 done chan struct{} 31 } 32 33 type packetReadResult struct { 34 buffer *buf.Buffer 35 destination M.Socksaddr 36 err error 37 } 38 39 func NewPacketReader(timeoutReader TimeoutPacketReader) PacketReader { 40 return &packetReader{ 41 TimeoutPacketReader: timeoutReader, 42 pipeDeadline: makePipeDeadline(), 43 result: make(chan *packetReadResult, 1), 44 done: makeFilledChan(), 45 } 46 } 47 48 func (r *packetReader) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 49 select { 50 case result := <-r.result: 51 return r.pipeReturnFrom(result, p) 52 default: 53 } 54 select { 55 case result := <-r.result: 56 return r.pipeReturnFrom(result, p) 57 case <-r.pipeDeadline.wait(): 58 return 0, nil, os.ErrDeadlineExceeded 59 case <-r.done: 60 go r.pipeReadFrom(len(p)) 61 } 62 select { 63 case result := <-r.result: 64 return r.pipeReturnFrom(result, p) 65 case <-r.pipeDeadline.wait(): 66 return 0, nil, os.ErrDeadlineExceeded 67 } 68 } 69 70 func (r *packetReader) pipeReadFrom(pLen int) { 71 buffer := buf.NewSize(pLen) 72 n, addr, err := r.TimeoutPacketReader.ReadFrom(buffer.FreeBytes()) 73 buffer.Truncate(n) 74 r.result <- &packetReadResult{ 75 buffer: buffer, 76 destination: M.SocksaddrFromNet(addr), 77 err: err, 78 } 79 r.done <- struct{}{} 80 } 81 82 func (r *packetReader) pipeReturnFrom(result *packetReadResult, p []byte) (n int, addr net.Addr, err error) { 83 n = copy(p, result.buffer.Bytes()) 84 if result.destination.IsValid() { 85 if result.destination.IsFqdn() { 86 addr = result.destination 87 } else { 88 addr = result.destination.UDPAddr() 89 } 90 } 91 result.buffer.Advance(n) 92 if result.buffer.IsEmpty() { 93 result.buffer.Release() 94 err = result.err 95 } else { 96 r.result <- result 97 } 98 return 99 } 100 101 func (r *packetReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { 102 select { 103 case result := <-r.result: 104 return r.pipeReturnFromBuffer(result, buffer) 105 default: 106 } 107 select { 108 case result := <-r.result: 109 return r.pipeReturnFromBuffer(result, buffer) 110 case <-r.pipeDeadline.wait(): 111 return M.Socksaddr{}, os.ErrDeadlineExceeded 112 case <-r.done: 113 go r.pipeReadFrom(buffer.FreeLen()) 114 } 115 select { 116 case result := <-r.result: 117 return r.pipeReturnFromBuffer(result, buffer) 118 case <-r.pipeDeadline.wait(): 119 return M.Socksaddr{}, os.ErrDeadlineExceeded 120 } 121 } 122 123 func (r *packetReader) pipeReturnFromBuffer(result *packetReadResult, buffer *buf.Buffer) (M.Socksaddr, error) { 124 n, _ := buffer.Write(result.buffer.Bytes()) 125 result.buffer.Advance(n) 126 if !result.buffer.IsEmpty() { 127 r.result <- result 128 return result.destination, nil 129 } else { 130 result.buffer.Release() 131 return result.destination, result.err 132 } 133 } 134 135 func (r *packetReader) SetReadDeadline(t time.Time) error { 136 r.deadline.Store(t) 137 r.pipeDeadline.set(t) 138 return nil 139 } 140 141 func (r *packetReader) ReaderReplaceable() bool { 142 select { 143 case <-r.done: 144 r.done <- struct{}{} 145 default: 146 return false 147 } 148 select { 149 case result := <-r.result: 150 r.result <- result 151 return false 152 default: 153 } 154 return r.deadline.Load().IsZero() 155 } 156 157 func (r *packetReader) UpstreamReader() any { 158 return r.TimeoutPacketReader 159 }