github.com/ethersphere/bee/v2@v2.2.0/pkg/p2p/streamtest/streamtest.go (about) 1 // Copyright 2020 The Swarm Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package streamtest 6 7 import ( 8 "context" 9 "errors" 10 "io" 11 "sync" 12 "testing" 13 "time" 14 15 "github.com/ethersphere/bee/v2/pkg/p2p" 16 "github.com/ethersphere/bee/v2/pkg/spinlock" 17 "github.com/ethersphere/bee/v2/pkg/swarm" 18 ma "github.com/multiformats/go-multiaddr" 19 ) 20 21 var ( 22 ErrRecordsNotFound = errors.New("records not found") 23 ErrStreamNotSupported = errors.New("stream not supported") 24 ErrStreamClosed = errors.New("stream closed") 25 26 noopMiddleware = func(f p2p.HandlerFunc) p2p.HandlerFunc { 27 return f 28 } 29 ) 30 31 type Recorder struct { 32 base swarm.Address 33 fullNode bool 34 records map[string][]*Record 35 recordsMu sync.Mutex 36 protocols []p2p.ProtocolSpec 37 middlewares []p2p.HandlerMiddleware 38 streamErr func(swarm.Address, string, string, string) error 39 pingErr func(ma.Multiaddr) (time.Duration, error) 40 protocolsWithPeers map[string]p2p.ProtocolSpec 41 } 42 43 func WithProtocols(protocols ...p2p.ProtocolSpec) Option { 44 return optionFunc(func(r *Recorder) { 45 r.protocols = append(r.protocols, protocols...) 46 }) 47 } 48 49 func WithPeerProtocols(protocolsWithPeers map[string]p2p.ProtocolSpec) Option { 50 return optionFunc(func(r *Recorder) { 51 r.protocolsWithPeers = protocolsWithPeers 52 }) 53 } 54 55 func WithMiddlewares(middlewares ...p2p.HandlerMiddleware) Option { 56 return optionFunc(func(r *Recorder) { 57 r.middlewares = append(r.middlewares, middlewares...) 58 }) 59 } 60 61 func WithBaseAddr(a swarm.Address) Option { 62 return optionFunc(func(r *Recorder) { 63 r.base = a 64 }) 65 } 66 67 func WithLightNode() Option { 68 return optionFunc(func(r *Recorder) { 69 r.fullNode = false 70 }) 71 } 72 73 func WithStreamError(streamErr func(swarm.Address, string, string, string) error) Option { 74 return optionFunc(func(r *Recorder) { 75 r.streamErr = streamErr 76 }) 77 } 78 79 func WithPingErr(pingErr func(ma.Multiaddr) (time.Duration, error)) Option { 80 return optionFunc(func(r *Recorder) { 81 r.pingErr = pingErr 82 }) 83 } 84 85 func New(opts ...Option) *Recorder { 86 r := &Recorder{ 87 records: make(map[string][]*Record), 88 fullNode: true, 89 } 90 91 r.middlewares = append(r.middlewares, noopMiddleware) 92 93 for _, o := range opts { 94 o.apply(r) 95 } 96 return r 97 } 98 99 func (r *Recorder) SetProtocols(protocols ...p2p.ProtocolSpec) { 100 r.protocols = append(r.protocols, protocols...) 101 } 102 103 func (r *Recorder) NewStream(ctx context.Context, addr swarm.Address, h p2p.Headers, protocolName, protocolVersion, streamName string) (p2p.Stream, error) { 104 if r.streamErr != nil { 105 err := r.streamErr(addr, protocolName, protocolVersion, streamName) 106 if err != nil { 107 return nil, err 108 } 109 } 110 111 recordIn := newRecord() 112 recordOut := newRecord() 113 streamOut := newStream(recordIn, recordOut) 114 streamIn := newStream(recordOut, recordIn) 115 116 var handler p2p.HandlerFunc 117 var headler p2p.HeadlerFunc 118 peerHandlers, ok := r.protocolsWithPeers[addr.String()] 119 if !ok { 120 for _, p := range r.protocols { 121 if p.Name == protocolName && p.Version == protocolVersion { 122 peerHandlers = p 123 } 124 } 125 } 126 for _, s := range peerHandlers.StreamSpecs { 127 if s.Name == streamName { 128 handler = s.Handler 129 headler = s.Headler 130 } 131 } 132 if handler == nil { 133 return nil, ErrStreamNotSupported 134 } 135 for i := len(r.middlewares) - 1; i >= 0; i-- { 136 handler = r.middlewares[i](handler) 137 } 138 if headler != nil { 139 streamOut.headers = headler(h, addr) 140 } 141 record := &Record{in: recordIn, out: recordOut, done: make(chan struct{})} 142 go func() { 143 defer close(record.done) 144 145 // pass a new context to handler, 146 streamIn.responseHeaders = streamOut.headers 147 // do not cancel it with the client stream context 148 err := handler(context.Background(), p2p.Peer{Address: r.base, FullNode: r.fullNode}, streamIn) 149 if err != nil && !errors.Is(err, io.EOF) { 150 record.setErr(err) 151 } 152 }() 153 154 id := addr.String() + p2p.NewSwarmStreamName(protocolName, protocolVersion, streamName) 155 156 r.recordsMu.Lock() 157 defer r.recordsMu.Unlock() 158 159 r.records[id] = append(r.records[id], record) 160 return streamOut, nil 161 } 162 163 func (r *Recorder) Ping(ctx context.Context, addr ma.Multiaddr) (rtt time.Duration, err error) { 164 if r.pingErr != nil { 165 return r.pingErr(addr) 166 } 167 return rtt, err 168 } 169 170 func (r *Recorder) Records(addr swarm.Address, protocolName, protocolVersio, streamName string) ([]*Record, error) { 171 id := addr.String() + p2p.NewSwarmStreamName(protocolName, protocolVersio, streamName) 172 173 r.recordsMu.Lock() 174 defer r.recordsMu.Unlock() 175 176 records, ok := r.records[id] 177 if !ok { 178 return nil, ErrRecordsNotFound 179 } 180 // wait for all records goroutines to terminate 181 for _, r := range records { 182 <-r.done 183 } 184 return records, nil 185 } 186 187 // WaitRecords waits for some time for records to come into the recorder. If msgs is 0, the timeoutSec period is waited to verify 188 // that _no_ messages arrive during this time period. 189 func (r *Recorder) WaitRecords(t *testing.T, addr swarm.Address, proto, version, stream string, msgs, timeoutSec int) []*Record { 190 t.Helper() 191 192 var recs []*Record 193 err := spinlock.Wait(time.Second*time.Duration(timeoutSec), func() bool { 194 recs, _ = r.Records(addr, proto, version, stream) 195 if l := len(recs); l > msgs { 196 t.Fatalf("too many records. want %d got %d", msgs, l) 197 } else if msgs > 0 && l == msgs { 198 return true 199 } 200 return false 201 // we can be here if msgs == 0 && l == 0 202 // or msgs = x && l < x, both cases are fine 203 // and we should continue waiting 204 }) 205 if err != nil && msgs > 0 { 206 t.Fatal("timed out while waiting for records") 207 } 208 209 return recs 210 } 211 212 type Record struct { 213 in *record 214 out *record 215 err error 216 errMu sync.Mutex 217 done chan struct{} 218 } 219 220 func (r *Record) In() []byte { 221 return r.in.bytes() 222 } 223 224 func (r *Record) Out() []byte { 225 return r.out.bytes() 226 } 227 228 func (r *Record) Err() error { 229 r.errMu.Lock() 230 defer r.errMu.Unlock() 231 232 return r.err 233 } 234 235 func (r *Record) setErr(err error) { 236 r.errMu.Lock() 237 defer r.errMu.Unlock() 238 239 r.err = err 240 } 241 242 type stream struct { 243 in *record 244 out *record 245 headers p2p.Headers 246 responseHeaders p2p.Headers 247 closed bool 248 lock sync.Mutex 249 } 250 251 func newStream(in, out *record) *stream { 252 return &stream{in: in, out: out} 253 } 254 255 func (s *stream) Read(p []byte) (int, error) { 256 if s.Closed() { 257 return 0, ErrStreamClosed 258 } 259 260 return s.out.Read(p) 261 } 262 263 func (s *stream) Write(p []byte) (int, error) { 264 if s.Closed() { 265 return 0, ErrStreamClosed 266 } 267 268 return s.in.Write(p) 269 } 270 271 func (s *stream) Headers() p2p.Headers { 272 return s.headers 273 } 274 275 func (s *stream) ResponseHeaders() p2p.Headers { 276 return s.responseHeaders 277 } 278 279 func (s *stream) Close() error { 280 s.lock.Lock() 281 defer s.lock.Unlock() 282 283 if s.closed { 284 return ErrStreamClosed 285 } 286 287 s.closed = true 288 s.in.close() 289 290 return nil 291 } 292 293 func (s *stream) Closed() bool { 294 s.lock.Lock() 295 defer s.lock.Unlock() 296 297 return s.closed 298 } 299 300 func (s *stream) FullClose() error { 301 s.lock.Lock() 302 defer s.lock.Unlock() 303 304 if s.closed { 305 return ErrStreamClosed 306 } 307 308 s.closed = true 309 s.in.close() 310 s.out.close() 311 312 return nil 313 } 314 315 func (s *stream) Reset() (err error) { 316 return s.FullClose() 317 } 318 319 type record struct { 320 b []byte 321 c int 322 lock sync.Mutex 323 dataSigC chan struct{} 324 closed bool 325 } 326 327 func newRecord() *record { 328 return &record{ 329 dataSigC: make(chan struct{}, 16), 330 } 331 } 332 333 func (r *record) Read(p []byte) (n int, err error) { 334 for r.c == r.bytesSize() { 335 _, ok := <-r.dataSigC 336 if !ok { 337 return 0, io.EOF 338 } 339 } 340 341 r.lock.Lock() 342 defer r.lock.Unlock() 343 344 end := r.c + len(p) 345 if end > len(r.b) { 346 end = len(r.b) 347 } 348 n = copy(p, r.b[r.c:end]) 349 r.c += n 350 351 return n, nil 352 } 353 354 func (r *record) Write(p []byte) (int, error) { 355 r.lock.Lock() 356 defer r.lock.Unlock() 357 358 if r.closed { 359 return 0, ErrStreamClosed 360 } 361 362 r.b = append(r.b, p...) 363 r.dataSigC <- struct{}{} 364 365 return len(p), nil 366 } 367 368 func (r *record) close() { 369 r.lock.Lock() 370 defer r.lock.Unlock() 371 372 if r.closed { 373 return 374 } 375 376 r.closed = true 377 close(r.dataSigC) 378 } 379 380 func (r *record) bytes() []byte { 381 return r.b 382 } 383 384 func (r *record) bytesSize() int { 385 r.lock.Lock() 386 defer r.lock.Unlock() 387 return len(r.b) 388 } 389 390 type Option interface { 391 apply(*Recorder) 392 } 393 type optionFunc func(*Recorder) 394 395 func (f optionFunc) apply(r *Recorder) { f(r) } 396 397 var _ p2p.StreamerDisconnecter = (*RecorderDisconnecter)(nil) 398 399 type RecorderDisconnecter struct { 400 *Recorder 401 disconnected map[string]struct{} 402 blocklisted map[string]time.Duration 403 mu sync.RWMutex 404 } 405 406 func NewRecorderDisconnecter(r *Recorder) *RecorderDisconnecter { 407 return &RecorderDisconnecter{ 408 Recorder: r, 409 disconnected: make(map[string]struct{}), 410 blocklisted: make(map[string]time.Duration), 411 } 412 } 413 414 func (r *RecorderDisconnecter) Disconnect(overlay swarm.Address, _ string) error { 415 r.mu.Lock() 416 defer r.mu.Unlock() 417 418 r.disconnected[overlay.String()] = struct{}{} 419 return nil 420 } 421 422 func (r *RecorderDisconnecter) Blocklist(overlay swarm.Address, d time.Duration, _ string) error { 423 r.mu.Lock() 424 defer r.mu.Unlock() 425 426 r.blocklisted[overlay.String()] = d 427 return nil 428 } 429 430 func (r *RecorderDisconnecter) IsDisconnected(overlay swarm.Address) bool { 431 r.mu.RLock() 432 defer r.mu.RUnlock() 433 434 _, yes := r.disconnected[overlay.String()] 435 return yes 436 } 437 438 func (r *RecorderDisconnecter) IsBlocklisted(overlay swarm.Address) (bool, time.Duration) { 439 r.mu.RLock() 440 defer r.mu.RUnlock() 441 442 d, yes := r.blocklisted[overlay.String()] 443 return yes, d 444 } 445 446 // NetworkStatus implements p2p.NetworkStatuser interface. 447 // It always returns p2p.NetworkStatusAvailable. 448 func (r *RecorderDisconnecter) NetworkStatus() p2p.NetworkStatus { 449 return p2p.NetworkStatusAvailable 450 }