gorgonia.org/gorgonia@v0.9.17/analysis.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  )
     7  
     8  // dataflow analysis
     9  
// dataflow holds the state accumulated by the pre-compilation analysis
// passes: value numbering (CSE), device targeting, device-transport
// insertion, and liveness interval construction.
type dataflow struct {
	// uniques maps a node's hashcode to the canonical node carrying that
	// value — the basis of value numbering / common subexpression elimination.
	uniques map[uint32]*Node

	// replacements maps every analyzed node to its canonical (CSE'd) node.
	replacements map[*Node]*Node
	// intervals maps each node to its liveness interval, built by buildIntervals.
	intervals    map[*Node]*interval

	// tracks the special nodes' children and parents
	devTransChildren map[*Node]Nodes
	devTransRepl     map[*Node]*Node
}
    20  
    21  func newdataflow() *dataflow {
    22  	df := new(dataflow)
    23  	df.uniques = make(map[uint32]*Node)
    24  	df.devTransChildren = make(map[*Node]Nodes)
    25  	df.devTransRepl = make(map[*Node]*Node)
    26  	return df
    27  }
    28  
    29  // equivalent to the value numbering algorithm
    30  // it returns true if it is unique
    31  func (df *dataflow) vn(n *Node) (retVal *Node, unique bool) {
    32  	compileLogf("Value numbering")
    33  	enterLogScope()
    34  	defer leaveLogScope()
    35  
    36  	node, ok := df.uniques[n.Hashcode()]
    37  
    38  	if ok {
    39  		return node, false
    40  	}
    41  
    42  	compileLogf("adding a new unique")
    43  	// otherwise, add it to uniques, and then return itself
    44  	df.uniques[n.Hashcode()] = n
    45  
    46  	return n, true
    47  }
    48  
    49  // analyzeDevice records which node is supposed to be executed on which device.
    50  //
    51  // Currently it will only use Device 0. In the future, we can be smart about which device to use
    52  func (df *dataflow) analyzeDevice(n *Node) {
    53  	switch n.op.(type) {
    54  	case CUDADoer:
    55  		if n.dataOn == CPU {
    56  			n.dataOn = Device(0)
    57  		}
    58  	case CLDoer:
    59  		if n.dataOn == CPU {
    60  			n.dataOn = Device(0)
    61  		}
    62  	default:
    63  		n.dataOn = CPU
    64  	}
    65  }
    66  
    67  // replaceWithSelf fills the replacement map with itself. This is the method used in the lispMachine only, as it skips value numbering
    68  func (df *dataflow) replaceWithSelf(sorted Nodes) {
    69  	df.replacements = make(map[*Node]*Node)
    70  	for _, n := range sorted {
    71  		df.replacements[n] = n
    72  		df.analyzeDevice(n) // Device Targeting
    73  	}
    74  }
    75  
    76  // fixIntervalDevices is used only by the lispMachine. It fixes the intervals to have the correct devices
    77  func (df *dataflow) fixIntervalDevices(sorted Nodes) {
    78  	for _, n := range sorted {
    79  		df.intervals[n].result.device = n.dataOn
    80  	}
    81  }
    82  
    83  func analyze(g *ExprGraph, sorted Nodes) *dataflow {
    84  	compileLogf("Performing dataflow analysis")
    85  	enterLogScope()
    86  	defer leaveLogScope()
    87  
    88  	compileLogf("Finding unique leaves")
    89  	df := newdataflow()
    90  	for _, n := range g.leaves {
    91  		df.uniques[n.Hashcode()] = n
    92  	}
    93  
    94  	// compileLogf("Common subexpression elimination")
    95  	// compileLogf("analyzing devices")
    96  	replacements := make(map[*Node]*Node)
    97  	for _, n := range sorted {
    98  		r, _ := df.vn(n)
    99  		replacements[n] = r // CSE
   100  		df.analyzeDevice(n) // Device targeting
   101  	}
   102  	df.replacements = replacements
   103  	compileLogf("replacements: %-p", FmtNodeMap(replacements))
   104  
   105  	// TODO
   106  	// constant propagation
   107  	/*
   108  		for _, node := range g.nodes {
   109  			n := node.(*Node)
   110  			if len(n.Children) > 0 {
   111  				allConst := true
   112  				for _, child := range n.Children {
   113  					if _, ok := child.Op.(constant); !ok {
   114  						allConst = false
   115  						break
   116  					}
   117  				}
   118  			}
   119  		}
   120  	*/
   121  	return df
   122  }
   123  
   124  func newDevTransNode(read, write *Node, from, to Device) *Node {
   125  	op := devTrans{from, to, write}
   126  	n := borrowNode()
   127  	n.id = -1
   128  	n.op = op
   129  	n.shape = read.shape.Clone()
   130  	n.t = read.t
   131  	n.isStmt = true
   132  	n.children = Nodes{read}
   133  	return n
   134  }
   135  
