gorgonia.org/tensor@v0.9.24/defaultengine_matop_misc.go (about) 1 package tensor 2 3 import ( 4 "github.com/pkg/errors" 5 "gorgonia.org/tensor/internal/storage" 6 ) 7 8 var ( 9 _ Diager = StdEng{} 10 ) 11 12 type fastcopier interface { 13 fastCopyDenseRepeat(t DenseTensor, d *Dense, outers, size, stride, newStride int, repeats []int) error 14 } 15 16 // Repeat ... 17 func (e StdEng) Repeat(t Tensor, axis int, repeats ...int) (Tensor, error) { 18 switch tt := t.(type) { 19 case DenseTensor: 20 newShape, newRepeats, newAxis, size, err := e.denseRepeatCheck(t, axis, repeats) 21 if err != nil { 22 return nil, err 23 } 24 rr := recycledDense(t.Dtype(), newShape, WithEngine(StdEng{})) 25 return e.denseRepeat(tt, rr, newShape, newAxis, size, newRepeats) 26 default: 27 return nil, errors.Errorf("NYI") 28 } 29 } 30 31 // RepeatReuse is like Repeat, but with a provided reuse Tensor. The reuseTensor must be of the same type as the input t. 32 func (e StdEng) RepeatReuse(t Tensor, reuse Tensor, axis int, repeats ...int) (Tensor, error) { 33 switch tt := t.(type) { 34 case DenseTensor: 35 newShape, newRepeats, newAxis, size, err := e.denseRepeatCheck(t, axis, repeats) 36 if err != nil { 37 return nil, err 38 } 39 40 rr, ok := reuse.(DenseTensor) 41 if !ok { 42 return nil, errors.Errorf("t is a DenseTensor but reuse is of %T", reuse) 43 } 44 if !reuse.Shape().Eq(newShape) { 45 return nil, errors.Errorf("Reuse shape is %v. Expected shape is %v", reuse.Shape(), newShape) 46 } 47 return e.denseRepeat(tt, rr, newShape, newAxis, size, newRepeats) 48 default: 49 return nil, errors.Errorf("NYI") 50 } 51 } 52 53 func (StdEng) denseRepeatCheck(t Tensor, axis int, repeats []int) (newShape Shape, newRepeats []int, newAxis, size int, err error) { 54 if newShape, newRepeats, size, err = t.Shape().Repeat(axis, repeats...); err != nil { 55 return nil, nil, -1, -1, errors.Wrap(err, "Unable to get repeated shape") 56 } 57 newAxis = axis 58 if axis == AllAxes { 59 newAxis = 0 60 } 61 62 return 63 } 64 65 func (StdEng) denseRepeat(t, reuse DenseTensor, newShape Shape, axis, size int, repeats []int) (retVal DenseTensor, err error) { 66 d, err := assertDense(reuse) 67 if err != nil { 68 return nil, errors.Wrapf(err, "Repeat reuse is not a *Dense") 69 } 70 var outers int 71 if t.IsScalar() { 72 outers = 1 73 } else { 74 outers = ProdInts(t.Shape()[0:axis]) 75 } 76 77 var stride, newStride int 78 if newShape.IsVector() || t.IsVector() { 79 stride = 1 // special case because CalcStrides() will return []int{1} as the strides for a vector 80 } else { 81 stride = t.ostrides()[axis] 82 } 83 84 if newShape.IsVector() { 85 newStride = 1 86 } else { 87 newStride = d.ostrides()[axis] 88 } 89 90 var destStart, srcStart int 91 // fastCopy is not bypassing the copyDenseSliced method to populate the output tensor 92 var fastCopy bool 93 var fce fastcopier 94 // we need an engine for fastCopying... 95 e := t.Engine() 96 // e can never be nil. Error would have occurred elsewhere 97 var ok bool 98 if fce, ok = e.(fastcopier); ok { 99 fastCopy = true 100 } 101 102 // In this case, let's not implement the fast copy to keep the code readable 103 if ms, ok := t.(MaskedTensor); ok && ms.IsMasked() { 104 fastCopy = false 105 } 106 107 // if d is not a fastcopier, then we also cannot use fast copy 108 if _, ok := d.Engine().(fastcopier); !ok { 109 fastCopy = false 110 } 111 112 if fastCopy { 113 if err := fce.fastCopyDenseRepeat(t, d, outers, size, stride, newStride, repeats); err != nil { 114 return nil, err 115 } 116 return d, nil 117 } 118 119 for i := 0; i < outers; i++ { 120 for j := 0; j < size; j++ { 121 var tmp int 122 tmp = repeats[j] 123 124 for k := 0; k < tmp; k++ { 125 if srcStart >= t.len() || destStart+stride > d.len() { 126 break 127 } 128 copyDenseSliced(d, destStart, d.len(), t, srcStart, t.len()) 129 destStart += newStride 130 } 131 srcStart += stride 132 } 133 } 134 return d, nil 135 } 136 137 func (e StdEng) fastCopyDenseRepeat(src DenseTensor, dest *Dense, outers, size, stride, newStride int, repeats []int) error { 138 sarr := src.arr() 139 darr := dest.arr() 140 141 var destStart, srcStart int 142 for i := 0; i < outers; i++ { 143 // faster shortcut for common case. 144 // 145 // Consider a case where: 146 // a := ⎡ 1 ⎤ 147 // ⎢ 2 ⎥ 148 // ⎢ 3 ⎥ 149 // ⎣ 4 ⎦ 150 // a has a shape of (4, 1). it is a *Dense. 151 // 152 // Now assume we want to repeat it on axis 1, 3 times. We want to repeat it into `b`, 153 // which is already allocated and zeroed, as shown below 154 // 155 // b := ⎡ 0 0 0 ⎤ 156 // ⎢ 0 0 0 ⎥ 157 // ⎢ 0 0 0 ⎥ 158 // ⎣ 0 0 0 ⎦ 159 // 160 // Now, both `a` and `b` have a stride of 1. 161 // 162 // The desired result is: 163 // b := ⎡ 1 1 1 ⎤ 164 // ⎢ 2 2 2 ⎥ 165 // ⎢ 3 3 3 ⎥ 166 // ⎣ 4 4 4 ⎦ 167 /// 168 // Observe that this is simply broadcasting (copying) a[0] (a scalar value) to the row b[0], and so on and so forth. 169 // This can be done without knowing the full type - we simply copy the bytes over. 170 if stride == 1 && newStride == 1 { 171 for sz := 0; sz < size; sz++ { 172 tmp := repeats[sz] 173 174 // first we get the bounds of the src and the dest 175 // the srcStart and destStart are the indices assuming a flat array of []T 176 // we need to get the byte slice equivalent. 177 bSrcStart := srcStart * int(sarr.t.Size()) 178 bSrcEnd := (srcStart + stride) * int(sarr.t.Size()) 179 bDestStart := destStart * int(darr.t.Size()) 180 bDestEnd := (destStart + tmp) * int(darr.t.Size()) 181 182 // then we get the data as a slice of raw bytes 183 sBS := sarr.Header.Raw 184 dBS := darr.Header.Raw 185 186 // recall that len(src) < len(dest) 187 // it's easier to understand if we define the ranges. 188 // Less prone to errors. 189 sRange := sBS[bSrcStart:bSrcEnd] 190 dRange := dBS[bDestStart:bDestEnd] 191 192 // finally we copy things. 193 for i := 0; i < len(dRange); i += len(sRange) { 194 copy(dRange[i:], sRange) 195 } 196 srcStart += stride 197 destStart += tmp 198 } 199 200 // we can straightaway broadcast 201 202 continue 203 } 204 205 for j := 0; j < size; j++ { 206 var tmp int 207 tmp = repeats[j] 208 var tSlice array 209 210 tSlice = sarr.slice(srcStart, src.len()) 211 212 for k := 0; k < tmp; k++ { 213 if srcStart >= src.len() || destStart+stride > dest.len() { 214 break 215 } 216 217 dSlice := darr.slice(destStart, destStart+newStride) 218 219 // THIS IS AN OPTIMIZATION. REVISIT WHEN NEEDED. 220 storage.Copy(dSlice.t.Type, &dSlice.Header, &tSlice.Header) 221 222 destStart += newStride 223 } 224 srcStart += stride 225 } 226 } 227 return nil 228 } 229 230 // Concat tensors 231 func (e StdEng) Concat(t Tensor, axis int, others ...Tensor) (retVal Tensor, err error) { 232 switch tt := t.(type) { 233 case DenseTensor: 234 var denses []DenseTensor 235 if denses, err = tensorsToDenseTensors(others); err != nil { 236 return nil, errors.Wrap(err, "Concat failed") 237 } 238 return e.denseConcat(tt, axis, denses) 239 default: 240 return nil, errors.Errorf("NYI") 241 } 242 } 243 244 func (e StdEng) denseConcat(a DenseTensor, axis int, Ts []DenseTensor) (DenseTensor, error) { 245 ss := make([]Shape, len(Ts)) 246 var err error 247 var isMasked bool 248 for i, T := range Ts { 249 ss[i] = T.Shape() 250 if mt, ok := T.(MaskedTensor); ok { 251 isMasked = isMasked || mt.IsMasked() 252 } 253 } 254 255 var newShape Shape 256 if newShape, err = a.Shape().Concat(axis, ss...); err != nil { 257 return nil, errors.Wrap(err, "Unable to find new shape that results from concatenation") 258 } 259 260 retVal := recycledDense(a.Dtype(), newShape, WithEngine(e)) 261 if isMasked { 262 retVal.makeMask() 263 } 264 265 all := make([]DenseTensor, len(Ts)+1) 266 all[0] = a 267 copy(all[1:], Ts) 268 269 // TODO: OPIMIZATION 270 // When (axis == 0 && a is row major and all others is row major) || (axis == last axis of A && all tensors are colmajor) 271 // just flat copy 272 // 273 274 // isOuter is true when the axis is the outermost axis 275 // isInner is true when the axis is the inner most axis 276 isOuter := axis == 0 277 isInner := axis == (a.Shape().Dims() - 1) 278 279 // special case 280 var start, end int 281 for _, T := range all { 282 end += T.Shape()[axis] 283 slices := make([]Slice, axis+1) 284 slices[axis] = makeRS(start, end) 285 286 var v *Dense 287 if v, err = sliceDense(retVal, slices...); err != nil { 288 return nil, errors.Wrap(err, "Unable to slice DenseTensor while performing denseConcat") 289 } 290 291 // keep dims after slicing 292 switch { 293 case v.IsVector() && T.IsMatrix() && axis == 0: 294 v.reshape(v.shape[0], 1) 295 case T.IsRowVec() && axis == 0: 296 T.reshape(T.Shape()[1]) 297 case v.Shape().IsScalarEquiv() && T.Shape().IsScalarEquiv(): 298 copyArray(v.arrPtr(), T.arrPtr()) 299 if mt, ok := T.(MaskedTensor); ok { 300 copy(v.mask, mt.Mask()) 301 } 302 start = end 303 continue 304 default: 305 diff := retVal.Shape().Dims() - v.Shape().Dims() 306 if diff > 0 && isOuter { 307 newShape := make(Shape, v.Shape().Dims()+diff) 308 for i := 0; i < diff; i++ { 309 newShape[i] = 1 310 } 311 copy(newShape[diff:], v.Shape()) 312 v.reshape(newShape...) 313 } else if diff > 0 && isInner { 314 newShape := v.Shape().Clone() 315 newStrides := v.strides 316 for i := 0; i < diff; i++ { 317 newShape = append(newShape, 1) 318 newStrides = append(newStrides, 1) 319 } 320 v.shape = newShape 321 v.strides = newStrides 322 } else if T.Shape()[axis] == 1 { 323 if err := v.unsqueeze(axis); err != nil { 324 return nil, errors.Wrapf(err, "Unable to keep dims after slicing a shape %v on axis %d where the size is 1", T.Shape(), axis) 325 } 326 } 327 } 328 329 var vmask, Tmask []bool 330 vmask = v.mask 331 v.mask = nil 332 if mt, ok := T.(MaskedTensor); ok && mt.IsMasked() { 333 Tmask = mt.Mask() 334 mt.SetMask(nil) 335 336 } 337 338 if err = assignArray(v, T); err != nil { 339 return nil, errors.Wrap(err, "Unable to assignArray in denseConcat") 340 } 341 // if it's a masked tensor, we copy the mask as well 342 if Tmask != nil { 343 if vmask != nil { 344 if cap(vmask) < len(Tmask) { 345 vmask2 := make([]bool, len(Tmask)) 346 copy(vmask2, vmask) 347 vmask = vmask2 348 } 349 copy(vmask, Tmask) 350 v.SetMask(vmask) 351 } 352 // mt.SetMask(Tmask) 353 } 354 355 start = end 356 } 357 358 return retVal, nil 359 } 360 361 // Diag ... 362 func (e StdEng) Diag(t Tensor) (retVal Tensor, err error) { 363 a, ok := t.(DenseTensor) 364 if !ok { 365 return nil, errors.Errorf("StdEng only works with DenseTensor for Diagonal()") 366 } 367 368 if a.Dims() != 2 { 369 err = errors.Errorf(dimMismatch, 2, a.Dims()) 370 return 371 } 372 373 if err = typeclassCheck(a.Dtype(), numberTypes); err != nil { 374 return nil, errors.Wrap(err, "Diagonal") 375 } 376 377 rstride := a.Strides()[0] 378 cstride := a.Strides()[1] 379 380 r := a.Shape()[0] 381 c := a.Shape()[1] 382 383 m := MinInt(r, c) 384 stride := rstride + cstride 385 386 b := a.Clone().(DenseTensor) 387 b.Zero() 388 389 switch a.rtype().Size() { 390 case 1: 391 bdata := b.hdr().Uint8s() 392 adata := a.hdr().Uint8s() 393 for i := 0; i < m; i++ { 394 bdata[i] = adata[i*stride] 395 } 396 case 2: 397 bdata := b.hdr().Uint16s() 398 adata := a.hdr().Uint16s() 399 for i := 0; i < m; i++ { 400 bdata[i] = adata[i*stride] 401 } 402 case 4: 403 bdata := b.hdr().Uint32s() 404 adata := a.hdr().Uint32s() 405 for i := 0; i < m; i++ { 406 bdata[i] = adata[i*stride] 407 } 408 case 8: 409 bdata := b.hdr().Uint64s() 410 adata := a.hdr().Uint64s() 411 for i := 0; i < m; i++ { 412 bdata[i] = adata[i*stride] 413 } 414 default: 415 return nil, errors.Errorf(typeNYI, "Arbitrary sized diag", t) 416 } 417 return b, nil 418 }