github.com/wzzhu/tensor@v0.9.24/dense_format.go (about)

     1  package tensor
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"reflect"
     7  	"strconv"
     8  )
     9  
    10  var fmtFlags = [...]rune{'+', '-', '#', ' ', '0'}
    11  var fmtByte = []byte("%")
    12  var precByte = []byte(".")
    13  var newline = []byte("\n")
    14  
    15  var (
    16  	matFirstStart = []byte("⎡")
    17  	matFirstEnd   = []byte("⎤\n")
    18  	matLastStart  = []byte("⎣")
    19  	matLastEnd    = []byte("⎦\n")
    20  	rowStart      = []byte("⎢")
    21  	rowEnd        = []byte("⎥\n")
    22  	vecStart      = []byte("[")
    23  	vecEnd        = []byte("]")
    24  	colVecStart   = []byte("C[")
    25  	rowVecStart   = []byte("R[")
    26  
    27  	hElisionCompact = []byte("⋯ ")
    28  	hElision        = []byte("... ")
    29  	vElisionCompact = []byte("  ⋮  \n")
    30  	vElision        = []byte(".\n.\n.\n")
    31  
    32  	ufVec    = []byte("Vector")
    33  	ufMat    = []byte("Matrix")
    34  	ufTensor = []byte("Tensor-")
    35  
    36  	hInvalid = []byte("--")
    37  )
    38  
    39  type fmtState struct {
    40  	fmt.State
    41  
    42  	buf                *bytes.Buffer
    43  	pad                []byte
    44  	hElision, vElision []byte
    45  
    46  	meta bool
    47  	flat bool
    48  	ext  bool // extended (i.e no elision)
    49  	comp bool // compact
    50  	c    rune // c is here mainly for struct packing reasons
    51  
    52  	w, p int // width and precision
    53  	base int // used only for int/byte arrays
    54  
    55  	rows, cols int
    56  	pr, pc     int // printed row, printed col
    57  }
    58  
    59  func newFmtState(f fmt.State, c rune) *fmtState {
    60  	retVal := &fmtState{
    61  		State: f,
    62  		buf:   bytes.NewBuffer(make([]byte, 10)),
    63  		c:     c,
    64  
    65  		meta:     f.Flag('+'),
    66  		flat:     f.Flag('-'),
    67  		ext:      f.Flag('#'),
    68  		comp:     c == 's',
    69  		hElision: hElision,
    70  		vElision: vElision,
    71  	}
    72  
    73  	w, _ := f.Width()
    74  	p, _ := f.Precision()
    75  	retVal.w = w
    76  	retVal.p = p
    77  	return retVal
    78  }
    79  
    80  func (f *fmtState) originalFmt() string {
    81  	buf := bytes.NewBuffer(fmtByte)
    82  	for _, flag := range fmtFlags {
    83  		if f.Flag(int(flag)) {
    84  			buf.WriteRune(flag)
    85  		}
    86  	}
    87  
    88  	// width
    89  	if w, ok := f.Width(); ok {
    90  		buf.WriteString(strconv.Itoa(w))
    91  	}
    92  
    93  	// precision
    94  	if p, ok := f.Precision(); ok {
    95  		buf.Write(precByte)
    96  		buf.WriteString(strconv.Itoa(p))
    97  	}
    98  
    99  	buf.WriteRune(f.c)
   100  	return buf.String()
   101  
   102  }
   103  
   104  func (f *fmtState) cleanFmt() string {
   105  	buf := bytes.NewBuffer(fmtByte)
   106  
   107  	// width
   108  	if w, ok := f.Width(); ok {
   109  		buf.WriteString(strconv.Itoa(w))
   110  	}
   111  
   112  	// precision
   113  	if p, ok := f.Precision(); ok {
   114  		buf.Write(precByte)
   115  		buf.WriteString(strconv.Itoa(p))
   116  	}
   117  
   118  	buf.WriteRune(f.c)
   119  	return buf.String()
   120  }
   121  
   122  // does the calculation for metadata
   123  func (f *fmtState) populate(t *Dense) {
   124  	switch {
   125  	case t.IsVector():
   126  		f.rows = 1
   127  		f.cols = t.Size()
   128  	case t.IsScalarEquiv():
   129  		f.rows = 1
   130  		f.cols = 1
   131  	default:
   132  		f.rows = t.Shape()[t.Dims()-2]
   133  		f.cols = t.Shape()[t.Dims()-1]
   134  	}
   135  
   136  	switch {
   137  	case f.flat && f.ext:
   138  		f.pc = t.len()
   139  	case f.flat && f.comp:
   140  		f.pc = 5
   141  		f.hElision = hElisionCompact
   142  	case f.flat:
   143  		f.pc = 10
   144  	case f.ext:
   145  		f.pc = f.cols
   146  		f.pr = f.rows
   147  	case f.comp:
   148  		f.pc = MinInt(f.cols, 4)
   149  		f.pr = MinInt(f.rows, 4)
   150  		f.hElision = hElisionCompact
   151  		f.vElision = vElisionCompact
   152  	default:
   153  		f.pc = MinInt(f.cols, 8)
   154  		f.pr = MinInt(f.rows, 8)
   155  	}
   156  
   157  }
   158  
   159  func (f *fmtState) acceptableRune(d *Dense) {
   160  	if f.c == 'H' {
   161  		f.meta = true
   162  		return // accept H as header only
   163  	}
   164  	switch d.t.Kind() {
   165  	case reflect.Float64:
   166  		switch f.c {
   167  		case 'f', 'e', 'E', 'G', 'b':
   168  		default:
   169  			f.c = 'g'
   170  		}
   171  	case reflect.Float32:
   172  		switch f.c {
   173  		case 'f', 'e', 'E', 'G', 'b':
   174  		default:
   175  			f.c = 'g'
   176  		}
   177  	case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8:
   178  		switch f.c {
   179  		case 'b':
   180  			f.base = 2
   181  		case 'd':
   182  			f.base = 10
   183  		case 'o':
   184  			f.base = 8
   185  		case 'x', 'X':
   186  			f.base = 16
   187  		default:
   188  			f.base = 10
   189  			f.c = 'd'
   190  		}
   191  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   192  		switch f.c {
   193  		case 'b':
   194  			f.base = 2
   195  		case 'd':
   196  			f.base = 10
   197  		case 'o':
   198  			f.base = 8
   199  		case 'x', 'X':
   200  			f.base = 16
   201  		default:
   202  			f.base = 10
   203  			f.c = 'd'
   204  		}
   205  	case reflect.Bool:
   206  		f.c = 't'
   207  	default:
   208  		f.c = 'v'
   209  	}
   210  }
   211  
   212  func (f *fmtState) calcWidth(d *Dense) {
   213  	format := f.cleanFmt()
   214  	f.w = 0
   215  	masked := false
   216  	if d.IsMasked() {
   217  		if d.MaskedAny().(bool) {
   218  			masked = true
   219  		}
   220  	}
   221  	for i := 0; i < d.len(); i++ {
   222  		w, _ := fmt.Fprintf(f.buf, format, d.Get(i))
   223  		if masked {
   224  			if d.mask[i] {
   225  				w, _ = fmt.Fprintf(f.buf, "%s", hInvalid)
   226  			}
   227  		}
   228  		if w > f.w {
   229  			f.w = w
   230  		}
   231  		f.buf.Reset()
   232  	}
   233  }
   234  
   235  func (f *fmtState) makePad() {
   236  	f.pad = make([]byte, MaxInt(f.w, 2))
   237  	for i := range f.pad {
   238  		f.pad[i] = ' '
   239  	}
   240  }
   241  
   242  func (f *fmtState) writeHElision() {
   243  	f.Write(f.hElision)
   244  }
   245  
   246  func (f *fmtState) writeVElision() {
   247  	f.Write(f.vElision)
   248  }
   249  
   250  // Format implements fmt.Formatter. Formatting can be controlled with verbs and flags. All default Go verbs are supported and work as expected.
   251  // By default, only 8 columns and rows are printed (the first and the last 4 columns and rows, while the middle columns and rows are ellided)
   252  // Special flags are:
   253  // 		'-' for printing a flat array of values
   254  //		'+' for printing extra metadata before printing the tensor (it prints shape, stride and type, which are useful for debugging)
   255  //		'#' for printing the full tensor - there are no elisions. Overrides the 's' verb
   256  //
   257  // Special care also needs be taken for the verb 's' - it prints a super compressed version of the tensor, only printing 4 cols and 4 rows.
   258  func (t *Dense) Format(s fmt.State, c rune) {
   259  	if c == 'i' {
   260  		fmt.Fprintf(s, "INFO:\n\tAP:  %v\n\tOLD: %v\n\tTRANS %v\n\tENGINE: %T\n", t.AP, t.old, t.transposeWith, t.e)
   261  		return
   262  	}
   263  
   264  	f := newFmtState(s, c)
   265  	if t.IsScalar() {
   266  		o := f.originalFmt()
   267  		fmt.Fprintf(f, o, t.Get(0))
   268  		return
   269  	}
   270  
   271  	f.acceptableRune(t)
   272  	f.calcWidth(t)
   273  	f.makePad()
   274  	f.populate(t)
   275  
   276  	if f.meta {
   277  		switch {
   278  		case t.IsVector():
   279  			f.Write(ufVec)
   280  		case t.Dims() == 2:
   281  			f.Write(ufMat)
   282  		default:
   283  			f.Write(ufTensor)
   284  			fmt.Fprintf(f, "%d", t.Dims())
   285  		}
   286  		fmt.Fprintf(f, " %v %v\n", t.Shape(), t.Strides())
   287  	}
   288  
   289  	if f.c == 'H' {
   290  		return
   291  	}
   292  
   293  	if !t.IsNativelyAccessible() {
   294  		fmt.Fprintf(f, "Inaccesible data")
   295  		return
   296  	}
   297  
   298  	format := f.cleanFmt()
   299  
   300  	if f.flat {
   301  		f.Write(vecStart)
   302  		switch {
   303  		case f.ext:
   304  			for i := 0; i < t.len(); i++ {
   305  				if !t.IsMasked() {
   306  					fmt.Fprintf(f, format, t.Get(i))
   307  				} else {
   308  					if t.mask[i] {
   309  						fmt.Fprintf(f, "%s", hInvalid)
   310  					} else {
   311  						fmt.Fprintf(f, format, t.Get(i))
   312  					}
   313  				}
   314  				if i < t.len()-1 {
   315  					f.Write(f.pad[:1])
   316  				}
   317  			}
   318  		case t.viewOf != 0:
   319  			it := IteratorFromDense(t)
   320  			var c, i int
   321  			var err error
   322  			for i, err = it.Next(); err == nil; i, err = it.Next() {
   323  				if !t.IsMasked() {
   324  					fmt.Fprintf(f, format, t.Get(i))
   325  				} else {
   326  					if t.mask[i] {
   327  						fmt.Fprintf(f, "%s", hInvalid)
   328  					} else {
   329  						fmt.Fprintf(f, format, t.Get(i))
   330  					}
   331  				}
   332  				f.Write(f.pad[:1])
   333  
   334  				c++
   335  				if c >= f.pc {
   336  					f.writeHElision()
   337  					break
   338  				}
   339  			}
   340  			if err != nil {
   341  				if _, noop := err.(NoOpError); !noop {
   342  					fmt.Fprintf(f, "ERROR ITERATING: %v", err)
   343  
   344  				}
   345  			}
   346  		default:
   347  			for i := 0; i < f.pc; i++ {
   348  				if !t.IsMasked() {
   349  					fmt.Fprintf(f, format, t.Get(i))
   350  				} else {
   351  					if t.mask[i] {
   352  						fmt.Fprintf(f, "%s", hInvalid)
   353  					} else {
   354  						fmt.Fprintf(f, format, t.Get(i))
   355  					}
   356  				}
   357  				f.Write(f.pad[:1])
   358  			}
   359  			if f.pc < t.len() {
   360  				f.writeHElision()
   361  			}
   362  		}
   363  		f.Write(vecEnd)
   364  		return
   365  	}
   366  
   367  	// standard stuff
   368  	it := NewIterator(&t.AP)
   369  	coord := it.Coord()
   370  
   371  	firstRow := true
   372  	firstVal := true
   373  	var lastRow, lastCol int
   374  	var expected int
   375  	for next, err := it.Next(); err == nil; next, err = it.Next() {
   376  		if next < expected {
   377  			continue
   378  		}
   379  
   380  		var col, row int
   381  		row = lastRow
   382  		col = lastCol
   383  		if f.rows > f.pr && row > f.pr/2 && row < f.rows-f.pr/2 {
   384  			continue
   385  		}
   386  
   387  		if firstVal {
   388  			if firstRow {
   389  				switch {
   390  				case t.IsColVec():
   391  					f.Write(colVecStart)
   392  				case t.IsRowVec():
   393  					f.Write(rowVecStart)
   394  				case t.IsVector():
   395  					f.Write(vecStart)
   396  				case t.IsScalarEquiv():
   397  					for i := 0; i < t.Dims(); i++ {
   398  						f.Write(vecStart)
   399  					}
   400  				default:
   401  					f.Write(matFirstStart)
   402  				}
   403  
   404  			} else {
   405  				var matLastRow bool
   406  				if !t.IsVector() {
   407  					matLastRow = coord[len(coord)-2] == f.rows-1
   408  				}
   409  				if matLastRow {
   410  					f.Write(matLastStart)
   411  				} else {
   412  					f.Write(rowStart)
   413  				}
   414  			}
   415  			firstVal = false
   416  		}
   417  
   418  		// actual printing of the value
   419  		if f.cols <= f.pc || (col < f.pc/2 || (col >= f.cols-f.pc/2)) {
   420  			var w int
   421  
   422  			if t.IsMasked() {
   423  				if t.mask[next] {
   424  					w, _ = fmt.Fprintf(f.buf, "%s", hInvalid)
   425  				} else {
   426  					w, _ = fmt.Fprintf(f.buf, format, t.Get(next))
   427  				}
   428  			} else {
   429  				w, _ = fmt.Fprintf(f.buf, format, t.Get(next))
   430  			}
   431  			f.Write(f.pad[:f.w-w]) // prepad
   432  			f.Write(f.buf.Bytes()) // write
   433  
   434  			if col < f.cols-1 { // pad with a space
   435  				f.Write(f.pad[:2])
   436  			}
   437  			f.buf.Reset()
   438  		} else if col == f.pc/2 {
   439  			f.writeHElision()
   440  		}
   441  
   442  		// done printing
   443  		// check for end of rows
   444  		if col == f.cols-1 {
   445  			eom := row == f.rows-1
   446  			switch {
   447  			case t.IsVector():
   448  				f.Write(vecEnd)
   449  				return
   450  			case t.IsScalarEquiv():
   451  				for i := 0; i < t.Dims(); i++ {
   452  					f.Write(vecEnd)
   453  				}
   454  				return
   455  			case firstRow:
   456  				f.Write(matFirstEnd)
   457  			case eom:
   458  				f.Write(matLastEnd)
   459  				if t.IsMatrix() {
   460  					return
   461  				}
   462  
   463  				// one newline for every dimension above 2
   464  				for i := t.Dims(); i > 2; i-- {
   465  					f.Write(newline)
   466  				}
   467  
   468  			default:
   469  				f.Write(rowEnd)
   470  			}
   471  
   472  			if firstRow {
   473  				firstRow = false
   474  			}
   475  
   476  			if eom {
   477  				firstRow = true
   478  			}
   479  
   480  			firstVal = true
   481  
   482  			// figure out elision
   483  			if f.rows > f.pr && row+1 == f.pr/2 {
   484  				expectedCoord := BorrowInts(len(coord))
   485  				copy(expectedCoord, coord)
   486  				expectedCoord[len(expectedCoord)-2] = f.rows - (f.pr / 2)
   487  				expected, _ = Ltoi(t.Shape(), t.Strides(), expectedCoord...)
   488  				ReturnInts(expectedCoord)
   489  
   490  				f.writeVElision()
   491  			}
   492  		}
   493  
   494  		// cleanup
   495  		switch {
   496  		case t.IsRowVec():
   497  			lastRow = coord[len(coord)-2]
   498  			lastCol = coord[len(coord)-1]
   499  		case t.IsColVec():
   500  			lastRow = coord[len(coord)-1]
   501  			lastCol = coord[len(coord)-2]
   502  		case t.IsVector():
   503  			lastCol = coord[len(coord)-1]
   504  		default:
   505  			lastRow = coord[len(coord)-2]
   506  			lastCol = coord[len(coord)-1]
   507  		}
   508  	}
   509  }