github.com/egonelbre/exp@v0.0.0-20240430123955-ed1d3aa93911/niterator/shape/shape.go (about)

     1  package shape
     2  
     3  import "fmt"
     4  
     5  type AP struct {
     6  	Shape  []int
     7  	Stride []int
     8  }
     9  
    10  func (ap *AP) TotalSize() int {
    11  	total := 1
    12  	for _, size := range ap.Shape {
    13  		total *= size
    14  	}
    15  	return total
    16  }
    17  
    18  func (ap *AP) String() string {
    19  	return fmt.Sprintf("%v", ap.Shape)
    20  }
    21  
    22  func New(shape ...int) *AP {
    23  	ap := &AP{}
    24  	ap.Shape = shape
    25  	ap.Stride = make([]int, len(shape))
    26  
    27  	acc := 1
    28  	for i := len(shape) - 1; i >= 0; i-- {
    29  		ap.Stride[i] = acc
    30  		d := shape[i]
    31  		if d < 0 {
    32  			panic("negative dimension size does not make sense")
    33  		}
    34  		acc *= d
    35  	}
    36  	return ap
    37  }