go-ml.dev/pkg/base@v0.0.0-20200610162856-60c38abac71b/tables/lazy.go (about)

     1  package tables
     2  
     3  import (
     4  	"fmt"
     5  	"go-ml.dev/pkg/base/fu"
     6  	"go-ml.dev/pkg/base/fu/lazy"
     7  	"go-ml.dev/pkg/zorros"
     8  	"math"
     9  	"reflect"
    10  	"sync"
    11  )
    12  
    13  type Lazy lazy.Source
    14  type Sink lazy.Sink
    15  
    16  func (Lazy) IsLazy() bool     { return true }
    17  func (zf Lazy) Table() *Table { return zf.LuckyCollect() }
    18  func (zf Lazy) Lazy() Lazy    { return zf }
    19  
    20  func SourceError(err error) Lazy {
    21  	return func() lazy.Stream {
    22  		return func(_ uint64) (reflect.Value, error) {
    23  			return reflect.Value{}, err
    24  		}
    25  	}
    26  }
    27  
    28  func SinkError(err error) Sink {
    29  	return func(_ reflect.Value) error {
    30  		return err
    31  	}
    32  }
    33  
    34  func (zf Lazy) Map(f interface{}) Lazy {
    35  	return func() lazy.Stream {
    36  		z := zf()
    37  		vf := reflect.ValueOf(f)
    38  		vt := vf.Type()
    39  		or, ir := vt, vt
    40  		if vf.Kind() == reflect.Func {
    41  			ir = vt.In(0)
    42  			or = vt.Out(0)
    43  		} else if vf.Kind() != reflect.Struct {
    44  			panic("only func(struct{...})struct{...} and struct{...} is allowed as an argument of lazy.Map")
    45  		}
    46  		unwrap := fu.Unwrapper(ir)
    47  		wrap := fu.Wrapper(or)
    48  		return func(index uint64) (v reflect.Value, err error) {
    49  			if v, err = z(index); err != nil || v.Kind() == reflect.Bool {
    50  				return v, err
    51  			}
    52  			x := unwrap(v.Interface().(fu.Struct))
    53  			if vf.Kind() == reflect.Func {
    54  				x = vf.Call([]reflect.Value{x})[0]
    55  			}
    56  			return reflect.ValueOf(wrap(x)), nil
    57  		}
    58  	}
    59  }
    60  
    61  func (zf Lazy) Update(f interface{}) Lazy {
    62  	return func() lazy.Stream {
    63  		z := zf()
    64  		vf := reflect.ValueOf(f)
    65  		vt := vf.Type()
    66  		or, ir := vt, vt
    67  		if vf.Kind() == reflect.Func {
    68  			ir = vt.In(0)
    69  			or = vt.Out(0)
    70  		} else if vf.Kind() != reflect.Struct {
    71  			panic("only func(struct{...})struct{...} and struct{...} is allowed as an argument of lazy.Transform")
    72  		}
    73  		unwrap := fu.Unwrapper(ir)
    74  		transform := fu.Transformer(or)
    75  		return func(index uint64) (v reflect.Value, err error) {
    76  			if v, err = z(index); err != nil || v.Kind() == reflect.Bool {
    77  				return v, err
    78  			}
    79  			x := unwrap(v.Interface().(fu.Struct))
    80  			if vf.Kind() == reflect.Func {
    81  				x = vf.Call([]reflect.Value{x})[0]
    82  			}
    83  			return transform(x, v), nil
    84  		}
    85  	}
    86  }
    87  
    88  func (zf Lazy) Filter(f interface{}) Lazy {
    89  	return func() lazy.Stream {
    90  		z := zf()
    91  		vf := reflect.ValueOf(f)
    92  		vt := vf.Type()
    93  		unwrap := fu.Unwrapper(vt.In(0))
    94  		return func(index uint64) (v reflect.Value, err error) {
    95  			if v, err = z(index); err != nil || v.Kind() == reflect.Bool {
    96  				return v, err
    97  			}
    98  			x := unwrap(v.Interface().(fu.Struct))
    99  			if vf.Call([]reflect.Value{x})[0].Bool() {
   100  				return
   101  			}
   102  			return reflect.ValueOf(true), nil
   103  		}
   104  	}
   105  }
   106  
   107  func (zf Lazy) First(n int) Lazy {
   108  	return Lazy(lazy.Source(zf).First(n))
   109  }
   110  
   111  func (zf Lazy) Parallel(concurrency ...int) Lazy {
   112  	return Lazy(lazy.Source(zf).Parallel(concurrency...))
   113  }
   114  
   115  const iniCollectLength = 13
   116  const maxChankLength = 10000
   117  
   118  func (zf Lazy) Collect() (t *Table, err error) {
   119  	length := 0
   120  	columns := []reflect.Value{}
   121  	names := []string{}
   122  	na := []fu.Bits{}
   123  	err = zf.Drain(func(v reflect.Value) error {
   124  		if v.Kind() != reflect.Bool {
   125  			lr := v.Interface().(fu.Struct)
   126  			if length == 0 {
   127  				names = lr.Names
   128  				columns = make([]reflect.Value, len(names))
   129  				na = make([]fu.Bits, len(names))
   130  				for i, x := range lr.Columns {
   131  					columns[i] = reflect.MakeSlice(reflect.SliceOf(x.Type()), 0, iniCollectLength)
   132  				}
   133  			}
   134  			defer func() {
   135  				if e := recover(); e != nil {
   136  					fmt.Println(e)
   137  					fmt.Println(lr.String())
   138  					panic(e)
   139  				}
   140  			}()
   141  			for i, x := range lr.Columns {
   142  				if lr.Na.Bit(i) {
   143  					columns[i] = reflect.Append(columns[i], reflect.Zero(columns[i].Type().Elem()))
   144  					na[i].Set(length, true)
   145  				} else {
   146  					columns[i] = reflect.Append(columns[i], x)
   147  				}
   148  			}
   149  			length++
   150  		}
   151  		return nil
   152  	})
   153  	if err != nil {
   154  		return
   155  	}
   156  	return MakeTable(names, columns, na, length), nil
   157  }
   158  
   159  func (zf Lazy) LuckyCollect() *Table {
   160  	t, err := zf.Collect()
   161  	if err != nil {
   162  		panic(zorros.Panic(err))
   163  	}
   164  	return t
   165  }
   166  
   167  func (zf Lazy) Drain(sink Sink) (err error) {
   168  	return lazy.Source(zf).Drain(sink)
   169  }
   170  
   171  func (zf Lazy) LuckyDrain(sink Sink) {
   172  	if err := zf.Drain(sink); err != nil {
   173  		panic(zorros.Panic(err))
   174  	}
   175  }
   176  
   177  func (zf Lazy) Count() (int, error) {
   178  	return lazy.Source(zf).Count()
   179  }
   180  
   181  func (zf Lazy) LuckyCount() int {
   182  	c, err := zf.Count()
   183  	if err != nil {
   184  		panic(zorros.Panic(err))
   185  	}
   186  	return c
   187  }
   188  
   189  func (zf Lazy) Rand(seed int, prob float64) Lazy {
   190  	return Lazy(lazy.Source(zf).Rand(seed, prob))
   191  }
   192  
   193  func (zf Lazy) RandSkip(seed int, prob float64) Lazy {
   194  	return Lazy(lazy.Source(zf).RandSkip(seed, prob))
   195  }
   196  
   197  func (zf Lazy) RandomFlag(c string, seed int, prob float64) Lazy {
   198  	return func() lazy.Stream {
   199  		z := zf()
   200  		nr := fu.NaiveRandom{Value: uint32(seed)}
   201  		wc := fu.WaitCounter{Value: 0}
   202  		return func(index uint64) (v reflect.Value, err error) {
   203  			v, err = z(index)
   204  			if index == lazy.STOP {
   205  				wc.Stop()
   206  			}
   207  			if wc.Wait(index) {
   208  				if err == nil && v.Kind() != reflect.Bool {
   209  					lr := v.Interface().(fu.Struct)
   210  					p := nr.Float()
   211  					val := reflect.ValueOf(p < prob)
   212  					v = reflect.ValueOf(lr.Set(c, val))
   213  				}
   214  				wc.Inc()
   215  			}
   216  			return
   217  		}
   218  	}
   219  }
   220  
   221  func (zf Lazy) Round(prec int) Lazy {
   222  	return func() lazy.Stream {
   223  		z := zf()
   224  		return func(index uint64) (v reflect.Value, err error) {
   225  			v, err = z(index)
   226  			if err != nil || v.Kind() == reflect.Bool {
   227  				return
   228  			}
   229  			lrx := v.Interface().(fu.Struct)
   230  			lr := lrx.Copy(0)
   231  			for i, c := range lr.Columns {
   232  				switch c.Kind() {
   233  				case reflect.Float32:
   234  					lr.Columns[i] = reflect.ValueOf(fu.Round32(float32(c.Float()), prec))
   235  				case reflect.Float64:
   236  					lr.Columns[i] = reflect.ValueOf(fu.Round64(c.Float(), prec))
   237  				}
   238  			}
   239  			return reflect.ValueOf(lr), nil
   240  		}
   241  	}
   242  }
   243  
   244  func (zf Lazy) IfFlag(c string) Lazy {
   245  	return func() lazy.Stream {
   246  		z := zf()
   247  		return func(index uint64) (v reflect.Value, err error) {
   248  			v, err = z(index)
   249  			if err != nil || v.Kind() == reflect.Bool {
   250  				return
   251  			}
   252  			lr := v.Interface().(fu.Struct)
   253  			if j := fu.IndexOf(c, lr.Names); j >= 0 && lr.Columns[j].Bool() {
   254  				return
   255  			}
   256  			return fu.True, nil
   257  		}
   258  	}
   259  }
   260  
   261  func (zf Lazy) IfNotFlag(c string) Lazy {
   262  	return func() lazy.Stream {
   263  		z := zf()
   264  		return func(index uint64) (v reflect.Value, err error) {
   265  			v, err = z(index)
   266  			if err != nil || v.Kind() == reflect.Bool {
   267  				return
   268  			}
   269  			lr := v.Interface().(fu.Struct)
   270  			if j := fu.IndexOf(c, lr.Names); j < 0 || !lr.Columns[j].Bool() {
   271  				return
   272  			}
   273  			return fu.True, nil
   274  		}
   275  	}
   276  }
   277  
   278  func (zf Lazy) Alias(c string, a string) Lazy {
   279  	return func() lazy.Stream {
   280  		z := zf()
   281  		return func(index uint64) (v reflect.Value, err error) {
   282  			v, err = z(index)
   283  			if err != nil || v.Kind() == reflect.Bool {
   284  				return
   285  			}
   286  			lr := v.Interface().(fu.Struct)
   287  			return reflect.ValueOf(lr.Set(a, lr.Value(c))), nil
   288  		}
   289  	}
   290  }
   291  
   292  func (zf Lazy) True(c string) Lazy {
   293  	return func() lazy.Stream {
   294  		z := zf()
   295  		return func(index uint64) (v reflect.Value, err error) {
   296  			v, err = z(index)
   297  			if err != nil || v.Kind() == reflect.Bool {
   298  				return
   299  			}
   300  			lr := v.Interface().(fu.Struct)
   301  			return reflect.ValueOf(lr.Set(c, fu.True)), nil
   302  		}
   303  	}
   304  }
   305  
   306  func (zf Lazy) False(c string) Lazy {
   307  	return func() lazy.Stream {
   308  		z := zf()
   309  		return func(index uint64) (v reflect.Value, err error) {
   310  			v, err = z(index)
   311  			if err != nil || v.Kind() == reflect.Bool {
   312  				return
   313  			}
   314  			lr := v.Interface().(fu.Struct)
   315  			return reflect.ValueOf(lr.Set(c, fu.False)), nil
   316  		}
   317  	}
   318  }
   319  
   320  func (zf Lazy) Only(c ...string) Lazy {
   321  	return func() lazy.Stream {
   322  		z := zf()
   323  		var only func(fu.Struct) fu.Struct
   324  		mu := sync.Mutex{}
   325  		f := fu.AtomicFlag{}
   326  		return func(index uint64) (v reflect.Value, err error) {
   327  			v, err = z(index)
   328  			if err != nil || v.Kind() == reflect.Bool {
   329  				return
   330  			}
   331  			lr := v.Interface().(fu.Struct)
   332  			if !f.State() {
   333  				mu.Lock()
   334  				if !f.State() {
   335  					only = fu.OnlyFilter(lr.Names, c...)
   336  					f.Set()
   337  				}
   338  				mu.Unlock()
   339  			}
   340  			return reflect.ValueOf(only(lr)), nil
   341  		}
   342  	}
   343  }
   344  
   345  func (zf Lazy) Chain(zx Lazy) Lazy {
   346  	return Lazy(lazy.Source(zf).Chain(lazy.Source(zx), func(a, b reflect.Value) (eqt bool) {
   347  		if lr, ok := a.Interface().(fu.Struct); ok {
   348  			if lrx, ok := b.Interface().(fu.Struct); ok {
   349  				if len(lrx.Names) == len(lr.Names) {
   350  					for i, n := range lrx.Names {
   351  						if n != lr.Names[i] || lrx.Columns[i].Type() != lr.Columns[i].Type() {
   352  							return false
   353  						}
   354  					}
   355  					eqt = true
   356  				}
   357  			}
   358  		}
   359  		return
   360  	}))
   361  }
   362  
   363  func (zf Lazy) Kfold(seed int, kfold int, k int, name string) Lazy {
   364  	return func() lazy.Stream {
   365  		z := zf()
   366  		rnd := fu.NaiveRandom{Value: uint32(seed)}
   367  		ac := fu.AtomicCounter{Value: 0}
   368  		wc := fu.WaitCounter{Value: 0}
   369  		nx := make([]int, kfold)
   370  		for i := range nx {
   371  			nx[i] = i
   372  		}
   373  		return func(index uint64) (v reflect.Value, err error) {
   374  			v, err = z(index)
   375  			if index == lazy.STOP {
   376  				wc.Stop()
   377  			}
   378  			if wc.Wait(index) {
   379  				if err == nil && v.Kind() != reflect.Bool {
   380  					a := int(ac.PostInc())
   381  					if a%kfold == 0 {
   382  						for i := range nx {
   383  							j := int(rnd.Float() * float64(kfold))
   384  							nx[i], nx[j] = nx[j], nx[i]
   385  						}
   386  					}
   387  					lr := v.Interface().(fu.Struct)
   388  					if nx[a%kfold] == k {
   389  						v = reflect.ValueOf(lr.Set(name, fu.True))
   390  					} else {
   391  						v = reflect.ValueOf(lr.Set(name, fu.False))
   392  					}
   393  				}
   394  				wc.Inc()
   395  			}
   396  			return
   397  		}
   398  	}
   399  }
   400  
   401  func (zf Lazy) Transform(f func(fu.Struct) (fu.Struct, bool, error)) Lazy {
   402  	return func() lazy.Stream {
   403  		z := zf()
   404  		return func(index uint64) (v reflect.Value, err error) {
   405  			v, err = z(index)
   406  			if err != nil || v.Kind() == reflect.Bool {
   407  				return
   408  			}
   409  			lr := v.Interface().(fu.Struct)
   410  			lr, ok, err := f(lr)
   411  			if err != nil {
   412  				return fu.False, err
   413  			}
   414  			if !ok {
   415  				return fu.True, nil
   416  			}
   417  			return reflect.ValueOf(lr), nil
   418  		}
   419  	}
   420  }
   421  
   422  func (zf Lazy) BatchTransform(batch int, tf func(int) (FeaturesMapper, error)) Lazy {
   423  	return zf.Batch(batch).Transform(tf).Flat()
   424  }
   425  
   426  func (zf Lazy) BatchReduce(batch int, tf func(*Table) (fu.Struct, bool, error)) Lazy {
   427  	return zf.Batch(batch).Reduce(tf)
   428  }
   429  
   430  func (zf Lazy) Foreach(f func(fu.Struct) error) (err error) {
   431  	return zf.Drain(func(v reflect.Value) error {
   432  		if v.Kind() != reflect.Bool {
   433  			lr := v.Interface().(fu.Struct)
   434  			return f(lr)
   435  		}
   436  		return nil
   437  	})
   438  }
   439  
   440  func (zf Lazy) UnpackTensor(c string) Lazy {
   441  	return func() lazy.Stream {
   442  		var ft func(fu.Struct) fu.Struct
   443  		m := sync.Mutex{}
   444  		fc := fu.AtomicFlag{Value: 0}
   445  		z := zf()
   446  		return func(index uint64) (v reflect.Value, err error) {
   447  			v, err = z(index)
   448  			if err != nil || v.Kind() == reflect.Bool {
   449  				return
   450  			}
   451  			lr := v.Interface().(fu.Struct)
   452  			if !fc.State() {
   453  				m.Lock()
   454  				if !fc.State() {
   455  					ft = fu.TensorUnpacker(lr, c)
   456  					fc.Set()
   457  				}
   458  				m.Unlock()
   459  			}
   460  			return reflect.ValueOf(ft(lr)), nil
   461  		}
   462  	}
   463  }
   464  
   465  func LazyConcatf(f ...func() Lazy) Lazy {
   466  	return func() lazy.Stream {
   467  		zf := f[0]()()
   468  		fl := f[1:]
   469  		wc := fu.WaitCounter{Value: 0}
   470  		c := uint64(0)
   471  		return func(index uint64) (v reflect.Value, err error) {
   472  			if index == lazy.STOP {
   473  				v, err = zf(index)
   474  				wc.Stop()
   475  				return
   476  			}
   477  			if wc.Wait(index) {
   478  				v, err = zf(c)
   479  				c++
   480  				if err != nil {
   481  					wc.Stop()
   482  				} else {
   483  					if v.Kind() == reflect.Bool && !v.Bool() {
   484  						if len(fl) > 0 {
   485  							_, _ = zf(lazy.STOP)
   486  							zf = fl[0]()()
   487  							fl = fl[1:]
   488  							v = fu.True
   489  							c = 0
   490  						}
   491  					}
   492  					wc.Inc()
   493  				}
   494  			}
   495  			return
   496  		}
   497  	}
   498  }
   499  
   500  func LazyConcat(a ...AnyData) Lazy {
   501  	f := make([]func() Lazy, len(a))
   502  	for i, x := range a {
   503  		q := x
   504  		f[i] = func() Lazy { return q.Lazy() }
   505  	}
   506  	return LazyConcatf(f...)
   507  }
   508  
   509  func LazyZip(a ...AnyData) Lazy {
   510  	return func() lazy.Stream {
   511  		zf := make([]lazy.Stream, len(a))
   512  		vx := make([]fu.Struct, len(a))
   513  		dx := make([]uint64, len(a))
   514  		nx := []int{}
   515  		names := []string{}
   516  		for i, x := range a {
   517  			zf[i] = x.Lazy()()
   518  		}
   519  		wc := fu.WaitCounter{Value: 0}
   520  		return func(index uint64) (v reflect.Value, err error) {
   521  			if index == lazy.STOP {
   522  				for _, f := range zf {
   523  					_, _ = f(index)
   524  				}
   525  				wc.Stop()
   526  				return fu.False, nil
   527  			}
   528  			for wc.Wait(index) {
   529  				for i, f := range zf {
   530  				retry:
   531  					for {
   532  						v, err = f(dx[i])
   533  						dx[i]++
   534  						if err != nil || v.Kind() == reflect.Bool && !v.Bool() {
   535  							wc.Stop()
   536  							return fu.False, err
   537  						}
   538  						if v.Kind() != reflect.Bool {
   539  							vx[i] = v.Interface().(fu.Struct)
   540  							break retry
   541  						}
   542  					}
   543  				}
   544  				if len(names) == 0 {
   545  					for _, a := range vx {
   546  						for _, n := range a.Names {
   547  							if fu.IndexOf(n, names) < 0 {
   548  								nx = append(nx, len(names))
   549  								names = append(names, n)
   550  							} else {
   551  								nx = append(nx, -1)
   552  							}
   553  						}
   554  					}
   555  				}
   556  				k := 0
   557  				columns := make([]reflect.Value, len(names))
   558  				na := fu.Bits{}
   559  				for _, a := range vx {
   560  					for i, v := range a.Columns {
   561  						j := nx[k]
   562  						k++
   563  						if j >= 0 {
   564  							columns[j] = v
   565  							na.Set(j, a.Na.Bit(i))
   566  						}
   567  					}
   568  				}
   569  				lr := fu.Struct{names, columns, na}
   570  				wc.Inc()
   571  				return reflect.ValueOf(lr), nil
   572  			}
   573  			return fu.False, nil
   574  		}
   575  	}
   576  }
   577  
   578  func (zf Lazy) Reals() ([]float32, error) {
   579  	length := 0
   580  	r := []float32{}
   581  	err := zf.Drain(func(v reflect.Value) error {
   582  		if v.Kind() != reflect.Bool {
   583  			lr := v.Interface().(fu.Struct)
   584  			if len(lr.Names) != 1 {
   585  				return zorros.New("Lazy.Errors can handle only one column")
   586  			}
   587  			defer func() {
   588  				if e := recover(); e != nil {
   589  					fmt.Println(e)
   590  					fmt.Println(lr.String())
   591  					panic(e)
   592  				}
   593  			}()
   594  			if lr.Na.Bit(0) {
   595  				r = append(r, float32(math.NaN()))
   596  			} else {
   597  				r = append(r, fu.Cell{lr.Columns[0]}.Real())
   598  			}
   599  			length++
   600  		}
   601  		return nil
   602  	})
   603  	return r, err
   604  }
   605  
   606  func (zf Lazy) LuckyReals() []float32 {
   607  	r, err := zf.Reals()
   608  	if err != nil {
   609  		panic(zorros.Panic(err))
   610  	}
   611  	return r
   612  }