github.com/wzzhu/tensor@v0.9.24/consopt.go (about)

     1  package tensor
     2  
     3  import (
     4  	"reflect"
     5  
     6  	"github.com/wzzhu/tensor/internal/storage"
     7  )
     8  
     9  // ConsOpt is a tensor construction option.
    10  type ConsOpt func(Tensor)
    11  
    12  // Of is a construction option for a Tensor.
    13  func Of(a Dtype) ConsOpt {
    14  	Register(a)
    15  	f := func(t Tensor) {
    16  		switch tt := t.(type) {
    17  		case *Dense:
    18  			tt.t = a
    19  		case *CS:
    20  			tt.t = a
    21  		default:
    22  			panic("Unsupported Tensor type")
    23  		}
    24  	}
    25  	return f
    26  }
    27  
    28  // WithBacking is a construction option for a Tensor
    29  // Use it as such:
    30  //		backing := []float64{1,2,3,4}
    31  // 		t := New(WithBacking(backing))
    32  // It can be used with other construction options like WithShape
    33  func WithBacking(x interface{}, argMask ...[]bool) ConsOpt {
    34  	var mask []bool
    35  	if len(argMask) > 0 {
    36  		mask = argMask[0]
    37  	}
    38  	f := func(t Tensor) {
    39  		if x == nil {
    40  			return
    41  		}
    42  		switch tt := t.(type) {
    43  		case *Dense:
    44  			tt.fromSlice(x)
    45  			if len(argMask) > 0 {
    46  				tt.addMask(mask)
    47  			}
    48  		default:
    49  			panic("Unsupported Tensor type")
    50  		}
    51  	}
    52  	return f
    53  }
    54  
    55  // WithMask is a construction option for a Tensor
    56  // Use it as such:
    57  //		mask := []bool{true,true,false,false}
    58  // 		t := New(WithBacking(backing), WithMask(mask))
    59  // It can be used with other construction options like WithShape
    60  // The supplied mask can be any type. If non-boolean, then tensor mask is set to true
    61  // wherever non-zero value is obtained
    62  func WithMask(x interface{}) ConsOpt {
    63  	f := func(t Tensor) {
    64  		if x == nil {
    65  			return
    66  		}
    67  		switch tt := t.(type) {
    68  		case *Dense:
    69  			tt.MaskFromSlice(x)
    70  		default:
    71  			panic("Unsupported Tensor type")
    72  		}
    73  	}
    74  	return f
    75  }
    76  
    77  // WithShape is a construction option for a Tensor. It creates the ndarray in the required shape.
    78  func WithShape(dims ...int) ConsOpt {
    79  	f := func(t Tensor) {
    80  		switch tt := t.(type) {
    81  		case *Dense:
    82  			throw := BorrowInts(len(dims))
    83  			copy(throw, dims)
    84  			tt.setShape(throw...)
    85  		case *CS:
    86  			if len(dims) != 2 {
    87  				panic("Only sparse matrices are supported")
    88  			}
    89  			throw := BorrowInts(len(dims))
    90  			copy(throw, dims)
    91  			tt.s = throw
    92  
    93  		default:
    94  			panic("Unsupported Tensor type")
    95  		}
    96  	}
    97  	return f
    98  }
    99  
   100  // FromScalar is a construction option for representing a scalar value as a Tensor
   101  func FromScalar(x interface{}, argMask ...[]bool) ConsOpt {
   102  	var mask []bool
   103  	if len(argMask) > 0 {
   104  		mask = argMask[0]
   105  	}
   106  
   107  	f := func(t Tensor) {
   108  		switch tt := t.(type) {
   109  		case *Dense:
   110  			xT := reflect.TypeOf(x)
   111  			sxT := reflect.SliceOf(xT)
   112  			xv := reflect.MakeSlice(sxT, 1, 1) // []T
   113  			xv0 := xv.Index(0)                 // xv[0]
   114  			xv0.Set(reflect.ValueOf(x))
   115  			tt.array.Header.Raw = storage.AsByteSlice(xv.Interface())
   116  			tt.t = Dtype{xT}
   117  			tt.mask = mask
   118  
   119  		default:
   120  			panic("Unsupported Tensor Type")
   121  		}
   122  	}
   123  	return f
   124  }
   125  
   126  // FromMemory is a construction option for creating a *Dense (for now) from memory location. This is a useful
   127  // option for super large tensors that don't fit into memory - the user may need to `mmap` a file the tensor.
   128  //
   129  // Bear in mind that at the current stage of the ConsOpt design, the order of the ConsOpt is important.
   130  // FromMemory  requires the *Dense's Dtype be set already.
   131  // This would fail (and panic):
   132  //		New(FromMemory(ptr, size), Of(Float64))
   133  // This would not:
   134  //		New(Of(Float64), FromMemory(ptr, size))
   135  // This behaviour  of  requiring the ConsOpts to be in order might be changed in the future.
   136  //
   137  // Memory must be manually managed by the caller.
   138  // Tensors called with this construction option will not be returned to any pool - rather, all references to the pointers will be null'd.
   139  // Use with caution.
   140  //go:nocheckptr
   141  func FromMemory(ptr uintptr, memsize uintptr) ConsOpt {
   142  	f := func(t Tensor) {
   143  		switch tt := t.(type) {
   144  		case *Dense:
   145  			tt.Header.Raw = nil // GC anything if needed
   146  			tt.Header.Raw = storage.FromMemory(ptr, memsize)
   147  			tt.flag = MakeMemoryFlag(tt.flag, ManuallyManaged)
   148  		default:
   149  			panic("Unsupported Tensor type")
   150  		}
   151  	}
   152  	return f
   153  }
   154  
   155  // WithEngine is a construction option that would cause a Tensor to be linked with an execution engine.
   156  func WithEngine(e Engine) ConsOpt {
   157  	f := func(t Tensor) {
   158  		switch tt := t.(type) {
   159  		case *Dense:
   160  			tt.e = e
   161  			if e != nil && !e.AllocAccessible() {
   162  				tt.flag = MakeMemoryFlag(tt.flag, NativelyInaccessible)
   163  			}
   164  
   165  			tt.oe = nil
   166  			if oe, ok := e.(standardEngine); ok {
   167  				tt.oe = oe
   168  			}
   169  		case *CS:
   170  			tt.e = e
   171  			if e != nil && !e.AllocAccessible() {
   172  				tt.f = MakeMemoryFlag(tt.f, NativelyInaccessible)
   173  			}
   174  		}
   175  	}
   176  	return f
   177  }
   178  
   179  // AsFortran creates a *Dense with a col-major layout.
   180  // If the optional backing argument is passed, the backing is assumed to be C-order (row major), and
   181  // it will be transposed before being used.
   182  func AsFortran(backing interface{}, argMask ...[]bool) ConsOpt {
   183  	var mask []bool
   184  	if len(argMask) > 0 {
   185  		mask = argMask[0]
   186  	}
   187  	f := func(t Tensor) {
   188  		switch tt := t.(type) {
   189  		case *Dense:
   190  			if backing != nil {
   191  				// put the data into the tensor, then make a clone tensor to transpose
   192  				tt.fromSliceOrArrayer(backing)
   193  				// create a temporary tensor, to which the transpose will be done
   194  				tmp := NewDense(tt.Dtype(), tt.shape.Clone())
   195  				copyArray(tmp.arrPtr(), tt.arrPtr())
   196  				tmp.SetMask(mask)
   197  				tmp.T()
   198  				tmp.Transpose()
   199  				// copy the data back to the current tensor
   200  				copyArray(tt.arrPtr(), tmp.arrPtr())
   201  				tt.SetMask(tmp.Mask())
   202  				// cleanup: return the temporary tensor back to the pool
   203  				ReturnTensor(tmp)
   204  			}
   205  
   206  			tt.AP.o = MakeDataOrder(tt.AP.o, ColMajor)
   207  			if tt.AP.shape != nil {
   208  				ReturnInts(tt.AP.strides)
   209  				tt.AP.strides = nil
   210  				tt.AP.strides = tt.AP.calcStrides()
   211  			}
   212  		case *CS:
   213  			panic("AsFortran is not an available option for Compressed Sparse layouts")
   214  		}
   215  	}
   216  	return f
   217  }
   218  
   219  func AsDenseDiag(backing interface{}) ConsOpt {
   220  	f := func(t Tensor) {
   221  		switch tt := t.(type) {
   222  		case *Dense:
   223  			if bt, ok := backing.(Tensor); ok {
   224  				backing = bt.Data()
   225  			}
   226  			xT := reflect.TypeOf(backing)
   227  			if xT.Kind() != reflect.Slice {
   228  				panic("Expected a slice")
   229  			}
   230  			xV := reflect.ValueOf(backing)
   231  			l := xV.Len()
   232  			// elT := xT.Elem()
   233  
   234  			sli := reflect.MakeSlice(xT, l*l, l*l)
   235  
   236  			shape := Shape{l, l}
   237  			strides := shape.CalcStrides()
   238  			for i := 0; i < l; i++ {
   239  				idx, err := Ltoi(shape, strides, i, i)
   240  				if err != nil {
   241  					panic(err)
   242  				}
   243  
   244  				at := sli.Index(idx)
   245  				xi := xV.Index(i)
   246  				at.Set(xi)
   247  			}
   248  
   249  			tt.fromSliceOrArrayer(sli.Interface())
   250  			tt.setShape(l, l)
   251  
   252  		default:
   253  			panic("AsDenseDiag is not available as an option for CS")
   254  		}
   255  	}
   256  	return f
   257  }