github.com/fluhus/gostuff@v0.4.1-0.20240331134726-be71864f2b5d/sets/sets.go (about)

     1  // Package sets provides generic sets.
     2  package sets
     3  
     4  import (
     5  	"encoding/json"
     6  
     7  	"golang.org/x/exp/maps"
     8  )
     9  
// Set is a convenience wrapper around map[T]struct{}.
//
// Because Set is a map type, its zero value is a nil map. Read-only
// operations (Has, Intersect) and deletions (Remove, RemoveSet) are safe on a
// nil set. Add and AddSet write to the map, so they panic on a nil set when
// given at least one element; create a usable set with Set[T]{} or
// make(Set[T]). UnmarshalJSON allocates a new set when the target is nil.
type Set[T comparable] map[T]struct{}
    12  
    13  // Add inserts the given elements to s and returns s.
    14  func (s Set[T]) Add(t ...T) Set[T] {
    15  	for _, v := range t {
    16  		s[v] = struct{}{}
    17  	}
    18  	return s
    19  }
    20  
    21  // AddSet inserts the elements of t to s and returns s.
    22  func (s Set[T]) AddSet(t Set[T]) Set[T] {
    23  	for v := range t {
    24  		s[v] = struct{}{}
    25  	}
    26  	return s
    27  }
    28  
    29  // Remove deletes the given elements from s and returns s.
    30  func (s Set[T]) Remove(t ...T) Set[T] {
    31  	for _, v := range t {
    32  		delete(s, v)
    33  	}
    34  	return s
    35  }
    36  
    37  // RemoveSet deletes the elements of t from s and returns s.
    38  func (s Set[T]) RemoveSet(t Set[T]) Set[T] {
    39  	for v := range t {
    40  		delete(s, v)
    41  	}
    42  	return s
    43  }
    44  
    45  // Has returns whether t is a member of s.
    46  func (s Set[T]) Has(t T) bool {
    47  	_, ok := s[t]
    48  	return ok
    49  }
    50  
    51  // Intersect returns a new set holding the elements that are common
    52  // to s and t.
    53  func (s Set[T]) Intersect(t Set[T]) Set[T] {
    54  	if len(s) > len(t) {
    55  		s, t = t, s
    56  	}
    57  	result := Set[T]{}
    58  	for v := range s {
    59  		if t.Has(v) {
    60  			result[v] = struct{}{}
    61  		}
    62  	}
    63  	return result
    64  }
    65  
    66  // MarshalJSON implements the json.Marshaler interface.
    67  func (s Set[T]) MarshalJSON() ([]byte, error) {
    68  	return json.Marshal(maps.Keys(s))
    69  }
    70  
    71  // UnmarshalJSON implements the json.Unmarshaler interface.
    72  func (s *Set[T]) UnmarshalJSON(b []byte) error {
    73  	var slice []T
    74  	if err := json.Unmarshal(b, &slice); err != nil {
    75  		return err
    76  	}
    77  	if *s == nil {
    78  		*s = Set[T]{}
    79  	}
    80  	s.Add(slice...)
    81  	return nil
    82  }