go-ml.dev/pkg/base@v0.0.0-20200610162856-60c38abac71b/fu/struct.go (about)

     1  package fu
     2  
     3  import (
     4  	"fmt"
     5  	"go-ml.dev/pkg/zorros"
     6  	"reflect"
     7  	"strings"
     8  	"sync"
     9  )
    10  
    11  type Struct struct {
    12  	Names   []string
    13  	Columns []reflect.Value
    14  	Na      Bits
    15  }
    16  
    17  func (lr Struct) String() string {
    18  	r := make([]string, len(lr.Names))
    19  	for i, n := range lr.Names {
    20  		v := (interface{})("<nil>")
    21  		if lr.Columns[i].IsValid() {
    22  			v = lr.Columns[i].Interface()
    23  		}
    24  		r[i] = fmt.Sprintf("%v:%v", n, Ife(lr.Na.Bit(i), "N/A", v))
    25  	}
    26  	return "fu.Struct{" + strings.Join(r, ", ") + "}"
    27  }
    28  
    29  func (lr Struct) Copy(extra int) Struct {
    30  	width := len(lr.Names)
    31  	columns := make([]reflect.Value, width, width+extra)
    32  	copy(columns, lr.Columns)
    33  	names := make([]string, width, width+extra)
    34  	copy(names, lr.Names)
    35  	na := lr.Na.Copy()
    36  	return Struct{names, columns, na}
    37  }
    38  
    39  func (lrx Struct) With(lr Struct) (r Struct) {
    40  	extra := 0
    41  	ndx := make([]int, len(lr.Names))
    42  	for i, n := range lr.Names {
    43  		j := IndexOf(n, lrx.Names)
    44  		ndx[i] = j
    45  		if j < 0 {
    46  			extra++
    47  		}
    48  	}
    49  	r = lrx.Copy(extra)
    50  	for i, j := range ndx {
    51  		if j >= 0 {
    52  			r.Columns[j] = lr.Columns[i]
    53  			r.Na.Set(j, lr.Na.Bit(i))
    54  		} else {
    55  			r.Na.Set(len(r.Names), lr.Na.Bit(i))
    56  			r.Names = append(r.Names, lr.Names[i])
    57  			r.Columns = append(r.Columns, lr.Columns[i])
    58  		}
    59  	}
    60  	return
    61  }
    62  
    63  func Wrapper(rt reflect.Type) func(reflect.Value) Struct {
    64  	L := rt.NumField()
    65  	names := make([]string, L)
    66  	for i := range names {
    67  		names[i] = rt.Field(i).Name
    68  	}
    69  	return func(v reflect.Value) Struct {
    70  		lr := Struct{Columns: make([]reflect.Value, L), Names: names, Na: Bits{}}
    71  		for i := range names {
    72  			x := v.Field(i)
    73  			lr.Na.Set(i, Isna(x))
    74  			lr.Columns[i] = x
    75  		}
    76  		return lr
    77  	}
    78  }
    79  
    80  var uwrpMu = sync.Mutex{}
    81  
    82  func Unwrapper(v reflect.Type) func(lr Struct) reflect.Value {
    83  	var indecies [][]int
    84  	inif := AtomicFlag{0}
    85  	return func(lr Struct) reflect.Value {
    86  		if !inif.State() {
    87  			uwrpMu.Lock()
    88  			if !inif.State() {
    89  				var nd [][]int
    90  				L := v.NumField()
    91  				for i := 0; i < L; i++ {
    92  					vt := v.Field(i)
    93  					pat := string(vt.Tag)
    94  					if pat == "" {
    95  						pat = vt.Name
    96  					}
    97  					like := Pattern(pat)
    98  					q := []int{}
    99  					for i, n := range lr.Names {
   100  						if like(n) {
   101  							q = append(q, i)
   102  						}
   103  					}
   104  					if len(q) == 0 {
   105  						uwrpMu.Unlock()
   106  						panic(zorros.Panic(zorros.Errorf("Struct does not have filed(s) matched to " + pat)))
   107  					}
   108  					if vt.Type.Kind() == reflect.Slice {
   109  						nd = append(nd, q)
   110  					} else {
   111  						nd = append(nd, q[:1])
   112  					}
   113  				}
   114  				indecies = nd
   115  				inif.Set()
   116  			}
   117  			uwrpMu.Unlock()
   118  		}
   119  
   120  		x := reflect.New(v).Elem()
   121  		for i, nd := range indecies {
   122  			vt := v.Field(i)
   123  			if vt.Type.Kind() == reflect.Slice {
   124  				et := vt.Type.Elem()
   125  				a := reflect.MakeSlice(reflect.SliceOf(et), len(nd), len(nd))
   126  				for j, k := range nd {
   127  					a.Index(j).Set(Convert(lr.Columns[k], lr.Na.Bit(k), et))
   128  				}
   129  				x.Field(i).Set(a)
   130  			} else {
   131  				k := nd[0]
   132  				y := Convert(lr.Columns[k], lr.Na.Bit(k), vt.Type)
   133  				x.Field(i).Set(y)
   134  			}
   135  		}
   136  		return x
   137  	}
   138  }
   139  
   140  var trfMu = sync.Mutex{}
   141  
   142  func Transformer(rt reflect.Type) func(reflect.Value, reflect.Value) reflect.Value {
   143  	var (
   144  		names  []string
   145  		update []int
   146  	)
   147  	inif := AtomicFlag{0}
   148  	return func(v reflect.Value, olr reflect.Value) reflect.Value {
   149  		lrx := olr.Interface().(Struct)
   150  		if !inif.State() {
   151  			trfMu.Lock()
   152  			if !inif.State() {
   153  				names = make([]string, len(lrx.Names), len(lrx.Names)*2)
   154  				update = make([]int, len(lrx.Names), len(lrx.Names)*2)
   155  				copy(names, lrx.Names)
   156  				for i := range update {
   157  					update[i] = -1
   158  				}
   159  				L := rt.NumField()
   160  				for i := 0; i < L; i++ {
   161  					n := rt.Field(i).Name
   162  					if j := IndexOf(n, names); j < 0 {
   163  						names = append(names, n)
   164  						update = append(update, i)
   165  					} else {
   166  						update[j] = i
   167  					}
   168  				}
   169  				inif.Set()
   170  			}
   171  			trfMu.Unlock()
   172  		}
   173  		lr := Struct{Columns: make([]reflect.Value, len(names)), Names: names, Na: lrx.Na.Copy()}
   174  		for i := range names {
   175  			if j := update[i]; j >= 0 {
   176  				x := v.Field(j)
   177  				lr.Na.Set(i, Isna(x))
   178  				lr.Columns[i] = x
   179  			} else {
   180  				lr.Columns[i] = lrx.Columns[i]
   181  			}
   182  		}
   183  		return reflect.ValueOf(lr)
   184  	}
   185  }
   186  
   187  func NaStruct(names []string, tp reflect.Type) Struct {
   188  	columns := make([]reflect.Value, len(names))
   189  	for i := range columns {
   190  		columns[i] = reflect.Zero(tp)
   191  	}
   192  	return Struct{names, columns, FillBits(len(names))}
   193  }
   194  
   195  func MakeStruct(names []string, vals ...interface{}) Struct {
   196  	columns := make([]reflect.Value, len(names))
   197  	for i := range columns {
   198  		columns[i] = reflect.ValueOf(vals[i])
   199  	}
   200  	return Struct{names, columns, Bits{}}
   201  }
   202  
   203  func (lr Struct) Set(c string, val reflect.Value) Struct {
   204  	cj := IndexOf(c, lr.Names)
   205  	lr = lr.Copy(cj + 1)
   206  	if cj < 0 {
   207  		lr.Names = append(lr.Names, c)
   208  		lr.Columns = append(lr.Columns, val)
   209  	} else {
   210  		lr.Columns[cj] = val
   211  		lr.Na.Set(cj, false)
   212  	}
   213  	return lr
   214  }
   215  
   216  func (lr Struct) Pos(c string) int {
   217  	return IndexOf(c, lr.Names)
   218  }
   219  
   220  func (lr Struct) ValueAt(i int) reflect.Value {
   221  	return lr.Columns[i]
   222  }
   223  
   224  func (lr Struct) Value(c string) reflect.Value {
   225  	j := IndexOf(c, lr.Names)
   226  	return lr.Columns[j]
   227  }
   228  
   229  func (lr Struct) Index(c string) Cell {
   230  	j := IndexOf(c, lr.Names)
   231  	return Cell{lr.Columns[j]}
   232  }
   233  
   234  func (lr Struct) Int(c string) int       { return lr.Index(c).Int() }
   235  func (lr Struct) Float(c string) float64 { return lr.Index(c).Float() }
   236  func (lr Struct) Real(c string) float32  { return lr.Index(c).Real() }
   237  func (lr Struct) Text(c string) string   { return lr.Index(c).Text() }
   238  
   239  func (lr Struct) Round(p int) Struct {
   240  	c := lr.Copy(0)
   241  	for i, v := range c.Columns {
   242  		switch v.Kind() {
   243  		case reflect.Float32, reflect.Float64:
   244  			c.Columns[i] = reflect.ValueOf(Round64(v.Float(), p))
   245  		}
   246  	}
   247  	return c
   248  }
   249  
   250  func OnlyFilter(names []string, c ...string) func(Struct) Struct {
   251  	ns := make([]string, 0, len(names))
   252  	nx := make([]int, 0, len(names))
   253  	p := make([]func(string) bool, len(c))
   254  	for i, s := range c {
   255  		p[i] = Pattern(s)
   256  	}
   257  	for i, n := range names {
   258  	l:
   259  		for _, f := range p {
   260  			if f(n) {
   261  				ns = append(ns, n)
   262  				nx = append(nx, i)
   263  				break l
   264  			}
   265  		}
   266  	}
   267  	return func(lr Struct) Struct {
   268  		columns := make([]reflect.Value, len(ns))
   269  		for i, x := range nx {
   270  			columns[i] = lr.Columns[x]
   271  		}
   272  		return Struct{Names: ns, Columns: columns}
   273  	}
   274  }
   275  
   276  func tensorUnpacker(names []string, c string, volume int) func(lr Struct) Struct {
   277  	j := IndexOf(c, names)
   278  	ns := make([]string, len(names)-1+volume)
   279  	k := 0
   280  	for i, n := range names {
   281  		if i != j {
   282  			ns[k] = n
   283  			k++
   284  		}
   285  	}
   286  	for i := 1; k < len(ns); k++ {
   287  		ns[k] = fmt.Sprintf("%v%v", c, i)
   288  		i++
   289  	}
   290  	k = len(names) - 1
   291  	return func(lr Struct) Struct {
   292  		t := lr.ValueAt(j).Interface().(Tensor)
   293  		columns := make([]reflect.Value, len(lr.Names)-1+t.Volume())
   294  		na := Bits{}
   295  		t.Extract(columns[k:])
   296  		n := 0
   297  		for i, v := range lr.Columns {
   298  			if i != j {
   299  				if lr.Na.Bit(i) {
   300  					na.Set(j, true)
   301  				}
   302  				columns[n] = v
   303  				n++
   304  			}
   305  		}
   306  		return Struct{ns, columns, na}
   307  	}
   308  }
   309  
   310  func TensorUnpacker(lr Struct, c string) func(lr Struct) Struct {
   311  	return tensorUnpacker(lr.Names, c, lr.Value(c).Interface().(Tensor).Volume())
   312  }
   313  
   314  func (lr Struct) UnpackTensor(c string) Struct {
   315  	return TensorUnpacker(lr, c)(lr)
   316  }