github.com/wzzhu/tensor@v0.9.24/genlib2/array_getset.go (about)

     1  package main
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"text/template"
     7  )
     8  
// asSliceRaw is the template for a *Header method <Title>s() (one per kind,
// executed by generateHeaderGetSet) that reinterprets the slice header of
// h.Raw as a []T, where T is the kind's Go type from asType, and reslices it
// to h.TypedLen(<kind>Type) elements with capacity equal to that length.
// The <kind>Type variables are the ones rendered by reflectConstTemplateRaw.
// NOTE(review): the unsafe cast assumes h.Raw is a byte slice whose length is
// at least TypedLen * element size — confirm against the Header definition.
const asSliceRaw = `func (h *Header) {{asType . | strip | title}}s() []{{asType .}} {return (*(*[]{{asType .}})(unsafe.Pointer(&h.Raw)))[:h.TypedLen({{short . | unexport}}Type):h.TypedLen({{short . | unexport}}Type)]}
`
    11  
// setBasicRaw is the template for a *Header method Set<short>(i, x) that
// stores x at index i of the typed slice expression produced by sliceOf.
// It is executed once per non-parameterized kind by generateHeaderGetSet.
const setBasicRaw = `func (h *Header) Set{{short . }}(i int, x {{asType . }}) { h.{{sliceOf .}}[i] = x }
`
    14  
// getBasicRaw is the template for a *Header method Get<short>(i) that returns
// element i of the typed slice accessor. The accessor name is derived from
// the kind's String() through the lower | clean | strip | title pipeline plus
// an "s" suffix.
// NOTE(review): setBasicRaw reaches the same accessor through sliceOf instead;
// confirm both spellings always yield the same method name.
const getBasicRaw = `func (h *Header) Get{{short .}}(i int) {{asType .}} { return h.{{lower .String | clean | strip | title }}s()[i]}
`
    17  
    18  const getRaw = `// Get returns the ith element of the underlying array of the *Dense tensor.
    19  func (a *array) Get(i int) interface{} {
    20  	switch a.t.Kind() {
    21  	{{range .Kinds -}}
    22  		{{if isParameterized . -}}
    23  		{{else -}}
    24  	case reflect.{{reflectKind .}}:
    25  		return a.{{getOne .}}(i)
    26  		{{end -}};
    27  	{{end -}}
    28  	default:
    29  		val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size()))
    30  		val = reflect.Indirect(val)
    31  		return val.Interface()
    32  	}
    33  }
    34  
    35  `
