github.com/fluhus/gostuff@v0.4.1-0.20240331134726-be71864f2b5d/ppln/v2/serial_test.go (about)

     1  package ppln
     2  
     3  import (
     4  	"fmt"
     5  	"math/rand"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/fluhus/gostuff/gnum"
    10  )
    11  
    12  func ExampleSerial() {
    13  	ngoroutines := 4
    14  	var results []float64
    15  
    16  	Serial[int, float64](ngoroutines,
    17  		// Read/generate input data.
    18  		RangeInput(1, 101),
    19  		// Some processing.
    20  		func(a, i, g int) (float64, error) {
    21  			return float64(a*a) + 0.5, nil
    22  		},
    23  		// Accumulate/forward outputs.
    24  		func(a float64) error {
    25  			results = append(results, a)
    26  			return nil
    27  		})
    28  
    29  	fmt.Println(results[:3], results[len(results)-3:])
    30  
    31  	// Output:
    32  	// [1.5 4.5 9.5] [9604.5 9801.5 10000.5]
    33  }
    34  
    35  func ExampleSerial_parallelAggregation() {
    36  	ngoroutines := 4
    37  	results := make([]int, ngoroutines) // Goroutine-specific data and objects.
    38  
    39  	Serial(
    40  		ngoroutines,
    41  		// Read/generate input data.
    42  		RangeInput(1, 101),
    43  		// Accumulate in goroutine-specific memory.
    44  		func(a int, i, g int) (int, error) {
    45  			results[g] += a
    46  			return 0, nil // Unused.
    47  		},
    48  		// No outputs.
    49  		func(a int) error { return nil })
    50  
    51  	// Collect the results of all goroutines.
    52  	fmt.Println("Sum of 1-100:", gnum.Sum(results))
    53  
    54  	// Output:
    55  	// Sum of 1-100: 5050
    56  }
    57  
    58  func TestSerial(t *testing.T) {
    59  	for _, nt := range []int{1, 2, 4, 8} {
    60  		t.Run(fmt.Sprint(nt), func(t *testing.T) {
    61  			n := nt * 100
    62  			var result []int
    63  			err := Serial(
    64  				nt,
    65  				RangeInput(0, n),
    66  				func(a int, i int, g int) (int, error) {
    67  					time.Sleep(time.Millisecond * time.Duration(rand.Intn(3)))
    68  					return a * a, nil
    69  				},
    70  				func(i int) error {
    71  					result = append(result, i)
    72  					return nil
    73  				})
    74  			if err != nil {
    75  				t.Fatalf("Serial(...) failed: %d", err)
    76  			}
    77  			for i := range result {
    78  				if result[i] != i*i {
    79  					t.Errorf("result[%d]=%d, want %d", i, result[i], i*i)
    80  				}
    81  			}
    82  		})
    83  	}
    84  }
    85  
    86  func TestSerial_error(t *testing.T) {
    87  	for _, nt := range []int{1, 2, 4, 8} {
    88  		t.Run(fmt.Sprint(nt), func(t *testing.T) {
    89  			n := nt * 100
    90  			var result []int
    91  			err := Serial(
    92  				nt,
    93  				RangeInput(0, n),
    94  				func(a int, i int, g int) (int, error) {
    95  					time.Sleep(time.Millisecond * time.Duration(rand.Intn(3)))
    96  					if a > 300 {
    97  						return 0, fmt.Errorf("a too big: %d", a)
    98  					}
    99  					return a * a, nil
   100  				},
   101  				func(i int) error {
   102  					result = append(result, i)
   103  					return nil
   104  				})
   105  			if nt <= 3 {
   106  				if err != nil {
   107  					t.Fatalf("Serial(...) failed: %d", err)
   108  				}
   109  				for i := range result {
   110  					if result[i] != i*i {
   111  						t.Errorf("result[%d]=%d, want %d", i, result[i], i*i)
   112  					}
   113  				}
   114  			} else { // n > 3
   115  				if err == nil {
   116  					t.Fatalf("Serial(...) succeeded, want error")
   117  				}
   118  			}
   119  		})
   120  	}
   121  }