github.com/wzzhu/tensor@v0.9.24/defaultengine_matop_transpose_inplace.go (about) 1 // +build inplacetranspose 2 3 package tensor 4 5 import ( 6 "github.com/pkg/errors" 7 ) 8 9 func (e StdEng) Transpose(a Tensor, expStrides []int) error { 10 if !a.IsNativelyAccessible() { 11 return errors.Errorf("Cannot Transpose() on non-natively accessible tensor") 12 } 13 if dt, ok := a.(DenseTensor); ok { 14 e.denseTranspose(dt, expStrides) 15 return nil 16 } 17 return errors.Errorf("Tranpose for tensor of %T not supported", a) 18 } 19 20 func (e StdEng) denseTranspose(a DenseTensor, expStrides []int) { 21 if a.rtype() == String.Type { 22 e.denseTransposeString(a, expStrides) 23 return 24 } 25 26 e.transposeMask(a) 27 28 switch a.rtype().Size() { 29 case 1: 30 e.denseTranspose1(a, expStrides) 31 case 2: 32 e.denseTranspose2(a, expStrides) 33 case 4: 34 e.denseTranspose4(a, expStrides) 35 case 8: 36 e.denseTranspose8(a, expStrides) 37 default: 38 e.denseTransposeArbitrary(a, expStrides) 39 } 40 } 41 42 func (e StdEng) transposeMask(a DenseTensor) { 43 if !a.(*Dense).IsMasked() { 44 return 45 } 46 47 shape := a.Shape() 48 if len(shape) != 2 { 49 // TODO(poopoothegorilla): currently only two dimensions are implemented 50 return 51 } 52 n, m := shape[0], shape[1] 53 mask := a.(*Dense).Mask() 54 size := len(mask) 55 56 track := NewBitMap(size) 57 track.Set(0) 58 track.Set(size - 1) 59 60 for i := 0; i < size; i++ { 61 srci := i 62 if track.IsSet(srci) { 63 continue 64 } 65 srcv := mask[srci] 66 for { 67 oc := srci % n 68 or := (srci - oc) / n 69 desti := oc*m + or 70 71 if track.IsSet(desti) { 72 break 73 } 74 track.Set(desti) 75 destv := mask[desti] 76 mask[desti] = srcv 77 srci = desti 78 srcv = destv 79 } 80 } 81 } 82 83 func (e StdEng) denseTranspose1(a DenseTensor, expStrides []int) { 84 axes := a.transposeAxes() 85 size := a.len() 86 87 // first we'll create a bit-map to track which elements have been moved to their correct places 88 track := NewBitMap(size) 89 track.Set(0) 90 track.Set(size - 1) // first and last element of a transposedon't change 91 92 var saved, tmp byte 93 var i int 94 95 data := a.hdr().Uint8s() 96 if len(data) < 4 { 97 return 98 } 99 for i = 1; ; { 100 dest := a.transposeIndex(i, axes, expStrides) 101 102 if track.IsSet(i) && track.IsSet(dest) { 103 data[i] = saved 104 saved = 0 105 for i < size && track.IsSet(i) { 106 i++ 107 } 108 if i >= size { 109 break 110 } 111 continue 112 } 113 track.Set(i) 114 tmp = data[i] 115 data[i] = saved 116 saved = tmp 117 118 i = dest 119 } 120 } 121 122 func (e StdEng) denseTranspose2(a DenseTensor, expStrides []int) { 123 axes := a.transposeAxes() 124 size := a.len() 125 126 // first we'll create a bit-map to track which elements have been moved to their correct places 127 track := NewBitMap(size) 128 track.Set(0) 129 track.Set(size - 1) // first and last element of a transposedon't change 130 131 var saved, tmp uint16 132 var i int 133 134 data := a.hdr().Uint16s() 135 if len(data) < 4 { 136 return 137 } 138 for i = 1; ; { 139 dest := a.transposeIndex(i, axes, expStrides) 140 141 if track.IsSet(i) && track.IsSet(dest) { 142 data[i] = saved 143 saved = 0 144 for i < size && track.IsSet(i) { 145 i++ 146 } 147 if i >= size { 148 break 149 } 150 continue 151 } 152 track.Set(i) 153 tmp = data[i] 154 data[i] = saved 155 saved = tmp 156 157 i = dest 158 } 159 } 160 161 func (e StdEng) denseTranspose4(a DenseTensor, expStrides []int) { 162 axes := a.transposeAxes() 163 size := a.len() 164 165 // first we'll create a bit-map to track which elements have been moved to their correct places 166 track := NewBitMap(size) 167 track.Set(0) 168 track.Set(size - 1) // first and last element of a transposedon't change 169 170 var saved, tmp uint32 171 var i int 172 173 data := a.hdr().Uint32s() 174 if len(data) < 4 { 175 return 176 } 177 for i = 1; ; { 178 dest := a.transposeIndex(i, axes, expStrides) 179 180 if track.IsSet(i) && track.IsSet(dest) { 181 data[i] = saved 182 saved = 0 183 for i < size && track.IsSet(i) { 184 i++ 185 } 186 if i >= size { 187 break 188 } 189 continue 190 } 191 track.Set(i) 192 tmp = data[i] 193 data[i] = saved 194 saved = tmp 195 196 i = dest 197 } 198 } 199 200 func (e StdEng) denseTranspose8(a DenseTensor, expStrides []int) { 201 axes := a.transposeAxes() 202 size := a.len() 203 204 // first we'll create a bit-map to track which elements have been moved to their correct places 205 track := NewBitMap(size) 206 track.Set(0) 207 track.Set(size - 1) // first and last element of a transposedon't change 208 209 var saved, tmp uint64 210 var i int 211 212 data := a.hdr().Uint64s() 213 if len(data) < 4 { 214 return 215 } 216 for i = 1; ; { 217 dest := a.transposeIndex(i, axes, expStrides) 218 if track.IsSet(i) && track.IsSet(dest) { 219 data[i] = saved 220 saved = 0 221 for i < size && track.IsSet(i) { 222 i++ 223 } 224 if i >= size { 225 break 226 } 227 continue 228 } 229 track.Set(i) 230 // log.Printf("i: %d start %d, end %d | tmp %v saved %v", i, start, end, tmp, saved) 231 tmp = data[i] 232 data[i] = saved 233 saved = tmp 234 235 i = dest 236 } 237 } 238 239 func (e StdEng) denseTransposeString(a DenseTensor, expStrides []int) { 240 axes := a.transposeAxes() 241 size := a.len() 242 243 // first we'll create a bit-map to track which elements have been moved to their correct places 244 track := NewBitMap(size) 245 track.Set(0) 246 track.Set(size - 1) // first and last element of a transposedon't change 247 248 var saved, tmp string 249 var i int 250 251 data := a.hdr().Strings() 252 if len(data) < 4 { 253 return 254 } 255 for i = 1; ; { 256 dest := a.transposeIndex(i, axes, expStrides) 257 258 if track.IsSet(i) && track.IsSet(dest) { 259 data[i] = saved 260 saved = "" 261 for i < size && track.IsSet(i) { 262 i++ 263 } 264 if i >= size { 265 break 266 } 267 continue 268 } 269 track.Set(i) 270 tmp = data[i] 271 data[i] = saved 272 saved = tmp 273 274 i = dest 275 } 276 } 277 278 func (e StdEng) denseTransposeArbitrary(a DenseTensor, expStrides []int) { 279 axes := a.transposeAxes() 280 size := a.len() 281 rtype := a.rtype() 282 typeSize := int(rtype.Size()) 283 284 // first we'll create a bit-map to track which elements have been moved to their correct places 285 track := NewBitMap(size) 286 track.Set(0) 287 track.Set(size - 1) // first and last element of a transposedon't change 288 289 saved := make([]byte, typeSize, typeSize) 290 tmp := make([]byte, typeSize, typeSize) 291 var i int 292 data := a.arr().Raw 293 if len(data) < 4*typeSize { 294 return 295 } 296 for i = 1; ; { 297 dest := a.transposeIndex(i, axes, expStrides) 298 start := typeSize * i 299 end := start + typeSize 300 301 if track.IsSet(i) && track.IsSet(dest) { 302 copy(data[start:end], saved) 303 for i := range saved { 304 saved[i] = 0 305 } 306 for i < size && track.IsSet(i) { 307 i++ 308 } 309 if i >= size { 310 break 311 } 312 continue 313 } 314 track.Set(i) 315 copy(tmp, data[start:end]) 316 copy(data[start:end], saved) 317 copy(saved, tmp) 318 i = dest 319 } 320 }