// insertDeviceInstr walks the sorted nodes and, wherever a node consumes a
// child whose data lives on a different device, splices a device-transport
// node into the sorted slice immediately before the consumer. The consumer's
// substituted child list is recorded in df.devTransChildren (n.children is
// never mutated). It returns the possibly lengthened sorted slice.
func (df *dataflow) insertDeviceInstr(sorted Nodes) Nodes {
	compileLogf("Inserting Device Transport Instructions")
	enterLogScope()
	defer leaveLogScope()
	// input -> output
	for i := 0; i < len(sorted); i++ {
		node := sorted[i]
		n := df.replacements[node]
		dev := n.dataOn

		compileLogf("Working on %v. Replacement %v. Device %v", node, n, dev)
		var incr int           // how many transports were inserted before this node
		var useReplacement bool // set when at least one child required a transport
		replacementChildren := make(Nodes, len(n.children))
		enterLogScope()
		for j, child := range n.children {
			c := df.replacements[child]
			childDev := c.dataOn

			compileLogf("Working on child :%v. Device: %v, Parent Device %v", c, childDev, dev)
			if childDev != dev {
				useReplacement = true
				// reuse an existing transport for this child, if one was made earlier
				if repl, ok := df.devTransRepl[c]; ok {
					replacementChildren[j] = repl
					continue
				}
				transport := newDevTransNode(c, n, childDev, dev)
				// in-place insertion at index i: grow by one, shift the tail
				// [i:] up by one slot, then overwrite slot i with the transport.
				// The current node is now at i+1.
				sorted = append(sorted, nil)
				copy(sorted[i+1:], sorted[i:])
				sorted[i] = transport
				incr++
				compileLogf("Inserted %v", transport)

				// other stateful stuff
				df.devTransRepl[c] = transport
				df.replacements[transport] = transport
				replacementChildren[j] = transport
			} else {
				replacementChildren[j] = child
			}
		}
		leaveLogScope()

		if useReplacement {
			df.devTransChildren[n] = replacementChildren
		}

		// skip over the transports just inserted so the loop's i++ lands on
		// the node following the one just processed
		i += incr
	}
	return sorted
}
   187  
   188  /*
   189  	Notes on handling the live set:
   190  
   191  	1. We load all the SSAs listed in the block's LiveIn
   192  	2. Then we load all the SSAs used as input in this block Phi nodes
   193  		- The reason for this is so that those SSAs can have intervals created
   194  		  that are live in this block (well, they are kinda live)
   195  	3. These input SSAs are temporary only, because a path-dependent liveset will be calculated below
   196  
   197  	Consider a CFG that looks like this:
   198  
   199                             BLOCK 1           BLOCK 3
   200                             +-------+        +-------+
   201                       +---->| x = 1 +------->| y = 3 +----------------+
   202          BLOCK 0      |     +-------+        | use x |                v  BLOCK 4
   203         +-------+     |                      +-------+              +-------------+
   204         |       |+----+                                             | x = ϕ(1, 2) |
   205         +-------+     |     BLOCK 2                                 +-------------+
   206                       |     +-------+                                 ^
   207                       +---->| x = 2 +---------------------------------+
   208                             +-------+
   209  
   210  	`x = 1` needs to be live in BLOCK 1, BLOCK 3 and BLOCK 4
   211  	`x = 2` needs to be live in BLOCK 2 and BLOCK 4.
   212  
   213  	The solution: in BLOCK 4, load `x = 1` and `x = 2` so they can be considered live in Block 4.
   214  
   215  	The interval building process comes to BLOCK 3 next. It considers the SSAs that are live in BLOCK 4.
   216  	If `x = 2` is live in BLOCK 4, it's Bad News with capital letters (see comment below).
   217  
   218  	The solution: remove the InputSSAs of the Phi nodes when we're leaving this block.
   219  */
   220  // TODO: rephrase above to fit this package's function.
   221  // It's like the above, but without basic blocks, phi nodes, etc, making it a LOT simpler
   222  func (df *dataflow) buildIntervals(sorted Nodes) {
   223  	compileLogf("Building intervals for %v", sorted)
   224  	enterLogScope()
   225  	defer leaveLogScope()
   226  
   227  	intervals := make(map[*Node]*interval)
   228  
   229  	var g *ExprGraph
   230  	for _, n := range sorted {
   231  		if g == nil && n.g != nil {
   232  			g = n.g
   233  		}
   234  
   235  		intervals[n] = newInterval()
   236  	}
   237  
   238  	instructions := len(sorted)
   239  	for i := len(sorted) - 1; i >= 0; i-- {
   240  		n := sorted[i]
   241  		instrNum := i
   242  		nInter := intervals[n]
   243  		compileLogf("n %v | %v", n, nInter)
   244  
   245  		// inputs and constants will be live the entire program
   246  		if n.isInput() || n.isConstant() {
   247  			nInter.addRange(instrNum, instructions)
   248  			repl, ok := df.devTransRepl[n]
   249  			if ok {
   250  				interv, ok := intervals[repl]
   251  				if ok {
   252  					interv.addRange(instrNum, instructions)
   253  				}
   254  			}
   255  
   256  			continue
   257  		}
   258  		nInter.addRange(instrNum, instrNum)
   259  
   260  		// check for special cases requiring copying from device to device
   261  
   262  		var children Nodes
   263  		var ok bool
   264  		if children, ok = df.devTransChildren[n]; !ok {
   265  			children = n.children
   266  		}
   267  
   268  		for _, child := range children {
   269  			iv, ok := intervals[child]
   270  			if !ok {
   271  				// do something
   272  				// parents := g.to[n]
   273  				// for i, from := range parents {
   274  				// 	ioutil.WriteFile(fmt.Sprintf("n_%d.dot", i), []byte(from.ToDot()), 0644)
   275  				// }
   276  			}
   277  			iv.addUsePositions(instrNum)
   278  			// iv.setTo(instrNum)
   279  		}
   280  		// assume all derivations of input
   281  		if len(n.derivOf) > 0 {
   282  			for _, d := range n.derivOf {
   283  				if d.isInput() {
   284  					nInter.addUsePositions(instructions)
   285  					break
   286  				}
   287  			}
   288  		}
   289  	}
   290  
   291  	for _, iv := range intervals {
   292  		iv.fix()
   293  	}
   294  
   295  	var buf bytes.Buffer
   296  	for k, v := range intervals {
   297  		fmt.Fprintf(&buf, "%v: %v\n", k, v)
   298  	}
   299  	compileLogf("Intervals: %v", buf.String())
   300  
   301  	df.intervals = intervals
   302  	return
   303  }