gorgonia.org/gorgonia@v0.9.17/values_utils.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/chewxy/hm"
     7  	"github.com/pkg/errors"
     8  	"gorgonia.org/tensor"
     9  )
    10  
    11  // TypeOf returns the Type of the value
    12  func TypeOf(v Value) hm.Type {
    13  	switch t := v.(type) {
    14  	case tensor.Tensor:
    15  		dt, dim := tensorInfo(t)
    16  		return makeTensorType(dim, dt)
    17  	case Scalar:
    18  		return t.Dtype()
    19  	case Typer:
    20  		return t.Type()
    21  
    22  	default:
    23  		panic(fmt.Sprintf("TypeOf Not yet implemented for %v %T", v, v))
    24  	}
    25  }
    26  
    27  func typeCheckTypeOf(v Value) hm.Type {
    28  	switch t := v.(type) {
    29  	case tensor.Tensor:
    30  		dt, dim := tensorInfo(t)
    31  		return newTensorType(dim, dt)
    32  	case Scalar:
    33  		return t.Dtype()
    34  	case Typer:
    35  		return t.Type()
    36  
    37  	default:
    38  		panic(fmt.Sprintf("TypeOf Not yet implemented for %v %T", v, v))
    39  	}
    40  }
    41  
    42  // ValueEq is the equality function for values
    43  func ValueEq(a, b Value) bool {
    44  	if a == nil && b == nil {
    45  		return true
    46  	}
    47  	switch at := a.(type) {
    48  	case Scalar:
    49  		if bt, ok := b.(Scalar); ok {
    50  			return scalarEq(at, bt)
    51  		}
    52  		return false
    53  	case tensor.Tensor:
    54  		if bt, ok := b.(tensor.Tensor); ok {
    55  			return at.Eq(bt)
    56  			//log.Printf("at.info %#v, bt.info %#v", a.(*tensor.Dense).Info(), b.(*tensor.Dense).Info())
    57  		}
    58  		return false
    59  	case ValueEqualer:
    60  		return at.ValueEq(b)
    61  	default:
    62  		panic(fmt.Sprintf("Not implemented yet, %T", a))
    63  	}
    64  }
    65  
    66  // ValueClose checks whether two values are close to one another. It's predominantly used as an alternative equality test for floats
    67  func ValueClose(a, b Value) bool {
    68  	if a == nil && b == nil {
    69  		return true
    70  	}
    71  
    72  	switch at := a.(type) {
    73  	case Scalar:
    74  		if bt, ok := b.(Scalar); ok {
    75  			return scalarClose(at, bt)
    76  		}
    77  		return false
    78  	case tensor.Tensor:
    79  		if bt, ok := b.(tensor.Tensor); ok {
    80  			return tensorClose(at, bt)
    81  		}
    82  		return false
    83  	case ValueCloser:
    84  		return at.ValueClose(b)
    85  	default:
    86  		panic("Not implemented yet")
    87  	}
    88  }
    89  
    90  // CloneValue clones a value. For scalars, since Go copies scalars, it returns itself
    91  func CloneValue(v Value) (Value, error) {
    92  	switch vt := v.(type) {
    93  	case *F64:
    94  		retVal := *vt
    95  		return &retVal, nil
    96  	case *F32:
    97  		retVal := *vt
    98  		return &retVal, nil
    99  	case *I:
   100  		retVal := *vt
   101  		return &retVal, nil
   102  	case *I32:
   103  		retVal := *vt
   104  		return &retVal, nil
   105  	case *I64:
   106  		retVal := *vt
   107  		return &retVal, nil
   108  	case *U8:
   109  		retVal := *vt
   110  		return &retVal, nil
   111  	case *B:
   112  		retVal := *vt
   113  		return &retVal, nil
   114  	case tensor.Tensor:
   115  		return vt.Clone().(*tensor.Dense), nil
   116  	case CloneErrorer:
   117  		ret, err := vt.Clone()
   118  		if err != nil {
   119  			return nil, err
   120  		}
   121  		retVal, ok := ret.(Value)
   122  		if !ok {
   123  			return nil, errors.Errorf("Cloner is not a value: %v %T", v, v)
   124  		}
   125  		return retVal, nil
   126  	case Cloner:
   127  		return vt.Clone().(Value), nil
   128  	default:
   129  		return nil, errors.Errorf("Unable to clone value of type %T", v)
   130  	}
   131  }
   132  
   133  // ZeroValue returns the zero value of a type
   134  func ZeroValue(v Value) Value {
   135  	switch vt := v.(type) {
   136  	case *F64:
   137  		*vt = 0
   138  		return vt
   139  	case *F32:
   140  		*vt = 0
   141  		return vt
   142  	case *I:
   143  		*vt = 0
   144  		return vt
   145  	case *I32:
   146  		*vt = 0
   147  		return vt
   148  	case *I64:
   149  		*vt = 0
   150  		return vt
   151  	case *U8:
   152  		*vt = 0
   153  		return vt
   154  	case *B:
   155  		*vt = false
   156  		return vt
   157  	case tensor.Tensor:
   158  		vt.Zero()
   159  		return vt
   160  	case ZeroValuer:
   161  		return vt.ZeroValue()
   162  	default:
   163  		panic(fmt.Sprintf("Cannot return zero value of %T", v))
   164  	}
   165  }
   166  
   167  // Copy copies the src values into dest values. For scalars, it just returns itself
   168  func Copy(dest, src Value) (Value, error) {
   169  	var ok bool
   170  	switch srcT := src.(type) {
   171  	case *F64:
   172  		var destS *F64
   173  		if destS, ok = dest.(*F64); !ok {
   174  			return nil, errors.Errorf("Expected dest to be *F64. Got %T instead", dest)
   175  		}
   176  		*destS = *srcT
   177  		return destS, nil
   178  	case *F32:
   179  		var destS *F32
   180  		if destS, ok = dest.(*F32); !ok {
   181  			return nil, errors.Errorf("Expected dest to be *F32. Got %T instead", dest)
   182  		}
   183  		*destS = *srcT
   184  		return destS, nil
   185  	case *I:
   186  		var destS *I
   187  		if destS, ok = dest.(*I); !ok {
   188  			return nil, errors.Errorf("Expected dest to be *I) . Got %T instead", dest)
   189  		}
   190  		*destS = *srcT
   191  		return destS, nil
   192  	case *I64:
   193  		var destS *I64
   194  		if destS, ok = dest.(*I64); !ok {
   195  			return nil, errors.Errorf("Expected dest to be *I64. Got %T instead", dest)
   196  		}
   197  		*destS = *srcT
   198  		return destS, nil
   199  	case *I32:
   200  		var destS *I32
   201  		if destS, ok = dest.(*I32); !ok {
   202  			return nil, errors.Errorf("Expected dest to be *I32. Got %T instead", dest)
   203  		}
   204  		*destS = *srcT
   205  		return destS, nil
   206  	case *U8:
   207  		var destS *U8
   208  		if destS, ok = dest.(*U8); !ok {
   209  			return nil, errors.Errorf("Expected dest to be *U8). Got %T instead", dest)
   210  		}
   211  		*destS = *srcT
   212  		return destS, nil
   213  	case *B:
   214  		var destS *B
   215  		if destS, ok = dest.(*B); !ok {
   216  			return nil, errors.Errorf("Expected dest to be *B) . Got %T instead", dest)
   217  		}
   218  		*destS = *srcT
   219  		return destS, nil
   220  	case tensor.Tensor:
   221  		var destT tensor.Tensor
   222  		if destT, ok = dest.(tensor.Tensor); !ok {
   223  			return nil, errors.Errorf("Expected dest to be a tensor.Tensor. Got %T instead", dest)
   224  		}
   225  		err := tensor.Copy(destT, srcT)
   226  		return dest, err
   227  	case CopierTo:
   228  		err := srcT.CopyTo(dest)
   229  		return dest, err
   230  	default:
   231  		var copyFrom CopierFrom
   232  		if copyFrom, ok = dest.(CopierFrom); ok {
   233  			err := copyFrom.CopyFrom(src)
   234  			return dest, err
   235  		}
   236  		return nil, errors.Errorf("Unable to copy value of type %T into value of type %T", src, dest)
   237  	}
   238  }
   239  
   240  func setEngine(v Value, e tensor.Engine) {
   241  	switch vv := v.(type) {
   242  	case *dualValue:
   243  		setEngine(vv.Value, e)
   244  		setEngine(vv.d, e)
   245  	case tensor.Tensor:
   246  		tensor.WithEngine(e)(vv)
   247  	}
   248  }