gorgonia.org/tensor@v0.9.24/ap.go (about)

     1  package tensor
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/pkg/errors"
     7  )
     8  
     9  // An AP is an access pattern. It tells the various ndarrays how to access their data through the use of strides
    10  // Through the AP, there are several definitions of things, most notably there are two very specific "special cases":
    11  //		Scalar has Dims() of 0.
    12  //			- (1)
    13  //		Scalarlikes are higher order tensors, but each with a size of 1. The Dims() are not 0.
    14  //			- (1, 1)
    15  //			- (1, 1, 1)
    16  //			- (1, 1, 1, 1), etc
    17  //		Vector has Dims() of 1, but its shape can take several forms:
    18  //			- (x, 1)
    19  //			- (1, x)
    20  //			- (x)
    21  //		Matrix has Dims() of 2. This is the most basic form. The len(shape) has to be equal to 2 as well
    22  //		ndarray has Dims() of n.
    23  type AP struct {
    24  	shape   Shape // len(shape) is the operational definition of the dimensions
    25  	strides []int // strides is usually calculated from shape
    26  	fin     bool  // is this struct change-proof?
    27  
    28  	o DataOrder
    29  	Δ Triangle
    30  }
    31  
    32  func makeAP(size int) AP {
    33  	return AP{
    34  		shape:   Shape(BorrowInts(size)),
    35  		strides: BorrowInts(size),
    36  	}
    37  }
    38  
    39  // MakeAP creates an AP, given the shape and strides.
    40  func MakeAP(shape Shape, strides []int, o DataOrder, Δ Triangle) AP {
    41  	return AP{
    42  		shape:   shape,
    43  		strides: strides,
    44  		o:       o,
    45  		Δ:       Δ,
    46  		fin:     true,
    47  	}
    48  }
    49  
    50  // Init initializes an already created AP with a shape and stries.
    51  // It will panic if AP is nil.
    52  func (ap *AP) Init(shape Shape, strides []int) {
    53  	ap.shape = shape
    54  	ap.strides = strides
    55  	ap.fin = true
    56  }
    57  
    58  // SetShape is for very specific times when modifying the AP is necessary, such as reshaping and doing I/O related stuff
    59  //
    60  // Caveats:
    61  //
    62  // - SetShape will recalculate the strides.
    63  //
    64  // - If the AP is locked, nothing will happen
    65  func (ap *AP) SetShape(s ...int) {
    66  	if !ap.fin {
    67  		// scalars are a special case, we don't want to remove it completely
    68  		if len(s) == 0 {
    69  			if ap.shape == nil || ap.strides == nil {
    70  				ap.shape = Shape{}
    71  			}
    72  			ap.shape = ap.shape[:0]
    73  			ap.strides = ap.strides[:0]
    74  			return
    75  		}
    76  
    77  		if ap.shape != nil {
    78  			ReturnInts(ap.shape)
    79  			ap.shape = nil
    80  		}
    81  		if ap.strides != nil {
    82  			ReturnInts(ap.strides)
    83  			ap.strides = nil
    84  		}
    85  		ap.shape = Shape(s).Clone()
    86  		ap.strides = ap.calcStrides()
    87  	}
    88  }
    89  
    90  // Shape returns the shape of the AP
    91  func (ap *AP) Shape() Shape { return ap.shape }
    92  
    93  // Strides returns the strides of the AP
    94  func (ap *AP) Strides() []int { return ap.strides }
    95  
    96  // Dims returns the dimensions of the shape in the AP
    97  func (ap *AP) Dims() int { return ap.shape.Dims() }
    98  
    99  // Size returns the expected array size of the shape
   100  func (ap *AP) Size() int { return ap.shape.TotalSize() }
   101  
   102  // String implements fmt.Stringer and runtime.Stringer
   103  func (ap *AP) String() string { return fmt.Sprintf("%v", ap) }
   104  
   105  // Format implements fmt.Formatter
   106  func (ap *AP) Format(state fmt.State, c rune) {
   107  	fmt.Fprintf(state, "Shape: %v, Stride: %v, Lock: %t", ap.shape, ap.strides, ap.fin)
   108  }
   109  
   110  // IsVector returns whether the access pattern falls into one of three possible definitions of vectors:
   111  //		vanilla vector (not a row or a col)
   112  //		column vector
   113  //		row vector
   114  func (ap *AP) IsVector() bool { return ap.shape.IsVector() }
   115  
   116  // IsVectorLike returns true if the shape is vector-like (i.e. the shape only has one dim that is a non-1).
   117  func (ap *AP) IsVectorLike() bool {
   118  	return ap.shape.IsVectorLike() && allones(ap.strides)
   119  }
   120  
   121  // IsColVec returns true when the access pattern has the shape (x, 1)
   122  func (ap *AP) IsColVec() bool { return ap.shape.IsColVec() }
   123  
   124  // IsRowVec returns true when the access pattern has the shape (1, x)
   125  func (ap *AP) IsRowVec() bool { return ap.shape.IsRowVec() }
   126  
   127  // IsScalar returns true if the access pattern indicates it's a scalar value.
   128  func (ap *AP) IsScalar() bool { return ap.shape.IsScalar() }
   129  
   130  // IsScalarEquiv returns true if the access pattern is equivalent to a scalar shape.
   131  func (ap *AP) IsScalarEquiv() bool { return ap.shape.IsScalarEquiv() }
   132  
   133  // IsMatrix returns true if it's a matrix. This is mostly a convenience method. RowVec and ColVecs are also considered matrices
   134  func (ap *AP) IsMatrix() bool { return len(ap.shape) == 2 }
   135  
   136  // IsZero tell us if the ap has zero size
   137  func (ap *AP) IsZero() bool {
   138  	return len(ap.shape) == 0 && len(ap.strides) == 0 && !ap.fin && ap.o == 0 && ap.Δ == 0
   139  }
   140  
   141  // Zero zeros out an AP.
   142  func (ap *AP) zero() {
   143  	// log.Printf("ZEROING. Called by %v", string(debug.Stack()))
   144  
   145  	// Jorge's original implementation for zeroing a AP is as below
   146  	// but to cater for the (*Dense).fix() method of the *Dense
   147  	// a nil shape is used to signal unsetness
   148  	// so we cannot just truncate the shape even though it would be a lot more efficient
   149  
   150  	// ap.shape = ap.shape[:0]
   151  	// ap.strides = ap.strides[:0]
   152  	ReturnInts([]int(ap.shape))
   153  	ReturnInts(ap.strides)
   154  	ap.zeroOnly()
   155  }
   156  
   157  // side effect free zeroing
   158  func (ap *AP) zeroOnly() {
   159  	ap.shape = nil
   160  	ap.strides = nil
   161  
   162  	ap.fin = false
   163  	ap.o = 0
   164  	ap.Δ = 0
   165  }
   166  
   167  func (ap *AP) zeroWithDims(dims int) {
   168  	//ap.shape = BorrowInts(dims)
   169  	//ap.strides = BorrowInts(dims)
   170  	if cap(ap.shape) >= dims {
   171  		ap.shape = ap.shape[:dims]
   172  	}
   173  	ap.shape = BorrowInts(dims)
   174  	if cap(ap.strides) >= dims {
   175  		ap.strides = ap.strides[:dims]
   176  	}
   177  	ap.strides = BorrowInts(dims)
   178  }
   179  
   180  // Clone clones the *AP. Clearly. It returns AP
   181  func (ap *AP) Clone() (retVal AP) {
   182  	retVal = makeAP(cap(ap.shape))
   183  
   184  	copy(retVal.shape, ap.shape)
   185  	copy(retVal.strides, ap.strides)
   186  
   187  	// handle vectors
   188  	retVal.shape = retVal.shape[:len(ap.shape)]
   189  	retVal.strides = retVal.strides[:len(ap.strides)]
   190  
   191  	retVal.fin = ap.fin
   192  	retVal.o = ap.o
   193  	retVal.Δ = ap.Δ
   194  	return
   195  }
   196  
   197  func (ap *AP) CloneTo(dest *AP) {
   198  	dest.shape = append(dest.shape[:0], ap.shape...)
   199  	dest.strides = append(dest.strides[:0], ap.strides...)
   200  	dest.fin = ap.fin
   201  	dest.o = ap.o
   202  	dest.Δ = ap.Δ
   203  }
   204  
   205  // DataOrder returns the data order of the AP.
   206  func (ap *AP) DataOrder() DataOrder { return ap.o }
   207  
   208  // C returns true if the access pattern is C-contiguous array
   209  func (ap *AP) C() bool { return ap.o.IsRowMajor() && ap.o.IsContiguous() }
   210  
   211  // F returns true if the access pattern is Fortran contiguous array
   212  func (ap *AP) F() bool { return ap.o.IsColMajor() && ap.o.IsContiguous() }
   213  
   214  // S returns the metadata of the sliced tensor.
   215  func (ap *AP) S(size int, slices ...Slice) (newAP AP, ndStart, ndEnd int, err error) {
   216  	if len(slices) > len(ap.shape) {
   217  		// error
   218  		err = errors.Errorf(dimMismatch, len(ap.shape), len(slices))
   219  		return
   220  	}
   221  
   222  	ndEnd = size
   223  	newShape := ap.shape.Clone()   // the new shape
   224  	dims := ap.Dims()              // reported dimensions
   225  	newStrides := BorrowInts(dims) // the new strides
   226  
   227  	var outerDim int
   228  	order := ap.o
   229  	if ap.o.IsRowMajor() || ap.IsVector() {
   230  		outerDim = 0
   231  	} else {
   232  		outerDim = len(ap.shape) - 1
   233  	}
   234  
   235  	for i := 0; i < dims; i++ {
   236  		var sl Slice
   237  		if i <= len(slices)-1 {
   238  			sl = slices[i]
   239  		}
   240  
   241  		size := ap.shape[i]
   242  		var stride int
   243  		stride = ap.strides[i]
   244  		// if ap.IsVector() {
   245  		// 	// handles non-vanilla vectors
   246  		// 	stride = ap.strides[0]
   247  		// } else {
   248  		// 	stride = ap.strides[i]
   249  		// }
   250  
   251  		var start, end, step int
   252  		if start, end, step, err = SliceDetails(sl, size); err != nil {
   253  			err = errors.Wrapf(err, "Unable to get slice details on slice %d with size %d: %v", i, sl, size)
   254  			return
   255  		}
   256  
   257  		// a slice where start == end is []
   258  		ndStart = ndStart + start*stride
   259  		ndEnd = ndEnd - (size-end)*stride
   260  
   261  		if step > 0 {
   262  			if newShape[i] = (end - start) / step; (end-start)%step > 0 && i > 0 {
   263  				newShape[i]++
   264  			}
   265  			newStrides[i] = stride * step
   266  
   267  			//fix
   268  			if newShape[i] <= 0 {
   269  				newShape[i] = 1
   270  			}
   271  		} else {
   272  			newShape[i] = (end - start)
   273  			newStrides[i] = stride
   274  		}
   275  
   276  		if (sl != nil && (!ap.IsVector() && i != outerDim)) || step > 1 {
   277  			order = MakeDataOrder(order, NonContiguous)
   278  		}
   279  	}
   280  
   281  	if ndEnd-ndStart == 1 {
   282  		// scalars are a special case
   283  		newAP = AP{}
   284  		newAP.SetShape() // make it a Scalar
   285  		newAP.lock()
   286  	} else {
   287  
   288  		// drop any dimension with size 1, except the last dimension
   289  		offset := 0
   290  		for d := 0; d < dims; d++ {
   291  			if newShape[d] == 1 && offset+d <= len(slices)-1 && slices[offset+d] != nil /*&& d != t.dims-1  && dims > 2*/ {
   292  				newShape = append(newShape[:d], newShape[d+1:]...)
   293  				newStrides = append(newStrides[:d], newStrides[d+1:]...)
   294  				d--
   295  				dims--
   296  				offset++
   297  			}
   298  		}
   299  
   300  		newAP = MakeAP(newShape, newStrides, order, ap.Δ)
   301  	}
   302  	return
   303  }
   304  
   305  // T returns the transposed metadata based on the given input
   306  func (ap *AP) T(axes ...int) (retVal AP, a []int, err error) {
   307  
   308  	// prep axes
   309  	if len(axes) > 0 && len(axes) != ap.Dims() {
   310  		err = errors.Errorf(dimMismatch, ap.Dims(), len(axes))
   311  		return
   312  	}
   313  
   314  	dims := len(ap.shape)
   315  	if len(axes) == 0 || axes == nil {
   316  		axes = make([]int, dims)
   317  		for i := 0; i < dims; i++ {
   318  			axes[i] = dims - 1 - i
   319  		}
   320  	}
   321  	a = axes
   322  
   323  	if ap.shape.IsScalarEquiv() {
   324  		return ap.Clone(), a, noopError{}
   325  	}
   326  
   327  	// if axes is 0, 1, 2, 3... then no op
   328  	if monotonic, incr1 := IsMonotonicInts(axes); monotonic && incr1 && axes[0] == 0 {
   329  		return ap.Clone(), a, noopError{}
   330  	}
   331  
   332  	currentShape := ap.shape
   333  	currentStride := ap.strides
   334  	shape := make(Shape, len(currentShape))
   335  	strides := make([]int, len(currentStride))
   336  
   337  	switch {
   338  	case ap.IsScalar():
   339  		return
   340  	case ap.IsVector():
   341  		if axes[0] == 0 {
   342  			return
   343  		}
   344  		strides[0], strides[1] = 1, 1
   345  		shape[0], shape[1] = currentShape[1], currentShape[0]
   346  	default:
   347  		copy(shape, currentShape)
   348  		copy(strides, currentStride)
   349  		err = UnsafePermute(axes, shape, strides)
   350  		if err != nil {
   351  			err = handleNoOp(err)
   352  		}
   353  	}
   354  
   355  	o := MakeDataOrder(ap.o, Transposed)
   356  	retVal = MakeAP(shape, strides, o, ap.Δ)
   357  	retVal.fin = true
   358  	return
   359  }
   360  
   361  // locking and unlocking is used to ensure that the shape and stride doesn't change (it's not really safe though, as a direct mutation of the strides/shape would still mutate it, but at least the dimensions cannot change)
   362  func (ap *AP) lock()   { ap.fin = true }
   363  func (ap *AP) unlock() { ap.fin = false }
   364  
   365  func (ap *AP) calcStrides() []int {
   366  	switch {
   367  	case ap.o.IsRowMajor():
   368  		return ap.shape.CalcStrides()
   369  	case ap.o.IsColMajor():
   370  		return ap.shape.CalcStridesColMajor()
   371  	}
   372  	panic("unreachable")
   373  }
   374  
   375  // setDataOrder is a method such that any tensor that embeds *AP will have the same method
   376  func (ap *AP) setDataOrder(o DataOrder) {
   377  	if !o.HasSameOrder(ap.o) {
   378  		ap.o = ap.o.toggleColMajor()
   379  	}
   380  }
   381  
   382  // TransposeIndex returns the new index given the old index
   383  func TransposeIndex(i int, oldShape, pattern, oldStrides, newStrides []int) int {
   384  	oldCoord, err := Itol(i, oldShape, oldStrides)
   385  	if err != nil {
   386  		panic(err) // or return error?
   387  	}
   388  	/*
   389  		coordss, _ := Permute(pattern, oldCoord)
   390  		coords := coordss[0]
   391  		index, _ := Ltoi(newShape, strides, coords...)
   392  	*/
   393  
   394  	// The above is the "conceptual" algorithm.
   395  	// Too many checks above slows things down, so the below is the "optimized" edition
   396  	var index int
   397  	for i, axis := range pattern {
   398  		index += oldCoord[axis] * newStrides[i]
   399  	}
   400  	return index
   401  }
   402  
   403  // UntransposeIndex returns the old index given the new index
   404  func UntransposeIndex(i int, oldShape, pattern, oldStrides, newStrides []int) int {
   405  	newPattern := make([]int, len(pattern))
   406  	for i, p := range pattern {
   407  		newPattern[p] = i
   408  	}
   409  	return TransposeIndex(i, oldShape, newPattern, oldStrides, newStrides)
   410  }
   411  
   412  // BroadcastStrides handles broadcasting from different shapes.
   413  //
   414  // Deprecated: this function will be unexported
   415  func BroadcastStrides(destShape, srcShape Shape, destStrides, srcStrides []int) (retVal []int, err error) {
   416  	dims := len(destShape)
   417  	start := dims - len(srcShape)
   418  
   419  	if destShape.IsVector() && srcShape.IsVector() {
   420  		return []int{srcStrides[0]}, nil
   421  	}
   422  
   423  	if start < 0 {
   424  		//error
   425  		err = errors.Errorf(dimMismatch, dims, len(srcShape))
   426  		return
   427  	}
   428  
   429  	retVal = BorrowInts(len(destStrides))
   430  	for i := dims - 1; i >= start; i-- {
   431  		s := srcShape[i-start]
   432  		switch {
   433  		case s == 1:
   434  			retVal[i] = 0
   435  		case s != destShape[i]:
   436  			// error
   437  			err = errors.Errorf("Cannot broadcast from %v to %v", srcShape, destShape)
   438  			return
   439  		default:
   440  			retVal[i] = srcStrides[i-start]
   441  		}
   442  	}
   443  	for i := 0; i < start; i++ {
   444  		retVal[i] = 0
   445  	}
   446  	return
   447  }