github.com/wzzhu/tensor@v0.9.24/defaultengine_matop_transpose_inplace.go (about)

     1  // +build inplacetranspose
     2  
     3  package tensor
     4  
     5  import (
     6  	"github.com/pkg/errors"
     7  )
     8  
     9  func (e StdEng) Transpose(a Tensor, expStrides []int) error {
    10  	if !a.IsNativelyAccessible() {
    11  		return errors.Errorf("Cannot Transpose() on non-natively accessible tensor")
    12  	}
    13  	if dt, ok := a.(DenseTensor); ok {
    14  		e.denseTranspose(dt, expStrides)
    15  		return nil
    16  	}
    17  	return errors.Errorf("Tranpose for tensor of %T not supported", a)
    18  }
    19  
    20  func (e StdEng) denseTranspose(a DenseTensor, expStrides []int) {
    21  	if a.rtype() == String.Type {
    22  		e.denseTransposeString(a, expStrides)
    23  		return
    24  	}
    25  
    26  	e.transposeMask(a)
    27  
    28  	switch a.rtype().Size() {
    29  	case 1:
    30  		e.denseTranspose1(a, expStrides)
    31  	case 2:
    32  		e.denseTranspose2(a, expStrides)
    33  	case 4:
    34  		e.denseTranspose4(a, expStrides)
    35  	case 8:
    36  		e.denseTranspose8(a, expStrides)
    37  	default:
    38  		e.denseTransposeArbitrary(a, expStrides)
    39  	}
    40  }
    41  
    42  func (e StdEng) transposeMask(a DenseTensor) {
    43  	if !a.(*Dense).IsMasked() {
    44  		return
    45  	}
    46  
    47  	shape := a.Shape()
    48  	if len(shape) != 2 {
    49  		// TODO(poopoothegorilla): currently only two dimensions are implemented
    50  		return
    51  	}
    52  	n, m := shape[0], shape[1]
    53  	mask := a.(*Dense).Mask()
    54  	size := len(mask)
    55  
    56  	track := NewBitMap(size)
    57  	track.Set(0)
    58  	track.Set(size - 1)
    59  
    60  	for i := 0; i < size; i++ {
    61  		srci := i
    62  		if track.IsSet(srci) {
    63  			continue
    64  		}
    65  		srcv := mask[srci]
    66  		for {
    67  			oc := srci % n
    68  			or := (srci - oc) / n
    69  			desti := oc*m + or
    70  
    71  			if track.IsSet(desti) {
    72  				break
    73  			}
    74  			track.Set(desti)
    75  			destv := mask[desti]
    76  			mask[desti] = srcv
    77  			srci = desti
    78  			srcv = destv
    79  		}
    80  	}
    81  }
    82  
    83  func (e StdEng) denseTranspose1(a DenseTensor, expStrides []int) {
    84  	axes := a.transposeAxes()
    85  	size := a.len()
    86  
    87  	// first we'll create a bit-map to track which elements have been moved to their correct places
    88  	track := NewBitMap(size)
    89  	track.Set(0)
    90  	track.Set(size - 1) // first and last element of a transposedon't change
    91  
    92  	var saved, tmp byte
    93  	var i int
    94  
    95  	data := a.hdr().Uint8s()
    96  	if len(data) < 4 {
    97  		return
    98  	}
    99  	for i = 1; ; {
   100  		dest := a.transposeIndex(i, axes, expStrides)
   101  
   102  		if track.IsSet(i) && track.IsSet(dest) {
   103  			data[i] = saved
   104  			saved = 0
   105  			for i < size && track.IsSet(i) {
   106  				i++
   107  			}
   108  			if i >= size {
   109  				break
   110  			}
   111  			continue
   112  		}
   113  		track.Set(i)
   114  		tmp = data[i]
   115  		data[i] = saved
   116  		saved = tmp
   117  
   118  		i = dest
   119  	}
   120  }
   121  
   122  func (e StdEng) denseTranspose2(a DenseTensor, expStrides []int) {
   123  	axes := a.transposeAxes()
   124  	size := a.len()
   125  
   126  	// first we'll create a bit-map to track which elements have been moved to their correct places
   127  	track := NewBitMap(size)
   128  	track.Set(0)
   129  	track.Set(size - 1) // first and last element of a transposedon't change
   130  
   131  	var saved, tmp uint16
   132  	var i int
   133  
   134  	data := a.hdr().Uint16s()
   135  	if len(data) < 4 {
   136  		return
   137  	}
   138  	for i = 1; ; {
   139  		dest := a.transposeIndex(i, axes, expStrides)
   140  
   141  		if track.IsSet(i) && track.IsSet(dest) {
   142  			data[i] = saved
   143  			saved = 0
   144  			for i < size && track.IsSet(i) {
   145  				i++
   146  			}
   147  			if i >= size {
   148  				break
   149  			}
   150  			continue
   151  		}
   152  		track.Set(i)
   153  		tmp = data[i]
   154  		data[i] = saved
   155  		saved = tmp
   156  
   157  		i = dest
   158  	}
   159  }
   160  
   161  func (e StdEng) denseTranspose4(a DenseTensor, expStrides []int) {
   162  	axes := a.transposeAxes()
   163  	size := a.len()
   164  
   165  	// first we'll create a bit-map to track which elements have been moved to their correct places
   166  	track := NewBitMap(size)
   167  	track.Set(0)
   168  	track.Set(size - 1) // first and last element of a transposedon't change
   169  
   170  	var saved, tmp uint32
   171  	var i int
   172  
   173  	data := a.hdr().Uint32s()
   174  	if len(data) < 4 {
   175  		return
   176  	}
   177  	for i = 1; ; {
   178  		dest := a.transposeIndex(i, axes, expStrides)
   179  
   180  		if track.IsSet(i) && track.IsSet(dest) {
   181  			data[i] = saved
   182  			saved = 0
   183  			for i < size && track.IsSet(i) {
   184  				i++
   185  			}
   186  			if i >= size {
   187  				break
   188  			}
   189  			continue
   190  		}
   191  		track.Set(i)
   192  		tmp = data[i]
   193  		data[i] = saved
   194  		saved = tmp
   195  
   196  		i = dest
   197  	}
   198  }
   199  
   200  func (e StdEng) denseTranspose8(a DenseTensor, expStrides []int) {
   201  	axes := a.transposeAxes()
   202  	size := a.len()
   203  
   204  	// first we'll create a bit-map to track which elements have been moved to their correct places
   205  	track := NewBitMap(size)
   206  	track.Set(0)
   207  	track.Set(size - 1) // first and last element of a transposedon't change
   208  
   209  	var saved, tmp uint64
   210  	var i int
   211  
   212  	data := a.hdr().Uint64s()
   213  	if len(data) < 4 {
   214  		return
   215  	}
   216  	for i = 1; ; {
   217  		dest := a.transposeIndex(i, axes, expStrides)
   218  		if track.IsSet(i) && track.IsSet(dest) {
   219  			data[i] = saved
   220  			saved = 0
   221  			for i < size && track.IsSet(i) {
   222  				i++
   223  			}
   224  			if i >= size {
   225  				break
   226  			}
   227  			continue
   228  		}
   229  		track.Set(i)
   230  		// log.Printf("i: %d start %d, end %d | tmp %v saved %v", i, start, end, tmp, saved)
   231  		tmp = data[i]
   232  		data[i] = saved
   233  		saved = tmp
   234  
   235  		i = dest
   236  	}
   237  }
   238  
   239  func (e StdEng) denseTransposeString(a DenseTensor, expStrides []int) {
   240  	axes := a.transposeAxes()
   241  	size := a.len()
   242  
   243  	// first we'll create a bit-map to track which elements have been moved to their correct places
   244  	track := NewBitMap(size)
   245  	track.Set(0)
   246  	track.Set(size - 1) // first and last element of a transposedon't change
   247  
   248  	var saved, tmp string
   249  	var i int
   250  
   251  	data := a.hdr().Strings()
   252  	if len(data) < 4 {
   253  		return
   254  	}
   255  	for i = 1; ; {
   256  		dest := a.transposeIndex(i, axes, expStrides)
   257  
   258  		if track.IsSet(i) && track.IsSet(dest) {
   259  			data[i] = saved
   260  			saved = ""
   261  			for i < size && track.IsSet(i) {
   262  				i++
   263  			}
   264  			if i >= size {
   265  				break
   266  			}
   267  			continue
   268  		}
   269  		track.Set(i)
   270  		tmp = data[i]
   271  		data[i] = saved
   272  		saved = tmp
   273  
   274  		i = dest
   275  	}
   276  }
   277  
   278  func (e StdEng) denseTransposeArbitrary(a DenseTensor, expStrides []int) {
   279  	axes := a.transposeAxes()
   280  	size := a.len()
   281  	rtype := a.rtype()
   282  	typeSize := int(rtype.Size())
   283  
   284  	// first we'll create a bit-map to track which elements have been moved to their correct places
   285  	track := NewBitMap(size)
   286  	track.Set(0)
   287  	track.Set(size - 1) // first and last element of a transposedon't change
   288  
   289  	saved := make([]byte, typeSize, typeSize)
   290  	tmp := make([]byte, typeSize, typeSize)
   291  	var i int
   292  	data := a.arr().Raw
   293  	if len(data) < 4*typeSize {
   294  		return
   295  	}
   296  	for i = 1; ; {
   297  		dest := a.transposeIndex(i, axes, expStrides)
   298  		start := typeSize * i
   299  		end := start + typeSize
   300  
   301  		if track.IsSet(i) && track.IsSet(dest) {
   302  			copy(data[start:end], saved)
   303  			for i := range saved {
   304  				saved[i] = 0
   305  			}
   306  			for i < size && track.IsSet(i) {
   307  				i++
   308  			}
   309  			if i >= size {
   310  				break
   311  			}
   312  			continue
   313  		}
   314  		track.Set(i)
   315  		copy(tmp, data[start:end])
   316  		copy(data[start:end], saved)
   317  		copy(saved, tmp)
   318  		i = dest
   319  	}
   320  }