gorgonia.org/gorgonia@v0.9.17/collections.go (about) 1 package gorgonia 2 3 import ( 4 "fmt" 5 "sort" 6 "unsafe" 7 8 "github.com/xtgo/set" 9 ) 10 11 // Nodes is a slice of nodes, but it also acts as a set of nodes by implementing the Sort interface 12 type Nodes []*Node 13 14 // Node returns nil. Always. This is bound to cause a panic somewhere if an program is not using it correctly. 15 // The reason for implementing this is so that it may fulfil common interfaces. 16 func (ns Nodes) Node() *Node { return nil } 17 18 // Nodes returns itself. This is useful for interfaces 19 func (ns Nodes) Nodes() Nodes { return ns } 20 21 // Err returns nil always 22 func (ns Nodes) Err() error { return nil } 23 24 // implement sort.Interface 25 26 func (ns Nodes) Len() int { return len(ns) } 27 func (ns Nodes) Less(i, j int) bool { 28 return uintptr(unsafe.Pointer(ns[i])) < uintptr(unsafe.Pointer(ns[j])) 29 } 30 func (ns Nodes) Swap(i, j int) { ns[i], ns[j] = ns[j], ns[i] } 31 32 // uses xtgo/set stuff 33 34 // Set returns a uniquifies slice. It mutates the slice. 35 func (ns Nodes) Set() Nodes { 36 sort.Sort(ns) 37 size := set.Uniq(ns) 38 ns = ns[:size] 39 return ns 40 } 41 42 // Add adds to set 43 func (ns Nodes) Add(n *Node) Nodes { 44 for _, node := range ns { 45 if node == n { 46 return ns 47 } 48 } 49 ns = append(ns, n) 50 return ns 51 } 52 53 // Contains checks if the wanted node is in the set 54 func (ns Nodes) Contains(want *Node) bool { 55 for _, n := range ns { 56 if n == want { 57 return true 58 } 59 } 60 return false 61 } 62 63 // Format implements fmt.Formatter, which allows Nodes to be differently formatted depending on the verbs 64 func (ns Nodes) Format(s fmt.State, c rune) { 65 delimiter := ", " 66 if s.Flag(' ') { 67 delimiter = " " 68 } 69 if s.Flag('+') { 70 delimiter = ", \n" 71 } 72 switch c { 73 case 'd': 74 s.Write([]byte("[")) 75 for i, n := range ns { 76 fmt.Fprintf(s, "%x", n.ID()) 77 if i < len(ns)-1 { 78 fmt.Fprintf(s, "%s", delimiter) 79 } 80 } 81 s.Write([]byte("]")) 82 case 'v', 's': 83 s.Write([]byte("[")) 84 for i, n := range ns { 85 if s.Flag('#') { 86 fmt.Fprintf(s, "%s :: %v", n.Name(), n.t) 87 } else { 88 fmt.Fprintf(s, "%s", n.Name()) 89 } 90 if i < len(ns)-1 { 91 fmt.Fprintf(s, "%s", delimiter) 92 } 93 } 94 s.Write([]byte("]")) 95 case 'Y': 96 s.Write([]byte("[")) 97 for i, n := range ns { 98 fmt.Fprintf(s, "%v", n.t) 99 if i < len(ns)-1 { 100 fmt.Fprintf(s, "%s", delimiter) 101 } 102 } 103 s.Write([]byte("]")) 104 105 case 'P': 106 s.Write([]byte("[")) 107 for i, n := range ns { 108 fmt.Fprintf(s, "%p", n) 109 if i < len(ns)-1 { 110 fmt.Fprintf(s, "%s", delimiter) 111 } 112 } 113 s.Write([]byte("]")) 114 } 115 } 116 117 // Difference is ns - other. Bear in mind it is NOT commutative 118 func (ns Nodes) Difference(other Nodes) Nodes { 119 sort.Sort(ns) 120 sort.Sort(other) 121 s := append(ns, other...) 122 count := set.Diff(s, len(ns)) 123 return s[:count] 124 } 125 126 // Intersect performs an intersection with other Nodes 127 func (ns Nodes) Intersect(other Nodes) Nodes { 128 sort.Sort(ns) 129 sort.Sort(other) 130 s := append(ns, other...) 131 count := set.Inter(s, len(ns)) 132 return s[:count] 133 } 134 135 // AllSameGraph returns true if all the nodes in the slice belong to the same graph. Note that constants do not have to belong to the same graph. 136 func (ns Nodes) AllSameGraph() bool { 137 if len(ns) == 0 { 138 return false 139 } 140 141 var g *ExprGraph 142 for _, n := range ns { 143 if !n.isConstant() { 144 g = n.g 145 break 146 } 147 } 148 149 for _, n := range ns { 150 if n.g != g && !n.isConstant() { 151 return false 152 } 153 } 154 return true 155 } 156 157 // Equals returns true if two Nodes are the same 158 func (ns Nodes) Equals(other Nodes) bool { 159 if len(ns) != len(other) { 160 return false 161 } 162 163 for _, n := range ns { 164 if !other.Contains(n) { 165 return false 166 } 167 } 168 return true 169 } 170 171 func (ns Nodes) mapSet() NodeSet { return NewNodeSet(ns...) } 172 173 func (ns Nodes) index(n *Node) int { 174 for i, node := range ns { 175 if node == n { 176 return i 177 } 178 } 179 return -1 180 } 181 182 func (ns Nodes) reverse() { 183 l := len(ns) 184 for i := l/2 - 1; i >= 0; i-- { 185 o := l - 1 - i 186 ns[i], ns[o] = ns[o], ns[i] 187 } 188 } 189 190 func (ns Nodes) replace(what, with *Node) Nodes { 191 for i, n := range ns { 192 if n == what { 193 ns[i] = with 194 } 195 } 196 return ns 197 } 198 199 var removers = make(map[string]int) 200 201 func (ns Nodes) remove(what *Node) Nodes { 202 for i := ns.index(what); i != -1; i = ns.index(what) { 203 copy(ns[i:], ns[i+1:]) 204 ns[len(ns)-1] = nil // to prevent any unwanted references so things can be GC'd away 205 ns = ns[:len(ns)-1] 206 } 207 208 return ns 209 } 210 211 func (ns Nodes) dimSizers() []DimSizer { 212 retVal := borrowDimSizers(len(ns)) 213 for i, n := range ns { 214 if s, ok := n.op.(sizeOp); ok { 215 retVal[i] = s 216 } else { 217 retVal[i] = n.shape 218 } 219 } 220 return retVal 221 }