gorgonia.org/gorgonia@v0.9.17/differentiation.go (about) 1 package gorgonia 2 3 import ( 4 "github.com/pkg/errors" 5 "gonum.org/v1/gonum/graph" 6 ) 7 8 /* 9 This file holds code for symbolic differentiation. 10 The purpose of the symbolic differentiation is to analyze and prepare the nodes for automatic differentiation. 11 12 The main function that does all the magic is in Backpropagate(). 13 14 15 see also: http://colah.github.io/posts/2015-08-Backprop/ 16 */ 17 18 // forwardDiffAnalysis returns the nodes that affect outputs. 19 // 20 // Given a list of outputs, we want to know which nodes will affect the output 21 func forwardDiffAnalysis(outputs, sortedNodes Nodes) (retVal NodeSet, err error) { 22 symdiffLogf("Forward analysis. Already sorted?") 23 enterLogScope() 24 defer leaveLogScope() 25 26 if !outputs.AllSameGraph() { 27 return nil, errors.New("The supplied output Nodes are not the same graph") 28 } 29 30 diffSet := outputs.mapSet() 31 32 symdiffLogf("Diff Set: %v", diffSet) 33 symdiffLogf("%d", sortedNodes) 34 for _, n := range sortedNodes { 35 if diffSet.Contains(n) && !n.isInput() { 36 diffs := n.diffWRT() 37 for j, child := range n.children { 38 d := diffs[j] 39 if d { 40 symdiffLogf("Adding %x to differentiable set", child.ID()) 41 diffSet.Add(child) 42 } 43 } 44 } 45 } 46 return diffSet, nil 47 } 48 49 // backwardDiffAnalysis returns a list of Nodes that are affected by differentiating output. 50 // Given a list of WRTs, we want to find a list of nodes that will be affected when backpropagating. 51 func backwardDiffAnalysis(wrt, sortedNodes Nodes) (retVal NodeSet, err error) { 52 symdiffLogf("Backwards analysis") 53 enterLogScope() 54 defer leaveLogScope() 55 56 if !wrt.AllSameGraph() { 57 return nil, errors.New("The supplied output Nodes are not the same graph") 58 } 59 60 diffSet := wrt.mapSet() 61 symdiffLogf("wrt:%d diffset: %d", len(wrt), len(diffSet)) 62 symdiffLogf("%v", diffSet) 63 symdiffLogf("sorted: %d", sortedNodes) 64 65 enterLogScope() 66 for i := len(sortedNodes) - 1; i >= 0; i-- { 67 n := sortedNodes[i] 68 symdiffLogf("working on %v. Has %d children", n, len(n.children)) 69 70 var op SDOp 71 var ok bool 72 var diffs []bool 73 if op, ok = n.op.(SDOp); ok { 74 diffs = op.DiffWRT(len(n.children)) 75 } 76 77 symdiffLogf("differentiable WRT: %v", diffs) 78 enterLogScope() 79 symdiffLogf("Children: %v", n.children) 80 if len(diffs) == 0 { 81 // check if this makes nodes unreachable. If it does, then error out 82 if n.isStmt { 83 symdiffLogf("Statement nodes are Non differentiable!") 84 leaveLogScope() 85 continue 86 } else if n.isInput() { 87 symdiffLogf("Input nodes are Non differentiable") 88 leaveLogScope() 89 continue 90 } else if len(n.children) == 0 { 91 symdiffLogf("Leaf nodes have no children") 92 leaveLogScope() 93 continue 94 } 95 g := n.g 96 for _, child := range n.children { 97 parents := graph.NodesOf(g.To(child.ID())) 98 if len(parents) == 1 && len(child.children) > 0 { 99 leaveLogScope() 100 return nil, errors.Errorf("Being unable to differentiate %v would leave a portion of the graph unreachable. Unable to continue", n) 101 } 102 } 103 symdiffLogf("SKIPPING... Non differentiable!") 104 leaveLogScope() 105 continue 106 } 107 108 inner: 109 for j, child := range n.children { 110 d := diffs[j] 111 if diffSet.Contains(child) && d { 112 symdiffLogf("Adding %x to differentiable set", child.ID()) 113 diffSet.Add(n) 114 break inner 115 } 116 } 117 leaveLogScope() 118 } 119 leaveLogScope() 120 return diffSet, nil 121 } 122 123 // Backpropagate backpropagates errors by performing reverse-mode symbolic differentiation, starting from the outputs, and working its way towads the inputs. 124 // 125 // This is the rough algorithm: 126 // 1. Filter out nodes that are unreachable 127 // 2. Forwards analysis, where a list of nodes affecting the output is added to consideration 128 // 3. Backwards analysis, where a list of nodes affected by differentiating the output are added to the consideration 129 // 4. If there is a difference in both sets, it will cause an error (both sets should be the same) 130 // 5. Traverse the graph from output towards input. On each visit, perform the symbolic differentiation 131 // 132 // For most cases, Grad() should be used instead of Backpropagate(), as Grad() performs several checks which would be the general use case, before calling Backpropagate() 133 func Backpropagate(outputs, gradOutputs, wrt Nodes) (retVal Nodes, err error) { 134 symdiffLogf("BACKPROP START") 135 symdiffLogf("Outputs: %d", outputs) 136 symdiffLogf("gradOutputs: %d", gradOutputs) 137 symdiffLogf("WRT: %d", wrt) 138 139 enterLogScope() 140 defer leaveLogScope() 141 142 g := outputs[0].g 143 144 // this entire section about removing foreveralone nodes need a rethink 145 symdiffLogf("removing foreveralone nodes") 146 enterLogScope() 147 for i := 0; i < len(g.AllNodes()); i++ { 148 n := g.AllNodes()[i] 149 150 fr := g.From(n.ID()).Len() 151 to := g.To(n.ID()).Len() 152 153 if fr == 0 && to == 0 && !n.isConstant() && !n.isInput() { 154 g.RemoveNode(n) 155 symdiffLogf("removed %v(%p); %x; %s", n, n, n.ID(), n.Name()) 156 } 157 } 158 leaveLogScope() 159 160 var sortedNodes Nodes 161 if sortedNodes, err = Sort(g); err != nil { 162 return nil, errors.Wrap(err, sortFail) 163 } 164 symdiffLogf("sorted nodes: %v", sortedNodes) 165 symdiffLogf("sorted nodes: %d", sortedNodes) 166 167 var affectsOutput NodeSet 168 var affectedByOutput NodeSet 169 if affectsOutput, err = forwardDiffAnalysis(outputs, sortedNodes); err != nil { 170 return nil, errors.Wrap(err, "Failed during forward differentiation analysis") 171 } 172 173 if affectedByOutput, err = backwardDiffAnalysis(wrt, sortedNodes); err != nil { 174 return nil, errors.Wrap(err, "Failed during forward differentiation analysis") 175 } 176 177 symdiffLogf("affects output: %v", affectsOutput) 178 symdiffLogf("affected by output : %v", affectedByOutput) 179 180 wrtSet := wrt.mapSet() 181 badWRTs := wrtSet.Difference(affectsOutput) 182 if len(badWRTs) > 0 { 183 return nil, SymDiffError{nodes: badWRTs.ToSlice(), err: errors.Errorf("Non Differentiable WRTs: %v", badWRTs)} 184 } 185 186 outputSet := outputs.mapSet() 187 badOutputs := outputSet.Difference(affectedByOutput) 188 if len(badOutputs) > 0 { 189 symdiffLogf("badOutputs: %#v", badOutputs) 190 return nil, SymDiffError{nodes: badOutputs.ToSlice(), err: errors.Errorf("Non-Differentable Outputs: %v", badOutputs)} 191 } 192 193 // map a node to a list of gradient terms 194 // these gradient terms will be summed up when we visit the node 195 // when iterating through the nondes in reverse topological order 196 nodeGradMap := make(map[*Node]Nodes) 197 for i, n := range outputs { 198 symdiffLogf("Adding outputs for %x", n.ID()) 199 nodeGradMap[n] = Nodes{gradOutputs[i]} 200 } 201 202 // "active" nodes are the ones that are differentially influenced by the inputs 203 // and also differentiably influence the outputs. These are the nodes where we need to call the 204 // "pullback" function to backpropagate derivatives 205 activeNodes := affectsOutput.Intersect(affectedByOutput) 206 207 symdiffLogf("Active: %v", activeNodes) 208 209 symdiffLogf("Sorted: %d", sortedNodes) 210 symdiffLogf("nodeGradMap: %+#d", FmtNodeMap(nodeGradMap)) 211 enterLogScope() 212 213 for _, node := range sortedNodes { 214 if _, ok := activeNodes[node]; !ok { 215 symdiffLogf("skipping %x", node.ID()) 216 continue 217 } 218 219 if node.deriv != nil { 220 symdiffLogf("skipping %x - previously differentiated", node.ID()) 221 nodeGradMap[node] = append(nodeGradMap[node], node.deriv) 222 continue 223 } 224 225 symdiffLogf("Working on %x %v", node.ID(), node) 226 enterLogScope() 227 228 // Check if there is any grads coming into this node 229 if len(nodeGradMap[node]) < 1 { 230 leaveLogScope() 231 return nil, SymDiffError{ 232 single: node, 233 gradMap: nodeGradMap, 234 err: errors.New("No gradients found for node"), 235 } 236 } 237 238 // once we've reached a node, we already backpropagated from its dependents 239 // so we sum up the gradients 240 symdiffLogf("nodeGradMap[%x]: %d", node.ID(), nodeGradMap[node]) 241 if len(nodeGradMap[node]) > 1 { 242 var n *Node 243 symdiffLogf("reduce adding") 244 if n, err = ReduceAdd(nodeGradMap[node], WithGroupName(gradClust)); err != nil { 245 leaveLogScope() 246 return nil, SymDiffError{ 247 single: node, 248 nodes: nodeGradMap[node], 249 gradMap: nodeGradMap, 250 err: errors.Wrap(err, "ReduceAdd failed during differentiation"), 251 } 252 253 } 254 symdiffLogf("reduced to... %x", n.ID()) 255 // node.derives = append(node.derives, n) 256 n.derivOf = append(n.derivOf, node) 257 node.deriv = n 258 nodeGradMap[node] = Nodes{n} 259 // } 260 } else if len(nodeGradMap[node]) == 1 { 261 deriv := nodeGradMap[node][0] 262 deriv.derivOf = append(deriv.derivOf, node) 263 node.deriv = deriv 264 } 265 266 gradNode := nodeGradMap[node][0] 267 if !node.isInput() { 268 symdiffLogf("differentiating %x (%v)", node.ID(), node.op) 269 enterLogScope() 270 271 var op SDOp 272 var childrenGrads Nodes 273 var ok bool 274 275 if op, ok = node.op.(SDOp); !ok { 276 return nil, SymDiffError{ 277 single: node, 278 err: errors.New("Not a SymDifOp"), 279 } 280 } 281 282 symdiffLogf("op: %v || optype: %v || node: %v || Children: %#Y || Grad: %v", node.op, node.op.Type(), node.t, node.children, gradNode) 283 if childrenGrads, err = op.SymDiff(node.children, node, gradNode); err != nil { 284 leaveLogScope() 285 return nil, SymDiffError{ 286 single: node, 287 grad: gradNode, 288 gradMap: nodeGradMap, 289 err: errors.Wrapf(err, ".SymDiff() failed"), 290 } 291 } 292 293 symdiffLogf("Derived(%d): %P", len(childrenGrads), childrenGrads) 294 leaveLogScope() 295 296 diffs := node.diffWRT() 297 for i, child := range node.children { 298 symdiffLogf("child is %v, i: %v", child, i) 299 differentiable := diffs[i] 300 childGrad := childrenGrads[i] 301 302 if differentiable { 303 childGrad.setGroup(gradClust) 304 if grads, ok := nodeGradMap[child]; ok { 305 grads = append(grads, childGrad) 306 nodeGradMap[child] = grads 307 } else { 308 nodeGradMap[child] = Nodes{childGrad} 309 } 310 } else { 311 symdiffLogf("Child %x is non differentiable", child.ID()) 312 if childGrad != nil { 313 childGrad.setGroup(strayClust) 314 } 315 } 316 } 317 } else { 318 symdiffLogf("iz input") 319 symdiffLogf("%d ", nodeGradMap[node]) 320 } 321 leaveLogScope() 322 323 } 324 leaveLogScope() 325 // only we already summed up the gradients for the input nodes, so just take 326 // 0th element 327 for _, n := range wrt { 328 symdiffLogf("nodeGradMap wrt: %d", nodeGradMap[n]) 329 retVal = append(retVal, nodeGradMap[n][0]) 330 } 331 return 332 } 333 334 // SetDerivOf is used to hack around the fundamental limitations of Gorgonia. 335 // 336 // Specifically it is used to set a node as the derivative of another node, 337 // used in the cuDNN version of batch norm. 338 // 339 // The cuDNN BatchNorm operation produces the derivatives for the scale and bias as a side effect 340 // of calculating the derivative of the input. Because Gorgonia's Ops are modelled as pure functions (and no tuples) 341 // this causes a bit of trouble. With the clever use of scratch space ops multireturn can be simulated. 342 // But this causes derivatives to not be set correctly. 343 func SetDerivOf(deriv, of *Node) { 344 deriv.derivOf = append(deriv.derivOf, of) 345 of.deriv = deriv 346 }