github.com/wzzhu/tensor@v0.9.24/consopt.go (about) 1 package tensor 2 3 import ( 4 "reflect" 5 6 "github.com/wzzhu/tensor/internal/storage" 7 ) 8 9 // ConsOpt is a tensor construction option. 10 type ConsOpt func(Tensor) 11 12 // Of is a construction option for a Tensor. 13 func Of(a Dtype) ConsOpt { 14 Register(a) 15 f := func(t Tensor) { 16 switch tt := t.(type) { 17 case *Dense: 18 tt.t = a 19 case *CS: 20 tt.t = a 21 default: 22 panic("Unsupported Tensor type") 23 } 24 } 25 return f 26 } 27 28 // WithBacking is a construction option for a Tensor 29 // Use it as such: 30 // backing := []float64{1,2,3,4} 31 // t := New(WithBacking(backing)) 32 // It can be used with other construction options like WithShape 33 func WithBacking(x interface{}, argMask ...[]bool) ConsOpt { 34 var mask []bool 35 if len(argMask) > 0 { 36 mask = argMask[0] 37 } 38 f := func(t Tensor) { 39 if x == nil { 40 return 41 } 42 switch tt := t.(type) { 43 case *Dense: 44 tt.fromSlice(x) 45 if len(argMask) > 0 { 46 tt.addMask(mask) 47 } 48 default: 49 panic("Unsupported Tensor type") 50 } 51 } 52 return f 53 } 54 55 // WithMask is a construction option for a Tensor 56 // Use it as such: 57 // mask := []bool{true,true,false,false} 58 // t := New(WithBacking(backing), WithMask(mask)) 59 // It can be used with other construction options like WithShape 60 // The supplied mask can be any type. If non-boolean, then tensor mask is set to true 61 // wherever non-zero value is obtained 62 func WithMask(x interface{}) ConsOpt { 63 f := func(t Tensor) { 64 if x == nil { 65 return 66 } 67 switch tt := t.(type) { 68 case *Dense: 69 tt.MaskFromSlice(x) 70 default: 71 panic("Unsupported Tensor type") 72 } 73 } 74 return f 75 } 76 77 // WithShape is a construction option for a Tensor. It creates the ndarray in the required shape. 78 func WithShape(dims ...int) ConsOpt { 79 f := func(t Tensor) { 80 switch tt := t.(type) { 81 case *Dense: 82 throw := BorrowInts(len(dims)) 83 copy(throw, dims) 84 tt.setShape(throw...) 85 case *CS: 86 if len(dims) != 2 { 87 panic("Only sparse matrices are supported") 88 } 89 throw := BorrowInts(len(dims)) 90 copy(throw, dims) 91 tt.s = throw 92 93 default: 94 panic("Unsupported Tensor type") 95 } 96 } 97 return f 98 } 99 100 // FromScalar is a construction option for representing a scalar value as a Tensor 101 func FromScalar(x interface{}, argMask ...[]bool) ConsOpt { 102 var mask []bool 103 if len(argMask) > 0 { 104 mask = argMask[0] 105 } 106 107 f := func(t Tensor) { 108 switch tt := t.(type) { 109 case *Dense: 110 xT := reflect.TypeOf(x) 111 sxT := reflect.SliceOf(xT) 112 xv := reflect.MakeSlice(sxT, 1, 1) // []T 113 xv0 := xv.Index(0) // xv[0] 114 xv0.Set(reflect.ValueOf(x)) 115 tt.array.Header.Raw = storage.AsByteSlice(xv.Interface()) 116 tt.t = Dtype{xT} 117 tt.mask = mask 118 119 default: 120 panic("Unsupported Tensor Type") 121 } 122 } 123 return f 124 } 125 126 // FromMemory is a construction option for creating a *Dense (for now) from memory location. This is a useful 127 // option for super large tensors that don't fit into memory - the user may need to `mmap` a file the tensor. 128 // 129 // Bear in mind that at the current stage of the ConsOpt design, the order of the ConsOpt is important. 130 // FromMemory requires the *Dense's Dtype be set already. 131 // This would fail (and panic): 132 // New(FromMemory(ptr, size), Of(Float64)) 133 // This would not: 134 // New(Of(Float64), FromMemory(ptr, size)) 135 // This behaviour of requiring the ConsOpts to be in order might be changed in the future. 136 // 137 // Memory must be manually managed by the caller. 138 // Tensors called with this construction option will not be returned to any pool - rather, all references to the pointers will be null'd. 139 // Use with caution. 140 //go:nocheckptr 141 func FromMemory(ptr uintptr, memsize uintptr) ConsOpt { 142 f := func(t Tensor) { 143 switch tt := t.(type) { 144 case *Dense: 145 tt.Header.Raw = nil // GC anything if needed 146 tt.Header.Raw = storage.FromMemory(ptr, memsize) 147 tt.flag = MakeMemoryFlag(tt.flag, ManuallyManaged) 148 default: 149 panic("Unsupported Tensor type") 150 } 151 } 152 return f 153 } 154 155 // WithEngine is a construction option that would cause a Tensor to be linked with an execution engine. 156 func WithEngine(e Engine) ConsOpt { 157 f := func(t Tensor) { 158 switch tt := t.(type) { 159 case *Dense: 160 tt.e = e 161 if e != nil && !e.AllocAccessible() { 162 tt.flag = MakeMemoryFlag(tt.flag, NativelyInaccessible) 163 } 164 165 tt.oe = nil 166 if oe, ok := e.(standardEngine); ok { 167 tt.oe = oe 168 } 169 case *CS: 170 tt.e = e 171 if e != nil && !e.AllocAccessible() { 172 tt.f = MakeMemoryFlag(tt.f, NativelyInaccessible) 173 } 174 } 175 } 176 return f 177 } 178 179 // AsFortran creates a *Dense with a col-major layout. 180 // If the optional backing argument is passed, the backing is assumed to be C-order (row major), and 181 // it will be transposed before being used. 182 func AsFortran(backing interface{}, argMask ...[]bool) ConsOpt { 183 var mask []bool 184 if len(argMask) > 0 { 185 mask = argMask[0] 186 } 187 f := func(t Tensor) { 188 switch tt := t.(type) { 189 case *Dense: 190 if backing != nil { 191 // put the data into the tensor, then make a clone tensor to transpose 192 tt.fromSliceOrArrayer(backing) 193 // create a temporary tensor, to which the transpose will be done 194 tmp := NewDense(tt.Dtype(), tt.shape.Clone()) 195 copyArray(tmp.arrPtr(), tt.arrPtr()) 196 tmp.SetMask(mask) 197 tmp.T() 198 tmp.Transpose() 199 // copy the data back to the current tensor 200 copyArray(tt.arrPtr(), tmp.arrPtr()) 201 tt.SetMask(tmp.Mask()) 202 // cleanup: return the temporary tensor back to the pool 203 ReturnTensor(tmp) 204 } 205 206 tt.AP.o = MakeDataOrder(tt.AP.o, ColMajor) 207 if tt.AP.shape != nil { 208 ReturnInts(tt.AP.strides) 209 tt.AP.strides = nil 210 tt.AP.strides = tt.AP.calcStrides() 211 } 212 case *CS: 213 panic("AsFortran is not an available option for Compressed Sparse layouts") 214 } 215 } 216 return f 217 } 218 219 func AsDenseDiag(backing interface{}) ConsOpt { 220 f := func(t Tensor) { 221 switch tt := t.(type) { 222 case *Dense: 223 if bt, ok := backing.(Tensor); ok { 224 backing = bt.Data() 225 } 226 xT := reflect.TypeOf(backing) 227 if xT.Kind() != reflect.Slice { 228 panic("Expected a slice") 229 } 230 xV := reflect.ValueOf(backing) 231 l := xV.Len() 232 // elT := xT.Elem() 233 234 sli := reflect.MakeSlice(xT, l*l, l*l) 235 236 shape := Shape{l, l} 237 strides := shape.CalcStrides() 238 for i := 0; i < l; i++ { 239 idx, err := Ltoi(shape, strides, i, i) 240 if err != nil { 241 panic(err) 242 } 243 244 at := sli.Index(idx) 245 xi := xV.Index(i) 246 at.Set(xi) 247 } 248 249 tt.fromSliceOrArrayer(sli.Interface()) 250 tt.setShape(l, l) 251 252 default: 253 panic("AsDenseDiag is not available as an option for CS") 254 } 255 } 256 return f 257 }