github.com/devseccon/trivy@v0.47.1-0.20231123133102-bd902a0bd996/pkg/parallel/pipeline_test.go (about)

     1  package parallel_test
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"math"
     7  	"testing"
     8  
     9  	"github.com/stretchr/testify/assert"
    10  	"github.com/stretchr/testify/require"
    11  
    12  	"github.com/devseccon/trivy/pkg/parallel"
    13  )
    14  
    15  func TestPipeline_Do(t *testing.T) {
    16  	type field struct {
    17  		numWorkers int
    18  		items      []float64
    19  		onItem     func(context.Context, float64) (float64, error)
    20  	}
    21  	type testCase struct {
    22  		name    string
    23  		field   field
    24  		want    float64
    25  		wantErr require.ErrorAssertionFunc
    26  	}
    27  	tests := []testCase{
    28  		{
    29  			name: "pow",
    30  			field: field{
    31  				numWorkers: 5,
    32  				items: []float64{
    33  					1,
    34  					2,
    35  					3,
    36  					4,
    37  					5,
    38  					6,
    39  					7,
    40  					8,
    41  					9,
    42  					10,
    43  				},
    44  				onItem: func(_ context.Context, f float64) (float64, error) {
    45  					return math.Pow(f, 2), nil
    46  				},
    47  			},
    48  			want:    385,
    49  			wantErr: require.NoError,
    50  		},
    51  		{
    52  			name: "ceil",
    53  			field: field{
    54  				numWorkers: 3,
    55  				items: []float64{
    56  					1.1,
    57  					2.2,
    58  					3.3,
    59  					4.4,
    60  					5.5,
    61  					-1.1,
    62  					-2.2,
    63  					-3.3,
    64  				},
    65  				onItem: func(_ context.Context, f float64) (float64, error) {
    66  					return math.Round(f), nil
    67  				},
    68  			},
    69  			want:    10,
    70  			wantErr: require.NoError,
    71  		},
    72  		{
    73  			name: "error in series",
    74  			field: field{
    75  				numWorkers: 1,
    76  				items: []float64{
    77  					1,
    78  					2,
    79  					3,
    80  				},
    81  				onItem: func(_ context.Context, f float64) (float64, error) {
    82  					return 0, fmt.Errorf("error")
    83  				},
    84  			},
    85  			wantErr: require.Error,
    86  		},
    87  		{
    88  			name: "error in parallel",
    89  			field: field{
    90  				numWorkers: 3,
    91  				items: []float64{
    92  					1,
    93  					2,
    94  				},
    95  				onItem: func(_ context.Context, f float64) (float64, error) {
    96  					return 0, fmt.Errorf("error")
    97  				},
    98  			},
    99  			wantErr: require.Error,
   100  		},
   101  	}
   102  	for _, tt := range tests {
   103  		t.Run(tt.name, func(t *testing.T) {
   104  			var got float64
   105  			p := parallel.NewPipeline(tt.field.numWorkers, false, tt.field.items, tt.field.onItem, func(f float64) error {
   106  				got += f
   107  				return nil
   108  			})
   109  			err := p.Do(context.Background())
   110  			tt.wantErr(t, err)
   111  			assert.Equal(t, tt.want, got)
   112  		})
   113  	}
   114  }