github.com/matrixorigin/matrixone@v0.7.0/pkg/vm/engine/tae/logtail/service/session.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package service
    16  
    17  import (
    18  	"context"
    19  	"sync"
    20  	"time"
    21  
    22  	"github.com/matrixorigin/matrixone/pkg/common/log"
    23  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    24  	"github.com/matrixorigin/matrixone/pkg/common/morpc"
    25  	"github.com/matrixorigin/matrixone/pkg/pb/api"
    26  	"github.com/matrixorigin/matrixone/pkg/pb/logtail"
    27  	"github.com/matrixorigin/matrixone/pkg/pb/timestamp"
    28  	"go.uber.org/zap"
    29  )
    30  
    31  type TableState int
    32  
    33  const (
    34  	TableOnSubscription TableState = iota
    35  	TableSubscribed
    36  	TableNotFound
    37  )
    38  
    39  // SessionManager manages all client sessions.
    40  type SessionManager struct {
    41  	sync.RWMutex
    42  	clients map[morpcStream]*Session
    43  }
    44  
    45  // NewSessionManager constructs a session manager.
    46  func NewSessionManager() *SessionManager {
    47  	return &SessionManager{
    48  		clients: make(map[morpcStream]*Session),
    49  	}
    50  }
    51  
    52  // GetSession constructs a session for new morpc.ClientSession.
    53  func (sm *SessionManager) GetSession(
    54  	rootCtx context.Context,
    55  	logger *log.MOLogger,
    56  	sendTimeout time.Duration,
    57  	responses ResponsePool,
    58  	notifier SessionErrorNotifier,
    59  	stream morpcStream,
    60  	poisionTime time.Duration,
    61  ) *Session {
    62  	sm.Lock()
    63  	defer sm.Unlock()
    64  
    65  	if _, ok := sm.clients[stream]; !ok {
    66  		sm.clients[stream] = NewSession(
    67  			rootCtx, logger, sendTimeout, responses, notifier, stream, poisionTime,
    68  		)
    69  	}
    70  	return sm.clients[stream]
    71  }
    72  
    73  // DeleteSession deletes session from manager.
    74  func (sm *SessionManager) DeleteSession(stream morpcStream) {
    75  	sm.Lock()
    76  	defer sm.Unlock()
    77  	delete(sm.clients, stream)
    78  }
    79  
    80  // ListSession takes a snapshot of all sessions.
    81  func (sm *SessionManager) ListSession() []*Session {
    82  	sm.RLock()
    83  	defer sm.RUnlock()
    84  
    85  	sessions := make([]*Session, 0, len(sm.clients))
    86  	for _, ss := range sm.clients {
    87  		sessions = append(sessions, ss)
    88  	}
    89  	return sessions
    90  }
    91  
    92  // message describes response to be sent.
    93  type message struct {
    94  	timeout  time.Duration
    95  	response *LogtailResponse
    96  }
    97  
    98  // morpcStream describes morpc stream.
    99  type morpcStream struct {
   100  	streamID uint64
   101  	limit    int
   102  	logger   *log.MOLogger
   103  	cs       morpc.ClientSession
   104  	segments SegmentPool
   105  }
   106  
   107  // Close closes morpc client session.
   108  func (s *morpcStream) Close() error {
   109  	return s.cs.Close()
   110  }
   111  
   112  // write sends response by segment.
   113  func (s *morpcStream) write(
   114  	ctx context.Context, response *LogtailResponse,
   115  ) error {
   116  	size := response.ProtoSize()
   117  	buf := make([]byte, size)
   118  	n, err := response.MarshalToSizedBuffer(buf[:size])
   119  	if err != nil {
   120  		return err
   121  	}
   122  	chunks := Split(buf[:n], s.limit)
   123  
   124  	s.logger.Debug("start to send response by segment",
   125  		zap.Int("chunk-number", len(chunks)),
   126  		zap.Int("chunk-limit", s.limit),
   127  		zap.Int("message-size", size),
   128  	)
   129  
   130  	for index, chunk := range chunks {
   131  		seg := s.segments.Acquire()
   132  		seg.SetID(s.streamID)
   133  		seg.MessageSize = int32(size)
   134  		seg.Sequence = int32(index + 1)
   135  		seg.MaxSequence = int32(len(chunks))
   136  		n := copy(seg.Payload, chunk)
   137  		seg.Payload = seg.Payload[:n]
   138  
   139  		s.logger.Debug("real segment proto size", zap.Int("ProtoSize", seg.ProtoSize()))
   140  
   141  		if err := s.cs.Write(ctx, seg); err != nil {
   142  			s.segments.Release(seg)
   143  			return err
   144  		}
   145  	}
   146  
   147  	return nil
   148  }
   149  
   150  // Session manages subscription for logtail client.
   151  type Session struct {
   152  	sessionCtx context.Context
   153  	cancelFunc context.CancelFunc
   154  	wg         sync.WaitGroup
   155  
   156  	logger      *log.MOLogger
   157  	sendTimeout time.Duration
   158  	responses   ResponsePool
   159  	notifier    SessionErrorNotifier
   160  
   161  	stream      morpcStream
   162  	poisionTime time.Duration
   163  	sendChan    chan message
   164  
   165  	mu     sync.RWMutex
   166  	tables map[TableID]TableState
   167  }
   168  
   169  type SessionErrorNotifier interface {
   170  	NotifySessionError(*Session, error)
   171  }
   172  
   173  // NewSession constructs a session for logtail client.
   174  func NewSession(
   175  	rootCtx context.Context,
   176  	logger *log.MOLogger,
   177  	sendTimeout time.Duration,
   178  	responses ResponsePool,
   179  	notifier SessionErrorNotifier,
   180  	stream morpcStream,
   181  	poisionTime time.Duration,
   182  ) *Session {
   183  	ctx, cancel := context.WithCancel(rootCtx)
   184  	ss := &Session{
   185  		sessionCtx:  ctx,
   186  		cancelFunc:  cancel,
   187  		logger:      logger,
   188  		sendTimeout: sendTimeout,
   189  		responses:   responses,
   190  		notifier:    notifier,
   191  		stream:      stream,
   192  		poisionTime: poisionTime,
   193  		sendChan:    make(chan message, 16), // buffer response for morpc client session
   194  		tables:      make(map[TableID]TableState),
   195  	}
   196  
   197  	sender := func() {
   198  		defer ss.wg.Done()
   199  
   200  		for {
   201  			select {
   202  			case <-ss.sessionCtx.Done():
   203  				ss.logger.Error("stop session sender", zap.Error(ss.sessionCtx.Err()))
   204  				return
   205  
   206  			case msg, ok := <-ss.sendChan:
   207  				if !ok {
   208  					ss.logger.Info("session sender channel closed")
   209  					return
   210  				}
   211  
   212  				sendFunc := func() error {
   213  					defer ss.responses.Release(msg.response)
   214  
   215  					ctx, cancel := context.WithTimeout(ss.sessionCtx, msg.timeout)
   216  					defer cancel()
   217  					return ss.stream.write(ctx, msg.response)
   218  				}
   219  
   220  				if err := sendFunc(); err != nil {
   221  					ss.logger.Error("fail to send logtail response", zap.Error(err))
   222  					ss.notifier.NotifySessionError(ss, err)
   223  					return
   224  				}
   225  			}
   226  		}
   227  	}
   228  
   229  	ss.wg.Add(1)
   230  	go sender()
   231  
   232  	return ss
   233  }
   234  
   235  // Drop closes sender goroutine.
   236  func (ss *Session) PostClean() {
   237  	ss.cancelFunc()
   238  	ss.wg.Wait()
   239  }
   240  
   241  // Register registers table for client.
   242  //
   243  // The returned true value indicates repeated subscription.
   244  func (ss *Session) Register(id TableID, table api.TableID) bool {
   245  	ss.mu.Lock()
   246  	defer ss.mu.Unlock()
   247  
   248  	if _, ok := ss.tables[id]; ok {
   249  		return true
   250  	}
   251  	ss.tables[id] = TableOnSubscription
   252  	return false
   253  }
   254  
   255  // Unsubscribe unsubscribes table.
   256  func (ss *Session) Unregister(id TableID) TableState {
   257  	ss.mu.Lock()
   258  	defer ss.mu.Unlock()
   259  
   260  	state, ok := ss.tables[id]
   261  	if !ok {
   262  		return TableNotFound
   263  	}
   264  	delete(ss.tables, id)
   265  	return state
   266  }
   267  
   268  // ListTable takes a snapshot of all
   269  func (ss *Session) ListSubscribedTable() []TableID {
   270  	ss.mu.RLock()
   271  	defer ss.mu.RUnlock()
   272  
   273  	ids := make([]TableID, 0, len(ss.tables))
   274  	for id, state := range ss.tables {
   275  		if state == TableSubscribed {
   276  			ids = append(ids, id)
   277  		}
   278  	}
   279  	return ids
   280  }
   281  
   282  // FilterLogtail selects logtail for expected tables.
   283  func (ss *Session) FilterLogtail(tails ...wrapLogtail) []logtail.TableLogtail {
   284  	ss.mu.RLock()
   285  	defer ss.mu.RUnlock()
   286  
   287  	qualified := make([]logtail.TableLogtail, 0, len(ss.tables))
   288  	for _, t := range tails {
   289  		if state, ok := ss.tables[t.id]; ok && state == TableSubscribed {
   290  			qualified = append(qualified, t.tail)
   291  		}
   292  	}
   293  	return qualified
   294  }
   295  
   296  // Publish publishes additional logtail.
   297  func (ss *Session) Publish(
   298  	ctx context.Context, from, to timestamp.Timestamp, wraps ...wrapLogtail,
   299  ) error {
   300  	sendCtx, cancel := context.WithTimeout(ctx, ss.sendTimeout)
   301  	defer cancel()
   302  
   303  	qualified := ss.FilterLogtail(wraps...)
   304  	return ss.SendUpdateResponse(sendCtx, from, to, qualified...)
   305  }
   306  
   307  // TransitionState marks table as subscribed.
   308  func (ss *Session) AdvanceState(id TableID) {
   309  	ss.logger.Debug("mark table as subscribed", zap.String("table-id", string(id)))
   310  
   311  	ss.mu.Lock()
   312  	defer ss.mu.Unlock()
   313  
   314  	if _, ok := ss.tables[id]; !ok {
   315  		return
   316  	}
   317  	ss.tables[id] = TableSubscribed
   318  }
   319  
   320  // SendErrorResponse sends error response to logtail client.
   321  func (ss *Session) SendErrorResponse(
   322  	sendCtx context.Context, table api.TableID, code uint16, message string,
   323  ) error {
   324  	resp := ss.responses.Acquire()
   325  	resp.Response = newErrorResponse(table, code, message)
   326  	return ss.SendResponse(sendCtx, resp)
   327  }
   328  
   329  // SendSubscriptionResponse sends subscription response.
   330  func (ss *Session) SendSubscriptionResponse(
   331  	sendCtx context.Context, tail logtail.TableLogtail,
   332  ) error {
   333  	resp := ss.responses.Acquire()
   334  	resp.Response = newSubscritpionResponse(tail)
   335  	return ss.SendResponse(sendCtx, resp)
   336  }
   337  
   338  // SendUnsubscriptionResponse sends unsubscription response.
   339  func (ss *Session) SendUnsubscriptionResponse(
   340  	sendCtx context.Context, table api.TableID,
   341  ) error {
   342  	resp := ss.responses.Acquire()
   343  	resp.Response = newUnsubscriptionResponse(table)
   344  	return ss.SendResponse(sendCtx, resp)
   345  }
   346  
   347  // SendUpdateResponse sends publishment response.
   348  func (ss *Session) SendUpdateResponse(
   349  	sendCtx context.Context, from, to timestamp.Timestamp, tails ...logtail.TableLogtail,
   350  ) error {
   351  	resp := ss.responses.Acquire()
   352  	resp.Response = newUpdateResponse(from, to, tails...)
   353  	return ss.SendResponse(sendCtx, resp)
   354  }
   355  
   356  // SendResponse sends response.
   357  //
   358  // If the sender of Session finished, it would block until
   359  // sendCtx/sessionCtx cancelled or timeout.
   360  func (ss *Session) SendResponse(
   361  	sendCtx context.Context, response *LogtailResponse,
   362  ) error {
   363  	select {
   364  	case <-ss.sessionCtx.Done():
   365  		ss.logger.Error("session context done", zap.Error(ss.sessionCtx.Err()))
   366  		ss.responses.Release(response)
   367  		return ss.sessionCtx.Err()
   368  	case <-sendCtx.Done():
   369  		ss.logger.Error("send context done", zap.Error(sendCtx.Err()))
   370  		ss.responses.Release(response)
   371  		return sendCtx.Err()
   372  	default:
   373  	}
   374  
   375  	select {
   376  	case <-time.After(ss.poisionTime):
   377  		ss.logger.Error("poision morpc client session detected, close it")
   378  		ss.responses.Release(response)
   379  		if err := ss.stream.Close(); err != nil {
   380  			ss.logger.Error("fail to close poision morpc client session", zap.Error(err))
   381  		}
   382  		return moerr.NewStreamClosedNoCtx()
   383  	case ss.sendChan <- message{timeout: ContextTimeout(sendCtx, ss.sendTimeout), response: response}:
   384  		return nil
   385  	}
   386  }
   387  
   388  // newUnsubscriptionResponse constructs response for unsubscription.
   389  // go:inline
   390  func newUnsubscriptionResponse(
   391  	table api.TableID,
   392  ) *logtail.LogtailResponse_UnsubscribeResponse {
   393  	return &logtail.LogtailResponse_UnsubscribeResponse{
   394  		UnsubscribeResponse: &logtail.UnSubscribeResponse{
   395  			Table: &table,
   396  		},
   397  	}
   398  }
   399  
   400  // newUpdateResponse constructs response for publishment.
   401  // go:inline
   402  func newUpdateResponse(
   403  	from, to timestamp.Timestamp, tails ...logtail.TableLogtail,
   404  ) *logtail.LogtailResponse_UpdateResponse {
   405  	return &logtail.LogtailResponse_UpdateResponse{
   406  		UpdateResponse: &logtail.UpdateResponse{
   407  			From:        &from,
   408  			To:          &to,
   409  			LogtailList: tails,
   410  		},
   411  	}
   412  }
   413  
   414  // newSubscritpionResponse constructs response for subscription.
   415  // go:inline
   416  func newSubscritpionResponse(
   417  	tail logtail.TableLogtail,
   418  ) *logtail.LogtailResponse_SubscribeResponse {
   419  	return &logtail.LogtailResponse_SubscribeResponse{
   420  		SubscribeResponse: &logtail.SubscribeResponse{
   421  			Logtail: tail,
   422  		},
   423  	}
   424  }
   425  
   426  // newErrorResponse constructs response for error condition.
   427  // go:inline
   428  func newErrorResponse(
   429  	table api.TableID, code uint16, message string,
   430  ) *logtail.LogtailResponse_Error {
   431  	return &logtail.LogtailResponse_Error{
   432  		Error: &logtail.ErrorResponse{
   433  			Table: &table,
   434  			Status: logtail.Status{
   435  				Code:    uint32(code),
   436  				Message: message,
   437  			},
   438  		},
   439  	}
   440  }