gorgonia.org/gorgonia@v0.9.17/values_primitives.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"reflect"
     7  	"unsafe"
     8  
     9  	"github.com/chewxy/hm"
    10  	"github.com/pkg/errors"
    11  	"gorgonia.org/tensor"
    12  )
    13  
    14  // Scalar represents a scalar(non-array-based) value. Do note that it's the pointers of the scalar types (F64, F32, etc) that implement
    15  // the Scalar interface. The main reason is primarily due to optimizations with regards to memory allocation and copying for device interoperability.
    16  type Scalar interface {
    17  	Value
    18  	isScalar() bool
    19  }
    20  
    21  // F64 represents a float64 value.
    22  type F64 float64
    23  
    24  // F32 represents a float32 value.
    25  type F32 float32
    26  
    27  // I represents a int value.
    28  type I int
    29  
    30  // I64 represents a int64 value.
    31  type I64 int64
    32  
    33  // I32 represents a int32 value.
    34  type I32 int32
    35  
    36  // U8 represents a byte value.
    37  type U8 byte
    38  
    39  // B represents a bool value.
    40  type B bool
    41  
    42  func NewF64(v float64) *F64 { r := F64(v); return &r }
    43  func NewF32(v float32) *F32 { r := F32(v); return &r }
    44  func NewI(v int) *I         { r := I(v); return &r }
    45  func NewI64(v int64) *I64   { r := I64(v); return &r }
    46  func NewI32(v int32) *I32   { r := I32(v); return &r }
    47  func NewU8(v byte) *U8      { r := U8(v); return &r }
    48  func NewB(v bool) *B        { r := B(v); return &r }
    49  
    50  /* Shape() */
    51  
    52  // Shape returns a scalar shape for all scalar values
    53  func (v *F64) Shape() tensor.Shape { return scalarShape }
    54  
    55  // Shape returns a scalar shape for all scalar values
    56  func (v *F32) Shape() tensor.Shape { return scalarShape }
    57  
    58  // Shape returns a scalar shape for all scalar values
    59  func (v *I) Shape() tensor.Shape { return scalarShape }
    60  
    61  // Shape returns a scalar shape for all scalar values
    62  func (v *I64) Shape() tensor.Shape { return scalarShape }
    63  
    64  // Shape returns a scalar shape for all scalar values
    65  func (v *I32) Shape() tensor.Shape { return scalarShape }
    66  
    67  // Shape returns a scalar shape for all scalar values
    68  func (v *U8) Shape() tensor.Shape { return scalarShape }
    69  
    70  // Shape returns a scalar shape for all scalar values
    71  func (v *B) Shape() tensor.Shape { return scalarShape }
    72  
    73  // Size returns 0 for all scalar Values
    74  func (v *F64) Size() int { return 0 }
    75  
    76  // Size returns 0 for all scalar Values
    77  func (v *F32) Size() int { return 0 }
    78  
    79  // Size returns 0 for all scalar Values
    80  func (v *I) Size() int { return 0 }
    81  
    82  // Size returns 0 for all scalar Values
    83  func (v *I64) Size() int { return 0 }
    84  
    85  // Size returns 0 for all scalar Values
    86  func (v *I32) Size() int { return 0 }
    87  
    88  // Size returns 0 for all scalar Values
    89  func (v *U8) Size() int { return 0 }
    90  
    91  // Size returns 0 for all scalar Values
    92  func (v *B) Size() int { return 0 }
    93  
    94  /* Data() */
    95  
    96  // Data returns the original representation of the Value
    97  func (v *F64) Data() interface{} { return v.any() }
    98  
    99  // Data returns the original representation of the Value
   100  func (v *F32) Data() interface{} { return v.any() }
   101  
   102  // Data returns the original representation of the Value
   103  func (v *I) Data() interface{} { return v.any() }
   104  
   105  // Data returns the original representation of the Value
   106  func (v *I64) Data() interface{} { return v.any() }
   107  
   108  // Data returns the original representation of the Value
   109  func (v *I32) Data() interface{} { return v.any() }
   110  
   111  // Data returns the original representation of the Value
   112  func (v *U8) Data() interface{} { return v.any() }
   113  
   114  // Data returns the original representation of the Value
   115  func (v *B) Data() interface{} { return v.any() }
   116  
   117  func (v *F64) any() float64 { return float64(*v) }
   118  func (v *F32) any() float32 { return float32(*v) }
   119  func (v *I) any() int       { return int(*v) }
   120  func (v *I64) any() int64   { return int64(*v) }
   121  func (v *I32) any() int32   { return int32(*v) }
   122  func (v *U8) any() byte     { return byte(*v) }
   123  func (v *B) any() bool      { return bool(*v) }
   124  
   125  /* implements fmt.Formatter */
   126  
   127  // Format implements fmt.Formatter
   128  func (v *F64) Format(s fmt.State, c rune) { formatScalar(v, s, c) }
   129  
   130  // Format implements fmt.Formatter
   131  func (v *F32) Format(s fmt.State, c rune) { formatScalar(v, s, c) }
   132  
   133  // Format implements fmt.Formatter
   134  func (v *I) Format(s fmt.State, c rune) { formatScalar(v, s, c) }
   135  
   136  // Format implements fmt.Formatter
   137  func (v *I64) Format(s fmt.State, c rune) { formatScalar(v, s, c) }
   138  
   139  // Format implements fmt.Formatter
   140  func (v *I32) Format(s fmt.State, c rune) { formatScalar(v, s, c) }
   141  
   142  // Format implements fmt.Formatter
   143  func (v *U8) Format(s fmt.State, c rune) { formatScalar(v, s, c) }
   144  
   145  // Format implements fmt.Formatter
   146  func (v *B) Format(s fmt.State, c rune) { formatScalar(v, s, c) }
   147  
   148  /* Dtype() */
   149  
   150  // Dtype  returns the Dtype of the value
   151  func (v *F64) Dtype() tensor.Dtype { return tensor.Float64 }
   152  
   153  // Dtype  returns the Dtype of the value
   154  func (v *F32) Dtype() tensor.Dtype { return tensor.Float32 }
   155  
   156  // Dtype  returns the Dtype of the value
   157  func (v *I) Dtype() tensor.Dtype { return tensor.Int }
   158  
   159  // Dtype  returns the Dtype of the value
   160  func (v *I64) Dtype() tensor.Dtype { return tensor.Int64 }
   161  
   162  // Dtype  returns the Dtype of the value
   163  func (v *I32) Dtype() tensor.Dtype { return tensor.Int32 }
   164  
   165  // Dtype  returns the Dtype of the value
   166  func (v *U8) Dtype() tensor.Dtype { return tensor.Byte }
   167  
   168  // Dtype  returns the Dtype of the value
   169  func (v *B) Dtype() tensor.Dtype { return tensor.Bool }
   170  
   171  /* isScalar */
   172  
   173  func (v *F64) isScalar() bool { return true }
   174  func (v *F32) isScalar() bool { return true }
   175  func (v *I) isScalar() bool   { return true }
   176  func (v *I64) isScalar() bool { return true }
   177  func (v *I32) isScalar() bool { return true }
   178  func (v *U8) isScalar() bool  { return true }
   179  func (v *B) isScalar() bool   { return true }
   180  
   181  /* Uintptr */
   182  
   183  // Uintptr satisfies the tensor.Memory interface
   184  func (v *F64) Uintptr() uintptr { return uintptr(unsafe.Pointer(v)) }
   185  
   186  // Uintptr satisfies the tensor.Memory interface
   187  func (v *F32) Uintptr() uintptr { return uintptr(unsafe.Pointer(v)) }
   188  
   189  // Uintptr satisfies the tensor.Memory interface
   190  func (v *I) Uintptr() uintptr { return uintptr(unsafe.Pointer(v)) }
   191  
   192  // Uintptr satisfies the tensor.Memory interface
   193  func (v *I64) Uintptr() uintptr { return uintptr(unsafe.Pointer(v)) }
   194  
   195  // Uintptr satisfies the tensor.Memory interface
   196  func (v *I32) Uintptr() uintptr { return uintptr(unsafe.Pointer(v)) }
   197  
   198  // Uintptr satisfies the tensor.Memory interface
   199  func (v *U8) Uintptr() uintptr { return uintptr(unsafe.Pointer(v)) }
   200  
   201  // Uintptr satisfies the tensor.Memory interface
   202  func (v *B) Uintptr() uintptr { return uintptr(unsafe.Pointer(v)) }
   203  
   204  /* MemSize */
   205  
   206  // MemSize satisfies the tensor.Memory interface
   207  func (v *F64) MemSize() uintptr { return 8 }
   208  
   209  // MemSize satisfies the tensor.Memory interface
   210  func (v *F32) MemSize() uintptr { return 4 }
   211  
   212  // MemSize satisfies the tensor.Memory interface
   213  func (v *I) MemSize() uintptr { return reflect.TypeOf(*v).Size() }
   214  
   215  // MemSize satisfies the tensor.Memory interface
   216  func (v *I64) MemSize() uintptr { return 8 }
   217  
   218  // MemSize satisfies the tensor.Memory interface
   219  func (v *I32) MemSize() uintptr { return 4 }
   220  
   221  // MemSize satisfies the tensor.Memory interface
   222  func (v *U8) MemSize() uintptr { return 1 }
   223  
   224  // MemSize satisfies the tensor.Memory interface
   225  func (v *B) MemSize() uintptr { return reflect.TypeOf(*v).Size() }
   226  
   227  /* Pointer */
   228  
   229  // Pointer returns the pointer as an unsafe.Pointer. Satisfies the tensor.Memory interface
   230  func (v *F64) Pointer() unsafe.Pointer { return unsafe.Pointer(v) }
   231  
   232  // Pointer returns the pointer as an unsafe.Pointer. Satisfies the tensor.Memory interface
   233  func (v *F32) Pointer() unsafe.Pointer { return unsafe.Pointer(v) }
   234  
   235  // Pointer returns the pointer as an unsafe.Pointer. Satisfies the tensor.Memory interface
   236  func (v *I) Pointer() unsafe.Pointer { return unsafe.Pointer(v) }
   237  
   238  // Pointer returns the pointer as an unsafe.Pointer. Satisfies the tensor.Memory interface
   239  func (v *I64) Pointer() unsafe.Pointer { return unsafe.Pointer(v) }
   240  
   241  // Pointer returns the pointer as an unsafe.Pointer. Satisfies the tensor.Memory interface
   242  func (v *I32) Pointer() unsafe.Pointer { return unsafe.Pointer(v) }
   243  
   244  // Pointer returns the pointer as an unsafe.Pointer. Satisfies the tensor.Memory interface
   245  func (v *U8) Pointer() unsafe.Pointer { return unsafe.Pointer(v) }
   246  
   247  // Pointer returns the pointer as an unsafe.Pointer. Satisfies the tensor.Memory interface
   248  func (v *B) Pointer() unsafe.Pointer { return unsafe.Pointer(v) }
   249  
   250  func formatScalar(v Scalar, s fmt.State, c rune) {
   251  	var buf bytes.Buffer
   252  	var ok bool
   253  
   254  	buf.WriteRune('%')
   255  
   256  	var width int
   257  	if width, ok = s.Width(); ok {
   258  		fmt.Fprintf(&buf, "%d", width)
   259  	}
   260  
   261  	var prec int
   262  	if prec, ok = s.Precision(); ok {
   263  		fmt.Fprintf(&buf, ".%d", prec)
   264  	}
   265  
   266  	switch c {
   267  	case 's':
   268  		buf.WriteRune('v')
   269  	case 'd':
   270  		switch v.(type) {
   271  		case *F64, *F32, *U8, *B:
   272  			buf.WriteRune('v')
   273  		default:
   274  			buf.WriteRune(c)
   275  		}
   276  	case 'f', 'g':
   277  		switch v.(type) {
   278  		case *I, *I64, *I32, *U8, *B:
   279  			buf.WriteRune('v')
   280  		default:
   281  			buf.WriteRune(c)
   282  		}
   283  	default:
   284  		buf.WriteRune(c)
   285  	}
   286  
   287  	if s.Flag('+') {
   288  		s.Write([]byte(v.Dtype().String()))
   289  		s.Write([]byte{' '})
   290  	}
   291  
   292  	fmt.Fprintf(s, buf.String(), v.Data())
   293  }
   294  
   295  func anyToScalar(any interface{}) (Scalar, tensor.Dtype) {
   296  	switch at := any.(type) {
   297  	case Scalar:
   298  		return at, at.Dtype()
   299  	case float64:
   300  		return NewF64(at), Float64
   301  	case float32:
   302  		return NewF32(at), Float32
   303  	case int:
   304  		return NewI(at), Int
   305  	case int32:
   306  		return NewI32(at), Int32
   307  	case int64:
   308  		return NewI64(at), Int64
   309  	case byte:
   310  		return NewU8(at), Byte
   311  	case bool:
   312  		return NewB(at), Bool
   313  	default:
   314  		panic(fmt.Sprintf("%v(%T) not scalar/not handled", any, any))
   315  	}
   316  }
   317  
   318  func anyToValue(any interface{}) (val Value, t hm.Type, dt tensor.Dtype, err error) {
   319  	switch a := any.(type) {
   320  	case Value:
   321  		val = a
   322  		t = TypeOf(a)
   323  		dt = a.Dtype()
   324  		return
   325  	case float64, float32, int, int64, int32, byte, bool:
   326  		val, dt = anyToScalar(any)
   327  		t = dt
   328  		return
   329  	case F64:
   330  		return NewF64(float64(a)), tensor.Float64, tensor.Float64, nil
   331  	case F32:
   332  		return NewF32(float32(a)), tensor.Float32, tensor.Float32, nil
   333  	case I:
   334  		return NewI(int(a)), tensor.Int, tensor.Int, nil
   335  	case I64:
   336  		return NewI64(int64(a)), tensor.Int64, tensor.Int64, nil
   337  	case I32:
   338  		return NewI32(int32(a)), tensor.Int32, tensor.Int32, nil
   339  	case U8:
   340  		return NewU8(byte(a)), tensor.Uint8, tensor.Uint8, nil
   341  	case B:
   342  		return NewB(bool(a)), tensor.Bool, tensor.Bool, nil
   343  	case tensor.Tensor:
   344  		val = a
   345  		t = TypeOf(a)
   346  		dt = a.Dtype()
   347  		return
   348  	default:
   349  		err = errors.Errorf("value %v of %T not yet handled", any, any)
   350  		return
   351  	}
   352  }
   353  
   354  func one(dt tensor.Dtype) Scalar {
   355  	switch dt {
   356  	case tensor.Float64:
   357  		return NewF64(float64(1))
   358  	case tensor.Float32:
   359  		return NewF32(float32(1))
   360  	case tensor.Int:
   361  		return NewI(1)
   362  	case tensor.Int32:
   363  		return NewI32(int32(1))
   364  	case tensor.Int64:
   365  		return NewI64(int64(1))
   366  	case tensor.Byte:
   367  		return NewU8(byte(1))
   368  	case tensor.Bool:
   369  		return NewB(true)
   370  	default:
   371  		panic("Unhandled dtype")
   372  	}
   373  }
   374  
   375  func zero(dt tensor.Dtype) Scalar {
   376  	switch dt {
   377  	case tensor.Float64:
   378  		return NewF64(float64(0))
   379  	case tensor.Float32:
   380  		return NewF32(float32(0))
   381  	case tensor.Int:
   382  		return NewI(0)
   383  	case tensor.Int32:
   384  		return NewI32(int32(0))
   385  	case tensor.Int64:
   386  		return NewI64(int64(0))
   387  	case tensor.Byte:
   388  		return NewU8(byte(0))
   389  	case tensor.Bool:
   390  		return NewB(false)
   391  	default:
   392  		panic("Unhandled dtype")
   393  	}
   394  }