github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/client/internal/message/retry_test.go (about)

     1  // Copyright 2017 Google Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package message
    16  
    17  import (
    18  	"math/rand"
    19  	"sync"
    20  	"sync/atomic"
    21  	"testing"
    22  	"time"
    23  
    24  	"google.golang.org/protobuf/proto"
    25  
    26  	"github.com/google/fleetspeak/fleetspeak/src/client/comms"
    27  	"github.com/google/fleetspeak/fleetspeak/src/client/service"
    28  	"github.com/google/fleetspeak/fleetspeak/src/client/stats"
    29  
    30  	fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak"
    31  	anypb "google.golang.org/protobuf/types/known/anypb"
    32  )
    33  
    34  type statsCollector struct {
    35  	stats.RetryLoopCollector
    36  	retries, pending, pendingSize atomic.Int64
    37  }
    38  
    39  func (sc *statsCollector) BeforeMessageRetry(msg *fspb.Message) {
    40  	sc.retries.Add(1)
    41  }
    42  
    43  func (sc *statsCollector) MessagePending(msg *fspb.Message, size int) {
    44  	sc.pending.Add(1)
    45  	sc.pendingSize.Add(int64(size))
    46  }
    47  
    48  func (sc *statsCollector) MessageAcknowledged(msg *fspb.Message, size int) {
    49  	sc.pending.Add(-1)
    50  	sc.pendingSize.Add(-int64(size))
    51  }
    52  
    53  func makeMessages(count, size int) []service.AckMessage {
    54  	var ret []service.AckMessage
    55  	for i := range count {
    56  		payload := make([]byte, size)
    57  		rand.Read(payload)
    58  		ret = append(ret, service.AckMessage{
    59  			M: &fspb.Message{
    60  				MessageId: []byte{0, 0, 0, byte(i >> 8), byte(i | 0xFF)},
    61  				Source: &fspb.Address{
    62  					ServiceName: "TestService",
    63  					ClientId:    []byte{0, 0, 1},
    64  				},
    65  				Destination: &fspb.Address{
    66  					ServiceName: "TestService",
    67  				},
    68  				MessageType: "TestMessageType",
    69  				Data:        &anypb.Any{Value: payload},
    70  			}})
    71  	}
    72  	return ret
    73  }
    74  
    75  func TestRetryLoopNormal(t *testing.T) {
    76  	sc := &statsCollector{}
    77  	in := make(chan service.AckMessage)
    78  	out := make(chan comms.MessageInfo, 100)
    79  	go RetryLoop(in, out, sc, 20*1024*1024, 100)
    80  	defer close(in)
    81  
    82  	// Normal flow.
    83  	msgs := makeMessages(10, 5)
    84  	for _, m := range msgs {
    85  		in <- m
    86  	}
    87  
    88  	for _, m := range msgs {
    89  		got := <-out
    90  		if !proto.Equal(m.M, got.M) {
    91  			t.Errorf("Unexpected read from output channel. Got %v, want %v.", got.M, m)
    92  		}
    93  		got.Ack()
    94  	}
    95  	select {
    96  	case mi := <-out:
    97  		t.Errorf("Expected empty output channel, but read: %v", mi.M)
    98  	default:
    99  	}
   100  
   101  	retries := sc.retries.Load()
   102  	if retries != 0 {
   103  		t.Errorf("Unexpected number of retries reported, got: %d, want: 0", retries)
   104  	}
   105  }
   106  
   107  func TestRetryLoopNACK(t *testing.T) {
   108  	sc := &statsCollector{}
   109  	in := make(chan service.AckMessage)
   110  	out := make(chan comms.MessageInfo, 100)
   111  	go RetryLoop(in, out, sc, 20*1024*1024, 100)
   112  	defer close(in)
   113  
   114  	// Nack flow.
   115  	msgs := makeMessages(10, 5)
   116  
   117  	for _, m := range msgs {
   118  		in <- m
   119  	}
   120  	for _, m := range msgs {
   121  		got := <-out
   122  		if !proto.Equal(m.M, got.M) {
   123  			t.Errorf("Unexpected read from output channel. Got %v, want %v.", got.M, m)
   124  		}
   125  		got.Nack()
   126  	}
   127  	for _, m := range msgs {
   128  		got := <-out
   129  		if !proto.Equal(m.M, got.M) {
   130  			t.Errorf("Unexpected read from output channel. Got %v, want %v.", got.M, m)
   131  		}
   132  		got.Ack()
   133  	}
   134  	select {
   135  	case mi := <-out:
   136  		t.Errorf("Expected empty output channel, but read: %v", mi.M)
   137  	default:
   138  	}
   139  
   140  	retries := sc.retries.Load()
   141  	if retries != 10 {
   142  		t.Errorf("Unexpected number of retries reported, got: %d, want: 10", retries)
   143  	}
   144  }
   145  
   146  func TestRetryLoopSizing(t *testing.T) {
   147  	sc := &statsCollector{}
   148  	in := make(chan service.AckMessage)
   149  	out := make(chan comms.MessageInfo, 100)
   150  	go RetryLoop(in, out, sc, 20*1024*1024, 100)
   151  	defer close(in)
   152  
   153  	// Two test cases in which we try to overfill the buffer.
   154  	for _, tc := range []struct {
   155  		name                   string
   156  		count, size, shouldFit int
   157  	}{
   158  		{"Small Messages", 300, 5, 100},
   159  		{"Large Messages", 30, 1024 * 1024, 20},
   160  	} {
   161  		t.Run(tc.name, func(t *testing.T) {
   162  			// shouldFit should fit
   163  			msgs := makeMessages(tc.count, tc.size)
   164  			for i := range tc.shouldFit {
   165  				in <- msgs[i]
   166  			}
   167  
   168  			// Another message should not fit. Wait just a bit to make sure that it
   169  			// really won't fit.
   170  			select {
   171  			case in <- service.AckMessage{M: &fspb.Message{MessageId: []byte("asdf")}}:
   172  				t.Error("Was able to overstuff in.")
   173  			case <-time.After(100 * time.Millisecond):
   174  			}
   175  
   176  			var w sync.WaitGroup
   177  			w.Add(1)
   178  			// stuff the rest in as they fit:
   179  			go func() {
   180  				for i := tc.shouldFit; i < len(msgs); i++ {
   181  					in <- msgs[i]
   182  				}
   183  				w.Done()
   184  			}()
   185  
   186  			// Reading them all should be fine, so long as we ack them.
   187  			for _, m := range msgs {
   188  				got := <-out
   189  				if !proto.Equal(m.M, got.M) {
   190  					t.Errorf("Unexpected read from output channel. Got %v, want %v.", got.M, m)
   191  				}
   192  				got.Ack()
   193  			}
   194  			w.Wait()
   195  			select {
   196  			case mi := <-out:
   197  				t.Errorf("Expected empty output channel, but read: %v", mi.M)
   198  			default:
   199  			}
   200  
   201  			retries := sc.retries.Load()
   202  			if retries != 0 {
   203  				t.Errorf("Unexpected number of retries reported, got: %d, want: 0", retries)
   204  			}
   205  		})
   206  	}
   207  }
   208  
   209  func TestRetryLoopReportsPendingMessages(t *testing.T) {
   210  	sc := &statsCollector{}
   211  	in := make(chan service.AckMessage)
   212  	out := make(chan comms.MessageInfo, 100)
   213  	go RetryLoop(in, out, sc, 20*1024*1024, 100)
   214  	defer close(in)
   215  
   216  	msgs := makeMessages(10, 5)
   217  	var totalByteSize int64
   218  	for _, m := range msgs {
   219  		totalByteSize += int64(proto.Size(m.M))
   220  		in <- m
   221  	}
   222  
   223  	// Give RetryLoop goroutine a short while to take in msgs
   224  	time.Sleep(100 * time.Millisecond)
   225  	pending := sc.pending.Load()
   226  	if pending != 10 {
   227  		t.Errorf("Unexpected number of pending messages, got: %d, want: 10", pending)
   228  	}
   229  	pendingSize := sc.pendingSize.Load()
   230  	if pendingSize != totalByteSize {
   231  		t.Errorf("Unexpected size of pending messages, got: %d, want: %d", pendingSize, totalByteSize)
   232  	}
   233  
   234  	for range msgs {
   235  		got := <-out
   236  		got.Ack()
   237  	}
   238  
   239  	// Give RetryLoop goroutine a short while to process acks
   240  	time.Sleep(100 * time.Millisecond)
   241  	pending = sc.pending.Load()
   242  	if pending != 0 {
   243  		t.Errorf("Unexpected number of pending messages, got: %d, want: 0", pending)
   244  	}
   245  	pendingSize = sc.pendingSize.Load()
   246  	if pendingSize != 0 {
   247  		t.Errorf("Unexpected size of pending messages, got: %d, want: 0", pendingSize)
   248  	}
   249  }