github.com/15mga/kiwi@v0.0.2-0.20240324021231-b95d5c3ac751/network/tcp_agent.go (about) 1 package network 2 3 import ( 4 "context" 5 "fmt" 6 "github.com/15mga/kiwi" 7 "net" 8 "time" 9 10 "github.com/15mga/kiwi/ds" 11 12 "github.com/15mga/kiwi/util" 13 ) 14 15 // NewTcpAgent receiver接收字节如果使用异步方式需要copy一份,否则数据会被覆盖 16 func NewTcpAgent(addr string, receiver kiwi.FnAgentBytes, options ...kiwi.AgentOption) *tcpAgent { 17 ta := &tcpAgent{ 18 agent: newAgent(addr, receiver, options...), 19 } 20 switch ta.option.HeadLen { 21 case 2: 22 ta.headReader = func(bytes []byte) int { 23 return int(bytes[0])<<8 | int(bytes[1]) 24 } 25 ta.headWriter = func(buffer *util.ByteBuffer, bytes []byte) { 26 buffer.WUint16(uint16(len(bytes))) 27 } 28 case 4: 29 ta.headReader = func(bytes []byte) int { 30 return int(bytes[0])<<24 | int(bytes[1])<<16 | int(bytes[2])<<8 | int(bytes[3]) 31 } 32 ta.headWriter = func(buffer *util.ByteBuffer, bytes []byte) { 33 buffer.WUint32(uint32(len(bytes))) 34 } 35 default: 36 panic("wrong head length") 37 } 38 return ta 39 } 40 41 type tcpAgent struct { 42 agent 43 conn net.Conn 44 headReader util.BytesToInt 45 headWriter func(buffer *util.ByteBuffer, bytes []byte) 46 } 47 48 func (a *tcpAgent) Start(ctx context.Context, conn net.Conn) { 49 a.conn = conn 50 a.onClose = a.conn.Close 51 a.start(ctx) 52 switch a.option.AgentMode { 53 case kiwi.AgentRW: 54 go a.read() 55 go a.write() 56 case kiwi.AgentR: 57 go a.read() 58 case kiwi.AgentW: 59 go a.write() 60 } 61 } 62 63 func (a *tcpAgent) read() { 64 var ( 65 buffer = make([]byte, a.option.PacketMinCap) 66 ringBuffer = newRing(a.option.PacketMinCap, a.option.PacketMaxCap) 67 pkgLen int 68 err *util.Err 69 headLen = a.option.HeadLen 70 headReader = a.headReader 71 dur = time.Duration(a.option.DeadlineSecs) 72 ) 73 defer func() { 74 r := recover() 75 if r != nil { 76 kiwi.Error2(util.EcRecover, util.M{ 77 "remote addr": a.conn.RemoteAddr().String(), 78 "recover": fmt.Sprintf("%s", r), 79 }) 80 a.read() 81 return 82 } 83 a.close(err) 84 }() 85 86 for { 87 select { 88 case <-a.ctx.Done(): 89 return 90 default: 91 if dur > 0 { 92 _ = a.conn.SetReadDeadline(time.Now().Add(time.Second * dur)) 93 } 94 newLen, e := a.conn.Read(buffer) 95 if e != nil { 96 err = util.WrapErr(util.EcIo, e) 97 return 98 } 99 err = ringBuffer.Put(buffer[:newLen]) 100 if err != nil { 101 return 102 } 103 for { 104 if pkgLen == 0 { 105 if ringBuffer.Available() < headLen { 106 break 107 } 108 _ = ringBuffer.Read(buffer, headLen) 109 pkgLen = headReader(buffer) 110 if pkgLen == 0 { 111 err = util.NewErr(util.EcBadHead, nil) 112 return 113 } 114 } 115 if ringBuffer.Available() < pkgLen { 116 break 117 } 118 _ = ringBuffer.Read(buffer, pkgLen) 119 //log.Debug("receive", util.M{ 120 // "len": pkgLen, 121 // "hex": util.Hex(buffer[:pkgLen]), 122 //}) 123 a.receiver(a, buffer[:pkgLen]) 124 pkgLen = 0 125 } 126 } 127 } 128 } 129 130 func (a *tcpAgent) write() { 131 var ( 132 err *util.Err 133 ) 134 defer func() { 135 a.close(err) 136 }() 137 138 headWriter := a.headWriter 139 140 for { 141 select { 142 case <-a.ctx.Done(): 143 return 144 case <-a.writeSignCh: 145 var elem *ds.LinkElem[[]byte] 146 a.enable.Mtx.Lock() 147 if a.enable.Disabled() { 148 a.enable.Mtx.Unlock() 149 return 150 } 151 elem = a.bytesLink.PopAll() 152 a.enable.Mtx.Unlock() 153 if elem == nil { 154 continue 155 } 156 157 for ; elem != nil; elem = elem.Next { 158 bytes := elem.Value 159 //log.Debug("send", util.M{ 160 // "len": len(bytes), 161 // "hex": util.Hex(bytes), 162 //}) 163 var buffer util.ByteBuffer 164 buffer.InitCap(len(bytes) + a.option.HeadLen) 165 headWriter(&buffer, bytes) 166 _, _ = buffer.Write(bytes) 167 _, e := a.conn.Write(buffer.All()) 168 util.RecycleBytes(bytes) 169 buffer.Dispose() 170 if e != nil { 171 err = util.WrapErr(util.EcIo, e) 172 return 173 } 174 } 175 } 176 } 177 } 178 179 func newRing(minCap, maxCap int) *ring { 180 r := &ring{ 181 buffer: make([]byte, minCap), 182 bufferCap: minCap, 183 halfBuffCap: minCap >> 1, 184 minCap: minCap, 185 maxCap: maxCap, 186 shrink: 64, 187 shrinkCount: 64, 188 } 189 r.defVal = r.buffer[0] 190 return r 191 } 192 193 type ring struct { 194 defVal byte 195 available int 196 readIdx int 197 writeIdx int 198 buffer []byte 199 bufferCap int 200 minCap int 201 maxCap int 202 halfBuffCap int 203 shrink int 204 shrinkCount int 205 } 206 207 func (r *ring) Available() int { 208 return r.available 209 } 210 211 func (r *ring) testCap(c int) *util.Err { 212 if c > r.bufferCap { 213 c, ok := util.NextCap(c, r.bufferCap, 2048) 214 if ok { 215 if r.maxCap > 0 && c >= r.maxCap { 216 return util.NewErr(util.EcTooLong, util.M{ 217 "total": c, 218 }) 219 } 220 r.resetBuffer(c) 221 } 222 return nil 223 } 224 if r.minCap == r.bufferCap { 225 return nil 226 } 227 if c > r.halfBuffCap { 228 r.shrink = r.shrinkCount 229 return nil 230 } 231 r.shrink-- 232 if r.shrink > 0 { 233 return nil 234 } 235 r.resetBuffer(r.halfBuffCap) 236 return nil 237 } 238 239 func (r *ring) resetBuffer(cap int) { 240 buf := make([]byte, cap) 241 if r.available > 0 { 242 if r.writeIdx > r.readIdx { 243 copy(buf, r.buffer[r.readIdx:r.writeIdx]) 244 } else { 245 n := copy(buf, r.buffer[r.readIdx:]) 246 copy(buf[n:], r.buffer[:r.writeIdx]) 247 } 248 } 249 r.writeIdx = r.available 250 r.readIdx = 0 251 r.bufferCap = cap 252 r.halfBuffCap = cap >> 1 253 r.buffer = buf 254 r.shrink = r.shrinkCount 255 r.buffer = make([]byte, cap) 256 } 257 258 func (r *ring) Put(items []byte) *util.Err { 259 l := len(items) 260 c := r.available + l 261 err := r.testCap(c) 262 if err != nil { 263 return err 264 } 265 r.available = c 266 i := r.writeIdx + l 267 if i <= r.bufferCap { 268 copy(r.buffer[r.writeIdx:], items) 269 r.writeIdx = i 270 } else { 271 copy(r.buffer[r.writeIdx:r.bufferCap], items) 272 j := r.bufferCap - r.writeIdx 273 copy(r.buffer, items[j:l]) 274 r.writeIdx = l - j 275 } 276 return nil 277 } 278 279 func (r *ring) Read(s []byte, l int) *util.Err { 280 sl := len(s) 281 if l > sl || l > r.available { 282 return util.NewErr(util.EcNotEnough, util.M{ 283 "length": l, 284 "slice": sl, 285 "available": r.available, 286 }) 287 } 288 r.read(s, l) 289 return nil 290 } 291 292 func (r *ring) read(s []byte, l int) { 293 p := r.readIdx + l 294 if p < r.bufferCap { 295 copy(s, r.buffer[r.readIdx:p]) 296 r.readIdx = p 297 } else { 298 p -= r.bufferCap 299 copy(s, r.buffer[r.readIdx:r.bufferCap]) 300 copy(s[r.bufferCap-r.readIdx:], r.buffer[:p]) 301 r.readIdx = p 302 } 303 r.available -= l 304 }