github.com/wzzhu/tensor@v0.9.24/genlib2/array_getset.go (about) 1 package main 2 3 import ( 4 "fmt" 5 "io" 6 "text/template" 7 ) 8 9 const asSliceRaw = `func (h *Header) {{asType . | strip | title}}s() []{{asType .}} {return (*(*[]{{asType .}})(unsafe.Pointer(&h.Raw)))[:h.TypedLen({{short . | unexport}}Type):h.TypedLen({{short . | unexport}}Type)]} 10 ` 11 12 const setBasicRaw = `func (h *Header) Set{{short . }}(i int, x {{asType . }}) { h.{{sliceOf .}}[i] = x } 13 ` 14 15 const getBasicRaw = `func (h *Header) Get{{short .}}(i int) {{asType .}} { return h.{{lower .String | clean | strip | title }}s()[i]} 16 ` 17 18 const getRaw = `// Get returns the ith element of the underlying array of the *Dense tensor. 19 func (a *array) Get(i int) interface{} { 20 switch a.t.Kind() { 21 {{range .Kinds -}} 22 {{if isParameterized . -}} 23 {{else -}} 24 case reflect.{{reflectKind .}}: 25 return a.{{getOne .}}(i) 26 {{end -}}; 27 {{end -}} 28 default: 29 val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size())) 30 val = reflect.Indirect(val) 31 return val.Interface() 32 } 33 } 34 35 ` 36 const setRaw = `// Set sets the value of the underlying array at the index i. 37 func (a *array) Set(i int, x interface{}) { 38 switch a.t.Kind() { 39 {{range .Kinds -}} 40 {{if isParameterized . -}} 41 {{else -}} 42 case reflect.{{reflectKind .}}: 43 xv := x.({{asType .}}) 44 a.{{setOne .}}(i, xv) 45 {{end -}} 46 {{end -}} 47 default: 48 xv := reflect.ValueOf(x) 49 val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size())) 50 val = reflect.Indirect(val) 51 val.Set(xv) 52 } 53 } 54 55 ` 56 57 const memsetRaw = `// Memset sets all values in the array. 58 func (a *array) Memset(x interface{}) error { 59 switch a.t { 60 {{range .Kinds -}} 61 {{if isParameterized . -}} 62 {{else -}} 63 case {{reflectKind .}}: 64 if xv, ok := x.({{asType .}}); ok { 65 data := a.{{sliceOf .}} 66 for i := range data{ 67 data[i] = xv 68 } 69 return nil 70 } 71 72 {{end -}} 73 {{end -}} 74 } 75 76 xv := reflect.ValueOf(x) 77 l := a.Len() 78 for i := 0; i < l; i++ { 79 val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size())) 80 val = reflect.Indirect(val) 81 val.Set(xv) 82 } 83 return nil 84 } 85 ` 86 87 const arrayEqRaw = ` // Eq checks that any two arrays are equal 88 func (a array) Eq(other interface{}) bool { 89 if oa, ok := other.(*array); ok { 90 if oa.t != a.t { 91 return false 92 } 93 94 if oa.Len() != a.Len() { 95 return false 96 } 97 /* 98 if oa.C != a.C { 99 return false 100 } 101 */ 102 103 // same exact thing 104 if uintptr(unsafe.Pointer(&oa.Header.Raw[0])) == uintptr(unsafe.Pointer(&a.Header.Raw[0])){ 105 return true 106 } 107 108 switch a.t.Kind() { 109 {{range .Kinds -}} 110 {{if isParameterized . -}} 111 {{else -}} 112 case reflect.{{reflectKind .}}: 113 for i, v := range a.{{sliceOf .}} { 114 if oa.{{getOne .}}(i) != v { 115 return false 116 } 117 } 118 {{end -}} 119 {{end -}} 120 default: 121 for i := 0; i < a.Len(); i++{ 122 if !reflect.DeepEqual(a.Get(i), oa.Get(i)){ 123 return false 124 } 125 } 126 } 127 return true 128 } 129 return false 130 }` 131 132 const copyArrayIterRaw = `func copyArrayIter(dst, src array, diter, siter Iterator) (count int, err error){ 133 if dst.t != src.t { 134 panic("Cannot copy arrays of different types") 135 } 136 137 if diter == nil && siter == nil { 138 return copyArray(dst, src), nil 139 } 140 141 if (diter != nil && siter == nil) || (diter == nil && siter != nil) { 142 return 0, errors.Errorf("Cannot copy array when only one iterator was passed in") 143 } 144 145 k := dest.t.Kind() 146 var i, j int 147 var validi, validj bool 148 for { 149 if i, validi, err = diter.NextValidity(); err != nil { 150 if err = handleNoOp(err); err != nil { 151 return count, err 152 } 153 break 154 } 155 if j, validj, err = siter.NextValidity(); err != nil { 156 if err = handleNoOp(err); err != nil { 157 return count, err 158 } 159 break 160 } 161 switch k { 162 {{range .Kinds -}} 163 {{if isParameterized . -}} 164 {{else -}} 165 case reflect.{{reflectKind .}}: 166 dest.{{setOne .}}(i, src.{{getOne .}}(j)) 167 {{end -}} 168 {{end -}} 169 default: 170 dest.Set(i, src.Get(j)) 171 } 172 count++ 173 } 174 175 } 176 ` 177 178 const memsetIterRaw = ` 179 func (a *array) memsetIter(x interface{}, it Iterator) (err error) { 180 var i int 181 switch a.t{ 182 {{range .Kinds -}} 183 {{if isParameterized . -}} 184 {{else -}} 185 case {{reflectKind .}}: 186 xv, ok := x.({{asType .}}) 187 if !ok { 188 return errors.Errorf(dtypeMismatch, a.t, x) 189 } 190 data := a.{{sliceOf .}} 191 for i, err = it.Next(); err == nil; i, err = it.Next(){ 192 data[i] = xv 193 } 194 err = handleNoOp(err) 195 {{end -}} 196 {{end -}} 197 default: 198 xv := reflect.ValueOf(x) 199 for i, err = it.Next(); err == nil; i, err = it.Next(){ 200 val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size())) 201 val = reflect.Indirect(val) 202 val.Set(xv) 203 } 204 err = handleNoOp(err) 205 } 206 return 207 } 208 209 ` 210 211 const zeroIterRaw = `func (a *array) zeroIter(it Iterator) (err error){ 212 var i int 213 switch a.t { 214 {{range .Kinds -}} 215 {{if isParameterized . -}} 216 {{else -}} 217 case {{reflectKind .}}: 218 data := a.{{sliceOf .}} 219 for i, err = it.Next(); err == nil; i, err = it.Next(){ 220 data[i] = {{if eq .String "bool" -}} 221 false 222 {{else if eq .String "string" -}}"" 223 {{else if eq .String "unsafe.Pointer" -}}nil 224 {{else -}}0{{end -}} 225 } 226 err = handleNoOp(err) 227 {{end -}} 228 {{end -}} 229 default: 230 for i, err = it.Next(); err == nil; i, err = it.Next(){ 231 val := reflect.NewAt(a.t.Type, storage.ElementAt(i, unsafe.Pointer(&a.Header.Raw[0]), a.t.Size())) 232 val = reflect.Indirect(val) 233 val.Set(reflect.Zero(a.t)) 234 } 235 err = handleNoOp(err) 236 } 237 return 238 } 239 ` 240 241 const reflectConstTemplateRaw = `var ( 242 {{range .Kinds -}} 243 {{if isParameterized . -}} 244 {{else -}} 245 {{short . | unexport}}Type = reflect.TypeOf({{asType .}}({{if eq .String "bool" -}} false {{else if eq .String "string" -}}"" {{else if eq .String "unsafe.Pointer" -}}nil {{else -}}0{{end -}})) 246 {{end -}} 247 {{end -}} 248 )` 249 250 var ( 251 AsSlice *template.Template 252 SimpleSet *template.Template 253 SimpleGet *template.Template 254 Get *template.Template 255 Set *template.Template 256 Memset *template.Template 257 MemsetIter *template.Template 258 Eq *template.Template 259 ZeroIter *template.Template 260 ReflectType *template.Template 261 ) 262 263 func init() { 264 AsSlice = template.Must(template.New("AsSlice").Funcs(funcs).Parse(asSliceRaw)) 265 SimpleSet = template.Must(template.New("SimpleSet").Funcs(funcs).Parse(setBasicRaw)) 266 SimpleGet = template.Must(template.New("SimpleGet").Funcs(funcs).Parse(getBasicRaw)) 267 Get = template.Must(template.New("Get").Funcs(funcs).Parse(getRaw)) 268 Set = template.Must(template.New("Set").Funcs(funcs).Parse(setRaw)) 269 Memset = template.Must(template.New("Memset").Funcs(funcs).Parse(memsetRaw)) 270 MemsetIter = template.Must(template.New("MemsetIter").Funcs(funcs).Parse(memsetIterRaw)) 271 Eq = template.Must(template.New("ArrayEq").Funcs(funcs).Parse(arrayEqRaw)) 272 ZeroIter = template.Must(template.New("Zero").Funcs(funcs).Parse(zeroIterRaw)) 273 ReflectType = template.Must(template.New("ReflectType").Funcs(funcs).Parse(reflectConstTemplateRaw)) 274 } 275 276 func generateArrayMethods(f io.Writer, ak Kinds) { 277 Set.Execute(f, ak) 278 fmt.Fprintf(f, "\n\n\n") 279 Get.Execute(f, ak) 280 fmt.Fprintf(f, "\n\n\n") 281 Memset.Execute(f, ak) 282 fmt.Fprintf(f, "\n\n\n") 283 MemsetIter.Execute(f, ak) 284 fmt.Fprintf(f, "\n\n\n") 285 Eq.Execute(f, ak) 286 fmt.Fprintf(f, "\n\n\n") 287 ZeroIter.Execute(f, ak) 288 fmt.Fprintf(f, "\n\n\n") 289 } 290 291 func generateHeaderGetSet(f io.Writer, ak Kinds) { 292 for _, k := range ak.Kinds { 293 if !isParameterized(k) { 294 fmt.Fprintf(f, "/* %v */\n\n", k) 295 AsSlice.Execute(f, k) 296 SimpleSet.Execute(f, k) 297 SimpleGet.Execute(f, k) 298 fmt.Fprint(f, "\n") 299 } 300 } 301 } 302 303 func generateReflectTypes(f io.Writer, ak Kinds) { 304 ReflectType.Execute(f, ak) 305 fmt.Fprintf(f, "\n\n\n") 306 }