github.com/fawick/restic@v0.1.1-0.20171126184616-c02923fbfc79/internal/worker/pool_test.go (about)

     1  package worker_test
     2  
     3  import (
     4  	"context"
     5  	"testing"
     6  
     7  	"github.com/restic/restic/internal/errors"
     8  
     9  	"github.com/restic/restic/internal/worker"
    10  )
    11  
    12  const concurrency = 10
    13  
    14  var errTooLarge = errors.New("too large")
    15  
    16  func square(ctx context.Context, job worker.Job) (interface{}, error) {
    17  	n := job.Data.(int)
    18  	if n > 2000 {
    19  		return nil, errTooLarge
    20  	}
    21  	return n * n, nil
    22  }
    23  
    24  func newBufferedPool(ctx context.Context, bufsize int, n int, f worker.Func) (chan worker.Job, chan worker.Job, *worker.Pool) {
    25  	inCh := make(chan worker.Job, bufsize)
    26  	outCh := make(chan worker.Job, bufsize)
    27  
    28  	return inCh, outCh, worker.New(ctx, n, f, inCh, outCh)
    29  }
    30  
    31  func TestPool(t *testing.T) {
    32  	inCh, outCh, p := newBufferedPool(context.TODO(), 200, concurrency, square)
    33  
    34  	for i := 0; i < 150; i++ {
    35  		inCh <- worker.Job{Data: i}
    36  	}
    37  
    38  	close(inCh)
    39  	p.Wait()
    40  
    41  	for res := range outCh {
    42  		if res.Error != nil {
    43  			t.Errorf("unexpected error for job %v received: %v", res.Data, res.Error)
    44  			continue
    45  		}
    46  
    47  		n := res.Data.(int)
    48  		m := res.Result.(int)
    49  
    50  		if m != n*n {
    51  			t.Errorf("wrong value for job %d returned: want %d, got %d", n, n*n, m)
    52  		}
    53  	}
    54  }
    55  
    56  func TestPoolErrors(t *testing.T) {
    57  	inCh, outCh, p := newBufferedPool(context.TODO(), 200, concurrency, square)
    58  
    59  	for i := 0; i < 150; i++ {
    60  		inCh <- worker.Job{Data: i + 1900}
    61  	}
    62  
    63  	close(inCh)
    64  	p.Wait()
    65  
    66  	for res := range outCh {
    67  		n := res.Data.(int)
    68  
    69  		if n > 2000 {
    70  			if res.Error == nil {
    71  				t.Errorf("expected error not found, result is %v", res)
    72  				continue
    73  			}
    74  
    75  			if res.Error != errTooLarge {
    76  				t.Errorf("unexpected error found, result is %v", res)
    77  			}
    78  
    79  			continue
    80  		} else {
    81  			if res.Error != nil {
    82  				t.Errorf("unexpected error for job %v received: %v", res.Data, res.Error)
    83  				continue
    84  			}
    85  		}
    86  
    87  		m := res.Result.(int)
    88  		if m != n*n {
    89  			t.Errorf("wrong value for job %d returned: want %d, got %d", n, n*n, m)
    90  		}
    91  	}
    92  }