github.com/glide-im/glide@v1.6.0/pkg/gate/user_client.go (about)

     1  package gate
     2  
     3  import (
     4  	"errors"
     5  	"github.com/glide-im/glide/pkg/conn"
     6  	"github.com/glide-im/glide/pkg/logger"
     7  	"github.com/glide-im/glide/pkg/messages"
     8  	"github.com/glide-im/glide/pkg/timingwheel"
     9  	"sync"
    10  	"sync/atomic"
    11  	"time"
    12  )
    13  
    14  // tw is a timer for heartbeat.
    15  var tw = timingwheel.NewTimingWheel(time.Millisecond*500, 3, 20)
    16  
    17  const (
    18  	defaultServerHeartbeatDuration = time.Second * 30
    19  	defaultHeartbeatDuration       = time.Second * 20
    20  	defaultHeartbeatLostLimit      = 3
    21  	defaultCloseImmediately        = false
    22  )
    23  
    24  // client state
    25  const (
    26  	_ int32 = iota
    27  	// stateRunning client is running, can runRead and runWrite message.
    28  	stateRunning
    29  	// stateClosed client is closed, cannot do anything.
    30  	stateClosed
    31  )
    32  
// ClientConfig holds the tunables of a UserClient. Passing a nil *ClientConfig
// to NewClientWithConfig selects the package defaults.
type ClientConfig struct {

	// ClientHeartbeatDuration is how long the read loop waits for any inbound
	// message before counting one lost client heartbeat and sending a heartbeat
	// to the client. Every inbound message restarts this timer.
	ClientHeartbeatDuration time.Duration

	// ServerHeartbeatDuration is how long the write loop may go without writing
	// anything before it enqueues a server heartbeat message.
	ServerHeartbeatDuration time.Duration

	// HeartbeatLostLimit is the number of consecutive lost client heartbeats
	// tolerated; the client exits once the count exceeds this value.
	HeartbeatLostLimit int

	// CloseImmediately controls what Exit does with messages still queued.
	// When true, queued messages are discarded and the connection is closed at once.
	// When false, the read and write loops are stopped, new messages are rejected,
	// and a background goroutine writes out the messages already queued before
	// closing the connection.
	CloseImmediately bool
}
    50  
    51  type MessageInterceptor = func(dc DefaultClient, msg *messages.GlideMessage) bool
    52  
    53  type DefaultClient interface {
    54  	Client
    55  
    56  	SetCredentials(credentials *ClientAuthCredentials)
    57  
    58  	GetCredentials() *ClientAuthCredentials
    59  
    60  	AddMessageInterceptor(interceptor MessageInterceptor)
    61  }
    62  
    63  var _ DefaultClient = (*UserClient)(nil)
    64  
    65  // UserClient represent a user conn client.
    66  type UserClient struct {
    67  
    68  	// conn is the real connection
    69  	conn conn.Connection
    70  
    71  	// state is the client state
    72  	state int32
    73  
    74  	// queuedMessage message count in the messages channel
    75  	queuedMessage int64
    76  	// messages is the buffered channel for message to push to client.
    77  	messages chan *messages.GlideMessage
    78  
    79  	// closeReadCh is the channel for runRead goroutine to close
    80  	closeReadCh chan struct{}
    81  	// closeWriteCh is the channel for runWrite goroutine to close
    82  	closeWriteCh chan struct{}
    83  
    84  	// closeWriteOnce is the once for close runWrite goroutine
    85  	closeWriteOnce sync.Once
    86  	// closeReadOnce is the once for close runRead goroutine
    87  	closeReadOnce sync.Once
    88  
    89  	// hbC is the timer for client heartbeat
    90  	hbC *timingwheel.Task
    91  	// hbS is the timer for server heartbeat
    92  	hbS *timingwheel.Task
    93  	// hbLost is the count of heartbeat lost
    94  	hbLost int
    95  
    96  	// info is the client info
    97  	info *Info
    98  
    99  	credentials *ClientAuthCredentials
   100  
   101  	// mgr the client manager which manage this client
   102  	mgr Gateway
   103  	// msgHandler client message handler
   104  	msgHandler MessageHandler
   105  
   106  	// config is the client config
   107  	config *ClientConfig
   108  }
   109  
   110  func NewClientWithConfig(conn conn.Connection, mgr Gateway, handler MessageHandler, config *ClientConfig) DefaultClient {
   111  	if config == nil {
   112  		config = &ClientConfig{
   113  			ClientHeartbeatDuration: defaultHeartbeatDuration,
   114  			ServerHeartbeatDuration: defaultServerHeartbeatDuration,
   115  			HeartbeatLostLimit:      defaultHeartbeatLostLimit,
   116  			CloseImmediately:        false,
   117  		}
   118  	}
   119  
   120  	ret := UserClient{
   121  		conn:         conn,
   122  		messages:     make(chan *messages.GlideMessage, 100),
   123  		closeReadCh:  make(chan struct{}),
   124  		closeWriteCh: make(chan struct{}),
   125  		hbC:          tw.After(config.ClientHeartbeatDuration),
   126  		hbS:          tw.After(config.ServerHeartbeatDuration),
   127  		info: &Info{
   128  			ConnectionAt: time.Now().UnixMilli(),
   129  			CliAddr:      conn.GetConnInfo().Addr,
   130  		},
   131  		mgr:        mgr,
   132  		msgHandler: handler,
   133  		config:     config,
   134  	}
   135  	return &ret
   136  }
   137  
   138  func NewClient(conn conn.Connection, mgr Gateway, handler MessageHandler) DefaultClient {
   139  	return NewClientWithConfig(conn, mgr, handler, nil)
   140  }
   141  
   142  func (c *UserClient) SetCredentials(credentials *ClientAuthCredentials) {
   143  	c.credentials = credentials
   144  	c.info.ConnectionId = credentials.ConnectionID
   145  }
   146  
   147  func (c *UserClient) GetCredentials() *ClientAuthCredentials {
   148  	return c.credentials
   149  }
   150  
   151  func (c *UserClient) AddMessageInterceptor(interceptor MessageInterceptor) {
   152  	h := c.msgHandler
   153  	c.msgHandler = func(cliInfo *Info, msg *messages.GlideMessage) {
   154  		if interceptor(c, msg) {
   155  			return
   156  		}
   157  		h(cliInfo, msg)
   158  	}
   159  }
   160  
   161  func (c *UserClient) GetInfo() Info {
   162  	return *c.info
   163  }
   164  
   165  // SetID set client id.
   166  func (c *UserClient) SetID(id ID) {
   167  	c.info.ID = id
   168  }
   169  
   170  // IsRunning return true if client is running
   171  func (c *UserClient) IsRunning() bool {
   172  	return atomic.LoadInt32(&c.state) == stateRunning
   173  }
   174  
   175  // EnqueueMessage enqueue message to client message queue.
   176  func (c *UserClient) EnqueueMessage(msg *messages.GlideMessage) error {
   177  	if atomic.LoadInt32(&c.state) == stateClosed {
   178  		return errors.New("client has closed")
   179  	}
   180  	logger.I("EnqueueMessage ID=%s msg=%v", c.info.ID, msg)
   181  	select {
   182  	case c.messages <- msg:
   183  		atomic.AddInt64(&c.queuedMessage, 1)
   184  	default:
   185  		logger.E("msg chan is full, id=%v", c.info.ID)
   186  	}
   187  	return nil
   188  }
   189  
   190  // runRead message from client.
   191  func (c *UserClient) runRead() {
   192  	defer func() {
   193  		err := recover()
   194  		if err != nil {
   195  			logger.E("read message panic: %v", err)
   196  			c.Exit()
   197  		}
   198  	}()
   199  
   200  	readChan, done := messageReader.ReadCh(c.conn)
   201  	var closeReason string
   202  	for {
   203  		select {
   204  		case <-c.closeReadCh:
   205  			if closeReason == "" {
   206  				closeReason = "closed initiative"
   207  			}
   208  			goto STOP
   209  		case <-c.hbC.C:
   210  			if !c.IsRunning() {
   211  				goto STOP
   212  			}
   213  			c.hbLost++
   214  			if c.hbLost > c.config.HeartbeatLostLimit {
   215  				closeReason = "heartbeat lost"
   216  				c.Exit()
   217  			}
   218  			c.hbC.Cancel()
   219  			c.hbC = tw.After(c.config.ClientHeartbeatDuration)
   220  			_ = c.EnqueueMessage(messages.NewMessage(0, messages.ActionHeartbeat, nil))
   221  		case msg := <-readChan:
   222  			if msg == nil {
   223  				closeReason = "readCh closed"
   224  				c.Exit()
   225  				continue
   226  			}
   227  			if msg.err != nil {
   228  				if messages.IsDecodeError(msg.err) {
   229  					_ = c.EnqueueMessage(messages.NewMessage(0, messages.ActionNotifyError, msg.err.Error()))
   230  					continue
   231  				}
   232  				closeReason = msg.err.Error()
   233  				c.Exit()
   234  				continue
   235  			}
   236  			if c.info.ID == "" {
   237  				closeReason = "client not logged"
   238  				c.Exit()
   239  				break
   240  			}
   241  			c.hbLost = 0
   242  			c.hbC.Cancel()
   243  			c.hbC = tw.After(c.config.ClientHeartbeatDuration)
   244  
   245  			if msg.m.GetAction() == messages.ActionHello {
   246  				c.handleHello(msg.m)
   247  			} else {
   248  				c.msgHandler(c.info, msg.m)
   249  			}
   250  			msg.Recycle()
   251  		}
   252  	}
   253  STOP:
   254  	close(done)
   255  	c.hbC.Cancel()
   256  	logger.I("read exit, reason=%s", closeReason)
   257  }
   258  
   259  // runWrite message to client.
   260  func (c *UserClient) runWrite() {
   261  	defer func() {
   262  		err := recover()
   263  		if err != nil {
   264  			logger.D("write message error, exit client: %v", err)
   265  			c.Exit()
   266  		}
   267  	}()
   268  
   269  	var closeReason string
   270  	for {
   271  		select {
   272  		case <-c.closeWriteCh:
   273  			if closeReason == "" {
   274  				closeReason = "closed initiative"
   275  			}
   276  			goto STOP
   277  		case <-c.hbS.C:
   278  			if !c.IsRunning() {
   279  				closeReason = "client not running"
   280  				goto STOP
   281  			}
   282  			_ = c.EnqueueMessage(messages.NewMessage(0, messages.ActionHeartbeat, nil))
   283  			c.hbS.Cancel()
   284  			c.hbS = tw.After(c.config.ServerHeartbeatDuration)
   285  		case m := <-c.messages:
   286  			if m == nil {
   287  				closeReason = "message is nil, maybe client has closed"
   288  				c.Exit()
   289  				break
   290  			}
   291  			c.write2Conn(m)
   292  			c.hbS.Cancel()
   293  			c.hbS = tw.After(c.config.ServerHeartbeatDuration)
   294  		}
   295  	}
   296  STOP:
   297  	c.hbS.Cancel()
   298  	logger.D("write exit, addr=%s, reason:%s", c.info.CliAddr, closeReason)
   299  }
   300  
   301  // Exit client, note: exit client will not close conn right now, but will close when message chan is empty.
   302  // It's close read right now, and close write2Conn when all message in queue is sent.
   303  func (c *UserClient) Exit() {
   304  	if atomic.LoadInt32(&c.state) == stateClosed {
   305  		return
   306  	}
   307  	atomic.StoreInt32(&c.state, stateClosed)
   308  
   309  	id := c.info.ID
   310  	// exit by client self, remove client from manager
   311  	if c.mgr != nil && id != "" {
   312  		_ = c.mgr.ExitClient(id)
   313  	}
   314  	c.SetID("")
   315  	c.mgr = nil
   316  	c.stopReadWrite()
   317  
   318  	if c.config.CloseImmediately {
   319  		// dropping all message in queue and close connection immediately
   320  		c.close()
   321  	} else {
   322  		// close connection when all message in queue is sent
   323  		go func() {
   324  			for {
   325  				select {
   326  				case m := <-c.messages:
   327  					c.write2Conn(m)
   328  				default:
   329  					goto END
   330  				}
   331  			}
   332  		END:
   333  			c.close()
   334  		}()
   335  	}
   336  }
   337  
   338  func (c *UserClient) Run() {
   339  	logger.I("new client running addr:%s id:%s", c.conn.GetConnInfo().Addr, c.info.ID)
   340  	atomic.StoreInt32(&c.state, stateRunning)
   341  	c.closeWriteOnce = sync.Once{}
   342  	c.closeReadOnce = sync.Once{}
   343  
   344  	go c.runRead()
   345  	go c.runWrite()
   346  }
   347  
   348  func (c *UserClient) isClosed() bool {
   349  	return atomic.LoadInt32(&c.state) == stateClosed
   350  }
   351  
   352  func (c *UserClient) close() {
   353  	close(c.messages)
   354  	_ = c.conn.Close()
   355  }
   356  
   357  func (c *UserClient) write2Conn(m *messages.GlideMessage) {
   358  	b, err := codec.Encode(m)
   359  	if err != nil {
   360  		logger.E("serialize output message", err)
   361  		return
   362  	}
   363  	err = c.conn.Write(b)
   364  	atomic.AddInt64(&c.queuedMessage, -1)
   365  	if err != nil {
   366  		logger.D("runWrite error: %s", err.Error())
   367  		c.closeWriteOnce.Do(func() {
   368  			close(c.closeWriteCh)
   369  		})
   370  	}
   371  }
   372  
   373  func (c *UserClient) stopReadWrite() {
   374  	c.closeWriteOnce.Do(func() {
   375  		close(c.closeWriteCh)
   376  	})
   377  	c.closeReadOnce.Do(func() {
   378  		close(c.closeReadCh)
   379  	})
   380  }
   381  
   382  func (c *UserClient) handleHello(m *messages.GlideMessage) {
   383  	hello := messages.Hello{}
   384  	err := m.Data.Deserialize(&hello)
   385  	if err != nil {
   386  		_ = c.EnqueueMessage(messages.NewMessage(0, messages.ActionNotifyError, "invalid handleHello message"))
   387  	} else {
   388  		c.info.Version = hello.ClientVersion
   389  	}
   390  }