github.com/wzzhu/tensor@v0.9.24/sparse.go (about)

     1  package tensor
     2  
     3  import (
     4  	"reflect"
     5  
     6  	"sort"
     7  
     8  	"github.com/pkg/errors"
     9  )
    10  
    11  var (
    12  	_ Sparse = &CS{}
    13  )
    14  
    15  // Sparse is a sparse tensor.
    16  type Sparse interface {
    17  	Tensor
    18  	Densor
    19  	NonZeroes() int // NonZeroes returns the number of nonzero values
    20  }
    21  
    22  // coo is an internal representation of the Coordinate type sparse matrix.
    23  // It's not exported because you probably shouldn't be using it.
    24  // Instead, constructors for the *CS type supports using a coordinate as an input.
    25  type coo struct {
    26  	o      DataOrder
    27  	xs, ys []int
    28  	data   array
    29  }
    30  
    31  func (c *coo) Len() int { return c.data.Len() }
    32  func (c *coo) Less(i, j int) bool {
    33  	if c.o.IsColMajor() {
    34  		return c.colMajorLess(i, j)
    35  	}
    36  	return c.rowMajorLess(i, j)
    37  }
    38  func (c *coo) Swap(i, j int) {
    39  	c.xs[i], c.xs[j] = c.xs[j], c.xs[i]
    40  	c.ys[i], c.ys[j] = c.ys[j], c.ys[i]
    41  	c.data.swap(i, j)
    42  }
    43  
    44  func (c *coo) colMajorLess(i, j int) bool {
    45  	if c.ys[i] < c.ys[j] {
    46  		return true
    47  	}
    48  	if c.ys[i] == c.ys[j] {
    49  		// check xs
    50  		if c.xs[i] <= c.xs[j] {
    51  			return true
    52  		}
    53  	}
    54  	return false
    55  }
    56  
    57  func (c *coo) rowMajorLess(i, j int) bool {
    58  	if c.xs[i] < c.xs[j] {
    59  		return true
    60  	}
    61  
    62  	if c.xs[i] == c.xs[j] {
    63  		// check ys
    64  		if c.ys[i] <= c.ys[j] {
    65  			return true
    66  		}
    67  	}
    68  	return false
    69  }
    70  
    71  // CS is a compressed sparse data structure. It can be used to represent both CSC and CSR sparse matrices.
    72  // Refer to the individual creation functions for more information.
    73  type CS struct {
    74  	s Shape
    75  	o DataOrder
    76  	e Engine
    77  	f MemoryFlag
    78  	z interface{} // z is the "zero" value. Typically it's not used.
    79  
    80  	indices []int
    81  	indptr  []int
    82  
    83  	array
    84  }
    85  
    86  // NewCSR creates a new Compressed Sparse Row matrix. The data has to be a slice or it panics.
    87  func NewCSR(indices, indptr []int, data interface{}, opts ...ConsOpt) *CS {
    88  	t := new(CS)
    89  	t.indices = indices
    90  	t.indptr = indptr
    91  	t.array = arrayFromSlice(data)
    92  	t.o = NonContiguous
    93  	t.e = StdEng{}
    94  
    95  	for _, opt := range opts {
    96  		opt(t)
    97  	}
    98  	return t
    99  }
   100  
   101  // NewCSC creates a new Compressed Sparse Column matrix. The data has to be a slice, or it panics.
   102  func NewCSC(indices, indptr []int, data interface{}, opts ...ConsOpt) *CS {
   103  	t := new(CS)
   104  	t.indices = indices
   105  	t.indptr = indptr
   106  	t.array = arrayFromSlice(data)
   107  	t.o = MakeDataOrder(ColMajor, NonContiguous)
   108  	t.e = StdEng{}
   109  
   110  	for _, opt := range opts {
   111  		opt(t)
   112  	}
   113  	return t
   114  }
   115  
   116  // CSRFromCoord creates a new Compressed Sparse Row matrix given the coordinates. The data has to be a slice or it panics.
   117  func CSRFromCoord(shape Shape, xs, ys []int, data interface{}) *CS {
   118  	t := new(CS)
   119  	t.s = shape
   120  	t.o = NonContiguous
   121  	t.array = arrayFromSlice(data)
   122  	t.e = StdEng{}
   123  
   124  	// coord matrix
   125  	cm := &coo{t.o, xs, ys, t.array}
   126  	sort.Sort(cm)
   127  
   128  	r := shape[0]
   129  	c := shape[1]
   130  	if r <= cm.xs[len(cm.xs)-1] || c <= MaxInts(cm.ys...) {
   131  		panic("Cannot create sparse matrix where provided shape is smaller than the implied shape of the data")
   132  	}
   133  
   134  	indptr := make([]int, r+1)
   135  
   136  	var i, j, tmp int
   137  	for i = 1; i < r+1; i++ {
   138  		for j = tmp; j < len(xs) && xs[j] < i; j++ {
   139  
   140  		}
   141  		tmp = j
   142  		indptr[i] = j
   143  	}
   144  	t.indices = ys
   145  	t.indptr = indptr
   146  	return t
   147  }
   148  
   149  // CSRFromCoord creates a new Compressed Sparse Column matrix given the coordinates. The data has to be a slice or it panics.
   150  func CSCFromCoord(shape Shape, xs, ys []int, data interface{}) *CS {
   151  	t := new(CS)
   152  	t.s = shape
   153  	t.o = MakeDataOrder(NonContiguous, ColMajor)
   154  	t.array = arrayFromSlice(data)
   155  	t.e = StdEng{}
   156  
   157  	// coord matrix
   158  	cm := &coo{t.o, xs, ys, t.array}
   159  	sort.Sort(cm)
   160  
   161  	r := shape[0]
   162  	c := shape[1]
   163  
   164  	// check shape
   165  	if r <= MaxInts(cm.xs...) || c <= cm.ys[len(cm.ys)-1] {
   166  		panic("Cannot create sparse matrix where provided shape is smaller than the implied shape of the data")
   167  	}
   168  
   169  	indptr := make([]int, c+1)
   170  
   171  	var i, j, tmp int
   172  	for i = 1; i < c+1; i++ {
   173  		for j = tmp; j < len(ys) && ys[j] < i; j++ {
   174  
   175  		}
   176  		tmp = j
   177  		indptr[i] = j
   178  	}
   179  	t.indices = xs
   180  	t.indptr = indptr
   181  	return t
   182  }
   183  
   184  func (t *CS) Shape() Shape         { return t.s }
   185  func (t *CS) Strides() []int       { return nil }
   186  func (t *CS) Dtype() Dtype         { return t.t }
   187  func (t *CS) Dims() int            { return 2 }
   188  func (t *CS) Size() int            { return t.s.TotalSize() }
   189  func (t *CS) DataSize() int        { return t.Len() }
   190  func (t *CS) Engine() Engine       { return t.e }
   191  func (t *CS) DataOrder() DataOrder { return t.o }
   192  
   193  func (t *CS) Slice(...Slice) (View, error) {
   194  	return nil, errors.Errorf("Slice for sparse tensors not implemented yet")
   195  }
   196  
   197  func (t *CS) At(coord ...int) (interface{}, error) {
   198  	if len(coord) != t.Dims() {
   199  		return nil, errors.Errorf("Expected coordinates to be of %d-dimensions. Got %v instead", t.Dims(), coord)
   200  	}
   201  	if i, ok := t.at(coord...); ok {
   202  		return t.Get(i), nil
   203  	}
   204  	if t.z == nil {
   205  		return reflect.Zero(t.t.Type).Interface(), nil
   206  	}
   207  	return t.z, nil
   208  }
   209  
   210  func (t *CS) SetAt(v interface{}, coord ...int) error {
   211  	if i, ok := t.at(coord...); ok {
   212  		t.Set(i, v)
   213  		return nil
   214  	}
   215  	return errors.Errorf("Cannot set value in a compressed sparse matrix: Coordinate %v not found", coord)
   216  }
   217  
   218  func (t *CS) Reshape(...int) error { return errors.New("compressed sparse matrix cannot be reshaped") }
   219  
   220  // T transposes the matrix. Concretely, it just changes a bit - the state goes from CSC to CSR, and vice versa.
   221  func (t *CS) T(axes ...int) error {
   222  	dims := t.Dims()
   223  	if len(axes) != dims && len(axes) != 0 {
   224  		return errors.Errorf("Cannot transpose along axes %v", axes)
   225  	}
   226  	if len(axes) == 0 || axes == nil {
   227  
   228  		axes = make([]int, dims)
   229  		for i := 0; i < dims; i++ {
   230  			axes[i] = dims - 1 - i
   231  		}
   232  	}
   233  	UnsafePermute(axes, []int(t.s))
   234  	t.o = t.o.toggleColMajor()
   235  	t.o = MakeDataOrder(t.o, Transposed)
   236  	return errors.Errorf(methodNYI, "T", t)
   237  }
   238  
   239  // UT untransposes the CS
   240  func (t *CS) UT() { t.T(); t.o = t.o.clearTransposed() }
   241  
   242  // Transpose is a no-op. The data does not move
   243  func (t *CS) Transpose() error { return nil }
   244  
   245  func (t *CS) Apply(fn interface{}, opts ...FuncOpt) (Tensor, error) {
   246  	return nil, errors.Errorf(methodNYI, "Apply", t)
   247  }
   248  
   249  func (t *CS) Eq(other interface{}) bool {
   250  	if ot, ok := other.(*CS); ok {
   251  		if t == ot {
   252  			return true
   253  		}
   254  
   255  		if len(ot.indices) != len(t.indices) {
   256  			return false
   257  		}
   258  		if len(ot.indptr) != len(t.indptr) {
   259  			return false
   260  		}
   261  		if !t.s.Eq(ot.s) {
   262  			return false
   263  		}
   264  		if ot.o != t.o {
   265  			return false
   266  		}
   267  		for i, ind := range t.indices {
   268  			if ot.indices[i] != ind {
   269  				return false
   270  			}
   271  		}
   272  		for i, ind := range t.indptr {
   273  			if ot.indptr[i] != ind {
   274  				return false
   275  			}
   276  		}
   277  		return t.array.Eq(&ot.array)
   278  	}
   279  	return false
   280  }
   281  
   282  func (t *CS) Clone() interface{} {
   283  	retVal := new(CS)
   284  	retVal.s = t.s.Clone()
   285  	retVal.o = t.o
   286  	retVal.e = t.e
   287  	retVal.indices = make([]int, len(t.indices))
   288  	retVal.indptr = make([]int, len(t.indptr))
   289  	copy(retVal.indices, t.indices)
   290  	copy(retVal.indptr, t.indptr)
   291  	retVal.array = makeArray(t.t, t.array.Len())
   292  	copyArray(&retVal.array, &t.array)
   293  	retVal.e = t.e
   294  	return retVal
   295  }
   296  
   297  func (t *CS) IsScalar() bool           { return false }
   298  func (t *CS) ScalarValue() interface{} { panic("Sparse Matrices cannot represent Scalar Values") }
   299  
   300  func (t *CS) MemSize() uintptr { return uintptr(calcMemSize(t.t, t.array.Len())) }
   301  func (t *CS) Uintptr() uintptr { return t.array.Uintptr() }
   302  
   303  // NonZeroes returns the nonzeroes. In academic literature this is often written as NNZ.
   304  func (t *CS) NonZeroes() int         { return t.Len() }
   305  func (t *CS) RequiresIterator() bool { return true }
   306  func (t *CS) Iterator() Iterator     { return NewFlatSparseIterator(t) }
   307  
   308  func (t *CS) at(coord ...int) (int, bool) {
   309  	var r, c int
   310  	if t.o.IsColMajor() {
   311  		r = coord[1]
   312  		c = coord[0]
   313  	} else {
   314  		r = coord[0]
   315  		c = coord[1]
   316  	}
   317  
   318  	for i := t.indptr[r]; i < t.indptr[r+1]; i++ {
   319  		if t.indices[i] == c {
   320  			return i, true
   321  		}
   322  	}
   323  	return -1, false
   324  }
   325  
   326  // Dense creates a Dense tensor from the compressed one.
   327  func (t *CS) Dense() *Dense {
   328  	if t.e != nil && t.e != (StdEng{}) {
   329  		// use
   330  	}
   331  
   332  	d := recycledDense(t.t, t.Shape().Clone(), WithEngine(t.e))
   333  	if t.o.IsColMajor() {
   334  		for i := 0; i < len(t.indptr)-1; i++ {
   335  			for j := t.indptr[i]; j < t.indptr[i+1]; j++ {
   336  				d.SetAt(t.Get(j), t.indices[j], i)
   337  			}
   338  		}
   339  	} else {
   340  		for i := 0; i < len(t.indptr)-1; i++ {
   341  			for j := t.indptr[i]; j < t.indptr[i+1]; j++ {
   342  				d.SetAt(t.Get(j), i, t.indices[j])
   343  			}
   344  		}
   345  	}
   346  	return d
   347  }
   348  
   349  // Other Accessors
   350  
   351  func (t *CS) Indptr() []int {
   352  	retVal := BorrowInts(len(t.indptr))
   353  	copy(retVal, t.indptr)
   354  	return retVal
   355  }
   356  
   357  func (t *CS) Indices() []int {
   358  	retVal := BorrowInts(len(t.indices))
   359  	copy(retVal, t.indices)
   360  	return retVal
   361  }
   362  
   363  func (t *CS) AsCSR() {
   364  	if t.o.IsRowMajor() {
   365  		return
   366  	}
   367  	t.o.toggleColMajor()
   368  }
   369  
   370  func (t *CS) AsCSC() {
   371  	if t.o.IsColMajor() {
   372  		return
   373  	}
   374  	t.o.toggleColMajor()
   375  }
   376  
   377  func (t *CS) IsNativelyAccessible() bool { return t.f.nativelyAccessible() }
   378  func (t *CS) IsManuallyManaged() bool    { return t.f.manuallyManaged() }
   379  
   380  func (t *CS) arr() array                     { return t.array }
   381  func (t *CS) arrPtr() *array                 { return &t.array }
   382  func (t *CS) standardEngine() standardEngine { return nil }