gorgonia.org/gorgonia@v0.9.17/differentiation.go

     1  package gorgonia
     2  
     3  import (
     4  	"github.com/pkg/errors"
     5  	"gonum.org/v1/gonum/graph"
     6  )
     7  
     8  /*
     9  This file holds code for symbolic differentiation.
    10  The purpose of symbolic differentiation is to analyze and prepare the nodes for automatic differentiation.
    11  
    12  The main function that does all the magic is Backpropagate().
    13  
    14  
    15  see also: http://colah.github.io/posts/2015-08-Backprop/
    16  */
    17  
    18  // forwardDiffAnalysis returns the nodes that affect outputs.
    19  //
    20  // Given a list of outputs, we want to know which nodes affect those outputs.
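        //
        // A worked illustration (hedged; it assumes every op involved reports itself as differentiable):
        // for an output z = x*y + b, the returned set would contain z, the intermediate x*y,
        // and the leaves x, y and b: every node that differentiably feeds into z.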
    21  func forwardDiffAnalysis(outputs, sortedNodes Nodes) (retVal NodeSet, err error) {
    22  	symdiffLogf("Forward analysis. Already sorted?")
    23  	enterLogScope()
    24  	defer leaveLogScope()
    25  
    26  	if !outputs.AllSameGraph() {
    27  		return nil, errors.New("The supplied output Nodes are not from the same graph")
    28  	}
    29  
    30  	diffSet := outputs.mapSet()
    31  
    32  	symdiffLogf("Diff Set: %v", diffSet)
    33  	symdiffLogf("%d", sortedNodes)
    34  	for _, n := range sortedNodes {
    35  		if diffSet.Contains(n) && !n.isInput() {
    36  			diffs := n.diffWRT()
    37  			for j, child := range n.children {
    38  				d := diffs[j]
    39  				if d {
    40  					symdiffLogf("Adding %x to differentiable set", child.ID())
    41  					diffSet.Add(child)
    42  				}
    43  			}
    44  		}
    45  	}
    46  	return diffSet, nil
    47  }
    48  
    49  // backwardDiffAnalysis returns a list of Nodes that are affected by differentiating the output.
    50  // Given a list of WRTs, we want to find a list of nodes that will be affected when backpropagating.
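        //
        // A worked illustration (hedged): with wrt = {x} and an output z = x*y + b, the returned set
        // would contain x, the intermediate x*y, and z, i.e. every node that lies on a differentiable
        // path from x up to an output; b would not be included.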
    51  func backwardDiffAnalysis(wrt, sortedNodes Nodes) (retVal NodeSet, err error) {
    52  	symdiffLogf("Backwards analysis")
    53  	enterLogScope()
    54  	defer leaveLogScope()
    55  
    56  	if !wrt.AllSameGraph() {
    57  		return nil, errors.New("The supplied WRT Nodes are not from the same graph")
    58  	}
    59  
    60  	diffSet := wrt.mapSet()
    61  	symdiffLogf("wrt:%d diffset: %d", len(wrt), len(diffSet))
    62  	symdiffLogf("%v", diffSet)
    63  	symdiffLogf("sorted: %d", sortedNodes)
    64  
    65  	enterLogScope()
    66  	for i := len(sortedNodes) - 1; i >= 0; i-- {
    67  		n := sortedNodes[i]
    68  		symdiffLogf("working on %v. Has %d children", n, len(n.children))
    69  
    70  		var op SDOp
    71  		var ok bool
    72  		var diffs []bool
    73  		if op, ok = n.op.(SDOp); ok {
    74  			diffs = op.DiffWRT(len(n.children))
    75  		}
    76  
    77  		symdiffLogf("differentiable WRT: %v", diffs)
    78  		enterLogScope()
    79  		symdiffLogf("Children: %v", n.children)
    80  		if len(diffs) == 0 {
    81  			// check if this makes nodes unreachable. If it does, then error out
    82  			if n.isStmt {
    83  				symdiffLogf("Statement nodes are Non differentiable!")
    84  				leaveLogScope()
    85  				continue
    86  			} else if n.isInput() {
    87  				symdiffLogf("Input nodes are Non differentiable")
    88  				leaveLogScope()
    89  				continue
    90  			} else if len(n.children) == 0 {
    91  				symdiffLogf("Leaf nodes have no children")
    92  				leaveLogScope()
    93  				continue
    94  			}
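        			// n itself is non-differentiable. If any child has n as its sole parent and that child
        			// has children of its own, skipping n would leave that child's subgraph unreachable,
        			// so we error out instead of silently dropping it.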
    95  			g := n.g
    96  			for _, child := range n.children {
    97  				parents := graph.NodesOf(g.To(child.ID()))
    98  				if len(parents) == 1 && len(child.children) > 0 {
    99  					leaveLogScope()
   100  					return nil, errors.Errorf("Being unable to differentiate %v would leave a portion of the graph unreachable. Unable to continue", n)
   101  				}
   102  			}
   103  			symdiffLogf("SKIPPING... Non differentiable!")
   104  			leaveLogScope()
   105  			continue
   106  		}
   107  
   108  	inner:
   109  		for j, child := range n.children {
   110  			d := diffs[j]
   111  			if diffSet.Contains(child) && d {
   112  				symdiffLogf("Adding %x to differentiable set", n.ID())
   113  				diffSet.Add(n)
   114  				break inner
   115  			}
   116  		}
   117  		leaveLogScope()
   118  	}
   119  	leaveLogScope()
   120  	return diffSet, nil
   121  }
   122  
   123  // Backpropagate backpropagates errors by performing reverse-mode symbolic differentiation, starting from the outputs, and working its way towards the inputs.
   124  //
   125  // This is the rough algorithm:
   126  //		1. Filter out nodes that are unreachable
   127  //		2. Forwards analysis, where a list of nodes affecting the output is added to consideration
   128  //		3. Backwards analysis, where a list of nodes affected by differentiating the output is added to consideration
   129  //		4. If any WRT node does not affect the outputs, or any output is not affected by the WRT nodes, an error is returned
   130  //		5. Traverse the graph from output towards input. On each visit, perform the symbolic differentiation
   131  //
   132  // For most cases, Grad() should be used instead of Backpropagate(), as Grad() performs several checks that cover the general use case before calling Backpropagate().
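        //
        // A minimal sketch of calling it directly (hedged; the node names are illustrative, and in
        // typical code Grad() constructs the output gradients for you):
        //
        //	g := NewGraph()
        //	x := NewScalar(g, Float64, WithName("x"))
        //	y := Must(Mul(x, x))
        //	dy := NewConstant(1.0, WithName("dy")) // seed gradient for the output
        //	dydx, err := Backpropagate(Nodes{y}, Nodes{dy}, Nodes{x})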
   133  func Backpropagate(outputs, gradOutputs, wrt Nodes) (retVal Nodes, err error) {
   134  	symdiffLogf("BACKPROP START")
   135  	symdiffLogf("Outputs: %d", outputs)
   136  	symdiffLogf("gradOutputs: %d", gradOutputs)
   137  	symdiffLogf("WRT: %d", wrt)
   138  
   139  	enterLogScope()
   140  	defer leaveLogScope()
   141  
   142  	g := outputs[0].g
   143  
   144  	// this entire section about removing foreveralone nodes needs a rethink
   145  	symdiffLogf("removing foreveralone nodes")
   146  	enterLogScope()
   147  	for i := 0; i < len(g.AllNodes()); i++ {
   148  		n := g.AllNodes()[i]
   149  
   150  		fr := g.From(n.ID()).Len()
   151  		to := g.To(n.ID()).Len()
   152  
   153  		if fr == 0 && to == 0 && !n.isConstant() && !n.isInput() {
   154  			g.RemoveNode(n)
   155  			symdiffLogf("removed %v(%p); %x; %s", n, n, n.ID(), n.Name())
   156  		}
   157  	}
   158  	leaveLogScope()
   159  
   160  	var sortedNodes Nodes
   161  	if sortedNodes, err = Sort(g); err != nil {
   162  		return nil, errors.Wrap(err, sortFail)
   163  	}
   164  	symdiffLogf("sorted nodes: %v", sortedNodes)
   165  	symdiffLogf("sorted nodes: %d", sortedNodes)
   166  
   167  	var affectsOutput NodeSet
   168  	var affectedByOutput NodeSet
   169  	if affectsOutput, err = forwardDiffAnalysis(outputs, sortedNodes); err != nil {
   170  		return nil, errors.Wrap(err, "Failed during forward differentiation analysis")
   171  	}
   172  
   173  	if affectedByOutput, err = backwardDiffAnalysis(wrt, sortedNodes); err != nil {
   174  		return nil, errors.Wrap(err, "Failed during backward differentiation analysis")
   175  	}
   176  
   177  	symdiffLogf("affects output: %v", affectsOutput)
   178  	symdiffLogf("affected by output : %v", affectedByOutput)
   179  
   180  	wrtSet := wrt.mapSet()
   181  	badWRTs := wrtSet.Difference(affectsOutput)
   182  	if len(badWRTs) > 0 {
   183  		return nil, SymDiffError{nodes: badWRTs.ToSlice(), err: errors.Errorf("Non-Differentiable WRTs: %v", badWRTs)}
   184  	}
   185  
   186  	outputSet := outputs.mapSet()
   187  	badOutputs := outputSet.Difference(affectedByOutput)
   188  	if len(badOutputs) > 0 {
   189  		symdiffLogf("badOutputs: %#v", badOutputs)
   190  		return nil, SymDiffError{nodes: badOutputs.ToSlice(), err: errors.Errorf("Non-Differentiable Outputs: %v", badOutputs)}
   191  	}
   192  
   193  	// nodeGradMap maps a node to a list of gradient terms.
   194  	// These gradient terms will be summed up when we visit the node
   195  	// while iterating through the nodes in reverse topological order.
   196  	nodeGradMap := make(map[*Node]Nodes)
   197  	for i, n := range outputs {
   198  		symdiffLogf("Adding outputs for %x", n.ID())
   199  		nodeGradMap[n] = Nodes{gradOutputs[i]}
   200  	}
   201  
   202  	// "active" nodes are the ones that are differentiably influenced by the inputs
   203  	// and also differentiably influence the outputs. These are the nodes where we need to call the
   204  	// "pullback" function (the op's SymDiff method) to backpropagate derivatives.
   205  	activeNodes := affectsOutput.Intersect(affectedByOutput)
   206  
   207  	symdiffLogf("Active: %v", activeNodes)
   208  
   209  	symdiffLogf("Sorted: %d", sortedNodes)
   210  	symdiffLogf("nodeGradMap: %+#d", FmtNodeMap(nodeGradMap))
   211  	enterLogScope()
   212  
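        	// sortedNodes runs from the outputs towards the inputs, so by the time we visit a node,
        	// all of its dependents have already deposited their gradient contributions into nodeGradMap.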
   213  	for _, node := range sortedNodes {
   214  		if _, ok := activeNodes[node]; !ok {
   215  			symdiffLogf("skipping %x", node.ID())
   216  			continue
   217  		}
   218  
   219  		if node.deriv != nil {
   220  			symdiffLogf("skipping %x - previously differentiated", node.ID())
   221  			nodeGradMap[node] = append(nodeGradMap[node], node.deriv)
   222  			continue
   223  		}
   224  
   225  		symdiffLogf("Working on %x %v", node.ID(), node)
   226  		enterLogScope()
   227  
   228  		// Check if there are any grads coming into this node
   229  		if len(nodeGradMap[node]) < 1 {
   230  			leaveLogScope()
   231  			return nil, SymDiffError{
   232  				single:  node,
   233  				gradMap: nodeGradMap,
   234  				err:     errors.New("No gradients found for node"),
   235  			}
   236  		}
   237  
   238  		// once we've reached a node, we have already backpropagated from its dependents,
   239  		// so we sum up the gradients
   240  		symdiffLogf("nodeGradMap[%x]: %d", node.ID(), nodeGradMap[node])
   241  		if len(nodeGradMap[node]) > 1 {
   242  			var n *Node
   243  			symdiffLogf("reduce adding")
   244  			if n, err = ReduceAdd(nodeGradMap[node], WithGroupName(gradClust)); err != nil {
   245  				leaveLogScope()
   246  				return nil, SymDiffError{
   247  					single:  node,
   248  					nodes:   nodeGradMap[node],
   249  					gradMap: nodeGradMap,
   250  					err:     errors.Wrap(err, "ReduceAdd failed during differentiation"),
   251  				}
   252  
   253  			}
   254  			symdiffLogf("reduced to... %x", n.ID())
   255  			// node.derives = append(node.derives, n)
   256  			n.derivOf = append(n.derivOf, node)
   257  			node.deriv = n
   258  			nodeGradMap[node] = Nodes{n}
   260  		} else if len(nodeGradMap[node]) == 1 {
   261  			deriv := nodeGradMap[node][0]
   262  			deriv.derivOf = append(deriv.derivOf, node)
   263  			node.deriv = deriv
   264  		}
   265  
   266  		gradNode := nodeGradMap[node][0]
   267  		if !node.isInput() {
   268  			symdiffLogf("differentiating %x (%v)", node.ID(), node.op)
   269  			enterLogScope()
   270  
   271  			var op SDOp
   272  			var childrenGrads Nodes
   273  			var ok bool
   274  
   275  			if op, ok = node.op.(SDOp); !ok {
   276  				return nil, SymDiffError{
   277  					single: node,
   278  					err:    errors.New("Not an SDOp"),
   279  				}
   280  			}
   281  
   282  			symdiffLogf("op: %v || optype: %v || node: %v || Children: %#Y || Grad: %v", node.op, node.op.Type(), node.t, node.children, gradNode)
   283  			if childrenGrads, err = op.SymDiff(node.children, node, gradNode); err != nil {
   284  				leaveLogScope()
   285  				return nil, SymDiffError{
   286  					single:  node,
   287  					grad:    gradNode,
   288  					gradMap: nodeGradMap,
   289  					err:     errors.Wrap(err, ".SymDiff() failed"),
   290  				}
   291  			}
   292  
   293  			symdiffLogf("Derived(%d): %P", len(childrenGrads), childrenGrads)
   294  			leaveLogScope()
   295  
   296  			diffs := node.diffWRT()
   297  			for i, child := range node.children {
   298  				symdiffLogf("child is %v, i: %v", child, i)
   299  				differentiable := diffs[i]
   300  				childGrad := childrenGrads[i]
   301  
   302  				if differentiable {
   303  					childGrad.setGroup(gradClust)
   304  					if grads, ok := nodeGradMap[child]; ok {
   305  						grads = append(grads, childGrad)
   306  						nodeGradMap[child] = grads
   307  					} else {
   308  						nodeGradMap[child] = Nodes{childGrad}
   309  					}
   310  				} else {
   311  					symdiffLogf("Child %x is non differentiable", child.ID())
   312  					if childGrad != nil {
   313  						childGrad.setGroup(strayClust)
   314  					}
   315  				}
   316  			}
   317  		} else {
   318  			symdiffLogf("iz input")
   319  			symdiffLogf("%d ", nodeGradMap[node])
   320  		}
   321  		leaveLogScope()
   322  
   323  	}
   324  	leaveLogScope()
   325  	// we have already summed up the gradients for the input nodes, so just take the
   326  	// 0th element
   327  	for _, n := range wrt {
   328  		symdiffLogf("nodeGradMap wrt: %d", nodeGradMap[n])
   329  		retVal = append(retVal, nodeGradMap[n][0])
   330  	}
   331  	return
   332  }
   333  
   334  // SetDerivOf is used to hack around the fundamental limitations of Gorgonia.
   335  //
   336  // Specifically, it is used to set a node as the derivative of another node,
   337  // as is done in the cuDNN version of batch norm.
   338  //
   339  // The cuDNN BatchNorm operation produces the derivatives for the scale and bias as a side effect
   340  // of calculating the derivative of the input. Because Gorgonia's Ops are modelled as pure functions (with no tuple returns),
   341  // this causes a bit of trouble. With the clever use of scratch space ops, multiple return values can be simulated,
   342  // but this causes derivatives to not be set correctly.
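        //
        // A hypothetical sketch (dScale and scale are illustrative names, not part of this package):
        //
        //	// the batch norm op produced dScale in scratch space as a side effect;
        //	// record it as the symbolic derivative of scale so the rest of Gorgonia can find it.
        //	SetDerivOf(dScale, scale)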
   343  func SetDerivOf(deriv, of *Node) {
   344  	deriv.derivOf = append(deriv.derivOf, of)
   345  	of.deriv = deriv
   346  }