// setRaw is the template for (*array).Set(i, x). For each non-parameterized
// kind it switches on a.t.Kind(), type-asserts x to the kind's Go type and
// stores it through the setOne helper. The single-value assertion panics if x
// has a different dynamic type. Every other kind falls back to reflection: it
// builds a pointer to element i inside a.Header.Raw with storage.ElementAt and
// assigns reflect.ValueOf(x) to it.
const setRaw = `// Set sets the value of the underlying array at the index i.
func (a *array) Set(i int, x interface{}) {
	switch a.t.Kind() {
	{{range .Kinds -}}
		{{if isParameterized . -}}
		{{else -}}
	case reflect.{{reflectKind .}}:
		xv := x.({{asType .}})
		a.{{setOne .}}(i, xv)
		{{end -}}
	{{end -}}
	default:
		xv := reflect.ValueOf(x)
		val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size()))
		val = reflect.Indirect(val)
		val.Set(xv)
	}
}

`
    56  
    57  const memsetRaw = `// Memset sets all values in the array.
    58  func (a *array) Memset(x interface{}) error {
    59  	switch a.t {
    60  	{{range .Kinds -}}
    61  		{{if isParameterized . -}}
    62  		{{else -}}
    63  	case {{reflectKind .}}:
    64  		if xv, ok := x.({{asType .}}); ok {
    65  			data := a.{{sliceOf .}}
    66  			for i := range data{
    67  				data[i] = xv
    68  			}
    69  			return nil
    70  		}
    71  
    72  		{{end -}}
    73  	{{end -}}
    74  	}
    75  
    76  	xv := reflect.ValueOf(x)
    77  	l := a.Len()
    78  	for i := 0; i < l; i++ {
    79  		val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size()))
    80  		val = reflect.Indirect(val)
    81  		val.Set(xv)
    82  	}
    83  	return nil
    84  }
    85  `
    86  
    87  const arrayEqRaw = ` // Eq checks that any two arrays are equal
    88  func (a array) Eq(other interface{}) bool {
    89  	if oa, ok := other.(*array); ok {
    90  		if oa.t != a.t {
    91  			return false
    92  		}
    93  
    94  		if oa.Len() != a.Len() {
    95  			return false
    96  		}
    97  		/*
    98  		if oa.C != a.C {
    99  			return false
   100  		}
   101  		*/
   102  
   103  		// same exact thing
   104  		if uintptr(unsafe.Pointer(&oa.Header.Raw[0])) == uintptr(unsafe.Pointer(&a.Header.Raw[0])){
   105  			return true
   106  		}
   107  
   108  		switch a.t.Kind() {
   109  		{{range .Kinds -}}
   110  			{{if isParameterized . -}}
   111  			{{else -}}
   112  		case reflect.{{reflectKind .}}:
   113  			for i, v := range a.{{sliceOf .}} {
   114  				if oa.{{getOne .}}(i) != v {
   115  					return false
   116  				}
   117  			}
   118  			{{end -}}
   119  		{{end -}}
   120  		default:
   121  			for i := 0; i < a.Len(); i++{
   122  				if !reflect.DeepEqual(a.Get(i), oa.Get(i)){
   123  					return false
   124  				}
   125  			}
   126  		}
   127  		return true
   128  	}
   129  	return false
   130  }`
   131  
   132  const copyArrayIterRaw = `func copyArrayIter(dst, src array, diter, siter Iterator) (count int, err error){
   133  	if dst.t != src.t {
   134  		panic("Cannot copy arrays of different types")
   135  	}
   136  
   137  	if diter == nil && siter == nil {
   138  		return copyArray(dst, src), nil
   139  	}
   140  
   141  	if (diter != nil && siter == nil) || (diter == nil && siter != nil) {
   142  		return 0, errors.Errorf("Cannot copy array when only one iterator was passed in")
   143  	}
   144  
   145  	k := dest.t.Kind()
   146  	var i, j int
   147  	var validi, validj bool
   148  	for {
   149  		if i, validi, err = diter.NextValidity(); err != nil {
   150  			if err = handleNoOp(err); err != nil {
   151  				return count, err
   152  			}
   153  			break
   154  		}
   155  		if j, validj, err = siter.NextValidity(); err != nil {
   156  			if err = handleNoOp(err); err != nil {
   157  				return count, err
   158  			}
   159  			break
   160  		}
   161  		switch k {
   162  		{{range .Kinds -}}
   163  			{{if isParameterized . -}}
   164  			{{else -}}
   165  		case reflect.{{reflectKind .}}:
   166  			dest.{{setOne .}}(i, src.{{getOne .}}(j))
   167  			{{end -}}
   168  		{{end -}}
   169  		default:
   170  			dest.Set(i, src.Get(j))
   171  		}
   172  		count++
   173  	}
   174  
   175  }
   176  `
   177  
// memsetIterRaw is the template for (*array).memsetIter(x, it). The generated
// method assigns x to each index yielded by it.Next() until Next returns an
// error, then passes that error through handleNoOp.
//
// Non-parameterized kinds are matched by comparing a.t with the dtype named
// by reflectKind. For those, x must have the kind's Go type; otherwise the
// method returns a dtypeMismatch error before any write. Other dtypes are
// written element by element through reflection.
const memsetIterRaw = `
func (a *array) memsetIter(x interface{}, it Iterator) (err error) {
	var i int
	switch a.t{
	{{range .Kinds -}}
		{{if isParameterized . -}}
		{{else -}}
	case {{reflectKind .}}:
		xv, ok := x.({{asType .}})
		if !ok {
			return errors.Errorf(dtypeMismatch, a.t, x)
		}
		data := a.{{sliceOf .}}
		for i, err = it.Next(); err == nil; i, err = it.Next(){
			data[i] = xv
		}
		err = handleNoOp(err)
		{{end -}}
	{{end -}}
	default:
		xv := reflect.ValueOf(x)
		for i, err = it.Next(); err == nil; i, err = it.Next(){
			val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size()))
			val = reflect.Indirect(val)
			val.Set(xv)
		}
		err = handleNoOp(err)
	}
	return
}

`
   210  
// zeroIterRaw is the template for (*array).zeroIter(it). The generated method
// writes a zero value to each index yielded by it.Next() and passes the
// terminating error through handleNoOp.
//
// Fast paths emit a literal zero chosen from the kind's String(): false for
// bool, "" for string, nil for unsafe.Pointer and 0 otherwise. Dtypes without
// a fast path are zeroed through reflection with reflect.Zero(a.t).
const zeroIterRaw = `func (a *array) zeroIter(it Iterator) (err error){
	var i int
	switch a.t {
	{{range .Kinds -}}
		{{if isParameterized . -}}
		{{else -}}
	case {{reflectKind .}}:
		data := a.{{sliceOf .}}
		for i, err = it.Next(); err == nil; i, err = it.Next(){
			data[i] = {{if eq .String "bool" -}}
				false
			{{else if eq .String "string" -}}""
			{{else if eq .String "unsafe.Pointer" -}}nil
			{{else -}}0{{end -}}
		}
		err = handleNoOp(err)
		{{end -}}
	{{end -}}
	default:
		for i, err = it.Next(); err == nil; i, err = it.Next(){
			val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size()))
			val = reflect.Indirect(val)
			val.Set(reflect.Zero(a.t))
		}
		err = handleNoOp(err)
	}
	return
}
`
   240  
