github.com/matrixorigin/matrixone@v1.2.0/pkg/vm/engine/tae/logtail/service/session_test.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package service
    16  
    17  import (
    18  	"context"
    19  	"os"
    20  	"sync"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/stretchr/testify/require"
    25  	"go.uber.org/zap"
    26  
    27  	"github.com/matrixorigin/matrixone/pkg/common/log"
    28  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    29  	"github.com/matrixorigin/matrixone/pkg/common/morpc"
    30  	"github.com/matrixorigin/matrixone/pkg/logutil"
    31  	"github.com/matrixorigin/matrixone/pkg/pb/api"
    32  	"github.com/matrixorigin/matrixone/pkg/pb/logtail"
    33  	"github.com/matrixorigin/matrixone/pkg/pb/metadata"
    34  	"github.com/matrixorigin/matrixone/pkg/pb/timestamp"
    35  )
    36  
    37  func TestMain(m *testing.M) {
    38  	// make responseBufferSize small enough temporarily
    39  	// Larger buffer size would have a negative effect on CI.
    40  	original := responseBufferSize
    41  	responseBufferSize = 1024
    42  	ret := m.Run()
    43  	responseBufferSize = original
    44  	os.Exit(ret)
    45  }
    46  
    47  func TestSessionManger(t *testing.T) {
    48  	sm := NewSessionManager()
    49  
    50  	ctx := context.Background()
    51  
    52  	// constructs mocker
    53  	logger := mockMOLogger()
    54  	pooler := NewLogtailResponsePool()
    55  	notifier := mockSessionErrorNotifier(logger.RawLogger())
    56  	sendTimeout := 5 * time.Second
    57  	poisonTime := 10 * time.Millisecond
    58  	heartbeatInterval := 50 * time.Millisecond
    59  	chunkSize := 1024
    60  
    61  	/* ---- 1. register sessioin A ---- */
    62  	csA := mockNormalClientSession(logger.RawLogger())
    63  	streamA := mockMorpcStream(csA, 10, chunkSize)
    64  	sessionA := sm.GetSession(
    65  		ctx, logger, pooler, notifier, streamA,
    66  		sendTimeout, poisonTime, heartbeatInterval,
    67  	)
    68  	require.NotNil(t, sessionA)
    69  	require.Equal(t, 1, len(sm.ListSession()))
    70  
    71  	/* ---- 2. register sessioin B ---- */
    72  	csB := mockNormalClientSession(logger.RawLogger())
    73  	streamB := mockMorpcStream(csB, 11, chunkSize)
    74  	sessionB := sm.GetSession(
    75  		ctx, logger, pooler, notifier, streamB,
    76  		sendTimeout, poisonTime, heartbeatInterval,
    77  	)
    78  	require.NotNil(t, sessionB)
    79  	require.Equal(t, 2, len(sm.ListSession()))
    80  
    81  	/* ---- 3. delete sessioin ---- */
    82  	sm.DeleteSession(streamA)
    83  	require.Equal(t, 1, len(sm.ListSession()))
    84  	sm.DeleteSession(streamB)
    85  	require.Equal(t, 0, len(sm.ListSession()))
    86  }
    87  
    88  func TestSessionError(t *testing.T) {
    89  	ctx, cancel := context.WithCancel(context.Background())
    90  	defer cancel()
    91  
    92  	// constructs mocker
    93  	logger := mockMOLogger()
    94  	pooler := NewLogtailResponsePool()
    95  	notifier := mockSessionErrorNotifier(logger.RawLogger())
    96  	cs := mockBrokenClientSession()
    97  	stream := mockMorpcStream(cs, 10, 1024)
    98  	sendTimeout := 5 * time.Second
    99  	poisionTime := 10 * time.Millisecond
   100  	heartbeatInterval := 50 * time.Millisecond
   101  
   102  	tableA := mockTable(1, 2, 3)
   103  	ss := NewSession(
   104  		ctx, logger, pooler, notifier, stream,
   105  		sendTimeout, poisionTime, heartbeatInterval,
   106  	)
   107  
   108  	/* ---- 1. send subscription response ---- */
   109  	err := ss.SendSubscriptionResponse(
   110  		context.Background(),
   111  		logtail.TableLogtail{
   112  			Table: &tableA,
   113  		},
   114  		nil,
   115  	)
   116  	require.NoError(t, err)
   117  
   118  	// wait session cleaned
   119  	<-ss.sessionCtx.Done()
   120  
   121  	/* ---- 2. send subscription response ---- */
   122  	err = ss.SendSubscriptionResponse(
   123  		context.Background(),
   124  		logtail.TableLogtail{
   125  			Table: &tableA,
   126  		},
   127  		nil,
   128  	)
   129  	require.Error(t, err)
   130  }
   131  
   132  func TestPoisionSession(t *testing.T) {
   133  	ctx, cancel := context.WithCancel(context.Background())
   134  	defer cancel()
   135  
   136  	// constructs mocker
   137  	logger := mockMOLogger()
   138  	pooler := NewLogtailResponsePool()
   139  	notifier := mockSessionErrorNotifier(logger.RawLogger())
   140  	cs := mockBlockStream()
   141  	stream := mockMorpcStream(cs, 10, 1024)
   142  	sendTimeout := 5 * time.Second
   143  	poisionTime := 10 * time.Millisecond
   144  	heartbeatInterval := 50 * time.Millisecond
   145  
   146  	tableA := mockTable(1, 2, 3)
   147  	ss := NewSession(
   148  		ctx, logger, pooler, notifier, stream,
   149  		sendTimeout, poisionTime, heartbeatInterval,
   150  	)
   151  
   152  	/* ---- 1. send response repeatedly ---- */
   153  	for i := 0; i < cap(ss.sendChan)+2; i++ {
   154  		err := ss.SendUpdateResponse(
   155  			context.Background(),
   156  			mockTimestamp(int64(i), 0),
   157  			mockTimestamp(int64(i+1), 0),
   158  			nil,
   159  			logtail.TableLogtail{
   160  				Table: &tableA,
   161  			},
   162  		)
   163  		if err != nil {
   164  			require.True(t, moerr.IsMoErrCode(err, moerr.ErrStreamClosed))
   165  			break
   166  		}
   167  	}
   168  }
   169  
   170  func TestSession(t *testing.T) {
   171  	ctx, cancel := context.WithCancel(context.Background())
   172  	defer cancel()
   173  
   174  	// constructs mocker
   175  	logger := mockMOLogger()
   176  	pooler := NewLogtailResponsePool()
   177  	notifier := mockSessionErrorNotifier(logger.RawLogger())
   178  	cs := mockNormalClientSession(logger.RawLogger())
   179  	stream := mockMorpcStream(cs, 10, 1024)
   180  	sendTimeout := 5 * time.Second
   181  	poisionTime := 10 * time.Millisecond
   182  	heartbeatInterval := 50 * time.Millisecond
   183  
   184  	// constructs tables
   185  	tableA := mockTable(1, 2, 3)
   186  	idA := MarshalTableID(&tableA)
   187  	tableB := mockTable(1, 4, 3)
   188  	idB := MarshalTableID(&tableB)
   189  
   190  	ss := NewSession(
   191  		ctx, logger, pooler, notifier, stream,
   192  		sendTimeout, poisionTime, heartbeatInterval,
   193  	)
   194  	defer ss.PostClean()
   195  
   196  	// no table resigered now
   197  	require.Equal(t, 0, len(ss.ListSubscribedTable()))
   198  
   199  	/* ---- 1. register table ---- */
   200  	require.False(t, ss.Register(idA, tableA))
   201  	require.True(t, ss.Register(idA, tableA))
   202  
   203  	/* ---- 2. unregister table ---- */
   204  	require.Equal(t, TableOnSubscription, ss.Unregister(idA))
   205  	require.Equal(t, TableNotFound, ss.Unregister(idA))
   206  
   207  	/* ---- 3. register more table ---- */
   208  	require.False(t, ss.Register(idA, tableA))
   209  	require.False(t, ss.Register(idB, tableB))
   210  	require.Equal(t, 0, len(ss.ListSubscribedTable()))
   211  
   212  	/* ---- 4. filter logtail ---- */
   213  	// promote state for table A
   214  	ss.AdvanceState(idA)
   215  	require.Equal(t, 1, len(ss.ListSubscribedTable()))
   216  	// promote state for non-exist table
   217  	ss.AdvanceState(TableID("non-exist"))
   218  	require.Equal(t, 1, len(ss.ListSubscribedTable()))
   219  	// filter logtail for subscribed table
   220  	qualified := ss.FilterLogtail(
   221  		mockWrapLogtail(tableA),
   222  		mockWrapLogtail(tableB),
   223  	)
   224  	require.Equal(t, 1, len(qualified))
   225  	require.Equal(t, tableA.String(), qualified[0].Table.String())
   226  
   227  	// promote state for table B
   228  	ss.AdvanceState(idB)
   229  	require.Equal(t, 2, len(ss.ListSubscribedTable()))
   230  	// filter logtail for subscribed table
   231  	qualified = ss.FilterLogtail(
   232  		mockWrapLogtail(tableA),
   233  		mockWrapLogtail(tableB),
   234  	)
   235  	require.Equal(t, 2, len(qualified))
   236  
   237  	/* ---- 5. send error response ---- */
   238  	err := ss.SendErrorResponse(
   239  		context.Background(),
   240  		tableA,
   241  		moerr.ErrInternal,
   242  		"interval error",
   243  	)
   244  	require.NoError(t, err)
   245  
   246  	/* ---- 6. send subscription response ---- */
   247  	err = ss.SendSubscriptionResponse(
   248  		context.Background(),
   249  		logtail.TableLogtail{
   250  			Table: &tableA,
   251  		},
   252  		nil,
   253  	)
   254  	require.NoError(t, err)
   255  
   256  	/* ---- 7. send unsubscription response ---- */
   257  	err = ss.SendUnsubscriptionResponse(
   258  		context.Background(),
   259  		tableA,
   260  	)
   261  	require.NoError(t, err)
   262  
   263  	/* ---- 8. send update response ---- */
   264  	{
   265  		from := mockTimestamp(1, 0)
   266  		to := mockTimestamp(2, 0)
   267  		err = ss.SendUpdateResponse(
   268  			context.Background(),
   269  			from,
   270  			to,
   271  			nil,
   272  			mockLogtail(tableA, to),
   273  			mockLogtail(tableB, to),
   274  		)
   275  		require.NoError(t, err)
   276  	}
   277  
   278  	/* ---- 9. publish update response ---- */
   279  	err = ss.Publish(
   280  		context.Background(),
   281  		mockTimestamp(2, 0),
   282  		mockTimestamp(3, 0),
   283  		nil,
   284  		mockWrapLogtail(tableA),
   285  		mockWrapLogtail(tableB),
   286  	)
   287  	require.NoError(t, err)
   288  }
   289  
   290  type blockStream struct {
   291  	once sync.Once
   292  	ch   chan bool
   293  }
   294  
   295  func mockBlockStream() morpc.ClientSession {
   296  	return &blockStream{
   297  		ch: make(chan bool),
   298  	}
   299  }
   300  
   301  func (m *blockStream) RemoteAddress() string {
   302  	return "block"
   303  }
   304  
   305  func (m *blockStream) Write(ctx context.Context, message morpc.Message) error {
   306  	<-m.ch
   307  	return moerr.NewStreamClosedNoCtx()
   308  }
   309  
   310  func (m *blockStream) AsyncWrite(message morpc.Message) error {
   311  	<-m.ch
   312  	return moerr.NewStreamClosedNoCtx()
   313  }
   314  
   315  func (m *blockStream) Close() error {
   316  	m.once.Do(func() {
   317  		close(m.ch)
   318  	})
   319  	return nil
   320  }
   321  
   322  func (m *blockStream) CreateCache(
   323  	ctx context.Context,
   324  	cacheID uint64) (morpc.MessageCache, error) {
   325  	panic("not implement")
   326  }
   327  
   328  func (m *blockStream) DeleteCache(cacheID uint64) {
   329  	panic("not implement")
   330  }
   331  
   332  func (m *blockStream) GetCache(cacheID uint64) (morpc.MessageCache, error) {
   333  	panic("not implement")
   334  }
   335  
   336  type brokenStream struct{}
   337  
   338  func mockBrokenClientSession() morpc.ClientSession {
   339  	return &brokenStream{}
   340  }
   341  
   342  func (m *brokenStream) RemoteAddress() string {
   343  	return "broken"
   344  }
   345  
   346  func (m *brokenStream) Write(ctx context.Context, message morpc.Message) error {
   347  	return moerr.NewStreamClosedNoCtx()
   348  }
   349  
   350  func (cs *brokenStream) AsyncWrite(response morpc.Message) error {
   351  	return nil
   352  }
   353  
   354  func (m *brokenStream) Close() error {
   355  	return nil
   356  }
   357  
   358  func (m *brokenStream) CreateCache(
   359  	ctx context.Context,
   360  	cacheID uint64) (morpc.MessageCache, error) {
   361  	panic("not implement")
   362  }
   363  
   364  func (m *brokenStream) DeleteCache(cacheID uint64) {
   365  	panic("not implement")
   366  }
   367  
   368  func (m *brokenStream) GetCache(cacheID uint64) (morpc.MessageCache, error) {
   369  	panic("not implement")
   370  }
   371  
   372  type normalStream struct {
   373  	logger *zap.Logger
   374  }
   375  
   376  func mockNormalClientSession(logger *zap.Logger) morpc.ClientSession {
   377  	return &normalStream{
   378  		logger: logger,
   379  	}
   380  }
   381  
   382  func (m *normalStream) RemoteAddress() string {
   383  	return "normal"
   384  }
   385  
   386  func (m *normalStream) Write(ctx context.Context, message morpc.Message) error {
   387  	response := message.(*LogtailResponseSegment)
   388  	m.logger.Info("write response segment:", zap.String("segment", response.String()))
   389  	return nil
   390  }
   391  
   392  func (m *normalStream) AsyncWrite(message morpc.Message) error {
   393  	response := message.(*LogtailResponseSegment)
   394  	m.logger.Info("write response segment:", zap.String("segment", response.String()))
   395  	return nil
   396  }
   397  
   398  func (m *normalStream) Close() error {
   399  	return nil
   400  }
   401  
   402  func (m *normalStream) CreateCache(
   403  	ctx context.Context,
   404  	cacheID uint64) (morpc.MessageCache, error) {
   405  	panic("not implement")
   406  }
   407  
   408  func (m *normalStream) DeleteCache(cacheID uint64) {
   409  	panic("not implement")
   410  }
   411  
   412  func (m *normalStream) GetCache(cacheID uint64) (morpc.MessageCache, error) {
   413  	panic("not implement")
   414  }
   415  
   416  type notifySessionError struct {
   417  	logger *zap.Logger
   418  }
   419  
   420  func mockSessionErrorNotifier(logger *zap.Logger) SessionErrorNotifier {
   421  	return &notifySessionError{
   422  		logger: logger,
   423  	}
   424  }
   425  
   426  func (m *notifySessionError) NotifySessionError(ss *Session, err error) {
   427  	if err != nil {
   428  		m.logger.Error("receive session error", zap.Error(err))
   429  		ss.PostClean()
   430  	}
   431  }
   432  
   433  func mockWrapLogtail(table api.TableID) wrapLogtail {
   434  	return wrapLogtail{
   435  		id: MarshalTableID(&table),
   436  		tail: logtail.TableLogtail{
   437  			Table: &table,
   438  		},
   439  	}
   440  }
   441  
   442  func mockLogtail(table api.TableID, ts timestamp.Timestamp) logtail.TableLogtail {
   443  	return logtail.TableLogtail{
   444  		CkpLocation: "checkpoint",
   445  		Table:       &table,
   446  		Ts:          &ts,
   447  	}
   448  }
   449  
   450  func mockMorpcStream(
   451  	cs morpc.ClientSession, id uint64, maxMessageSize int,
   452  ) morpcStream {
   453  	segments := NewLogtailServerSegmentPool(maxMessageSize)
   454  
   455  	return morpcStream{
   456  		streamID: id,
   457  		remote:   "mock",
   458  		limit:    segments.LeastEffectiveCapacity(),
   459  		logger:   mockMOLogger(),
   460  		cs:       cs,
   461  		segments: segments,
   462  	}
   463  }
   464  
   465  func mockMOLogger() *log.MOLogger {
   466  	return log.GetServiceLogger(
   467  		logutil.GetGlobalLogger().Named(LogtailServiceRPCName),
   468  		metadata.ServiceType_TN,
   469  		"uuid",
   470  	)
   471  }