gorgonia.org/gorgonia@v0.9.17/ermagerdmonards.go (about) 1 package gorgonia 2 3 import ( 4 "github.com/pkg/errors" 5 ) 6 7 var ( 8 _ Result = (*Node)(nil) 9 _ Result = (Nodes)(nil) 10 _ Result = gErr{} 11 ) 12 13 // Result is either a Node or Nodes or error. It's a poor man's sum types and it's not sealed for good reason 14 type Result interface { 15 Input 16 Errer 17 } 18 19 // Input is something that can produce both a *Node and Nodes. Returning nil is OK. 20 type Input interface { 21 Node() *Node 22 Nodes() Nodes 23 } 24 25 // Errer is an interface that can return an error. 26 type Errer interface { 27 Err() error 28 } 29 30 // Mker is an interface of any Input that can make a new version of itself 31 type Mker interface { 32 Mk(...Input) Input 33 } 34 35 // Lift1 decorates a function with a precheck and post function lifting 36 func Lift1(fn func(a *Node) (*Node, error)) func(a Input) Result { 37 return func(a Input) Result { 38 if err := CheckOne(a); err != nil { 39 return Err(errors.WithStack(err)) 40 } 41 return TransformResult(a)(fn(a.Node())) 42 } 43 } 44 45 // Lift1Axial decorates a function with a precheck and post function lifting 46 func Lift1Axial(fn func(a *Node, axes ...int) (*Node, error)) func(a Input, axes ...int) Result { 47 return func(a Input, axes ...int) Result { 48 if err := CheckOne(a); err != nil { 49 return Err(errors.WithStack(err)) 50 } 51 return TransformResult(a)(fn(a.Node(), axes...)) 52 } 53 } 54 55 // Lift2 decorates a function with a precheck and post function lifting 56 func Lift2(fn func(a, b *Node) (*Node, error)) func(a, b Input) Result { 57 return func(a, b Input) Result { 58 if err := CheckOne(a); err != nil { 59 return Err(errors.WithStack(err)) 60 } 61 if err := CheckOne(b); err != nil { 62 return Err(errors.WithStack(err)) 63 } 64 return TransformResult(a, b)(fn(a.Node(), b.Node())) 65 } 66 } 67 68 // Lift2Broadcast decorates a function with a precheck and post function lifting 69 func Lift2Broadcast(fn func(a, b *Node, pat1, pat2 []byte) (*Node, error)) func(a, b Input, pat1, pat2 []byte) Result { 70 return func(a, b Input, pat1, pat2 []byte) Result { 71 if err := CheckOne(a); err != nil { 72 return Err(errors.WithStack(err)) 73 } 74 if err := CheckOne(b); err != nil { 75 return Err(errors.WithStack(err)) 76 } 77 return TransformResult(a, b)(fn(a.Node(), b.Node(), pat1, pat2)) 78 } 79 } 80 81 // gErr implements Result and error. 82 type gErr struct{ error } 83 84 // Err is a function that returns a gErr. It wraps errors with stack information. 85 // A gErr implements Result, as well as error. 86 // This way, the Err() method acts as an unwrapper. 87 func Err(e error) gErr { return gErr{errors.WithStack(e)} } 88 89 func (err gErr) Node() *Node { return nil } 90 func (err gErr) Nodes() Nodes { return nil } 91 func (err gErr) Err() error { return err.error } 92 93 // resultM is a wrapper for Input to create a Result. This is the default Result if an unknown Input was passed in. 94 type resultM struct{ Input } 95 96 func (r resultM) Err() error { return nil } 97 98 // LiftResult creates a Result from a Input and error pair. 99 // If the error is not nil, the Input is discarded. 100 // 101 // The usual use case is in a function that returns a `(*Node, error)`. 102 // e.g LiftResult(Add(a, b)) 103 func LiftResult(a Input, err error) Result { 104 if err != nil { 105 return Err(err) 106 } 107 switch at := a.(type) { 108 case Result: 109 return at 110 default: 111 return resultM{a} 112 } 113 } 114 115 // TransformResult is like LiftResult, but allows for custom data types that fulfil Mker 116 func TransformResult(ins ...Input) func(a Input, err error) Result { 117 return func(a Input, err error) Result { 118 if err != nil { 119 return Err(err) 120 } 121 for _, in := range ins { 122 if mk, ok := in.(Mker); ok { 123 a = mk.Mk(a) 124 } 125 } 126 switch at := a.(type) { 127 case Result: 128 return at 129 default: 130 return resultM{a} 131 } 132 } 133 } 134 135 // CheckOne checks whether an input is an error 136 func CheckOne(in Input) error { 137 if errer, ok := in.(Errer); ok && errer.Err() != nil { 138 return errer.Err() 139 } 140 return nil 141 } 142 143 // NodesFromInputs creates a Nodes from a list of Input. 144 func NodesFromInputs(xs ...Input) (Nodes, error) { 145 for i := range xs { 146 if err := CheckOne(xs[i]); err != nil { 147 return nil, errors.Wrapf(err, "NodesFromInputs %dth input", i) 148 } 149 // check if the Input is a *Node 150 if xs[i].Node() == nil { 151 return nil, errors.Errorf("Input %d is not a *Node", i) 152 } 153 } 154 155 retVal := make(Nodes, len(xs)) 156 for i := range xs { 157 retVal[i] = xs[i].Node() 158 } 159 return retVal, nil 160 }