go-ml.dev/pkg/base@v0.0.0-20200610162856-60c38abac71b/fu/lazy/lazy.go (about)

     1  package lazy
     2  
     3  import (
     4  	"go-ml.dev/pkg/base/fu"
     5  	"go-ml.dev/pkg/zorros"
     6  	"math"
     7  	"reflect"
     8  	"runtime"
     9  	"sync/atomic"
    10  	"unsafe"
    11  )
    12  
    13  const STOP = math.MaxUint64
    14  
    15  type Stream func(index uint64) (reflect.Value, error)
    16  type Source func() Stream
    17  type Sink func(reflect.Value) error
    18  type Parallel int
    19  
    20  var falseValue = reflect.ValueOf(false)
    21  var trueValue = reflect.ValueOf(true)
    22  
    23  func (zf Source) Map(f interface{}) Source {
    24  	return func() Stream {
    25  		z := zf()
    26  		return func(index uint64) (v reflect.Value, err error) {
    27  			if v, err = z(index); err != nil || v.Kind() == reflect.Bool {
    28  				return v, err
    29  			}
    30  			fv := reflect.ValueOf(f)
    31  			return fv.Call([]reflect.Value{v})[0], nil
    32  		}
    33  	}
    34  }
    35  
    36  func (zf Source) Filter(f interface{}) Source {
    37  	return func() Stream {
    38  		z := zf()
    39  		return func(index uint64) (v reflect.Value, err error) {
    40  			if v, err = z(index); err != nil || v.Kind() == reflect.Bool {
    41  				return v, err
    42  			}
    43  			fv := reflect.ValueOf(f)
    44  			if fv.Call([]reflect.Value{v})[0].Bool() {
    45  				return
    46  			}
    47  			return trueValue, nil
    48  		}
    49  	}
    50  }
    51  
