github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/topic/topicwriterinternal/queue_test.go (about)

     1  package topicwriterinternal
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"math/rand"
     7  	"runtime"
     8  	"runtime/debug"
     9  	"sync/atomic"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/stretchr/testify/require"
    14  
    15  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/empty"
    16  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/grpcwrapper/rawtopic/rawtopicwriter"
    17  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext"
    18  )
    19  
    20  func TestMessageQueue_AddMessages(t *testing.T) {
    21  	t.Run("Empty", func(t *testing.T) {
    22  		q := newMessageQueue()
    23  		require.NoError(t, q.AddMessages(newTestMessagesWithContent(1, 3, 5)))
    24  
    25  		require.Equal(t, 3, q.lastWrittenIndex)
    26  
    27  		expected := map[int]messageWithDataContent{
    28  			1: newTestMessageWithDataContent(1),
    29  			2: newTestMessageWithDataContent(3),
    30  			3: newTestMessageWithDataContent(5),
    31  		}
    32  		require.Equal(t, expected, q.messagesByOrder)
    33  
    34  		require.Len(t, q.seqNoToOrderID, 3)
    35  		require.Equal(t, 1, q.seqNoToOrderID[1])
    36  		require.Equal(t, 2, q.seqNoToOrderID[3])
    37  		require.Equal(t, 3, q.seqNoToOrderID[5])
    38  	})
    39  	t.Run("Closed", func(t *testing.T) {
    40  		q := newMessageQueue()
    41  		_ = q.Close(errors.New("err"))
    42  		require.Error(t, q.AddMessages(newTestMessagesWithContent(1, 3, 5)))
    43  	})
    44  	t.Run("OverflowIndex", func(t *testing.T) {
    45  		q := newMessageQueue()
    46  		q.lastWrittenIndex = maxInt - 1
    47  		require.NoError(t, q.AddMessages(newTestMessagesWithContent(1, 3, 5)))
    48  		require.Len(t, q.messagesByOrder, 3)
    49  		q.messagesByOrder[maxInt] = newTestMessageWithDataContent(1)
    50  		q.messagesByOrder[minInt] = newTestMessageWithDataContent(3)
    51  		q.messagesByOrder[minInt+1] = newTestMessageWithDataContent(5)
    52  		require.Equal(t, minInt+1, q.lastWrittenIndex)
    53  	})
    54  	t.Run("BadOrder", func(t *testing.T) {
    55  		q := newMessageQueue()
    56  		require.Error(t, q.AddMessages(newTestMessagesWithContent(2, 1)))
    57  	})
    58  }
    59  
    60  func TestMessageQueue_CheckMessages(t *testing.T) {
    61  	t.Run("Empty", func(t *testing.T) {
    62  		q := newMessageQueue()
    63  		require.NoError(t, q.checkNewMessagesBeforeAddNeedLock(newTestMessagesWithContent()))
    64  	})
    65  	t.Run("Unordered", func(t *testing.T) {
    66  		q := newMessageQueue()
    67  		require.Error(t, q.checkNewMessagesBeforeAddNeedLock(newTestMessagesWithContent(2, 2)))
    68  		require.Error(t, q.checkNewMessagesBeforeAddNeedLock(newTestMessagesWithContent(2, 1)))
    69  	})
    70  	t.Run("NoGreaterThenLastSent", func(t *testing.T) {
    71  		q := newMessageQueue()
    72  		q.lastSeqNo = 10
    73  		require.Error(t, q.checkNewMessagesBeforeAddNeedLock(newTestMessagesWithContent(int(q.lastSeqNo-1))))
    74  		require.Error(t, q.checkNewMessagesBeforeAddNeedLock(newTestMessagesWithContent(int(q.lastSeqNo))))
    75  		require.NoError(t, q.checkNewMessagesBeforeAddNeedLock(newTestMessagesWithContent(int(q.lastSeqNo+1))))
    76  	})
    77  }
    78  
    79  func TestMessageQueue_Close(t *testing.T) {
    80  	q := newMessageQueue()
    81  	testErr := errors.New("test")
    82  	require.NoError(t, q.Close(testErr))
    83  	require.Error(t, q.Close(errors.New("second")))
    84  	require.Equal(t, testErr, q.closedErr)
    85  	require.True(t, q.closed)
    86  	<-q.closedChan
    87  }
    88  
    89  func TestMessageQueue_GetMessages(t *testing.T) {
    90  	ctx := context.Background()
    91  	t.Run("Simple", func(t *testing.T) {
    92  		q := newMessageQueue()
    93  		require.NoError(t, q.AddMessages(newTestMessagesWithContent(1, 2)))
    94  		require.NoError(t, q.AddMessages(newTestMessagesWithContent(3, 4)))
    95  
    96  		messages, err := q.GetMessagesForSend(ctx)
    97  		require.NoError(t, err)
    98  		require.Equal(t, []int64{1, 2, 3, 4}, getSeqNumbers(messages))
    99  	})
   100  
   101  	t.Run("SendMessagesAfterStartWait", func(t *testing.T) {
   102  		q := newMessageQueue()
   103  
   104  		var err error
   105  		var messages []messageWithDataContent
   106  		gotMessages := make(empty.Chan)
   107  		go func() {
   108  			messages, err = q.GetMessagesForSend(ctx)
   109  			close(gotMessages)
   110  		}()
   111  
   112  		waitGetMessageStarted(&q)
   113  		require.NoError(t, q.AddMessages(newTestMessagesWithContent(1, 2, 3)))
   114  
   115  		<-gotMessages
   116  		require.NoError(t, err)
   117  		require.Equal(t, []int64{1, 2, 3}, getSeqNumbers(messages))
   118  	})
   119  
   120  	t.Run("Stress", func(t *testing.T) {
   121  		iterations := 100000
   122  		q := newMessageQueue()
   123  
   124  		var lastSentSeqNo int64
   125  		sendFinished := make(empty.Chan)
   126  		fatalChan := make(chan string)
   127  
   128  		go func() {
   129  			//nolint:gosec
   130  			sendRand := rand.New(rand.NewSource(0))
   131  			for i := 0; i < iterations; i++ {
   132  				count := sendRand.Intn(10) + 1
   133  				var m []messageWithDataContent
   134  				for k := 0; k < count; k++ {
   135  					number := int(atomic.AddInt64(&lastSentSeqNo, 1))
   136  					m = append(m, newTestMessageWithDataContent(number))
   137  				}
   138  				require.NoError(t, q.AddMessages(m))
   139  			}
   140  			close(sendFinished)
   141  		}()
   142  
   143  		readFinished := make(empty.Chan)
   144  		var lastReadSeqNo atomic.Int64
   145  
   146  		readCtx, readCancel := xcontext.WithCancel(ctx)
   147  		defer readCancel()
   148  
   149  		go func() {
   150  			defer close(readFinished)
   151  
   152  			for {
   153  				messages, err := q.GetMessagesForSend(readCtx)
   154  				if err != nil {
   155  					break
   156  				}
   157  
   158  				for _, mess := range messages {
   159  					if lastReadSeqNo.Load()+1 != mess.SeqNo {
   160  						fatalChan <- string(debug.Stack())
   161  
   162  						return
   163  					}
   164  					lastReadSeqNo.Store(mess.SeqNo)
   165  				}
   166  			}
   167  		}()
   168  
   169  		select {
   170  		case <-sendFinished:
   171  		case stack := <-fatalChan:
   172  			t.Fatal(stack)
   173  		}
   174  
   175  		waitTimeout := time.Second * 10
   176  		startWait := time.Now()
   177  	waitReader:
   178  		for {
   179  			if lastReadSeqNo.Load() == lastSentSeqNo {
   180  				readCancel()
   181  			}
   182  			select {
   183  			case <-readFinished:
   184  				break waitReader
   185  			case stack := <-fatalChan:
   186  				t.Fatal(stack)
   187  			default:
   188  			}
   189  
   190  			runtime.Gosched()
   191  			if time.Since(startWait) > waitTimeout {
   192  				t.Fatal()
   193  			}
   194  		}
   195  	})
   196  
   197  	t.Run("ClosedContext", func(t *testing.T) {
   198  		closedCtx, cancel := xcontext.WithCancel(ctx)
   199  		cancel()
   200  
   201  		q := newMessageQueue()
   202  		require.NoError(t, q.AddMessages(newTestMessagesWithContent(1, 2)))
   203  
   204  		_, err := q.GetMessagesForSend(closedCtx)
   205  		require.ErrorIs(t, err, context.Canceled)
   206  	})
   207  
   208  	t.Run("CallOnClosedQueue", func(t *testing.T) {
   209  		q := newMessageQueue()
   210  		_ = q.Close(errors.New("test"))
   211  		_, err := q.GetMessagesForSend(ctx)
   212  		require.Error(t, err)
   213  	})
   214  
   215  	t.Run("CloseContextAfterCall", func(t *testing.T) {
   216  		q := newMessageQueue()
   217  		q.notifyNewMessages()
   218  
   219  		var err error
   220  		gotErr := make(empty.Chan)
   221  		go func() {
   222  			_, err = q.GetMessagesForSend(ctx)
   223  			close(gotErr)
   224  		}()
   225  
   226  		waitGetMessageStarted(&q)
   227  
   228  		testErr := errors.New("test")
   229  		require.NoError(t, q.Close(testErr))
   230  
   231  		<-gotErr
   232  		require.ErrorIs(t, err, testErr)
   233  	})
   234  }
   235  
   236  func TestMessageQueue_ResetSentProgress(t *testing.T) {
   237  	ctx := context.Background()
   238  
   239  	t.Run("Simple", func(t *testing.T) {
   240  		q := newMessageQueue()
   241  		require.NoError(t, q.AddMessages(newTestMessagesWithContent(1, 2, 3)))
   242  		res1, err := q.GetMessagesForSend(ctx)
   243  		require.NoError(t, err)
   244  
   245  		q.ResetSentProgress()
   246  		require.Equal(t, 0, q.lastSentIndex)
   247  		require.Equal(t, 3, q.lastWrittenIndex)
   248  		res2, err := q.GetMessagesForSend(ctx)
   249  		require.NoError(t, err)
   250  		require.Equal(t, res1, res2)
   251  	})
   252  
   253  	t.Run("Overflow", func(t *testing.T) {
   254  		q := newMessageQueue()
   255  		q.lastWrittenIndex = maxInt - 1
   256  		q.lastSentIndex = q.lastWrittenIndex
   257  
   258  		require.NoError(t, q.AddMessages(newTestMessagesWithContent(1, 2, 3)))
   259  		res1, err := q.GetMessagesForSend(ctx)
   260  		require.NoError(t, err)
   261  
   262  		q.ResetSentProgress()
   263  		require.Equal(t, maxInt-1, q.lastSentIndex)
   264  		require.Equal(t, minInt+1, q.lastWrittenIndex)
   265  		res2, err := q.GetMessagesForSend(ctx)
   266  		require.NoError(t, err)
   267  		require.Equal(t, res1, res2)
   268  	})
   269  }
   270  
   271  func TestIsFirstCycledIndexLess(t *testing.T) {
   272  	table := []struct {
   273  		name   string
   274  		first  int
   275  		second int
   276  		result bool
   277  	}{
   278  		{
   279  			name:   "smallPositivesFirstLess",
   280  			first:  1,
   281  			second: 2,
   282  			result: true,
   283  		},
   284  		{
   285  			name:   "smallPositivesEquals",
   286  			first:  1,
   287  			second: 1,
   288  			result: false,
   289  		},
   290  		{
   291  			name:   "smallPositivesFirstGreater",
   292  			first:  2,
   293  			second: 1,
   294  			result: false,
   295  		},
   296  		{
   297  			name:   "edgePositivesFirstLess",
   298  			first:  minPositiveIndexWhichOrderLessThenNegative - 1,
   299  			second: minPositiveIndexWhichOrderLessThenNegative,
   300  			result: true,
   301  		},
   302  		{
   303  			name:   "edgePositivesFirstGreater",
   304  			first:  minPositiveIndexWhichOrderLessThenNegative,
   305  			second: minPositiveIndexWhichOrderLessThenNegative - 1,
   306  			result: false,
   307  		},
   308  		{
   309  			name:   "overflowEdgeFirstPositive",
   310  			first:  maxInt,
   311  			second: minInt,
   312  			result: true,
   313  		},
   314  		{
   315  			name:   "overflowEdgeFirstNegative",
   316  			first:  minInt,
   317  			second: maxInt,
   318  			result: false,
   319  		},
   320  		{
   321  			name:   "nearZeroFirstNegativeSecondZero",
   322  			first:  -1,
   323  			second: 0,
   324  			result: true,
   325  		},
   326  		{
   327  			name:   "nearZeroFirstZeroSecondNegative",
   328  			first:  0,
   329  			second: -1,
   330  			result: false,
   331  		},
   332  		{
   333  			name:   "nearZeroFirstZeroSecondPositive",
   334  			first:  0,
   335  			second: 1,
   336  			result: true,
   337  		},
   338  		{
   339  			name:   "nearZeroFirstNegativeSecondPositive",
   340  			first:  -1,
   341  			second: 1,
   342  			result: true,
   343  		},
   344  		{
   345  			name:   "nearZeroFirstPositiveSecondNegative",
   346  			first:  1,
   347  			second: -1,
   348  			result: false,
   349  		},
   350  	}
   351  
   352  	for _, test := range table {
   353  		t.Run(test.name, func(t *testing.T) {
   354  			require.Equal(t, test.result, isFirstCycledIndexLess(test.first, test.second))
   355  		})
   356  	}
   357  }
   358  
   359  func TestMinMaxIntConst(t *testing.T) {
   360  	v := maxInt
   361  	v++
   362  	require.Equal(t, minInt, v)
   363  }
   364  
   365  func TestSortIndexes(t *testing.T) {
   366  	table := []struct {
   367  		name     string
   368  		source   []int
   369  		expected []int
   370  	}{
   371  		{
   372  			name:     "empty",
   373  			source:   []int{},
   374  			expected: []int{},
   375  		},
   376  		{
   377  			name:     "usual",
   378  			source:   []int{30, 1, 2},
   379  			expected: []int{1, 2, 30},
   380  		},
   381  		{
   382  			name:     "nearZero",
   383  			source:   []int{0, 1, -1},
   384  			expected: []int{-1, 0, 1},
   385  		},
   386  		{
   387  			name:     "indexoverflow",
   388  			source:   []int{minInt, minInt + 1, maxInt - 1, maxInt},
   389  			expected: []int{maxInt - 1, maxInt, minInt, minInt + 1},
   390  		},
   391  	}
   392  
   393  	for _, test := range table {
   394  		t.Run(test.name, func(t *testing.T) {
   395  			sortMessageQueueIndexes(test.source)
   396  			require.Equal(t, test.expected, test.source)
   397  		})
   398  	}
   399  }
   400  
   401  func TestQueuePanicOnOverflow(t *testing.T) {
   402  	require.Panics(t, func() {
   403  		q := newMessageQueue()
   404  		q.messagesByOrder[123] = messageWithDataContent{}
   405  		q.lastWrittenIndex = maxInt
   406  		q.addMessageNeedLock(messageWithDataContent{})
   407  	})
   408  }
   409  
   410  func TestRegressionIssue1038_ReceiveAckAfterCloseQueue(t *testing.T) {
   411  	counter := 0
   412  
   413  	q := newMessageQueue()
   414  	q.OnAckReceived = func(count int) {
   415  		counter -= count
   416  	}
   417  	require.NoError(t, q.AddMessages(newTestMessagesWithContent(1)))
   418  	counter++
   419  
   420  	require.NoError(t, q.Close(errors.New("test err")))
   421  	require.ErrorIs(t, q.AcksReceived([]rawtopicwriter.WriteAck{
   422  		{
   423  			SeqNo:              1,
   424  			MessageWriteStatus: rawtopicwriter.MessageWriteStatus{},
   425  		},
   426  	}), errAckOnClosedMessageQueue)
   427  	require.Zero(t, counter)
   428  }
   429  
   430  func TestQueue_Ack(t *testing.T) {
   431  	t.Run("First", func(t *testing.T) {
   432  		q := newMessageQueue()
   433  		require.NoError(t, q.AddMessages(newTestMessagesWithContent(1, 2, 5)))
   434  
   435  		require.NoError(t, q.AcksReceived([]rawtopicwriter.WriteAck{
   436  			{
   437  				SeqNo: 2,
   438  			},
   439  		}))
   440  		expectedMap := map[int]messageWithDataContent{
   441  			1: newTestMessageWithDataContent(1),
   442  			3: newTestMessageWithDataContent(5),
   443  		}
   444  		require.Equal(t, expectedMap, q.messagesByOrder)
   445  	})
   446  	t.Run("Unexisted", func(t *testing.T) {
   447  		q := newMessageQueue()
   448  		require.NoError(t, q.AddMessages(newTestMessagesWithContent(1)))
   449  
   450  		// remove first with the seqno
   451  		require.Error(t, q.AcksReceived([]rawtopicwriter.WriteAck{
   452  			{
   453  				SeqNo: 5,
   454  			},
   455  		}))
   456  
   457  		expectedMap := map[int]messageWithDataContent{
   458  			1: newTestMessageWithDataContent(1),
   459  		}
   460  
   461  		require.Equal(t, expectedMap, q.messagesByOrder)
   462  	})
   463  
   464  	t.Run("OnAckReceived", func(t *testing.T) {
   465  		receivedCount := 0
   466  
   467  		q := newMessageQueue()
   468  		q.OnAckReceived = func(count int) {
   469  			receivedCount = count
   470  		}
   471  
   472  		err := q.AddMessages(newTestMessagesWithContent(1, 2, 3))
   473  		require.NoError(t, err)
   474  
   475  		err = q.AcksReceived([]rawtopicwriter.WriteAck{
   476  			{
   477  				SeqNo: 1,
   478  			},
   479  			{
   480  				SeqNo: 3,
   481  			},
   482  		})
   483  
   484  		require.NoError(t, err)
   485  		require.Equal(t, 2, receivedCount)
   486  
   487  		// Double ack
   488  		err = q.AcksReceived([]rawtopicwriter.WriteAck{
   489  			{
   490  				SeqNo: 1,
   491  			},
   492  			{
   493  				SeqNo: 3,
   494  			},
   495  		})
   496  
   497  		require.Error(t, err)
   498  		require.Equal(t, 0, receivedCount)
   499  	})
   500  }
   501  
   502  func waitGetMessageStarted(q *messageQueue) {
   503  	q.notifyNewMessages()
   504  	for len(q.hasNewMessages) != 0 {
   505  		runtime.Gosched()
   506  	}
   507  }
   508  
   509  func getSeqNumbers(messages []messageWithDataContent) []int64 {
   510  	res := make([]int64, 0, len(messages))
   511  	for i := range messages {
   512  		res = append(res, messages[i].SeqNo)
   513  	}
   514  
   515  	return res
   516  }