github.com/wzzhu/tensor@v0.9.24/sparse.go (about) 1 package tensor 2 3 import ( 4 "reflect" 5 6 "sort" 7 8 "github.com/pkg/errors" 9 ) 10 11 var ( 12 _ Sparse = &CS{} 13 ) 14 15 // Sparse is a sparse tensor. 16 type Sparse interface { 17 Tensor 18 Densor 19 NonZeroes() int // NonZeroes returns the number of nonzero values 20 } 21 22 // coo is an internal representation of the Coordinate type sparse matrix. 23 // It's not exported because you probably shouldn't be using it. 24 // Instead, constructors for the *CS type supports using a coordinate as an input. 25 type coo struct { 26 o DataOrder 27 xs, ys []int 28 data array 29 } 30 31 func (c *coo) Len() int { return c.data.Len() } 32 func (c *coo) Less(i, j int) bool { 33 if c.o.IsColMajor() { 34 return c.colMajorLess(i, j) 35 } 36 return c.rowMajorLess(i, j) 37 } 38 func (c *coo) Swap(i, j int) { 39 c.xs[i], c.xs[j] = c.xs[j], c.xs[i] 40 c.ys[i], c.ys[j] = c.ys[j], c.ys[i] 41 c.data.swap(i, j) 42 } 43 44 func (c *coo) colMajorLess(i, j int) bool { 45 if c.ys[i] < c.ys[j] { 46 return true 47 } 48 if c.ys[i] == c.ys[j] { 49 // check xs 50 if c.xs[i] <= c.xs[j] { 51 return true 52 } 53 } 54 return false 55 } 56 57 func (c *coo) rowMajorLess(i, j int) bool { 58 if c.xs[i] < c.xs[j] { 59 return true 60 } 61 62 if c.xs[i] == c.xs[j] { 63 // check ys 64 if c.ys[i] <= c.ys[j] { 65 return true 66 } 67 } 68 return false 69 } 70 71 // CS is a compressed sparse data structure. It can be used to represent both CSC and CSR sparse matrices. 72 // Refer to the individual creation functions for more information. 73 type CS struct { 74 s Shape 75 o DataOrder 76 e Engine 77 f MemoryFlag 78 z interface{} // z is the "zero" value. Typically it's not used. 79 80 indices []int 81 indptr []int 82 83 array 84 } 85 86 // NewCSR creates a new Compressed Sparse Row matrix. The data has to be a slice or it panics. 87 func NewCSR(indices, indptr []int, data interface{}, opts ...ConsOpt) *CS { 88 t := new(CS) 89 t.indices = indices 90 t.indptr = indptr 91 t.array = arrayFromSlice(data) 92 t.o = NonContiguous 93 t.e = StdEng{} 94 95 for _, opt := range opts { 96 opt(t) 97 } 98 return t 99 } 100 101 // NewCSC creates a new Compressed Sparse Column matrix. The data has to be a slice, or it panics. 102 func NewCSC(indices, indptr []int, data interface{}, opts ...ConsOpt) *CS { 103 t := new(CS) 104 t.indices = indices 105 t.indptr = indptr 106 t.array = arrayFromSlice(data) 107 t.o = MakeDataOrder(ColMajor, NonContiguous) 108 t.e = StdEng{} 109 110 for _, opt := range opts { 111 opt(t) 112 } 113 return t 114 } 115 116 // CSRFromCoord creates a new Compressed Sparse Row matrix given the coordinates. The data has to be a slice or it panics. 117 func CSRFromCoord(shape Shape, xs, ys []int, data interface{}) *CS { 118 t := new(CS) 119 t.s = shape 120 t.o = NonContiguous 121 t.array = arrayFromSlice(data) 122 t.e = StdEng{} 123 124 // coord matrix 125 cm := &coo{t.o, xs, ys, t.array} 126 sort.Sort(cm) 127 128 r := shape[0] 129 c := shape[1] 130 if r <= cm.xs[len(cm.xs)-1] || c <= MaxInts(cm.ys...) { 131 panic("Cannot create sparse matrix where provided shape is smaller than the implied shape of the data") 132 } 133 134 indptr := make([]int, r+1) 135 136 var i, j, tmp int 137 for i = 1; i < r+1; i++ { 138 for j = tmp; j < len(xs) && xs[j] < i; j++ { 139 140 } 141 tmp = j 142 indptr[i] = j 143 } 144 t.indices = ys 145 t.indptr = indptr 146 return t 147 } 148 149 // CSRFromCoord creates a new Compressed Sparse Column matrix given the coordinates. The data has to be a slice or it panics. 150 func CSCFromCoord(shape Shape, xs, ys []int, data interface{}) *CS { 151 t := new(CS) 152 t.s = shape 153 t.o = MakeDataOrder(NonContiguous, ColMajor) 154 t.array = arrayFromSlice(data) 155 t.e = StdEng{} 156 157 // coord matrix 158 cm := &coo{t.o, xs, ys, t.array} 159 sort.Sort(cm) 160 161 r := shape[0] 162 c := shape[1] 163 164 // check shape 165 if r <= MaxInts(cm.xs...) || c <= cm.ys[len(cm.ys)-1] { 166 panic("Cannot create sparse matrix where provided shape is smaller than the implied shape of the data") 167 } 168 169 indptr := make([]int, c+1) 170 171 var i, j, tmp int 172 for i = 1; i < c+1; i++ { 173 for j = tmp; j < len(ys) && ys[j] < i; j++ { 174 175 } 176 tmp = j 177 indptr[i] = j 178 } 179 t.indices = xs 180 t.indptr = indptr 181 return t 182 } 183 184 func (t *CS) Shape() Shape { return t.s } 185 func (t *CS) Strides() []int { return nil } 186 func (t *CS) Dtype() Dtype { return t.t } 187 func (t *CS) Dims() int { return 2 } 188 func (t *CS) Size() int { return t.s.TotalSize() } 189 func (t *CS) DataSize() int { return t.Len() } 190 func (t *CS) Engine() Engine { return t.e } 191 func (t *CS) DataOrder() DataOrder { return t.o } 192 193 func (t *CS) Slice(...Slice) (View, error) { 194 return nil, errors.Errorf("Slice for sparse tensors not implemented yet") 195 } 196 197 func (t *CS) At(coord ...int) (interface{}, error) { 198 if len(coord) != t.Dims() { 199 return nil, errors.Errorf("Expected coordinates to be of %d-dimensions. Got %v instead", t.Dims(), coord) 200 } 201 if i, ok := t.at(coord...); ok { 202 return t.Get(i), nil 203 } 204 if t.z == nil { 205 return reflect.Zero(t.t.Type).Interface(), nil 206 } 207 return t.z, nil 208 } 209 210 func (t *CS) SetAt(v interface{}, coord ...int) error { 211 if i, ok := t.at(coord...); ok { 212 t.Set(i, v) 213 return nil 214 } 215 return errors.Errorf("Cannot set value in a compressed sparse matrix: Coordinate %v not found", coord) 216 } 217 218 func (t *CS) Reshape(...int) error { return errors.New("compressed sparse matrix cannot be reshaped") } 219 220 // T transposes the matrix. Concretely, it just changes a bit - the state goes from CSC to CSR, and vice versa. 221 func (t *CS) T(axes ...int) error { 222 dims := t.Dims() 223 if len(axes) != dims && len(axes) != 0 { 224 return errors.Errorf("Cannot transpose along axes %v", axes) 225 } 226 if len(axes) == 0 || axes == nil { 227 228 axes = make([]int, dims) 229 for i := 0; i < dims; i++ { 230 axes[i] = dims - 1 - i 231 } 232 } 233 UnsafePermute(axes, []int(t.s)) 234 t.o = t.o.toggleColMajor() 235 t.o = MakeDataOrder(t.o, Transposed) 236 return errors.Errorf(methodNYI, "T", t) 237 } 238 239 // UT untransposes the CS 240 func (t *CS) UT() { t.T(); t.o = t.o.clearTransposed() } 241 242 // Transpose is a no-op. The data does not move 243 func (t *CS) Transpose() error { return nil } 244 245 func (t *CS) Apply(fn interface{}, opts ...FuncOpt) (Tensor, error) { 246 return nil, errors.Errorf(methodNYI, "Apply", t) 247 } 248 249 func (t *CS) Eq(other interface{}) bool { 250 if ot, ok := other.(*CS); ok { 251 if t == ot { 252 return true 253 } 254 255 if len(ot.indices) != len(t.indices) { 256 return false 257 } 258 if len(ot.indptr) != len(t.indptr) { 259 return false 260 } 261 if !t.s.Eq(ot.s) { 262 return false 263 } 264 if ot.o != t.o { 265 return false 266 } 267 for i, ind := range t.indices { 268 if ot.indices[i] != ind { 269 return false 270 } 271 } 272 for i, ind := range t.indptr { 273 if ot.indptr[i] != ind { 274 return false 275 } 276 } 277 return t.array.Eq(&ot.array) 278 } 279 return false 280 } 281 282 func (t *CS) Clone() interface{} { 283 retVal := new(CS) 284 retVal.s = t.s.Clone() 285 retVal.o = t.o 286 retVal.e = t.e 287 retVal.indices = make([]int, len(t.indices)) 288 retVal.indptr = make([]int, len(t.indptr)) 289 copy(retVal.indices, t.indices) 290 copy(retVal.indptr, t.indptr) 291 retVal.array = makeArray(t.t, t.array.Len()) 292 copyArray(&retVal.array, &t.array) 293 retVal.e = t.e 294 return retVal 295 } 296 297 func (t *CS) IsScalar() bool { return false } 298 func (t *CS) ScalarValue() interface{} { panic("Sparse Matrices cannot represent Scalar Values") } 299 300 func (t *CS) MemSize() uintptr { return uintptr(calcMemSize(t.t, t.array.Len())) } 301 func (t *CS) Uintptr() uintptr { return t.array.Uintptr() } 302 303 // NonZeroes returns the nonzeroes. In academic literature this is often written as NNZ. 304 func (t *CS) NonZeroes() int { return t.Len() } 305 func (t *CS) RequiresIterator() bool { return true } 306 func (t *CS) Iterator() Iterator { return NewFlatSparseIterator(t) } 307 308 func (t *CS) at(coord ...int) (int, bool) { 309 var r, c int 310 if t.o.IsColMajor() { 311 r = coord[1] 312 c = coord[0] 313 } else { 314 r = coord[0] 315 c = coord[1] 316 } 317 318 for i := t.indptr[r]; i < t.indptr[r+1]; i++ { 319 if t.indices[i] == c { 320 return i, true 321 } 322 } 323 return -1, false 324 } 325 326 // Dense creates a Dense tensor from the compressed one. 327 func (t *CS) Dense() *Dense { 328 if t.e != nil && t.e != (StdEng{}) { 329 // use 330 } 331 332 d := recycledDense(t.t, t.Shape().Clone(), WithEngine(t.e)) 333 if t.o.IsColMajor() { 334 for i := 0; i < len(t.indptr)-1; i++ { 335 for j := t.indptr[i]; j < t.indptr[i+1]; j++ { 336 d.SetAt(t.Get(j), t.indices[j], i) 337 } 338 } 339 } else { 340 for i := 0; i < len(t.indptr)-1; i++ { 341 for j := t.indptr[i]; j < t.indptr[i+1]; j++ { 342 d.SetAt(t.Get(j), i, t.indices[j]) 343 } 344 } 345 } 346 return d 347 } 348 349 // Other Accessors 350 351 func (t *CS) Indptr() []int { 352 retVal := BorrowInts(len(t.indptr)) 353 copy(retVal, t.indptr) 354 return retVal 355 } 356 357 func (t *CS) Indices() []int { 358 retVal := BorrowInts(len(t.indices)) 359 copy(retVal, t.indices) 360 return retVal 361 } 362 363 func (t *CS) AsCSR() { 364 if t.o.IsRowMajor() { 365 return 366 } 367 t.o.toggleColMajor() 368 } 369 370 func (t *CS) AsCSC() { 371 if t.o.IsColMajor() { 372 return 373 } 374 t.o.toggleColMajor() 375 } 376 377 func (t *CS) IsNativelyAccessible() bool { return t.f.nativelyAccessible() } 378 func (t *CS) IsManuallyManaged() bool { return t.f.manuallyManaged() } 379 380 func (t *CS) arr() array { return t.array } 381 func (t *CS) arrPtr() *array { return &t.array } 382 func (t *CS) standardEngine() standardEngine { return nil }