gorgonia.org/tensor@v0.9.24/defaultengine_matop_transpose.go (about) 1 // +build !inplacetranspose 2 3 package tensor 4 5 import ( 6 "github.com/pkg/errors" 7 ) 8 9 func (e StdEng) Transpose(a Tensor, expStrides []int) error { 10 if !a.IsNativelyAccessible() { 11 return errors.Errorf("Cannot Transpose() on non-natively accessible tensor") 12 } 13 if dt, ok := a.(DenseTensor); ok { 14 e.denseTranspose(dt, expStrides) 15 return nil 16 } 17 return errors.Errorf("Tranpose for tensor of %T not supported", a) 18 } 19 20 func (e StdEng) denseTranspose(a DenseTensor, expStrides []int) { 21 if a.rtype() == String.Type { 22 e.denseTransposeString(a, expStrides) 23 return 24 } 25 26 e.transposeMask(a) 27 28 switch a.rtype().Size() { 29 case 1: 30 e.denseTranspose1(a, expStrides) 31 case 2: 32 e.denseTranspose2(a, expStrides) 33 case 4: 34 e.denseTranspose4(a, expStrides) 35 case 8: 36 e.denseTranspose8(a, expStrides) 37 default: 38 e.denseTransposeArbitrary(a, expStrides) 39 } 40 } 41 42 func (e StdEng) transposeMask(a DenseTensor) { 43 if !a.(*Dense).IsMasked() { 44 return 45 } 46 47 orig := a.(*Dense).Mask() 48 tmp := make([]bool, len(orig)) 49 50 it := newFlatIterator(a.Info()) 51 var j int 52 for i, err := it.Next(); err == nil; i, err = it.Next() { 53 tmp[j] = orig[i] 54 j++ 55 } 56 copy(orig, tmp) 57 } 58 59 func (e StdEng) denseTranspose1(a DenseTensor, expStrides []int) { 60 var tmpArr array 61 e.makeArray(&tmpArr, a.Dtype(), a.Size()) 62 u8s := tmpArr.Uint8s() 63 64 orig := a.hdr().Uint8s() 65 it := newFlatIterator(a.Info()) 66 var j int 67 for i, err := it.Next(); err == nil; i, err = it.Next() { 68 u8s[j] = orig[i] 69 j++ 70 } 71 copy(orig, u8s) 72 } 73 74 func (e StdEng) denseTranspose2(a DenseTensor, expStrides []int) { 75 var tmpArr array 76 e.makeArray(&tmpArr, a.Dtype(), a.Size()) 77 u16s := tmpArr.Uint16s() 78 79 orig := a.hdr().Uint16s() 80 it := newFlatIterator(a.Info()) 81 var j int 82 for i, err := it.Next(); err == nil; i, err = it.Next() { 83 u16s[j] = orig[i] 84 j++ 85 } 86 copy(orig, u16s) 87 } 88 89 func (e StdEng) denseTranspose4(a DenseTensor, expStrides []int) { 90 var tmpArr array 91 e.makeArray(&tmpArr, a.Dtype(), a.Size()) 92 u32s := tmpArr.Uint32s() 93 94 orig := a.hdr().Uint32s() 95 it := newFlatIterator(a.Info()) 96 var j int 97 for i, err := it.Next(); err == nil; i, err = it.Next() { 98 u32s[j] = orig[i] 99 j++ 100 } 101 copy(orig, u32s) 102 } 103 104 func (e StdEng) denseTranspose8(a DenseTensor, expStrides []int) { 105 var tmpArr array 106 e.makeArray(&tmpArr, a.Dtype(), a.Size()) 107 u64s := tmpArr.Uint64s() 108 109 orig := a.hdr().Uint64s() 110 it := newFlatIterator(a.Info()) 111 var j int 112 for i, err := it.Next(); err == nil; i, err = it.Next() { 113 u64s[j] = orig[i] 114 j++ 115 } 116 copy(orig, u64s) 117 } 118 119 func (e StdEng) denseTransposeString(a DenseTensor, expStrides []int) { 120 var tmpArr array 121 e.makeArray(&tmpArr, a.Dtype(), a.Size()) 122 strs := tmpArr.Strings() 123 124 orig := a.hdr().Strings() 125 it := newFlatIterator(a.Info()) 126 var j int 127 for i, err := it.Next(); err == nil; i, err = it.Next() { 128 strs[j] = orig[i] 129 j++ 130 } 131 copy(orig, strs) 132 } 133 134 func (e StdEng) denseTransposeArbitrary(a DenseTensor, expStrides []int) { 135 rtype := a.rtype() 136 typeSize := int(rtype.Size()) 137 var tmpArr array 138 e.makeArray(&tmpArr, a.Dtype(), a.Size()) 139 // arbs := storage.AsByteSlice(tmpArr.hdr(), rtype) 140 arbs := tmpArr.byteSlice() 141 142 orig := a.hdr().Raw 143 it := newFlatIterator(a.Info()) 144 var j int 145 for i, err := it.Next(); err == nil; i, err = it.Next() { 146 srcStart := i * typeSize 147 srcEnd := srcStart + typeSize 148 dstStart := j * typeSize 149 dstEnd := dstStart + typeSize 150 151 copy(arbs[dstStart:dstEnd], orig[srcStart:srcEnd]) 152 j++ 153 } 154 copy(orig, arbs) 155 }