github.com/15mga/kiwi@v0.0.2-0.20240324021231-b95d5c3ac751/network/tcp_agent.go (about)

     1  package network
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"github.com/15mga/kiwi"
     7  	"net"
     8  	"time"
     9  
    10  	"github.com/15mga/kiwi/ds"
    11  
    12  	"github.com/15mga/kiwi/util"
    13  )
    14  
    15  // NewTcpAgent receiver接收字节如果使用异步方式需要copy一份,否则数据会被覆盖
    16  func NewTcpAgent(addr string, receiver kiwi.FnAgentBytes, options ...kiwi.AgentOption) *tcpAgent {
    17  	ta := &tcpAgent{
    18  		agent: newAgent(addr, receiver, options...),
    19  	}
    20  	switch ta.option.HeadLen {
    21  	case 2:
    22  		ta.headReader = func(bytes []byte) int {
    23  			return int(bytes[0])<<8 | int(bytes[1])
    24  		}
    25  		ta.headWriter = func(buffer *util.ByteBuffer, bytes []byte) {
    26  			buffer.WUint16(uint16(len(bytes)))
    27  		}
    28  	case 4:
    29  		ta.headReader = func(bytes []byte) int {
    30  			return int(bytes[0])<<24 | int(bytes[1])<<16 | int(bytes[2])<<8 | int(bytes[3])
    31  		}
    32  		ta.headWriter = func(buffer *util.ByteBuffer, bytes []byte) {
    33  			buffer.WUint32(uint32(len(bytes)))
    34  		}
    35  	default:
    36  		panic("wrong head length")
    37  	}
    38  	return ta
    39  }
    40  
// tcpAgent is an agent bound to a net.Conn that exchanges packets framed as
// [length header][payload], the header being option.HeadLen bytes long.
type tcpAgent struct {
	agent
	// conn is the connection assigned by Start.
	conn       net.Conn
	// headReader decodes the payload length from the first HeadLen bytes of
	// its argument; installed by NewTcpAgent.
	headReader util.BytesToInt
	// headWriter writes the length header for bytes into buffer; installed by
	// NewTcpAgent.
	headWriter func(buffer *util.ByteBuffer, bytes []byte)
}
    47  
    48  func (a *tcpAgent) Start(ctx context.Context, conn net.Conn) {
    49  	a.conn = conn
    50  	a.onClose = a.conn.Close
    51  	a.start(ctx)
    52  	switch a.option.AgentMode {
    53  	case kiwi.AgentRW:
    54  		go a.read()
    55  		go a.write()
    56  	case kiwi.AgentR:
    57  		go a.read()
    58  	case kiwi.AgentW:
    59  		go a.write()
    60  	}
    61  }
    62  
    63  func (a *tcpAgent) read() {
    64  	var (
    65  		buffer     = make([]byte, a.option.PacketMinCap)
    66  		ringBuffer = newRing(a.option.PacketMinCap, a.option.PacketMaxCap)
    67  		pkgLen     int
    68  		err        *util.Err
    69  		headLen    = a.option.HeadLen
    70  		headReader = a.headReader
    71  		dur        = time.Duration(a.option.DeadlineSecs)
    72  	)
    73  	defer func() {
    74  		r := recover()
    75  		if r != nil {
    76  			kiwi.Error2(util.EcRecover, util.M{
    77  				"remote addr": a.conn.RemoteAddr().String(),
    78  				"recover":     fmt.Sprintf("%s", r),
    79  			})
    80  			a.read()
    81  			return
    82  		}
    83  		a.close(err)
    84  	}()
    85  
    86  	for {
    87  		select {
    88  		case <-a.ctx.Done():
    89  			return
    90  		default:
    91  			if dur > 0 {
    92  				_ = a.conn.SetReadDeadline(time.Now().Add(time.Second * dur))
    93  			}
    94  			newLen, e := a.conn.Read(buffer)
    95  			if e != nil {
    96  				err = util.WrapErr(util.EcIo, e)
    97  				return
    98  			}
    99  			err = ringBuffer.Put(buffer[:newLen])
   100  			if err != nil {
   101  				return
   102  			}
   103  			for {
   104  				if pkgLen == 0 {
   105  					if ringBuffer.Available() < headLen {
   106  						break
   107  					}
   108  					_ = ringBuffer.Read(buffer, headLen)
   109  					pkgLen = headReader(buffer)
   110  					if pkgLen == 0 {
   111  						err = util.NewErr(util.EcBadHead, nil)
   112  						return
   113  					}
   114  				}
   115  				if ringBuffer.Available() < pkgLen {
   116  					break
   117  				}
   118  				_ = ringBuffer.Read(buffer, pkgLen)
   119  				//log.Debug("receive", util.M{
   120  				//	"len": pkgLen,
   121  				//	"hex": util.Hex(buffer[:pkgLen]),
   122  				//})
   123  				a.receiver(a, buffer[:pkgLen])
   124  				pkgLen = 0
   125  			}
   126  		}
   127  	}
   128  }
   129  
   130  func (a *tcpAgent) write() {
   131  	var (
   132  		err *util.Err
   133  	)
   134  	defer func() {
   135  		a.close(err)
   136  	}()
   137  
   138  	headWriter := a.headWriter
   139  
   140  	for {
   141  		select {
   142  		case <-a.ctx.Done():
   143  			return
   144  		case <-a.writeSignCh:
   145  			var elem *ds.LinkElem[[]byte]
   146  			a.enable.Mtx.Lock()
   147  			if a.enable.Disabled() {
   148  				a.enable.Mtx.Unlock()
   149  				return
   150  			}
   151  			elem = a.bytesLink.PopAll()
   152  			a.enable.Mtx.Unlock()
   153  			if elem == nil {
   154  				continue
   155  			}
   156  
   157  			for ; elem != nil; elem = elem.Next {
   158  				bytes := elem.Value
   159  				//log.Debug("send", util.M{
   160  				//	"len": len(bytes),
   161  				//	"hex": util.Hex(bytes),
   162  				//})
   163  				var buffer util.ByteBuffer
   164  				buffer.InitCap(len(bytes) + a.option.HeadLen)
   165  				headWriter(&buffer, bytes)
   166  				_, _ = buffer.Write(bytes)
   167  				_, e := a.conn.Write(buffer.All())
   168  				util.RecycleBytes(bytes)
   169  				buffer.Dispose()
   170  				if e != nil {
   171  					err = util.WrapErr(util.EcIo, e)
   172  					return
   173  				}
   174  			}
   175  		}
   176  	}
   177  }
   178  
   179  func newRing(minCap, maxCap int) *ring {
   180  	r := &ring{
   181  		buffer:      make([]byte, minCap),
   182  		bufferCap:   minCap,
   183  		halfBuffCap: minCap >> 1,
   184  		minCap:      minCap,
   185  		maxCap:      maxCap,
   186  		shrink:      64,
   187  		shrinkCount: 64,
   188  	}
   189  	r.defVal = r.buffer[0]
   190  	return r
   191  }
   192  