// Parallel returns a Source whose Stream evaluates zf's stream on several
// goroutines while handing results to the consumer one at a time over an
// unbuffered channel.
//
// The number of goroutines is taken from concurrency, falling back to
// runtime.NumCPU() (presumably fu.Fnzi picks its first non-zero argument —
// confirm in package fu). The goroutines start as soon as the Stream is
// created, not on the first read. Each claims the next upstream index,
// computes it, then waits on a WaitCounter for its turn before sending, which
// presumably keeps output in index order — assumes WaitCounter.Wait(n) blocks
// until n results have been counted; TODO confirm.
//
// The index argument of the returned Stream is ignored except for STOP:
// values are simply delivered in the order the workers send them.
//
// NOTE(review): the workers only exit after STOP closes the stop channel (or
// the counter is stopped); a consumer that never sends STOP leaks them.
// Sending STOP twice panics, because stop is closed without a guard.
func (zf Source) Parallel(concurrency ...int) Source {
	return func() Stream {
		z := zf()
		ccrn := fu.Fnzi(fu.Fnzi(concurrency...), runtime.NumCPU())
		// C carries one upstream result (value and error) across the channel.
		type C struct {
			reflect.Value
			error
		}
		// index hands out upstream indices to workers; wc orders their sends.
		index := fu.AtomicCounter{0}
		wc := fu.WaitCounter{Value: 0}
		c := make(chan C)
		stop := make(chan struct{})
		// alive counts running workers; the last one to exit closes c.
		alive := fu.AtomicCounter{uint64(ccrn)}
		for i := 0; i < ccrn; i++ {
			go func() {
			loop:
				for !wc.Stopped() {
					n := index.PostInc() // returns old value
					v, err := z(n)
					// Upstream work above runs concurrently; only the send is
					// serialized through wc.
					if n < STOP && wc.Wait(n) {
						select {
						case c <- C{v, err}:
						case <-stop:
							// Consumer sent STOP: stop the counter so that
							// workers blocked in wc.Wait also give up.
							wc.Stop()
							break loop
						}
						wc.Inc()
					}
				}
				if alive.Dec() == 0 { // return new value
					close(c)
				}
			}()
		}
		// This index parameter shadows the atomic index counter above.
		return func(index uint64) (reflect.Value, error) {
			if index == STOP {
				close(stop)
				return z(STOP)
			}
			if x, ok := <-c; ok {
				return x.Value, x.error
			}
			// c was closed: every worker has exited.
			return falseValue, nil
		}
	}
}
    98  
    99  func (zf Source) First(n int) Source {
   100  	return func() Stream {
   101  		z := zf()
   102  		count := 0
   103  		wc := fu.WaitCounter{Value: 0}
   104  		return func(index uint64) (v reflect.Value, err error) {
   105  			v, err = z(index)
   106  			if index != STOP && wc.Wait(index) {
   107  				if count < n && err == nil {
   108  					if v.Kind() != reflect.Bool {
   109  						count++
   110  					}
   111  					wc.Inc()
   112  					return
   113  				}
   114  				wc.Stop()
   115  			}
   116  			return falseValue, nil
   117  		}
   118  	}
   119  }
   120  
   121  func (zf Source) Drain(sink func(reflect.Value) error) (err error) {
   122  	z := zf()
   123  	var v reflect.Value
   124  	var i uint64
   125  	for {
   126  		if v, err = z(i); err != nil {
   127  			break
   128  		}
   129  		i++
   130  		if v.Kind() != reflect.Bool {
   131  			if err = sink(v); err != nil {
   132  				break
   133  			}
   134  		} else if !v.Bool() {
   135  			break
   136  		}
   137  	}
   138  	z(STOP)
   139  	e := sink(reflect.ValueOf(err == nil))
   140  	return fu.Fnze(err, e)
   141  }
   142  
   143  func Chan(c interface{}, stop ...chan struct{}) Source {
   144  	return func() Stream {
   145  		scase := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(c)}}
   146  		wc := fu.WaitCounter{Value: 0}
   147  		return func(index uint64) (v reflect.Value, err error) {
   148  			if index == STOP {
   149  				wc.Stop()
   150  				for _, s := range stop {
   151  					close(s)
   152  				}
   153  			}
   154  			if wc.Wait(index) {
   155  				_, r, ok := reflect.Select(scase)
   156  				if wc.Inc() && ok {
   157  					return r, nil
   158  				}
   159  			}
   160  			return falseValue, nil
   161  		}
   162  	}
   163  }
   164  
   165  func List(list interface{}) Source {
   166  	return func() Stream {
   167  		v := reflect.ValueOf(list)
   168  		l := uint64(v.Len())
   169  		flag := fu.AtomicFlag{Value: 1}
   170  		return func(index uint64) (reflect.Value, error) {
   171  			if index < l && flag.State() {
   172  				return v.Index(int(index)), nil
   173  			}
   174  			return falseValue, nil
   175  		}
   176  	}
   177  }
   178  
   179  func Flatn(listarr interface{}) Source {
   180  	return func() Stream {
   181  		t := reflect.TypeOf(listarr)
   182  		if t.Kind() != reflect.Slice || t.Elem().Kind() != reflect.Array {
   183  			panic(zorros.Panic(zorros.Errorf("only [][n]any allowed but %v occurred", t)))
   184  		}
   185  		v := reflect.ValueOf(listarr)
   186  		n := t.Elem().Len()
   187  		l := v.Len()
   188  		return func(index uint64) (reflect.Value, error) {
   189  			if index >= uint64(n*l) {
   190  				return fu.False, nil
   191  			}
   192  			i := int(index)
   193  			return v.Index(i / n).Index(i % n), nil
   194  		}
   195  	}
   196  }
   197  
   198  const iniCollectLength = 13
   199  
   200  func (zf Source) Collect() (r interface{}, err error) {
   201  	length := 0
   202  	values := reflect.ValueOf((interface{})(nil))
   203  	err = zf.Drain(func(v reflect.Value) error {
   204  		if length == 0 {
   205  			values = reflect.MakeSlice(reflect.SliceOf(v.Type()), 0, iniCollectLength)
   206  		}
   207  		if v.Kind() != reflect.Bool {
   208  			values = reflect.Append(values, v)
   209  			length++
   210  		}
   211  		return nil
   212  	})
   213  	if err != nil {
   214  		return
   215  	}
   216  	return values.Interface(), nil
   217  }
   218  
   219  func (zf Source) LuckyCollect() interface{} {
   220  	t, err := zf.Collect()
   221  	if err != nil {
   222  		panic(err)
   223  	}
   224  	return t
   225  }
   226  
   227  func (zf Source) Count() (count int, err error) {
   228  	err = zf.Drain(func(v reflect.Value) error {
   229  		if v.Kind() != reflect.Bool {
   230  			count++
   231  		}
   232  		return nil
   233  	})
   234  	return
   235  }
   236  
   237  func (zf Source) LuckyCount() int {
   238  	c, err := zf.Count()
   239  	if err != nil {
   240  		panic(err)
   241  	}
   242  	return c
   243  }
   244  
   245  func (zf Source) RandFilter(seed int, prob float64, t bool) Source {
   246  	return func() Stream {
   247  		z := zf()
   248  		nr := fu.NaiveRandom{Value: uint32(seed)}
   249  		wc := fu.WaitCounter{Value: 0}
   250  		return func(index uint64) (v reflect.Value, err error) {
   251  			v, err = z(index)
   252  			if index == STOP {
   253  				wc.Stop()
   254  			}
   255  			if wc.Wait(index) {
   256  				if v.Kind() != reflect.Bool {
   257  					p := nr.Float()
   258  					if (t && p <= prob) || (!t && p > prob) {
   259  						v = trueValue // skip
   260  					}
   261  				}
   262  				wc.Inc()
   263  			}
   264  			return
   265  		}
   266  	}
   267  }
   268  
   269  func (zf Source) RandSkip(seed int, prob float64) Source {
   270  	return zf.RandFilter(seed, prob, true)
   271  }
   272  
   273  func (zf Source) Rand(seed int, prob float64) Source {
   274  	return zf.RandFilter(seed, prob, false)
   275  }
   276  
   277  func Error(err error, z ...Stream) Stream {
   278  	return func(index uint64) (reflect.Value, error) {
   279  		if index == STOP && len(z) > 0 {
   280  			z[0](STOP)
   281  		}
   282  		return falseValue, err
   283  	}
   284  }
   285  
   286  func Wrap(e interface{}) Stream {
   287  	if stream, ok := e.(Stream); ok {
   288  		return stream
   289  	} else {
   290  		return Error(e.(error))
   291  	}
   292  }
   293  
   294  func (zf Source) Chain(zx Source, eqt ...func(a, b reflect.Value) bool) Source {
   295  	return func() Stream {
   296  		z0 := zf()
   297  		z1 := zx()
   298  		b := uint64(0)
   299  		ptr := unsafe.Pointer(nil)
   300  		return func(index uint64) (v reflect.Value, err error) {
   301  			q := atomic.LoadUint64(&b)
   302  			if index == STOP {
   303  				_, err = z0(index)
   304  				_, err1 := z1(index)
   305  				if q > 0 || err == nil {
   306  					err = err1
   307  				}
   308  				return falseValue, err
   309  			}
   310  			if q == 0 || index < q {
   311  				v, err = z0(index)
   312  				if err == nil {
   313  					if v.Kind() == reflect.Bool {
   314  						if !v.Bool() { // end first stream
   315  							atomic.CompareAndSwapUint64(&b, q, index)
   316  							return trueValue, nil
   317  						}
   318  					} else {
   319  						if q == 0 && atomic.LoadPointer(&ptr) == nil {
   320  							vx := v
   321  							atomic.CompareAndSwapPointer(&ptr, nil, unsafe.Pointer(&vx))
   322  						}
   323  					}
   324  				}
   325  			} else {
   326  				v, err = z1(index - q)
   327  				if err == nil && v.Kind() != reflect.Bool {
   328  					if p := atomic.LoadPointer(&ptr); p != nil {
   329  						vx := (*reflect.Value)(p)
   330  						if v.Type() != vx.Type() {
   331  							return falseValue, zorros.Errorf("chained stream is not compatible")
   332  						}
   333  						for _, f := range eqt {
   334  							if !f(v, *vx) {
   335  								return falseValue, zorros.Errorf("chained stream has non equal value type")
   336  							}
   337  						}
   338  						atomic.CompareAndSwapPointer(&ptr, p, nil)
   339  					}
   340  				}
   341  			}
   342  			return
   343  		}
   344  	}
   345  }