// reflectConstTemplateRaw renders a `var (...)` block (despite the "Const" in
// its name). For each non-parameterized kind it declares
// <kind>Type = reflect.TypeOf(<Go type>(<zero literal>)), using false, "",
// nil or 0 depending on the kind's String(). asSliceRaw's accessors refer to
// these variables through h.TypedLen(<kind>Type).
const reflectConstTemplateRaw = `var (
	{{range .Kinds -}}
		{{if isParameterized . -}}
		{{else -}}
			{{short . | unexport}}Type = reflect.TypeOf({{asType .}}({{if eq .String "bool" -}} false {{else if eq .String "string" -}}"" {{else if eq .String "unsafe.Pointer" -}}nil {{else -}}0{{end -}}))
		{{end -}}
	{{end -}}
)`
   249  
// Parsed templates, populated by init from the *Raw constants in this file.
// AsSlice, SimpleSet and SimpleGet are executed once per kind by
// generateHeaderGetSet. The remaining templates are executed with a whole
// Kinds value by generateArrayMethods and generateReflectTypes.
// NOTE(review): copyArrayIterRaw has no corresponding template variable.
var (
	AsSlice     *template.Template // from asSliceRaw
	SimpleSet   *template.Template // from setBasicRaw
	SimpleGet   *template.Template // from getBasicRaw
	Get         *template.Template // from getRaw
	Set         *template.Template // from setRaw
	Memset      *template.Template // from memsetRaw
	MemsetIter  *template.Template // from memsetIterRaw
	Eq          *template.Template // from arrayEqRaw (template name "ArrayEq")
	ZeroIter    *template.Template // from zeroIterRaw (template name "Zero")
	ReflectType *template.Template // from reflectConstTemplateRaw
)
   262  
   263  func init() {
   264  	AsSlice = template.Must(template.New("AsSlice").Funcs(funcs).Parse(asSliceRaw))
   265  	SimpleSet = template.Must(template.New("SimpleSet").Funcs(funcs).Parse(setBasicRaw))
   266  	SimpleGet = template.Must(template.New("SimpleGet").Funcs(funcs).Parse(getBasicRaw))
   267  	Get = template.Must(template.New("Get").Funcs(funcs).Parse(getRaw))
   268  	Set = template.Must(template.New("Set").Funcs(funcs).Parse(setRaw))
   269  	Memset = template.Must(template.New("Memset").Funcs(funcs).Parse(memsetRaw))
   270  	MemsetIter = template.Must(template.New("MemsetIter").Funcs(funcs).Parse(memsetIterRaw))
   271  	Eq = template.Must(template.New("ArrayEq").Funcs(funcs).Parse(arrayEqRaw))
   272  	ZeroIter = template.Must(template.New("Zero").Funcs(funcs).Parse(zeroIterRaw))
   273  	ReflectType = template.Must(template.New("ReflectType").Funcs(funcs).Parse(reflectConstTemplateRaw))
   274  }
   275  
   276  func generateArrayMethods(f io.Writer, ak Kinds) {
   277  	Set.Execute(f, ak)
   278  	fmt.Fprintf(f, "\n\n\n")
   279  	Get.Execute(f, ak)
   280  	fmt.Fprintf(f, "\n\n\n")
   281  	Memset.Execute(f, ak)
   282  	fmt.Fprintf(f, "\n\n\n")
   283  	MemsetIter.Execute(f, ak)
   284  	fmt.Fprintf(f, "\n\n\n")
   285  	Eq.Execute(f, ak)
   286  	fmt.Fprintf(f, "\n\n\n")
   287  	ZeroIter.Execute(f, ak)
   288  	fmt.Fprintf(f, "\n\n\n")
   289  }
   290  
   291  func generateHeaderGetSet(f io.Writer, ak Kinds) {
   292  	for _, k := range ak.Kinds {
   293  		if !isParameterized(k) {
   294  			fmt.Fprintf(f, "/* %v */\n\n", k)
   295  			AsSlice.Execute(f, k)
   296  			SimpleSet.Execute(f, k)
   297  			SimpleGet.Execute(f, k)
   298  			fmt.Fprint(f, "\n")
   299  		}
   300  	}
   301  }
   302  
   303  func generateReflectTypes(f io.Writer, ak Kinds) {
   304  	ReflectType.Execute(f, ak)
   305  	fmt.Fprintf(f, "\n\n\n")
   306  }