github.com/MetalBlockchain/metalgo@v1.11.9/utils/set/set.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package set
     5  
     6  import (
     7  	"bytes"
     8  	"encoding/json"
     9  	"slices"
    10  
    11  	"golang.org/x/exp/maps"
    12  
    13  	"github.com/MetalBlockchain/metalgo/utils"
    14  	"github.com/MetalBlockchain/metalgo/utils/wrappers"
    15  
    16  	avajson "github.com/MetalBlockchain/metalgo/utils/json"
    17  )
    18  
    19  // The minimum capacity of a set
    20  const minSetSize = 16
    21  
    22  var _ json.Marshaler = (*Set[int])(nil)
    23  
    24  // Set is a set of elements.
    25  type Set[T comparable] map[T]struct{}
    26  
    27  // Of returns a Set initialized with [elts]
    28  func Of[T comparable](elts ...T) Set[T] {
    29  	s := NewSet[T](len(elts))
    30  	s.Add(elts...)
    31  	return s
    32  }
    33  
    34  // Return a new set with initial capacity [size].
    35  // More or less than [size] elements can be added to this set.
    36  // Using NewSet() rather than Set[T]{} is just an optimization that can
    37  // be used if you know how many elements will be put in this set.
    38  func NewSet[T comparable](size int) Set[T] {
    39  	if size < 0 {
    40  		return Set[T]{}
    41  	}
    42  	return make(map[T]struct{}, size)
    43  }
    44  
    45  func (s *Set[T]) resize(size int) {
    46  	if *s == nil {
    47  		if minSetSize > size {
    48  			size = minSetSize
    49  		}
    50  		*s = make(map[T]struct{}, size)
    51  	}
    52  }
    53  
    54  // Add all the elements to this set.
    55  // If the element is already in the set, nothing happens.
    56  func (s *Set[T]) Add(elts ...T) {
    57  	s.resize(2 * len(elts))
    58  	for _, elt := range elts {
    59  		(*s)[elt] = struct{}{}
    60  	}
    61  }
    62  
    63  // Union adds all the elements from the provided set to this set.
    64  func (s *Set[T]) Union(set Set[T]) {
    65  	s.resize(2 * set.Len())
    66  	for elt := range set {
    67  		(*s)[elt] = struct{}{}
    68  	}
    69  }
    70  
    71  // Difference removes all the elements in [set] from [s].
    72  func (s *Set[T]) Difference(set Set[T]) {
    73  	for elt := range set {
    74  		delete(*s, elt)
    75  	}
    76  }
    77  
    78  // Contains returns true iff the set contains this element.
    79  func (s *Set[T]) Contains(elt T) bool {
    80  	_, contains := (*s)[elt]
    81  	return contains
    82  }
    83  
    84  // Overlaps returns true if the intersection of the set is non-empty
    85  func (s *Set[T]) Overlaps(big Set[T]) bool {
    86  	small := *s
    87  	if small.Len() > big.Len() {
    88  		small, big = big, small
    89  	}
    90  
    91  	for elt := range small {
    92  		if _, ok := big[elt]; ok {
    93  			return true
    94  		}
    95  	}
    96  	return false
    97  }
    98  
    99  // Len returns the number of elements in this set.
   100  func (s Set[_]) Len() int {
   101  	return len(s)
   102  }
   103  
   104  // Remove all the given elements from this set.
   105  // If an element isn't in the set, it's ignored.
   106  func (s *Set[T]) Remove(elts ...T) {
   107  	for _, elt := range elts {
   108  		delete(*s, elt)
   109  	}
   110  }
   111  
   112  // Clear empties this set
   113  func (s *Set[_]) Clear() {
   114  	clear(*s)
   115  }
   116  
   117  // List converts this set into a list
   118  func (s Set[T]) List() []T {
   119  	return maps.Keys(s)
   120  }
   121  
   122  // Equals returns true if the sets contain the same elements
   123  func (s Set[T]) Equals(other Set[T]) bool {
   124  	return maps.Equal(s, other)
   125  }
   126  
   127  // Removes and returns an element.
   128  // If the set is empty, does nothing and returns false.
   129  func (s *Set[T]) Pop() (T, bool) {
   130  	for elt := range *s {
   131  		delete(*s, elt)
   132  		return elt, true
   133  	}
   134  	return utils.Zero[T](), false
   135  }
   136  
   137  func (s *Set[T]) UnmarshalJSON(b []byte) error {
   138  	str := string(b)
   139  	if str == avajson.Null {
   140  		return nil
   141  	}
   142  	var elts []T
   143  	if err := json.Unmarshal(b, &elts); err != nil {
   144  		return err
   145  	}
   146  	s.Clear()
   147  	s.Add(elts...)
   148  	return nil
   149  }
   150  
   151  func (s Set[_]) MarshalJSON() ([]byte, error) {
   152  	var (
   153  		eltBytes = make([][]byte, len(s))
   154  		i        int
   155  		err      error
   156  	)
   157  	for elt := range s {
   158  		eltBytes[i], err = json.Marshal(elt)
   159  		if err != nil {
   160  			return nil, err
   161  		}
   162  		i++
   163  	}
   164  	// Sort for determinism
   165  	slices.SortFunc(eltBytes, bytes.Compare)
   166  
   167  	// Build the JSON
   168  	var (
   169  		jsonBuf = bytes.Buffer{}
   170  		errs    = wrappers.Errs{}
   171  	)
   172  	_, err = jsonBuf.WriteString("[")
   173  	errs.Add(err)
   174  	for i, elt := range eltBytes {
   175  		_, err := jsonBuf.Write(elt)
   176  		errs.Add(err)
   177  		if i != len(eltBytes)-1 {
   178  			_, err := jsonBuf.WriteString(",")
   179  			errs.Add(err)
   180  		}
   181  	}
   182  	_, err = jsonBuf.WriteString("]")
   183  	errs.Add(err)
   184  
   185  	return jsonBuf.Bytes(), errs.Err
   186  }
   187  
   188  // Returns a random element. If the set is empty, returns false
   189  func (s *Set[T]) Peek() (T, bool) {
   190  	for elt := range *s {
   191  		return elt, true
   192  	}
   193  	return utils.Zero[T](), false
   194  }