go-ml.dev/pkg/base@v0.0.0-20200610162856-60c38abac71b/tests/lazy1_test.go (about)

     1  package tests
     2  
import (
	"bytes"
	"io/ioutil"
	"os"
	"path/filepath"
	"testing"

	"go-ml.dev/pkg/base/fu"
	"go-ml.dev/pkg/base/tables/csv"
	"go-ml.dev/pkg/base/tables/rdb"
	"go-ml.dev/pkg/iokit"
	"gotest.tools/assert"
)
    12  
    13  func Test_LazyCsvRdb1(t *testing.T) {
    14  	const CSV = `id,f1,f2,f3,f4
    15  4,1,2,3,4
    16  8,5,6,7,8
    17  12,9,10,11,12
    18  `
    19  
    20  	z := csv.Source(iokit.StringIO(CSV)).
    21  		Map(func(x struct {
    22  			Id string    `id`
    23  			F  []float64 `f*`
    24  		}) (y struct {
    25  			Id     string
    26  			Target float64
    27  		}) {
    28  			y.Target = fu.Maxd(0, x.F...)
    29  			y.Id = x.Id
    30  			return
    31  		})
    32  
    33  	err := z.
    34  		Parallel().
    35  		Drain(rdb.Sink("sqlite3:file:/tmp/test.db",
    36  			rdb.Table("maxd"),
    37  			rdb.DropIfExists,
    38  			rdb.VARCHAR("Id").PrimaryKey(),
    39  			rdb.DECIMAL("Target", 2)))
    40  
    41  	assert.NilError(t, err)
    42  
    43  	c := rdb.Source("sqlite3:file:/tmp/test.db",
    44  		rdb.Table("maxd")).
    45  		Filter(func(x struct {
    46  			Id     string
    47  			Target string
    48  		}) bool {
    49  			return x.Id == x.Target
    50  		}).
    51  		LuckyCount()
    52  
    53  	assert.Assert(t, c == 3)
    54  
    55  	c = rdb.Source("sqlite3:file:/tmp/test.db",
    56  		rdb.Table("maxd")).
    57  		Filter(func(x struct {
    58  			Id     string
    59  			Target string
    60  		}) bool {
    61  			return x.Id != x.Target
    62  		}).
    63  		LuckyCount()
    64  
    65  	assert.Assert(t, c == 0)
    66  
    67  	bf := bytes.Buffer{}
    68  	err = rdb.Source("sqlite3:file:/tmp/test.db", rdb.Query("select Id from maxd")).
    69  		Map(struct {
    70  			Id string
    71  			F4 float64 `Id`
    72  		}{}).
    73  		Update(func(x struct{ F4 float64 }) (y struct{ F3 float64 }) {
    74  			y.F3 = x.F4 - 1
    75  			return
    76  		}).
    77  		Update(func(x struct{ F4 float64 }) (y struct{ F2, F1 float64 }) {
    78  			y.F1 = x.F4 - 3
    79  			y.F2 = x.F4 - 2
    80  			return
    81  		}).
    82  		Parallel().
    83  		Map(struct {
    84  			Id             string
    85  			F1, F2, F3, F4 float64
    86  		}{}).
    87  		Drain(csv.Sink(iokit.Writer(&bf),
    88  			csv.Column("Id").As("id"),
    89  			csv.Column("F*").Round(2).As("f*")))
    90  
    91  	assert.NilError(t, err)
    92  	assert.Assert(t, bf.String() == CSV)
    93  }
    94  
    95  /*
    96  func Test_LazyBatch(t *testing.T) {
    97  	dataset := fu.External("https://datahub.io/machine-learning/iris/r/iris.csv",
    98  		fu.Cached("go-model/dataset/iris/iris.csv"))
    99  
   100  	cls := tables.Enumset{}
   101  
   102  	z := csv.Source(dataset,
   103  		csv.Float32("sepallength").As("Feature1"),
   104  		csv.Float32("sepalwidth").As("Feature2"),
   105  		csv.Float32("petallength").As("Feature3"),
   106  		csv.Float32("petalwidth").As("Feature4"),
   107  		csv.Meta(cls.Integer(), "class").As("Label"))
   108  
   109  	q := z.RandSkip(42, 0.3).Parallel().LuckyCollect()
   110  	q2 := z.Rand(42, 0.3).Parallel().LuckyCollect()
   111  	assert.Assert(t, q.Len()+q2.Len() == z.LuckyCount())
   112  
   113  	n := 0
   114  	l := 0
   115  	batch := 30
   116  	z.RandSkip(42, 0.3).Batch(batch).Parallel().LuckyDrain(func(v reflect.Value) error {
   117  		if v.Kind() == reflect.Bool {
   118  			return nil
   119  		}
   120  		k := v.Interface().(*tables.Table)
   121  		for g := 0; g < fu.Mini(batch, k.Len()); g++ {
   122  			a, b := k.Row(g), q.Row(n*batch+g)
   123  			for e, c := range a {
   124  				assert.DeepEqual(t, b[e].Interface(), c.Interface())
   125  			}
   126  			for e, c := range b {
   127  				assert.DeepEqual(t, a[e].Interface(), c.Interface())
   128  			}
   129  		}
   130  		n++
   131  		l += k.Len()
   132  		return nil
   133  	})
   134  
   135  	assert.Assert(t, l == q.Len())
   136  }
   137  */