github.com/keltia/go-ipfs@v0.3.8-0.20150909044612-210793031c63/thirdparty/notifier/notifier_test.go (about)

     1  package notifier
     2  
     3  import (
     4  	"fmt"
     5  	"sync"
     6  	"testing"
     7  	"time"
     8  )
     9  
    10  // test data structures
    11  type Router struct {
    12  	queue    chan Packet
    13  	notifier Notifier
    14  }
    15  
    // Packet is an empty payload: the callbacks in this file only count packets
    // and never look inside them.
    type Packet struct{}
    17  
    18  type RouterNotifiee interface {
    19  	Enqueued(*Router, Packet)
    20  	Forwarded(*Router, Packet)
    21  	Dropped(*Router, Packet)
    22  }
    23  
    24  func (r *Router) Notify(n RouterNotifiee) {
    25  	r.notifier.Notify(n)
    26  }
    27  
    28  func (r *Router) StopNotify(n RouterNotifiee) {
    29  	r.notifier.StopNotify(n)
    30  }
    31  
    32  func (r *Router) notifyAll(notify func(n RouterNotifiee)) {
    33  	r.notifier.NotifyAll(func(n Notifiee) {
    34  		notify(n.(RouterNotifiee))
    35  	})
    36  }
    37  
    38  func (r *Router) Receive(p Packet) {
    39  
    40  	select {
    41  	case r.queue <- p: // enqueued
    42  		r.notifyAll(func(n RouterNotifiee) {
    43  			n.Enqueued(r, p)
    44  		})
    45  
    46  	default: // drop
    47  		r.notifyAll(func(n RouterNotifiee) {
    48  			n.Dropped(r, p)
    49  		})
    50  	}
    51  }
    52  
    53  func (r *Router) Forward() {
    54  	p := <-r.queue
    55  	r.notifyAll(func(n RouterNotifiee) {
    56  		n.Forwarded(r, p)
    57  	})
    58  }
    59  
    60  type Metrics struct {
    61  	enqueued  int
    62  	forwarded int
    63  	dropped   int
    64  	received  chan struct{}
    65  	sync.Mutex
    66  }
    67  
    68  func (m *Metrics) Enqueued(*Router, Packet) {
    69  	m.Lock()
    70  	m.enqueued++
    71  	m.Unlock()
    72  	if m.received != nil {
    73  		m.received <- struct{}{}
    74  	}
    75  }
    76  
    77  func (m *Metrics) Forwarded(*Router, Packet) {
    78  	m.Lock()
    79  	m.forwarded++
    80  	m.Unlock()
    81  	if m.received != nil {
    82  		m.received <- struct{}{}
    83  	}
    84  }
    85  
    86  func (m *Metrics) Dropped(*Router, Packet) {
    87  	m.Lock()
    88  	m.dropped++
    89  	m.Unlock()
    90  	if m.received != nil {
    91  		m.received <- struct{}{}
    92  	}
    93  }
    94  
    95  func (m *Metrics) String() string {
    96  	m.Lock()
    97  	defer m.Unlock()
    98  	return fmt.Sprintf("%d enqueued, %d forwarded, %d in queue, %d dropped",
    99  		m.enqueued, m.forwarded, m.enqueued-m.forwarded, m.dropped)
   100  }
   101  
   102  func TestNotifies(t *testing.T) {
   103  
   104  	m := Metrics{received: make(chan struct{})}
   105  	r := Router{queue: make(chan Packet, 10)}
   106  	r.Notify(&m)
   107  
   108  	for i := 0; i < 10; i++ {
   109  		r.Receive(Packet{})
   110  		<-m.received
   111  		if m.enqueued != (1 + i) {
   112  			t.Error("not notifying correctly", m.enqueued, 1+i)
   113  		}
   114  
   115  	}
   116  
   117  	for i := 0; i < 10; i++ {
   118  		r.Receive(Packet{})
   119  		<-m.received
   120  		if m.enqueued != 10 {
   121  			t.Error("not notifying correctly", m.enqueued, 10)
   122  		}
   123  		if m.dropped != (1 + i) {
   124  			t.Error("not notifying correctly", m.dropped, 1+i)
   125  		}
   126  	}
   127  }
   128  
   129  func TestStopsNotifying(t *testing.T) {
   130  	m := Metrics{received: make(chan struct{})}
   131  	r := Router{queue: make(chan Packet, 10)}
   132  	r.Notify(&m)
   133  
   134  	for i := 0; i < 5; i++ {
   135  		r.Receive(Packet{})
   136  		<-m.received
   137  		if m.enqueued != (1 + i) {
   138  			t.Error("not notifying correctly")
   139  		}
   140  	}
   141  
   142  	r.StopNotify(&m)
   143  
   144  	for i := 0; i < 5; i++ {
   145  		r.Receive(Packet{})
   146  		select {
   147  		case <-m.received:
   148  			t.Error("did not stop notifying")
   149  		default:
   150  		}
   151  		if m.enqueued != 5 {
   152  			t.Error("did not stop notifying")
   153  		}
   154  	}
   155  }
   156  
   157  func TestThreadsafe(t *testing.T) {
   158  	N := 1000
   159  	r := Router{queue: make(chan Packet, 10)}
   160  	m1 := Metrics{received: make(chan struct{})}
   161  	m2 := Metrics{received: make(chan struct{})}
   162  	m3 := Metrics{received: make(chan struct{})}
   163  	r.Notify(&m1)
   164  	r.Notify(&m2)
   165  	r.Notify(&m3)
   166  
   167  	var n int
   168  	var wg sync.WaitGroup
   169  	for i := 0; i < N; i++ {
   170  		n++
   171  		wg.Add(1)
   172  		go func() {
   173  			defer wg.Done()
   174  			r.Receive(Packet{})
   175  		}()
   176  
   177  		if i%3 == 0 {
   178  			n++
   179  			wg.Add(1)
   180  			go func() {
   181  				defer wg.Done()
   182  				r.Forward()
   183  			}()
   184  		}
   185  	}
   186  
   187  	// drain queues
   188  	for i := 0; i < (n * 3); i++ {
   189  		select {
   190  		case <-m1.received:
   191  		case <-m2.received:
   192  		case <-m3.received:
   193  		}
   194  	}
   195  
   196  	wg.Wait()
   197  
   198  	// counts should be correct and all agree. and this should
   199  	// run fine under `go test -race -cpu=5`
   200  
   201  	t.Log("m1", m1.String())
   202  	t.Log("m2", m2.String())
   203  	t.Log("m3", m3.String())
   204  
   205  	if m1.String() != m2.String() || m2.String() != m3.String() {
   206  		t.Error("counts disagree")
   207  	}
   208  }
   209  
   // highwatermark tracks a concurrent-callback count. It reports an error on
   // errs whenever the count rises above limit or falls below zero.
   type highwatermark struct {
   	mu    sync.Mutex
   	mark  int
   	limit int
   	errs  chan error
   }

   // incr records one more running callback. The error is built while mu is
   // held, so it reports a consistent mark. It is sent only after mu is
   // released: blocking on a full errs channel while holding the lock would
   // stall every other incr/decr too.
   func (m *highwatermark) incr() {
   	m.mu.Lock()
   	m.mark++
   	var err error
   	if m.mark > m.limit {
   		err = fmt.Errorf("went over limit: %d/%d", m.mark, m.limit)
   	}
   	m.mu.Unlock()

   	if err != nil {
   		m.errs <- err
   	}
   }

   // decr records one fewer running callback. Going negative means decr was
   // called more often than incr, which is reported on errs. As in incr, the
   // send happens outside the lock.
   func (m *highwatermark) decr() {
   	m.mu.Lock()
   	m.mark--
   	var err error
   	if m.mark < 0 {
   		err = fmt.Errorf("went under zero: %d/%d", m.mark, m.limit)
   	}
   	m.mu.Unlock()

   	if err != nil {
   		m.errs <- err
   	}
   }
   236  
   237  func TestLimited(t *testing.T) {
   238  	timeout := 10 * time.Second // huge timeout.
   239  	limit := 9
   240  
   241  	hwm := highwatermark{limit: limit, errs: make(chan error, 100)}
   242  	n := RateLimited(limit) // will stop after 3 rounds
   243  	n.Notify(1)
   244  	n.Notify(2)
   245  	n.Notify(3)
   246  
   247  	entr := make(chan struct{})
   248  	exit := make(chan struct{})
   249  	done := make(chan struct{})
   250  	go func() {
   251  		for i := 0; i < 10; i++ {
   252  			// fmt.Printf("round: %d\n", i)
   253  			n.NotifyAll(func(e Notifiee) {
   254  				hwm.incr()
   255  				entr <- struct{}{}
   256  				<-exit // wait
   257  				hwm.decr()
   258  			})
   259  		}
   260  		done <- struct{}{}
   261  	}()
   262  
   263  	for i := 0; i < 30; {
   264  		select {
   265  		case <-entr:
   266  			continue // let as many enter as possible
   267  		case <-time.After(1 * time.Millisecond):
   268  		}
   269  
   270  		// let one exit
   271  		select {
   272  		case <-entr:
   273  			continue // in case of timing issues.
   274  		case exit <- struct{}{}:
   275  		case <-time.After(timeout):
   276  			t.Error("got stuck")
   277  		}
   278  		i++
   279  	}
   280  
   281  	select {
   282  	case <-done: // two parts done
   283  	case <-time.After(timeout):
   284  		t.Error("did not finish")
   285  	}
   286  
   287  	close(hwm.errs)
   288  	for err := range hwm.errs {
   289  		t.Error(err)
   290  	}
   291  }