go-ml.dev/pkg/base@v0.0.0-20200610162856-60c38abac71b/tests/lazy1_test.go (about) 1 package tests 2 3 import ( 4 "bytes" 5 "go-ml.dev/pkg/base/fu" 6 "go-ml.dev/pkg/base/tables/csv" 7 "go-ml.dev/pkg/base/tables/rdb" 8 "go-ml.dev/pkg/iokit" 9 "gotest.tools/assert" 10 "testing" 11 ) 12 13 func Test_LazyCsvRdb1(t *testing.T) { 14 const CSV = `id,f1,f2,f3,f4 15 4,1,2,3,4 16 8,5,6,7,8 17 12,9,10,11,12 18 ` 19 20 z := csv.Source(iokit.StringIO(CSV)). 21 Map(func(x struct { 22 Id string `id` 23 F []float64 `f*` 24 }) (y struct { 25 Id string 26 Target float64 27 }) { 28 y.Target = fu.Maxd(0, x.F...) 29 y.Id = x.Id 30 return 31 }) 32 33 err := z. 34 Parallel(). 35 Drain(rdb.Sink("sqlite3:file:/tmp/test.db", 36 rdb.Table("maxd"), 37 rdb.DropIfExists, 38 rdb.VARCHAR("Id").PrimaryKey(), 39 rdb.DECIMAL("Target", 2))) 40 41 assert.NilError(t, err) 42 43 c := rdb.Source("sqlite3:file:/tmp/test.db", 44 rdb.Table("maxd")). 45 Filter(func(x struct { 46 Id string 47 Target string 48 }) bool { 49 return x.Id == x.Target 50 }). 51 LuckyCount() 52 53 assert.Assert(t, c == 3) 54 55 c = rdb.Source("sqlite3:file:/tmp/test.db", 56 rdb.Table("maxd")). 57 Filter(func(x struct { 58 Id string 59 Target string 60 }) bool { 61 return x.Id != x.Target 62 }). 63 LuckyCount() 64 65 assert.Assert(t, c == 0) 66 67 bf := bytes.Buffer{} 68 err = rdb.Source("sqlite3:file:/tmp/test.db", rdb.Query("select Id from maxd")). 69 Map(struct { 70 Id string 71 F4 float64 `Id` 72 }{}). 73 Update(func(x struct{ F4 float64 }) (y struct{ F3 float64 }) { 74 y.F3 = x.F4 - 1 75 return 76 }). 77 Update(func(x struct{ F4 float64 }) (y struct{ F2, F1 float64 }) { 78 y.F1 = x.F4 - 3 79 y.F2 = x.F4 - 2 80 return 81 }). 82 Parallel(). 83 Map(struct { 84 Id string 85 F1, F2, F3, F4 float64 86 }{}). 87 Drain(csv.Sink(iokit.Writer(&bf), 88 csv.Column("Id").As("id"), 89 csv.Column("F*").Round(2).As("f*"))) 90 91 assert.NilError(t, err) 92 assert.Assert(t, bf.String() == CSV) 93 } 94 95 /* 96 func Test_LazyBatch(t *testing.T) { 97 dataset := fu.External("https://datahub.io/machine-learning/iris/r/iris.csv", 98 fu.Cached("go-model/dataset/iris/iris.csv")) 99 100 cls := tables.Enumset{} 101 102 z := csv.Source(dataset, 103 csv.Float32("sepallength").As("Feature1"), 104 csv.Float32("sepalwidth").As("Feature2"), 105 csv.Float32("petallength").As("Feature3"), 106 csv.Float32("petalwidth").As("Feature4"), 107 csv.Meta(cls.Integer(), "class").As("Label")) 108 109 q := z.RandSkip(42, 0.3).Parallel().LuckyCollect() 110 q2 := z.Rand(42, 0.3).Parallel().LuckyCollect() 111 assert.Assert(t, q.Len()+q2.Len() == z.LuckyCount()) 112 113 n := 0 114 l := 0 115 batch := 30 116 z.RandSkip(42, 0.3).Batch(batch).Parallel().LuckyDrain(func(v reflect.Value) error { 117 if v.Kind() == reflect.Bool { 118 return nil 119 } 120 k := v.Interface().(*tables.Table) 121 for g := 0; g < fu.Mini(batch, k.Len()); g++ { 122 a, b := k.Row(g), q.Row(n*batch+g) 123 for e, c := range a { 124 assert.DeepEqual(t, b[e].Interface(), c.Interface()) 125 } 126 for e, c := range b { 127 assert.DeepEqual(t, a[e].Interface(), c.Interface()) 128 } 129 } 130 n++ 131 l += k.Len() 132 return nil 133 }) 134 135 assert.Assert(t, l == q.Len()) 136 } 137 */