gorgonia.org/gorgonia@v0.9.17/collections.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"fmt"
     5  	"sort"
     6  	"unsafe"
     7  
     8  	"github.com/xtgo/set"
     9  )
    10  
    11  // Nodes is a slice of nodes, but it also acts as a set of nodes by implementing the Sort interface
    12  type Nodes []*Node
    13  
    14  // Node returns nil. Always. This is bound to cause a panic somewhere if an program is not using it correctly.
    15  // The reason for implementing this is so that it may fulfil common interfaces.
    16  func (ns Nodes) Node() *Node { return nil }
    17  
    18  // Nodes returns itself. This is useful for interfaces
    19  func (ns Nodes) Nodes() Nodes { return ns }
    20  
    21  // Err returns nil always
    22  func (ns Nodes) Err() error { return nil }
    23  
    24  // implement sort.Interface
    25  
    26  func (ns Nodes) Len() int { return len(ns) }
    27  func (ns Nodes) Less(i, j int) bool {
    28  	return uintptr(unsafe.Pointer(ns[i])) < uintptr(unsafe.Pointer(ns[j]))
    29  }
    30  func (ns Nodes) Swap(i, j int) { ns[i], ns[j] = ns[j], ns[i] }
    31  
    32  // uses xtgo/set stuff
    33  
    34  // Set returns a uniquifies slice. It mutates the slice.
    35  func (ns Nodes) Set() Nodes {
    36  	sort.Sort(ns)
    37  	size := set.Uniq(ns)
    38  	ns = ns[:size]
    39  	return ns
    40  }
    41  
    42  // Add adds to set
    43  func (ns Nodes) Add(n *Node) Nodes {
    44  	for _, node := range ns {
    45  		if node == n {
    46  			return ns
    47  		}
    48  	}
    49  	ns = append(ns, n)
    50  	return ns
    51  }
    52  
    53  // Contains checks if the wanted node is in the set
    54  func (ns Nodes) Contains(want *Node) bool {
    55  	for _, n := range ns {
    56  		if n == want {
    57  			return true
    58  		}
    59  	}
    60  	return false
    61  }
    62  
    63  // Format implements fmt.Formatter, which allows Nodes to be differently formatted depending on the verbs
    64  func (ns Nodes) Format(s fmt.State, c rune) {
    65  	delimiter := ", "
    66  	if s.Flag(' ') {
    67  		delimiter = "  "
    68  	}
    69  	if s.Flag('+') {
    70  		delimiter = ", \n"
    71  	}
    72  	switch c {
    73  	case 'd':
    74  		s.Write([]byte("["))
    75  		for i, n := range ns {
    76  			fmt.Fprintf(s, "%x", n.ID())
    77  			if i < len(ns)-1 {
    78  				fmt.Fprintf(s, "%s", delimiter)
    79  			}
    80  		}
    81  		s.Write([]byte("]"))
    82  	case 'v', 's':
    83  		s.Write([]byte("["))
    84  		for i, n := range ns {
    85  			if s.Flag('#') {
    86  				fmt.Fprintf(s, "%s :: %v", n.Name(), n.t)
    87  			} else {
    88  				fmt.Fprintf(s, "%s", n.Name())
    89  			}
    90  			if i < len(ns)-1 {
    91  				fmt.Fprintf(s, "%s", delimiter)
    92  			}
    93  		}
    94  		s.Write([]byte("]"))
    95  	case 'Y':
    96  		s.Write([]byte("["))
    97  		for i, n := range ns {
    98  			fmt.Fprintf(s, "%v", n.t)
    99  			if i < len(ns)-1 {
   100  				fmt.Fprintf(s, "%s", delimiter)
   101  			}
   102  		}
   103  		s.Write([]byte("]"))
   104  
   105  	case 'P':
   106  		s.Write([]byte("["))
   107  		for i, n := range ns {
   108  			fmt.Fprintf(s, "%p", n)
   109  			if i < len(ns)-1 {
   110  				fmt.Fprintf(s, "%s", delimiter)
   111  			}
   112  		}
   113  		s.Write([]byte("]"))
   114  	}
   115  }
   116  
   117  // Difference is ns - other. Bear in mind it is NOT commutative
   118  func (ns Nodes) Difference(other Nodes) Nodes {
   119  	sort.Sort(ns)
   120  	sort.Sort(other)
   121  	s := append(ns, other...)
   122  	count := set.Diff(s, len(ns))
   123  	return s[:count]
   124  }
   125  
   126  // Intersect performs an intersection with other Nodes
   127  func (ns Nodes) Intersect(other Nodes) Nodes {
   128  	sort.Sort(ns)
   129  	sort.Sort(other)
   130  	s := append(ns, other...)
   131  	count := set.Inter(s, len(ns))
   132  	return s[:count]
   133  }
   134  
   135  // AllSameGraph returns true if all the nodes in the slice belong to the same graph. Note that constants do not have to belong to the same graph.
   136  func (ns Nodes) AllSameGraph() bool {
   137  	if len(ns) == 0 {
   138  		return false
   139  	}
   140  
   141  	var g *ExprGraph
   142  	for _, n := range ns {
   143  		if !n.isConstant() {
   144  			g = n.g
   145  			break
   146  		}
   147  	}
   148  
   149  	for _, n := range ns {
   150  		if n.g != g && !n.isConstant() {
   151  			return false
   152  		}
   153  	}
   154  	return true
   155  }
   156  
   157  // Equals returns true if two Nodes are the same
   158  func (ns Nodes) Equals(other Nodes) bool {
   159  	if len(ns) != len(other) {
   160  		return false
   161  	}
   162  
   163  	for _, n := range ns {
   164  		if !other.Contains(n) {
   165  			return false
   166  		}
   167  	}
   168  	return true
   169  }
   170  
   171  func (ns Nodes) mapSet() NodeSet { return NewNodeSet(ns...) }
   172  
   173  func (ns Nodes) index(n *Node) int {
   174  	for i, node := range ns {
   175  		if node == n {
   176  			return i
   177  		}
   178  	}
   179  	return -1
   180  }
   181  
   182  func (ns Nodes) reverse() {
   183  	l := len(ns)
   184  	for i := l/2 - 1; i >= 0; i-- {
   185  		o := l - 1 - i
   186  		ns[i], ns[o] = ns[o], ns[i]
   187  	}
   188  }
   189  
   190  func (ns Nodes) replace(what, with *Node) Nodes {
   191  	for i, n := range ns {
   192  		if n == what {
   193  			ns[i] = with
   194  		}
   195  	}
   196  	return ns
   197  }
   198  
   199  var removers = make(map[string]int)
   200  
   201  func (ns Nodes) remove(what *Node) Nodes {
   202  	for i := ns.index(what); i != -1; i = ns.index(what) {
   203  		copy(ns[i:], ns[i+1:])
   204  		ns[len(ns)-1] = nil // to prevent any unwanted references so things can be GC'd away
   205  		ns = ns[:len(ns)-1]
   206  	}
   207  
   208  	return ns
   209  }
   210  
   211  func (ns Nodes) dimSizers() []DimSizer {
   212  	retVal := borrowDimSizers(len(ns))
   213  	for i, n := range ns {
   214  		if s, ok := n.op.(sizeOp); ok {
   215  			retVal[i] = s
   216  		} else {
   217  			retVal[i] = n.shape
   218  		}
   219  	}
   220  	return retVal
   221  }