gorgonia.org/tensor@v0.9.24/defaultengine_matop_transpose.go (about)

     1  // +build !inplacetranspose
     2  
     3  package tensor
     4  
     5  import (
     6  	"github.com/pkg/errors"
     7  )
     8  
     9  func (e StdEng) Transpose(a Tensor, expStrides []int) error {
    10  	if !a.IsNativelyAccessible() {
    11  		return errors.Errorf("Cannot Transpose() on non-natively accessible tensor")
    12  	}
    13  	if dt, ok := a.(DenseTensor); ok {
    14  		e.denseTranspose(dt, expStrides)
    15  		return nil
    16  	}
    17  	return errors.Errorf("Tranpose for tensor of %T not supported", a)
    18  }
    19  
    20  func (e StdEng) denseTranspose(a DenseTensor, expStrides []int) {
    21  	if a.rtype() == String.Type {
    22  		e.denseTransposeString(a, expStrides)
    23  		return
    24  	}
    25  
    26  	e.transposeMask(a)
    27  
    28  	switch a.rtype().Size() {
    29  	case 1:
    30  		e.denseTranspose1(a, expStrides)
    31  	case 2:
    32  		e.denseTranspose2(a, expStrides)
    33  	case 4:
    34  		e.denseTranspose4(a, expStrides)
    35  	case 8:
    36  		e.denseTranspose8(a, expStrides)
    37  	default:
    38  		e.denseTransposeArbitrary(a, expStrides)
    39  	}
    40  }
    41  
    42  func (e StdEng) transposeMask(a DenseTensor) {
    43  	if !a.(*Dense).IsMasked() {
    44  		return
    45  	}
    46  
    47  	orig := a.(*Dense).Mask()
    48  	tmp := make([]bool, len(orig))
    49  
    50  	it := newFlatIterator(a.Info())
    51  	var j int
    52  	for i, err := it.Next(); err == nil; i, err = it.Next() {
    53  		tmp[j] = orig[i]
    54  		j++
    55  	}
    56  	copy(orig, tmp)
    57  }
    58  
    59  func (e StdEng) denseTranspose1(a DenseTensor, expStrides []int) {
    60  	var tmpArr array
    61  	e.makeArray(&tmpArr, a.Dtype(), a.Size())
    62  	u8s := tmpArr.Uint8s()
    63  
    64  	orig := a.hdr().Uint8s()
    65  	it := newFlatIterator(a.Info())
    66  	var j int
    67  	for i, err := it.Next(); err == nil; i, err = it.Next() {
    68  		u8s[j] = orig[i]
    69  		j++
    70  	}
    71  	copy(orig, u8s)
    72  }
    73  
    74  func (e StdEng) denseTranspose2(a DenseTensor, expStrides []int) {
    75  	var tmpArr array
    76  	e.makeArray(&tmpArr, a.Dtype(), a.Size())
    77  	u16s := tmpArr.Uint16s()
    78  
    79  	orig := a.hdr().Uint16s()
    80  	it := newFlatIterator(a.Info())
    81  	var j int
    82  	for i, err := it.Next(); err == nil; i, err = it.Next() {
    83  		u16s[j] = orig[i]
    84  		j++
    85  	}
    86  	copy(orig, u16s)
    87  }
    88  
    89  func (e StdEng) denseTranspose4(a DenseTensor, expStrides []int) {
    90  	var tmpArr array
    91  	e.makeArray(&tmpArr, a.Dtype(), a.Size())
    92  	u32s := tmpArr.Uint32s()
    93  
    94  	orig := a.hdr().Uint32s()
    95  	it := newFlatIterator(a.Info())
    96  	var j int
    97  	for i, err := it.Next(); err == nil; i, err = it.Next() {
    98  		u32s[j] = orig[i]
    99  		j++
   100  	}
   101  	copy(orig, u32s)
   102  }
   103  
   104  func (e StdEng) denseTranspose8(a DenseTensor, expStrides []int) {
   105  	var tmpArr array
   106  	e.makeArray(&tmpArr, a.Dtype(), a.Size())
   107  	u64s := tmpArr.Uint64s()
   108  
   109  	orig := a.hdr().Uint64s()
   110  	it := newFlatIterator(a.Info())
   111  	var j int
   112  	for i, err := it.Next(); err == nil; i, err = it.Next() {
   113  		u64s[j] = orig[i]
   114  		j++
   115  	}
   116  	copy(orig, u64s)
   117  }
   118  
   119  func (e StdEng) denseTransposeString(a DenseTensor, expStrides []int) {
   120  	var tmpArr array
   121  	e.makeArray(&tmpArr, a.Dtype(), a.Size())
   122  	strs := tmpArr.Strings()
   123  
   124  	orig := a.hdr().Strings()
   125  	it := newFlatIterator(a.Info())
   126  	var j int
   127  	for i, err := it.Next(); err == nil; i, err = it.Next() {
   128  		strs[j] = orig[i]
   129  		j++
   130  	}
   131  	copy(orig, strs)
   132  }
   133  
   134  func (e StdEng) denseTransposeArbitrary(a DenseTensor, expStrides []int) {
   135  	rtype := a.rtype()
   136  	typeSize := int(rtype.Size())
   137  	var tmpArr array
   138  	e.makeArray(&tmpArr, a.Dtype(), a.Size())
   139  	// arbs := storage.AsByteSlice(tmpArr.hdr(), rtype)
   140  	arbs := tmpArr.byteSlice()
   141  
   142  	orig := a.hdr().Raw
   143  	it := newFlatIterator(a.Info())
   144  	var j int
   145  	for i, err := it.Next(); err == nil; i, err = it.Next() {
   146  		srcStart := i * typeSize
   147  		srcEnd := srcStart + typeSize
   148  		dstStart := j * typeSize
   149  		dstEnd := dstStart + typeSize
   150  
   151  		copy(arbs[dstStart:dstEnd], orig[srcStart:srcEnd])
   152  		j++
   153  	}
   154  	copy(orig, arbs)
   155  }