// ring is a resizable circular byte buffer. tcpAgent.read uses it to
// reassemble packets from the TCP stream.
//
// Unread bytes occupy [readIdx, writeIdx), or, when the data wraps around,
// [readIdx, bufferCap) followed by [0, writeIdx).
type ring struct {
	defVal      byte   // always 0 (the zero value of a fresh buffer); not read in this file
	available   int    // number of unread bytes
	readIdx     int    // index of the next byte to read
	writeIdx    int    // index of the next byte to write; may equal bufferCap
	buffer      []byte // backing storage; len(buffer) == bufferCap
	bufferCap   int    // current capacity of buffer
	minCap      int    // initial capacity; testCap skips shrinking when bufferCap is at this size
	maxCap      int    // when > 0, growing to a capacity >= maxCap fails with EcTooLong
	halfBuffCap int    // bufferCap >> 1; puts that leave at most this many bytes count toward shrinking
	shrink      int    // countdown of consecutive low-usage puts before the buffer shrinks
	shrinkCount int    // value shrink is reset to
}
   206  
   207  func (r *ring) Available() int {
   208  	return r.available
   209  }
   210  
   211  func (r *ring) testCap(c int) *util.Err {
   212  	if c > r.bufferCap {
   213  		c, ok := util.NextCap(c, r.bufferCap, 2048)
   214  		if ok {
   215  			if r.maxCap > 0 && c >= r.maxCap {
   216  				return util.NewErr(util.EcTooLong, util.M{
   217  					"total": c,
   218  				})
   219  			}
   220  			r.resetBuffer(c)
   221  		}
   222  		return nil
   223  	}
   224  	if r.minCap == r.bufferCap {
   225  		return nil
   226  	}
   227  	if c > r.halfBuffCap {
   228  		r.shrink = r.shrinkCount
   229  		return nil
   230  	}
   231  	r.shrink--
   232  	if r.shrink > 0 {
   233  		return nil
   234  	}
   235  	r.resetBuffer(r.halfBuffCap)
   236  	return nil
   237  }
   238  
   239  func (r *ring) resetBuffer(cap int) {
   240  	buf := make([]byte, cap)
   241  	if r.available > 0 {
   242  		if r.writeIdx > r.readIdx {
   243  			copy(buf, r.buffer[r.readIdx:r.writeIdx])
   244  		} else {
   245  			n := copy(buf, r.buffer[r.readIdx:])
   246  			copy(buf[n:], r.buffer[:r.writeIdx])
   247  		}
   248  	}
   249  	r.writeIdx = r.available
   250  	r.readIdx = 0
   251  	r.bufferCap = cap
   252  	r.halfBuffCap = cap >> 1
   253  	r.buffer = buf
   254  	r.shrink = r.shrinkCount
   255  	r.buffer = make([]byte, cap)
   256  }
   257  
   258  func (r *ring) Put(items []byte) *util.Err {
   259  	l := len(items)
   260  	c := r.available + l
   261  	err := r.testCap(c)
   262  	if err != nil {
   263  		return err
   264  	}
   265  	r.available = c
   266  	i := r.writeIdx + l
   267  	if i <= r.bufferCap {
   268  		copy(r.buffer[r.writeIdx:], items)
   269  		r.writeIdx = i
   270  	} else {
   271  		copy(r.buffer[r.writeIdx:r.bufferCap], items)
   272  		j := r.bufferCap - r.writeIdx
   273  		copy(r.buffer, items[j:l])
   274  		r.writeIdx = l - j
   275  	}
   276  	return nil
   277  }
   278  
   279  func (r *ring) Read(s []byte, l int) *util.Err {
   280  	sl := len(s)
   281  	if l > sl || l > r.available {
   282  		return util.NewErr(util.EcNotEnough, util.M{
   283  			"length":    l,
   284  			"slice":     sl,
   285  			"available": r.available,
   286  		})
   287  	}
   288  	r.read(s, l)
   289  	return nil
   290  }
   291  
   292  func (r *ring) read(s []byte, l int) {
   293  	p := r.readIdx + l
   294  	if p < r.bufferCap {
   295  		copy(s, r.buffer[r.readIdx:p])
   296  		r.readIdx = p
   297  	} else {
   298  		p -= r.bufferCap
   299  		copy(s, r.buffer[r.readIdx:r.bufferCap])
   300  		copy(s[r.bufferCap-r.readIdx:], r.buffer[:p])
   301  		r.readIdx = p
   302  	}
   303  	r.available -= l
   304  }