github.com/fluhus/gostuff@v0.4.1-0.20240331134726-be71864f2b5d/ppln/v2/nserial_test.go (about)

     1  package ppln
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  	"testing"
     7  )
     8  
     9  func TestNonSerial(t *testing.T) {
    10  	want := 21082009.0
    11  	for _, nt := range []int{1, 2, 4, 8} {
    12  		t.Run(fmt.Sprint(nt), func(t *testing.T) {
    13  			got := 0.0
    14  			NonSerial[int, float64](
    15  				nt,
    16  				RangeInput(1, 100001),
    17  				func(a, g int) (float64, error) {
    18  					return math.Sqrt(float64(a)), nil
    19  				},
    20  				func(a float64) error {
    21  					got += a
    22  					return nil
    23  				},
    24  			)
    25  			if math.Round(got) != want {
    26  				t.Fatalf("NonSerial: got %f, want %f", got, want)
    27  			}
    28  		})
    29  	}
    30  }
    31  
    32  func TestNonSerial_inputError(t *testing.T) {
    33  	for _, nt := range []int{1, 2, 4, 8} {
    34  		t.Run(fmt.Sprint(nt), func(t *testing.T) {
    35  			got := 0.0
    36  			err := NonSerial[int, float64](
    37  				nt,
    38  				func(yield func(int, error) bool) {
    39  					for i, err := range RangeInput(1, 100001) {
    40  						if i == 1000 {
    41  							yield(0, fmt.Errorf("oh no"))
    42  							return
    43  						}
    44  						if !yield(i, err) {
    45  							return
    46  						}
    47  					}
    48  				},
    49  				func(a, g int) (float64, error) {
    50  					return math.Sqrt(float64(a)), nil
    51  				},
    52  				func(a float64) error {
    53  					got += a
    54  					return nil
    55  				},
    56  			)
    57  			if err == nil {
    58  				t.Fatalf("NonSerial succeeded, want error")
    59  			}
    60  		})
    61  	}
    62  }
    63  
    64  func TestNonSerial_transformError(t *testing.T) {
    65  	for _, nt := range []int{1, 2, 4, 8} {
    66  		t.Run(fmt.Sprint(nt), func(t *testing.T) {
    67  			got := 0.0
    68  			err := NonSerial[int, float64](
    69  				nt,
    70  				RangeInput(1, 100001),
    71  				func(a, g int) (float64, error) {
    72  					if a == 1000 {
    73  						return 0, fmt.Errorf("oh no")
    74  					}
    75  					return math.Sqrt(float64(a)), nil
    76  				},
    77  				func(a float64) error {
    78  					got += a
    79  					return nil
    80  				},
    81  			)
    82  			if err == nil {
    83  				t.Fatalf("NonSerial succeeded, want error")
    84  			}
    85  		})
    86  	}
    87  }
    88  
    89  func TestNonSerial_outputError(t *testing.T) {
    90  	for _, nt := range []int{1, 2, 4, 8} {
    91  		t.Run(fmt.Sprint(nt), func(t *testing.T) {
    92  			got := 0.0
    93  			err := NonSerial[int, float64](
    94  				nt,
    95  				RangeInput(1, 100001),
    96  				func(a, g int) (float64, error) {
    97  					return math.Sqrt(float64(a)), nil
    98  				},
    99  				func(a float64) error {
   100  					if a == 32 {
   101  						return fmt.Errorf("oh no")
   102  					}
   103  					got += a
   104  					return nil
   105  				},
   106  			)
   107  			if err == nil {
   108  				t.Fatalf("NonSerial succeeded, want error")
   109  			}
   110  		})
   111  	}
   112  }
   113  
   114  // TODO(amit): Error tests.