github.com/wzzhu/tensor@v0.9.24/genlib2/dense_getset.go (about)

     1  package main
     2  
     3  import (
     4  	"io"
     5  	"text/template"
     6  )
     7  
// copySlicedRaw is the template text for the generated copySliced function.
//
// The generated function panics when dest and src have different types. When
// src is masked it makes sure dest.mask can hold dend entries (allocating a
// new slice and copying the old mask over when the capacity is too small) and
// copies src.mask[sstart:send] into dest.mask[dstart:dend]. It then copies
// src[sstart:send] into dest[dstart:dend]: one typed case is emitted for each
// non-parameterized kind in .Kinds, and every other type goes through
// reflect.Copy in the default branch. The value returned is the count
// reported by copy / reflect.Copy.
const copySlicedRaw = `func copySliced(dest *Dense, dstart, dend int, src *Dense, sstart, send int) int{
	if dest.t != src.t {
		panic("Cannot copy arrays of different types")
	}

	if src.IsMasked(){
		mask:=dest.mask
		if cap(dest.mask) < dend{
			mask = make([]bool, dend)
		}
		copy(mask, dest.mask)
		dest.mask=mask
		copy(dest.mask[dstart:dend], src.mask[sstart:send])
	}
	switch dest.t {
	{{range .Kinds -}}
		{{if isParameterized .}}
		{{else -}}
	case {{reflectKind .}}:
		return copy(dest.{{sliceOf .}}[dstart:dend], src.{{sliceOf .}}[sstart:send])
		{{end -}}
	{{end -}}
	default:
		dv := reflect.ValueOf(dest.v)
		dv = dv.Slice(dstart, dend)
		sv := reflect.ValueOf(src.v)
		sv = sv.Slice(sstart, send)
		return reflect.Copy(dv, sv)
	}	
}
`
    39  
// copyIterRaw is the template text for the generated copyDenseIter function.
//
// The generated function panics when dest and src have different types. If
// both iterators are nil and neither tensor reports IsMaterializable, it
// delegates to copyDense. Otherwise a nil iterator is replaced by a new flat
// iterator over the corresponding tensor's AP.
//
// When src is masked, dest.mask is reallocated if its capacity is below
// src.DataSize() and then resliced to dest.DataSize().
// NOTE(review): the capacity check uses src.DataSize() while the reslice uses
// dest.DataSize(); if dest can be larger than src the reslice may exceed the
// capacity - confirm that callers guarantee compatible sizes.
//
// The loop advances both iterators in lock step and stops as soon as either
// one returns an error; an error that handleNoOp does not reduce to nil is
// returned together with the number of elements copied so far. Each step
// copies the mask entry (when masked) and one element, using a typed
// setter/getter case for every non-parameterized kind in .Kinds and the
// generic Set/Get pair otherwise. The result is the number of elements copied.
const copyIterRaw = `func copyDenseIter(dest, src *Dense, diter, siter *FlatIterator) (int, error) {
	if dest.t != src.t {
		panic("Cannot copy arrays of different types")
	}

	if diter == nil && siter == nil && !dest.IsMaterializable() && !src.IsMaterializable() {
		return copyDense(dest, src), nil
	}

	if diter == nil {
		diter = newFlatIterator(&dest.AP)	
	}
	if siter == nil {
		siter = newFlatIterator(&src.AP)
	}
	
	isMasked:= src.IsMasked()
	if isMasked{
		if cap(dest.mask)<src.DataSize(){
			dest.mask=make([]bool, src.DataSize())
		}
		dest.mask=dest.mask[:dest.DataSize()]
	}

	dt := dest.t
	var i, j, count int
	var err error
	for {
		if i, err = diter.Next() ; err != nil {
			if err = handleNoOp(err); err != nil{
				return count, err
			}
			break
		}
		if j, err = siter.Next() ; err != nil {
			if err = handleNoOp(err); err != nil{
				return count, err
			}
			break
		}
		if isMasked{
			dest.mask[i]=src.mask[j]
		}
		
		switch dt {
		{{range .Kinds -}}
			{{if isParameterized . -}}
			{{else -}}
		case {{reflectKind .}}:
			dest.{{setOne .}}(i, src.{{getOne .}}(j))
			{{end -}}
		{{end -}}
		default:
			dest.Set(i, src.Get(j))
		}
		count++
	}
	return count, err
}
`
   100  
// sliceRaw is the template text for the generated (*Dense).slice method.
//
// The generated method takes the [start:end] sub-slice of the tensor's
// backing data and hands it to t.fromSlice. One typed case is emitted for
// each non-parameterized kind in .Kinds; any other type is sliced through
// reflect in the default branch.
const sliceRaw = `// the method assumes the AP and metadata has already been set and this is simply slicing the values
func (t *Dense) slice(start, end int) {
	switch t.t {
	{{range .Kinds -}}
		{{if isParameterized .}}
		{{else -}}
	case {{reflectKind .}}:
		data := t.{{sliceOf .}}[start:end]
		t.fromSlice(data)
		{{end -}}
	{{end -}}
	default:
		v := reflect.ValueOf(t.v)
		v = v.Slice(start, end)
		t.fromSlice(v.Interface())
	}	
}
`
   119  
   120  var (
   121  	CopySliced *template.Template
   122  	CopyIter   *template.Template
   123  	Slice      *template.Template
   124  )
   125  
   126  func init() {
   127  
   128  	CopySliced = template.Must(template.New("copySliced").Funcs(funcs).Parse(copySlicedRaw))
   129  	CopyIter = template.Must(template.New("copyIter").Funcs(funcs).Parse(copyIterRaw))
   130  	Slice = template.Must(template.New("slice").Funcs(funcs).Parse(sliceRaw))
   131  }
   132  
// generateDenseGetSet is currently a no-op: every template execution in its
// body is commented out, so nothing is written to f and generic is not used.
func generateDenseGetSet(f io.Writer, generic Kinds) {

	// Disabled generation steps, kept for reference:
	// CopySliced.Execute(f, generic)
	// fmt.Fprintf(f, "\n\n\n")
	// CopyIter.Execute(f, generic)
	// fmt.Fprintf(f, "\n\n\n")
	// Slice.Execute(f, generic)
	// fmt.Fprintf(f, "\n\n\n")
}