gorgonia.org/gorgonia@v0.9.17/engine.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"github.com/pkg/errors"
     5  	"gorgonia.org/tensor"
     6  )
     7  
     8  // StandardEngine is the default CPU engine for gorgonia
     9  type StandardEngine struct {
    10  	tensor.StdEng
    11  }
    12  
    13  // Transpose tensor a according to expStrides
    14  func (e StandardEngine) Transpose(a tensor.Tensor, expStrides []int) error {
    15  	if !a.IsNativelyAccessible() {
    16  		return errors.Errorf("Cannot Transpose() on non-natively accessible tensor")
    17  	}
    18  	size := a.DataSize()
    19  	it := a.Iterator()
    20  	var i int
    21  	switch a.Dtype() {
    22  	case tensor.Float64:
    23  		tmp := make([]float64, size)
    24  		data := a.Data().([]float64)
    25  		for next, err := it.Next(); err == nil; next, err = it.Next() {
    26  			tmp[i] = data[next]
    27  			i++
    28  		}
    29  		copy(data, tmp)
    30  	case tensor.Float32:
    31  		tmp := make([]float32, size)
    32  		data := a.Data().([]float32)
    33  		for next, err := it.Next(); err == nil; next, err = it.Next() {
    34  			tmp[i] = data[next]
    35  			i++
    36  		}
    37  		copy(data, tmp)
    38  	default:
    39  		return e.StdEng.Transpose(a, expStrides)
    40  	}
    41  	return nil
    42  }