gorgonia.org/gorgonia@v0.9.17/equalities.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"gorgonia.org/dawson"
     5  	"gorgonia.org/tensor"
     6  )
     7  
     8  func scalarEq(a, b Scalar) bool {
     9  	switch at := a.(type) {
    10  	case *F64:
    11  		if bt, ok := b.(*F64); ok {
    12  			if at == bt {
    13  				return true
    14  			}
    15  			return *at == *bt
    16  		}
    17  		return false
    18  	case *F32:
    19  		if bt, ok := b.(*F32); ok {
    20  			if at == bt {
    21  				return true
    22  			}
    23  			return *at == *bt
    24  		}
    25  		return false
    26  	case *I:
    27  		if bt, ok := b.(*I); ok {
    28  			if at == bt {
    29  				return true
    30  			}
    31  			return *at == *bt
    32  		}
    33  		return false
    34  	case *I32:
    35  		if bt, ok := b.(*I32); ok {
    36  			if at == bt {
    37  				return true
    38  			}
    39  			return *at == *bt
    40  		}
    41  		return false
    42  	case *I64:
    43  		if bt, ok := b.(*I64); ok {
    44  			if at == bt {
    45  				return true
    46  			}
    47  			return *at == *bt
    48  		}
    49  		return false
    50  	case *U8:
    51  		if bt, ok := b.(*U8); ok {
    52  			if at == bt {
    53  				return true
    54  			}
    55  			return *at == *bt
    56  		}
    57  		return false
    58  	case *B:
    59  		if bt, ok := b.(*B); ok {
    60  			if at == bt {
    61  				return true
    62  			}
    63  			return *at == *bt
    64  		}
    65  		return false
    66  	}
    67  	return false
    68  }
    69  
    70  func scalarClose(a, b Scalar) bool {
    71  	switch at := a.(type) {
    72  	case *F64:
    73  		if bt, ok := b.(*F64); ok {
    74  			return dawson.CloseF64(float64(*at), float64(*bt))
    75  		}
    76  		return false
    77  	case *F32:
    78  		if bt, ok := b.(*F32); ok {
    79  			return dawson.CloseF32(float32(*at), float32(*bt))
    80  		}
    81  		return false
    82  	default:
    83  		return scalarEq(a, b)
    84  	}
    85  }
    86  
    87  func tensorClose(a, b tensor.Tensor) bool {
    88  	aDt := a.Dtype()
    89  	bDt := b.Dtype()
    90  	if aDt != bDt {
    91  		return false
    92  	}
    93  
    94  	switch aDt {
    95  	case tensor.Float64:
    96  		aFs := a.Data().([]float64)
    97  		bFs := b.Data().([]float64)
    98  		if len(aFs) != len(bFs) {
    99  			return false
   100  		}
   101  		aFs = aFs[:]
   102  		bFs = bFs[:len(aFs)]
   103  		for i, v := range aFs {
   104  			if !dawson.CloseF64(v, bFs[i]) {
   105  				return false
   106  			}
   107  		}
   108  		return true
   109  	case tensor.Float32:
   110  		aFs := a.Data().([]float32)
   111  		bFs := b.Data().([]float32)
   112  		if len(aFs) != len(bFs) {
   113  			return false
   114  		}
   115  		aFs = aFs[:]
   116  		bFs = bFs[:len(aFs)]
   117  		for i, v := range aFs {
   118  			if !dawson.CloseF32(v, bFs[i]) {
   119  				return false
   120  			}
   121  		}
   122  		return true
   123  	default:
   124  		return a.Eq(b)
   125  	}
   126  
   127  }
   128  
   129  /*
   130  func axesEq(a, b axes) bool {
   131  	if len(a) != len(b) {
   132  		return false
   133  	}
   134  
   135  	for i, s := range a {
   136  		if b[i] != s {
   137  			return false
   138  		}
   139  	}
   140  	return true
   141  }
   142  
   143  // yes it's exactly the same as axesEq
   144  func coordEq(a, b coordinates) bool {
   145  	if len(a) != len(b) {
   146  		return false
   147  	}
   148  
   149  	for i, s := range a {
   150  		if b[i] != s {
   151  			return false
   152  		}
   153  	}
   154  	return true
   155  }
   156  */
   157  
   158  func constEq(a, b constant) (ok bool) {
   159  	switch at := a.(type) {
   160  	case constantScalar:
   161  		var bt constantScalar
   162  		if bt, ok = b.(constantScalar); !ok {
   163  			return
   164  		}
   165  
   166  		return bt == at
   167  	case constantTensor:
   168  		var bt constantTensor
   169  		if bt, ok = b.(constantTensor); !ok {
   170  			return
   171  		}
   172  		return at.v.Eq(bt.v)
   173  	default:
   174  		panic("Not yet implemented")
   175  	}
   176  }
   177  
   178  // fastest comparisons to least fastest
   179  func nodeEq(a, b *Node) bool {
   180  	if a == b {
   181  		return true
   182  	}
   183  
   184  	if a.isInput() {
   185  		if !b.isInput() {
   186  			return false
   187  		}
   188  		return a.name == b.name
   189  	}
   190  
   191  	if b.isInput() {
   192  		return false
   193  	}
   194  
   195  	// hashcode is good for comparing Op (TODO: benchmark this vs reflect.DeepEq)
   196  	if a.op.Hashcode() != b.op.Hashcode() {
   197  		return false
   198  	}
   199  
   200  	if len(a.children) != len(b.children) {
   201  		return false
   202  	}
   203  
   204  	if a.t != b.t {
   205  		return false
   206  	}
   207  
   208  	if !a.shape.Eq(b.shape) {
   209  		return false
   210  	}
   211  
   212  	return true
   213  }