github.com/15mga/kiwi@v0.0.2-0.20240324021231-b95d5c3ac751/network/web_agent.go (about)

     1  package network
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"github.com/15mga/kiwi"
     7  	"github.com/15mga/kiwi/ds"
     8  	"time"
     9  
    10  	"github.com/15mga/kiwi/util"
    11  	"github.com/fasthttp/websocket"
    12  )
    13  
    14  func NewWebAgent(addr string, msgType int, receiver kiwi.FnAgentBytes, options ...kiwi.AgentOption) *webAgent {
    15  	return &webAgent{
    16  		agent:   newAgent(addr, receiver, options...),
    17  		msgType: msgType,
    18  	}
    19  }
    20  
    21  type webAgent struct {
    22  	agent
    23  	msgType int
    24  	conn    *websocket.Conn
    25  }
    26  
    27  func (a *webAgent) Start(ctx context.Context, conn *websocket.Conn) {
    28  	a.conn = conn
    29  	a.onClose = a.conn.Close
    30  	a.start(ctx)
    31  	switch a.option.AgentMode {
    32  	case kiwi.AgentRW:
    33  		go a.read()
    34  		go a.write()
    35  	case kiwi.AgentR:
    36  		go a.read()
    37  	case kiwi.AgentW:
    38  		go a.write()
    39  	}
    40  }
    41  
    42  func (a *webAgent) read() {
    43  	var err *util.Err
    44  	defer func() {
    45  		r := recover()
    46  		if r != nil {
    47  			kiwi.Error2(util.EcRecover, util.M{
    48  				"remote addr": a.conn.RemoteAddr().String(),
    49  				"recover":     fmt.Sprintf("%s", r),
    50  			})
    51  			a.read()
    52  			return
    53  		}
    54  		a.close(err)
    55  	}()
    56  
    57  	dur := time.Duration(a.option.DeadlineSecs)
    58  	c := a.conn
    59  	c.SetReadLimit(int64(a.option.PacketMaxCap))
    60  	for {
    61  		select {
    62  		case <-a.ctx.Done():
    63  			return
    64  		default:
    65  			if dur > 0 {
    66  				_ = a.conn.SetReadDeadline(time.Now().Add(time.Second * dur))
    67  			}
    68  			mt, bytes, e := c.ReadMessage()
    69  			if e != nil {
    70  				err = util.WrapErr(util.EcIo, e)
    71  				return
    72  			}
    73  			if mt != a.msgType {
    74  				err = util.NewErr(util.EcWrongType, util.M{
    75  					"receive message type": mt,
    76  					"need message type":    a.msgType,
    77  				})
    78  				return
    79  			}
    80  			newLen := uint32(len(bytes))
    81  			if newLen == 0 {
    82  				break
    83  			}
    84  			a.receiver(a, bytes)
    85  		}
    86  	}
    87  }
    88  
    89  func (a *webAgent) write() {
    90  	var (
    91  		err *util.Err
    92  	)
    93  	defer func() {
    94  		a.close(err)
    95  	}()
    96  
    97  	c := a.conn
    98  	msgType := a.msgType
    99  	for {
   100  		select {
   101  		case <-a.ctx.Done():
   102  			return
   103  		case <-a.writeSignCh:
   104  			var elem *ds.LinkElem[[]byte]
   105  			a.enable.Mtx.Lock()
   106  			if a.enable.Disabled() {
   107  				a.enable.Mtx.Unlock()
   108  				return
   109  			}
   110  			elem = a.bytesLink.PopAll()
   111  			a.enable.Mtx.Unlock()
   112  			if elem == nil {
   113  				continue
   114  			}
   115  
   116  			for ; elem != nil; elem = elem.Next {
   117  				bytes := elem.Value
   118  				e := c.WriteMessage(msgType, bytes)
   119  				util.RecycleBytes(bytes)
   120  				if e != nil {
   121  					err = util.WrapErr(util.EcIo, e)
   122  					return
   123  				}
   124  			}
   125  		}
   126  	}
   127  }