github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/client/internal/message/retry_test.go (about) 1 // Copyright 2017 Google Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package message 16 17 import ( 18 "math/rand" 19 "sync" 20 "sync/atomic" 21 "testing" 22 "time" 23 24 "google.golang.org/protobuf/proto" 25 26 "github.com/google/fleetspeak/fleetspeak/src/client/comms" 27 "github.com/google/fleetspeak/fleetspeak/src/client/service" 28 "github.com/google/fleetspeak/fleetspeak/src/client/stats" 29 30 fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak" 31 anypb "google.golang.org/protobuf/types/known/anypb" 32 ) 33 34 type statsCollector struct { 35 stats.RetryLoopCollector 36 retries, pending, pendingSize atomic.Int64 37 } 38 39 func (sc *statsCollector) BeforeMessageRetry(msg *fspb.Message) { 40 sc.retries.Add(1) 41 } 42 43 func (sc *statsCollector) MessagePending(msg *fspb.Message, size int) { 44 sc.pending.Add(1) 45 sc.pendingSize.Add(int64(size)) 46 } 47 48 func (sc *statsCollector) MessageAcknowledged(msg *fspb.Message, size int) { 49 sc.pending.Add(-1) 50 sc.pendingSize.Add(-int64(size)) 51 } 52 53 func makeMessages(count, size int) []service.AckMessage { 54 var ret []service.AckMessage 55 for i := range count { 56 payload := make([]byte, size) 57 rand.Read(payload) 58 ret = append(ret, service.AckMessage{ 59 M: &fspb.Message{ 60 MessageId: []byte{0, 0, 0, byte(i >> 8), byte(i | 0xFF)}, 61 Source: &fspb.Address{ 62 ServiceName: "TestService", 63 ClientId: []byte{0, 0, 1}, 64 }, 65 Destination: &fspb.Address{ 66 ServiceName: "TestService", 67 }, 68 MessageType: "TestMessageType", 69 Data: &anypb.Any{Value: payload}, 70 }}) 71 } 72 return ret 73 } 74 75 func TestRetryLoopNormal(t *testing.T) { 76 sc := &statsCollector{} 77 in := make(chan service.AckMessage) 78 out := make(chan comms.MessageInfo, 100) 79 go RetryLoop(in, out, sc, 20*1024*1024, 100) 80 defer close(in) 81 82 // Normal flow. 83 msgs := makeMessages(10, 5) 84 for _, m := range msgs { 85 in <- m 86 } 87 88 for _, m := range msgs { 89 got := <-out 90 if !proto.Equal(m.M, got.M) { 91 t.Errorf("Unexpected read from output channel. Got %v, want %v.", got.M, m) 92 } 93 got.Ack() 94 } 95 select { 96 case mi := <-out: 97 t.Errorf("Expected empty output channel, but read: %v", mi.M) 98 default: 99 } 100 101 retries := sc.retries.Load() 102 if retries != 0 { 103 t.Errorf("Unexpected number of retries reported, got: %d, want: 0", retries) 104 } 105 } 106 107 func TestRetryLoopNACK(t *testing.T) { 108 sc := &statsCollector{} 109 in := make(chan service.AckMessage) 110 out := make(chan comms.MessageInfo, 100) 111 go RetryLoop(in, out, sc, 20*1024*1024, 100) 112 defer close(in) 113 114 // Nack flow. 115 msgs := makeMessages(10, 5) 116 117 for _, m := range msgs { 118 in <- m 119 } 120 for _, m := range msgs { 121 got := <-out 122 if !proto.Equal(m.M, got.M) { 123 t.Errorf("Unexpected read from output channel. Got %v, want %v.", got.M, m) 124 } 125 got.Nack() 126 } 127 for _, m := range msgs { 128 got := <-out 129 if !proto.Equal(m.M, got.M) { 130 t.Errorf("Unexpected read from output channel. Got %v, want %v.", got.M, m) 131 } 132 got.Ack() 133 } 134 select { 135 case mi := <-out: 136 t.Errorf("Expected empty output channel, but read: %v", mi.M) 137 default: 138 } 139 140 retries := sc.retries.Load() 141 if retries != 10 { 142 t.Errorf("Unexpected number of retries reported, got: %d, want: 10", retries) 143 } 144 } 145 146 func TestRetryLoopSizing(t *testing.T) { 147 sc := &statsCollector{} 148 in := make(chan service.AckMessage) 149 out := make(chan comms.MessageInfo, 100) 150 go RetryLoop(in, out, sc, 20*1024*1024, 100) 151 defer close(in) 152 153 // Two test cases in which we try to overfill the buffer. 154 for _, tc := range []struct { 155 name string 156 count, size, shouldFit int 157 }{ 158 {"Small Messages", 300, 5, 100}, 159 {"Large Messages", 30, 1024 * 1024, 20}, 160 } { 161 t.Run(tc.name, func(t *testing.T) { 162 // shouldFit should fit 163 msgs := makeMessages(tc.count, tc.size) 164 for i := range tc.shouldFit { 165 in <- msgs[i] 166 } 167 168 // Another message should not fit. Wait just a bit to make sure that it 169 // really won't fit. 170 select { 171 case in <- service.AckMessage{M: &fspb.Message{MessageId: []byte("asdf")}}: 172 t.Error("Was able to overstuff in.") 173 case <-time.After(100 * time.Millisecond): 174 } 175 176 var w sync.WaitGroup 177 w.Add(1) 178 // stuff the rest in as they fit: 179 go func() { 180 for i := tc.shouldFit; i < len(msgs); i++ { 181 in <- msgs[i] 182 } 183 w.Done() 184 }() 185 186 // Reading them all should be fine, so long as we ack them. 187 for _, m := range msgs { 188 got := <-out 189 if !proto.Equal(m.M, got.M) { 190 t.Errorf("Unexpected read from output channel. Got %v, want %v.", got.M, m) 191 } 192 got.Ack() 193 } 194 w.Wait() 195 select { 196 case mi := <-out: 197 t.Errorf("Expected empty output channel, but read: %v", mi.M) 198 default: 199 } 200 201 retries := sc.retries.Load() 202 if retries != 0 { 203 t.Errorf("Unexpected number of retries reported, got: %d, want: 0", retries) 204 } 205 }) 206 } 207 } 208 209 func TestRetryLoopReportsPendingMessages(t *testing.T) { 210 sc := &statsCollector{} 211 in := make(chan service.AckMessage) 212 out := make(chan comms.MessageInfo, 100) 213 go RetryLoop(in, out, sc, 20*1024*1024, 100) 214 defer close(in) 215 216 msgs := makeMessages(10, 5) 217 var totalByteSize int64 218 for _, m := range msgs { 219 totalByteSize += int64(proto.Size(m.M)) 220 in <- m 221 } 222 223 // Give RetryLoop goroutine a short while to take in msgs 224 time.Sleep(100 * time.Millisecond) 225 pending := sc.pending.Load() 226 if pending != 10 { 227 t.Errorf("Unexpected number of pending messages, got: %d, want: 10", pending) 228 } 229 pendingSize := sc.pendingSize.Load() 230 if pendingSize != totalByteSize { 231 t.Errorf("Unexpected size of pending messages, got: %d, want: %d", pendingSize, totalByteSize) 232 } 233 234 for range msgs { 235 got := <-out 236 got.Ack() 237 } 238 239 // Give RetryLoop goroutine a short while to process acks 240 time.Sleep(100 * time.Millisecond) 241 pending = sc.pending.Load() 242 if pending != 0 { 243 t.Errorf("Unexpected number of pending messages, got: %d, want: 0", pending) 244 } 245 pendingSize = sc.pendingSize.Load() 246 if pendingSize != 0 { 247 t.Errorf("Unexpected size of pending messages, got: %d, want: 0", pendingSize) 248 } 249 }