gorgonia.org/tensor@v0.9.24/defaultengine_matop_stack.go (about) 1 package tensor 2 3 import ( 4 "github.com/pkg/errors" 5 ) 6 7 // This file contains code for the execution engine to stack tensors 8 9 func (e StdEng) StackDense(t DenseTensor, axis int, others ...DenseTensor) (retVal DenseTensor, err error) { 10 opdims := t.Dims() 11 if axis >= opdims+1 { 12 err = errors.Errorf(dimMismatch, opdims+1, axis) 13 return 14 } 15 16 newShape := Shape(BorrowInts(opdims + 1)) 17 newShape[axis] = len(others) + 1 18 shape := t.Shape() 19 var cur int 20 for i, s := range shape { 21 if i == axis { 22 cur++ 23 } 24 newShape[cur] = s 25 cur++ 26 } 27 28 info := t.Info() 29 var newStrides []int 30 if info.o.IsColMajor() { 31 newStrides = newShape.CalcStridesColMajor() 32 } else { 33 newStrides = newShape.CalcStrides() 34 35 } 36 ap := MakeAP(newShape, newStrides, info.o, info.Δ) 37 38 allNoMat := !t.RequiresIterator() 39 for _, ot := range others { 40 if allNoMat && ot.RequiresIterator() { 41 allNoMat = false 42 } 43 } 44 45 retVal = recycledDense(t.Dtype(), ap.Shape(), WithEngine(e)) 46 retVal.setAP(&ap) 47 48 // the "viewStack" method is the more generalized method 49 // and will work for all Tensors, regardless of whether it's a view 50 // But the simpleStack is faster, and is an optimization 51 52 if allNoMat { 53 retVal = e.denseSimpleStack(t, retVal, axis, others) 54 } else { 55 retVal, err = e.denseViewStack(t, retVal, axis, others) 56 } 57 return 58 } 59 60 func (e StdEng) denseSimpleStack(t, retVal DenseTensor, axis int, others []DenseTensor) DenseTensor { 61 switch axis { 62 case 0: 63 copyDense(retVal, t) 64 next := t.len() 65 for _, ot := range others { 66 copyDenseSliced(retVal, next, retVal.len(), ot, 0, ot.len()) 67 next += ot.len() 68 } 69 default: 70 axisStride := retVal.Info().Strides()[axis] 71 batches := retVal.len() / axisStride 72 73 destStart := 0 74 start := 0 75 end := start + axisStride 76 77 for i := 0; i < batches; i++ { 78 copyDenseSliced(retVal, destStart, retVal.len(), t, start, end) 79 for _, ot := range others { 80 destStart += axisStride 81 copyDenseSliced(retVal, destStart, retVal.len(), ot, start, end) 82 i++ 83 } 84 destStart += axisStride 85 start += axisStride 86 end += axisStride 87 } 88 } 89 return retVal 90 } 91 92 func (e StdEng) denseViewStack(t, retVal DenseTensor, axis int, others []DenseTensor) (DenseTensor, error) { 93 axisStride := retVal.Info().Strides()[axis] 94 batches := retVal.len() / axisStride 95 96 it := IteratorFromDense(t) 97 its := make([]Iterator, 0, len(others)) 98 for _, ot := range others { 99 oter := IteratorFromDense(ot) 100 its = append(its, oter) 101 } 102 103 err := e.doViewStack(t, retVal, axisStride, batches, it, others, its) 104 return retVal, err 105 } 106 107 func (e StdEng) doViewStack(t, retVal DenseTensor, axisStride, batches int, it Iterator, others []DenseTensor, its []Iterator) error { 108 switch int(t.Dtype().Size()) { 109 case 1: 110 return e.doViewStack1(t, retVal, axisStride, batches, it, others, its) 111 case 2: 112 return e.doViewStack2(t, retVal, axisStride, batches, it, others, its) 113 case 4: 114 return e.doViewStack4(t, retVal, axisStride, batches, it, others, its) 115 case 8: 116 return e.doViewStack8(t, retVal, axisStride, batches, it, others, its) 117 default: 118 return e.doViewStackArbitrary(t, retVal, axisStride, batches, it, others, its) 119 } 120 } 121 122 func (e StdEng) doViewStack1(t, retVal DenseTensor, axisStride, batches int, it Iterator, others []DenseTensor, its []Iterator) (err error) { 123 data := retVal.hdr().Uint8s()[:0] 124 var mask []bool 125 var retIsMasked bool 126 if mt, ok := t.(MaskedTensor); ok { 127 retIsMasked = mt.IsMasked() 128 } 129 for _, ot := range others { 130 if mt, ok := ot.(MaskedTensor); ok { 131 retIsMasked = retIsMasked || mt.IsMasked() 132 } 133 } 134 135 f := func(t DenseTensor, it Iterator) (last int, isMasked bool, err error) { 136 var tmask []bool 137 if mt, ok := t.(MaskedTensor); ok { 138 tmask = mt.Mask() 139 isMasked = mt.IsMasked() 140 } 141 142 for last = 0; last < axisStride; last++ { 143 id, err := it.Next() 144 if handleNoOp(err) != nil { 145 return -1, isMasked, errors.Wrap(err, "doviewStackfailed") 146 } 147 if err != nil { 148 break 149 } 150 data = append(data, t.hdr().Uint8s()[id]) 151 if isMasked { 152 mask = append(mask, tmask[id]) 153 } 154 } 155 return 156 } 157 158 for i := 0; i < batches; i++ { 159 var last int 160 var isMasked bool 161 if last, isMasked, err = f(t, it); err != nil { 162 return 163 } 164 if retIsMasked && (!isMasked) { 165 mask = append(mask, make([]bool, last)...) 166 } 167 for j, ot := range others { 168 if last, isMasked, err = f(ot, its[j]); err != nil { 169 return 170 } 171 if retIsMasked && (!isMasked) { 172 mask = append(mask, make([]bool, last)...) 173 } 174 } 175 } 176 177 if mt, ok := retVal.(MaskedTensor); ok { 178 mt.SetMask(mask) 179 } 180 return nil 181 } 182 183 func (e StdEng) doViewStack2(t, retVal DenseTensor, axisStride, batches int, it Iterator, others []DenseTensor, its []Iterator) (err error) { 184 data := retVal.hdr().Uint16s()[:0] 185 var mask []bool 186 var retIsMasked bool 187 if mt, ok := t.(MaskedTensor); ok { 188 retIsMasked = mt.IsMasked() 189 } 190 for _, ot := range others { 191 if mt, ok := ot.(MaskedTensor); ok { 192 retIsMasked = retIsMasked || mt.IsMasked() 193 } 194 } 195 196 f := func(t DenseTensor, it Iterator) (last int, isMasked bool, err error) { 197 var tmask []bool 198 if mt, ok := t.(MaskedTensor); ok { 199 tmask = mt.Mask() 200 isMasked = mt.IsMasked() 201 } 202 203 for last = 0; last < axisStride; last++ { 204 id, err := it.Next() 205 if handleNoOp(err) != nil { 206 return -1, isMasked, errors.Wrap(err, "doviewStackfailed") 207 } 208 if err != nil { 209 break 210 } 211 data = append(data, t.hdr().Uint16s()[id]) 212 if isMasked { 213 mask = append(mask, tmask[id]) 214 } 215 } 216 return 217 } 218 219 for i := 0; i < batches; i++ { 220 var last int 221 var isMasked bool 222 if last, isMasked, err = f(t, it); err != nil { 223 return 224 } 225 if retIsMasked && (!isMasked) { 226 mask = append(mask, make([]bool, last)...) 227 } 228 for j, ot := range others { 229 if last, isMasked, err = f(ot, its[j]); err != nil { 230 return 231 } 232 if retIsMasked && (!isMasked) { 233 mask = append(mask, make([]bool, last)...) 234 } 235 } 236 } 237 238 if mt, ok := retVal.(MaskedTensor); ok { 239 mt.SetMask(mask) 240 } 241 return nil 242 } 243 244 func (e StdEng) doViewStack4(t, retVal DenseTensor, axisStride, batches int, it Iterator, others []DenseTensor, its []Iterator) (err error) { 245 data := retVal.hdr().Uint32s()[:0] 246 var mask []bool 247 var retIsMasked bool 248 if mt, ok := t.(MaskedTensor); ok { 249 retIsMasked = mt.IsMasked() 250 } 251 for _, ot := range others { 252 if mt, ok := ot.(MaskedTensor); ok { 253 retIsMasked = retIsMasked || mt.IsMasked() 254 } 255 } 256 257 f := func(t DenseTensor, it Iterator) (last int, isMasked bool, err error) { 258 var tmask []bool 259 if mt, ok := t.(MaskedTensor); ok { 260 tmask = mt.Mask() 261 isMasked = mt.IsMasked() 262 } 263 264 for last = 0; last < axisStride; last++ { 265 id, err := it.Next() 266 if handleNoOp(err) != nil { 267 return -1, isMasked, errors.Wrap(err, "doviewStackfailed") 268 } 269 if err != nil { 270 break 271 } 272 data = append(data, t.hdr().Uint32s()[id]) 273 if isMasked { 274 mask = append(mask, tmask[id]) 275 } 276 } 277 return 278 } 279 280 for i := 0; i < batches; i++ { 281 var last int 282 var isMasked bool 283 if last, isMasked, err = f(t, it); err != nil { 284 return 285 } 286 if retIsMasked && (!isMasked) { 287 mask = append(mask, make([]bool, last)...) 288 } 289 for j, ot := range others { 290 if last, isMasked, err = f(ot, its[j]); err != nil { 291 return 292 } 293 if retIsMasked && (!isMasked) { 294 mask = append(mask, make([]bool, last)...) 295 } 296 } 297 } 298 299 if mt, ok := retVal.(MaskedTensor); ok { 300 mt.SetMask(mask) 301 } 302 return nil 303 } 304 305 func (e StdEng) doViewStack8(t, retVal DenseTensor, axisStride, batches int, it Iterator, others []DenseTensor, its []Iterator) (err error) { 306 data := retVal.hdr().Uint64s()[:0] 307 var mask []bool 308 var retIsMasked bool 309 if mt, ok := t.(MaskedTensor); ok { 310 retIsMasked = mt.IsMasked() 311 } 312 for _, ot := range others { 313 if mt, ok := ot.(MaskedTensor); ok { 314 retIsMasked = retIsMasked || mt.IsMasked() 315 } 316 } 317 318 f := func(t DenseTensor, it Iterator) (last int, isMasked bool, err error) { 319 var tmask []bool 320 if mt, ok := t.(MaskedTensor); ok { 321 tmask = mt.Mask() 322 isMasked = mt.IsMasked() 323 } 324 325 for last = 0; last < axisStride; last++ { 326 id, err := it.Next() 327 if handleNoOp(err) != nil { 328 return -1, isMasked, errors.Wrap(err, "doviewStackfailed") 329 } 330 if err != nil { 331 break 332 } 333 data = append(data, t.hdr().Uint64s()[id]) 334 if isMasked { 335 mask = append(mask, tmask[id]) 336 } 337 } 338 return 339 } 340 341 for i := 0; i < batches; i++ { 342 var last int 343 var isMasked bool 344 if last, isMasked, err = f(t, it); err != nil { 345 return 346 } 347 if retIsMasked && (!isMasked) { 348 mask = append(mask, make([]bool, last)...) 349 } 350 for j, ot := range others { 351 if last, isMasked, err = f(ot, its[j]); err != nil { 352 return 353 } 354 if retIsMasked && (!isMasked) { 355 mask = append(mask, make([]bool, last)...) 356 } 357 } 358 } 359 360 if mt, ok := retVal.(MaskedTensor); ok { 361 mt.SetMask(mask) 362 } 363 return nil 364 } 365 366 func (e StdEng) doViewStackArbitrary(t, retVal DenseTensor, axisStride, batches int, it Iterator, others []DenseTensor, its []Iterator) (err error) { 367 dt := t.Dtype() 368 data := retVal.hdr().Raw[:0] // truncate to 0 369 size := int(dt.Size()) 370 var mask []bool 371 var retIsMasked bool 372 if mt, ok := t.(MaskedTensor); ok { 373 retIsMasked = mt.IsMasked() 374 } 375 for _, ot := range others { 376 if mt, ok := ot.(MaskedTensor); ok { 377 retIsMasked = retIsMasked || mt.IsMasked() 378 } 379 } 380 381 f := func(t DenseTensor, it Iterator) (last int, isMasked bool, err error) { 382 var tmask []bool 383 if mt, ok := t.(MaskedTensor); ok { 384 tmask = mt.Mask() 385 isMasked = mt.IsMasked() 386 } 387 bs := t.hdr().Raw 388 389 for last = 0; last < axisStride; last++ { 390 id, err := it.Next() 391 if handleNoOp(err) != nil { 392 return -1, isMasked, errors.Wrap(err, "doviewStackfailed") 393 } 394 if err != nil { 395 break 396 } 397 v := bs[id*size : id*size+size] 398 data = append(data, v...) 399 if isMasked { 400 mask = append(mask, tmask[id]) 401 } 402 } 403 return 404 } 405 406 for i := 0; i < batches; i++ { 407 var last int 408 var isMasked bool 409 if last, isMasked, err = f(t, it); err != nil { 410 return 411 } 412 if retIsMasked && (!isMasked) { 413 mask = append(mask, make([]bool, last)...) 414 } 415 for j, ot := range others { 416 if last, isMasked, err = f(ot, its[j]); err != nil { 417 return 418 } 419 if retIsMasked && (!isMasked) { 420 mask = append(mask, make([]bool, last)...) 421 } 422 } 423 } 424 425 if mt, ok := retVal.(MaskedTensor); ok { 426 mt.SetMask(mask) 427 } 428 return nil 429 }