github.com/aristanetworks/goarista@v0.0.0-20240514173732-cca2755bbd44/key/map.go (about)

     1  // Copyright (c) 2019 Arista Networks, Inc.
     2  // Use of this source code is governed by the Apache License 2.0
     3  // that can be found in the COPYING file.
     4  
     5  package key
     6  
     7  import (
     8  	"errors"
     9  	"sort"
    10  	"strings"
    11  )
    12  
    13  // Map allows the indexing of entries with arbitrary key types, so long as the keys are
    14  // either hashable natively or implement Hashable
    15  type Map struct {
    16  	normal map[interface{}]interface{}
    17  	custom map[uint64]entry
    18  	length int // length of the Map
    19  }
    20  
    21  // NewMap creates a new Map from a list of key-value pairs, so long as the list is of even length.
    22  func NewMap(keysAndVals ...interface{}) *Map {
    23  	length := len(keysAndVals)
    24  	if length%2 != 0 {
    25  		panic("Odd number of arguments passed to NewMap. Arguments should be of form: " +
    26  			"key1, value1, key2, value2, ...")
    27  	}
    28  	m := Map{}
    29  	for i := 0; i < length; i += 2 {
    30  		m.Set(keysAndVals[i], keysAndVals[i+1])
    31  	}
    32  	return &m
    33  }
    34  
    35  // String outputs the string representation of the map
    36  func (m *Map) String() string {
    37  	if m == nil {
    38  		return "key.Map(nil)"
    39  	}
    40  	stringify := func(v interface{}) string {
    41  		return stringifyCollectionHelper(v)
    42  	}
    43  	type kv struct {
    44  		k string
    45  		v string
    46  	}
    47  	var length int
    48  	kvs := make([]kv, 0, m.Len())
    49  	_ = m.Iter(func(k, v interface{}) error {
    50  		element := kv{
    51  			k: stringify(k),
    52  			v: stringify(v),
    53  		}
    54  		kvs = append(kvs, element)
    55  		length += len(element.k) + len(element.v)
    56  		return nil
    57  	})
    58  	sort.Slice(kvs, func(i, j int) bool { return kvs[i].k < kvs[j].k })
    59  	var buf strings.Builder
    60  	buf.Grow(length + len("key.Map[]") + 2*len(kvs) /* room for seperators: ", :" */)
    61  	buf.WriteString("key.Map[")
    62  	for i, kv := range kvs {
    63  		if i != 0 {
    64  			buf.WriteByte(' ')
    65  		}
    66  		buf.WriteString(kv.k + ":" + kv.v)
    67  	}
    68  	buf.WriteString("]")
    69  	return buf.String()
    70  }
    71  
    72  // KeyString returns a string that is suitable to be used as a key in
    73  // an index of strings.
    74  func (m *Map) KeyString() string {
    75  	s := make(map[string]interface{}, m.Len())
    76  	_ = m.Iter(func(k, v interface{}) error {
    77  		s[stringify(k)] = v
    78  		return nil
    79  	})
    80  	keys := SortedKeys(s)
    81  	for i, k := range keys {
    82  		keys[i] = k + "=" + stringify(s[k])
    83  	}
    84  	return strings.Join(keys, "_")
    85  }
    86  
    87  // Len returns the length of the Map
    88  func (m *Map) Len() int {
    89  	if m == nil {
    90  		return 0
    91  	}
    92  	return m.length
    93  }
    94  
    95  // An entry represents an entry in a map whose key is not normally hashable,
    96  // and is therefore of type Hashable
    97  // (that is, a Hash method has been defined for this entry's key, and we can index it)
    98  //
    99  // Because hash collisions are possible (though unlikely), an entry is
   100  // actually a linked list. To save space, the valOrNext field serves
   101  // double-duty. It is either just the value associated with the
   102  // entry's key, indicating the end of the list, or it contains a
   103  // *chainedEntry. A chainedEntry holds the key's value and the next
   104  // entry (which may also contain a *chainedEntry).
   105  type entry struct {
   106  	k Hashable
   107  	// contains a value or *chainedEntry
   108  	valOrNext interface{}
   109  }
   110  
   111  type chainedEntry struct {
   112  	val interface{}
   113  	entry
   114  }
   115  
   116  // entrySearch searches entry for matching keys. It returns the
   117  // containing entry if found, and the last entry in the chain if not
   118  // found.
   119  func entrySearch(ent *entry, k Hashable) (containing *entry, found bool) {
   120  	for {
   121  		if k.Equal(ent.k) {
   122  			return ent, true
   123  		}
   124  		chEnt, ok := ent.valOrNext.(*chainedEntry)
   125  		if !ok {
   126  			return ent, false
   127  		}
   128  		ent = &chEnt.entry
   129  	}
   130  }
   131  
   132  func entryGetValue(ent *entry) interface{} {
   133  	if chEnt, ok := ent.valOrNext.(*chainedEntry); ok {
   134  		return chEnt.val
   135  	}
   136  	return ent.valOrNext
   137  }
   138  
   139  func entrySetValue(ent *entry, v interface{}) {
   140  	if chEnt, ok := ent.valOrNext.(*chainedEntry); ok {
   141  		chEnt.val = v
   142  		return
   143  	}
   144  	ent.valOrNext = v
   145  }
   146  
   147  // entryAppend appends a new entry to the end of ent.
   148  func entryAppend(ent *entry, k Hashable, v interface{}) {
   149  	if _, ok := ent.valOrNext.(*chainedEntry); ok {
   150  		panic("chained entry passed to entryAppend ")
   151  	}
   152  	ent.valOrNext = &chainedEntry{
   153  		val: ent.valOrNext,
   154  		entry: entry{
   155  			k:         k,
   156  			valOrNext: v,
   157  		},
   158  	}
   159  }
   160  
   161  // entryRemove removes an entry that has key k. The new head entry is
   162  // returned along with true if k is found and removed, and false if
   163  // not found.
   164  func entryRemove(head *entry, k Hashable) (*entry, bool) {
   165  	if k.Equal(head.k) {
   166  		// head of list matches
   167  		if chEnt, ok := head.valOrNext.(*chainedEntry); ok {
   168  			return &chEnt.entry, true
   169  		}
   170  		return nil, true
   171  	}
   172  	prev := head
   173  	for {
   174  		next, ok := prev.valOrNext.(*chainedEntry)
   175  		if !ok {
   176  			// reached end of chain, not found
   177  			return head, false
   178  		}
   179  		if k.Equal(next.entry.k) {
   180  			if nextNext, ok := next.entry.valOrNext.(*chainedEntry); ok {
   181  				// Remove next from entry list, but next contains
   182  				// prev's val, so move that into nextNext.
   183  				nextNext.val = next.val
   184  				prev.valOrNext = nextNext
   185  			} else {
   186  				// Remove end of the list
   187  				prev.valOrNext = next.val
   188  			}
   189  			return head, true
   190  		}
   191  		prev = &next.entry
   192  	}
   193  }
   194  
   195  func entryIter(ent entry, f func(k, v interface{}) error) error {
   196  	for {
   197  		if chEnt, ok := ent.valOrNext.(*chainedEntry); ok {
   198  			if err := f(ent.k, chEnt.val); err != nil {
   199  				return err
   200  			}
   201  			ent = chEnt.entry
   202  			continue
   203  		}
   204  		return f(ent.k, ent.valOrNext)
   205  	}
   206  }
   207  
   208  // Hashable represents the key for an entry in a Map that cannot natively be hashed
   209  type Hashable interface {
   210  	Hash() uint64
   211  	Equal(other interface{}) bool
   212  }
   213  
   214  // Equal compares two Maps
   215  func (m *Map) Equal(other interface{}) bool {
   216  	if (m == nil) != (other == nil) {
   217  		return false
   218  	}
   219  	o, ok := other.(*Map)
   220  	if !ok {
   221  		return false
   222  	}
   223  	if m.length != o.length {
   224  		return false
   225  	}
   226  	err := m.Iter(func(k, v interface{}) error {
   227  		otherV, ok := o.Get(k)
   228  		if !ok {
   229  			return errors.New("notequal")
   230  		}
   231  		if !keyEqual(v, otherV) {
   232  			return errors.New("notequal")
   233  		}
   234  		return nil
   235  	})
   236  	return err == nil
   237  }
   238  
   239  // Hash returns the hash value of this Map
   240  func (m *Map) Hash() uint64 {
   241  	if m == nil {
   242  		return 0
   243  	}
   244  	var h uintptr
   245  	m.Iter(func(k, v interface{}) error {
   246  		h += HashInterface(k) + HashInterface(v)
   247  		return nil
   248  	})
   249  	return uint64(h)
   250  }
   251  
   252  // Set adds a key-value pair to the Map
   253  func (m *Map) Set(k, v interface{}) {
   254  	if k == nil {
   255  		return
   256  	}
   257  	if hkey, ok := k.(Hashable); ok {
   258  		if m.custom == nil {
   259  			m.custom = make(map[uint64]entry)
   260  		}
   261  		// get hash, add to custom if not present
   262  		// if present, append to next of root entry
   263  		h := hkey.Hash()
   264  		rootentry, ok := m.custom[h]
   265  		if !ok {
   266  			m.custom[h] = entry{k: hkey, valOrNext: v}
   267  			m.length++
   268  			return
   269  		}
   270  		ent, found := entrySearch(&rootentry, hkey)
   271  		if found {
   272  			entrySetValue(ent, v)
   273  			m.custom[h] = rootentry
   274  			return
   275  		}
   276  		entryAppend(ent, hkey, v)
   277  		m.custom[h] = rootentry
   278  		m.length++
   279  	} else {
   280  		if m.normal == nil {
   281  			m.normal = make(map[interface{}]interface{})
   282  		}
   283  		l := len(m.normal)
   284  		m.normal[k] = v
   285  		if l != len(m.normal) { // len has changed
   286  			m.length++
   287  		}
   288  	}
   289  }
   290  
   291  // Get retrieves the value stored with key k from the Map
   292  func (m *Map) Get(k interface{}) (interface{}, bool) {
   293  	if m == nil {
   294  		return nil, false
   295  	}
   296  	if hkey, ok := k.(Hashable); ok {
   297  		h := hkey.Hash()
   298  		hentry, ok := m.custom[h]
   299  		if !ok {
   300  			return nil, false
   301  		}
   302  		ent, found := entrySearch(&hentry, hkey)
   303  		if !found {
   304  			return nil, false
   305  		}
   306  		return entryGetValue(ent), true
   307  	}
   308  	v, ok := m.normal[k]
   309  	return v, ok
   310  }
   311  
   312  // Del removes an entry with key k from the Map
   313  func (m *Map) Del(k interface{}) {
   314  	if m == nil {
   315  		return
   316  	}
   317  	if hkey, ok := k.(Hashable); ok {
   318  		if m.custom == nil {
   319  			return
   320  		}
   321  		h := hkey.Hash()
   322  		hentry, ok := m.custom[h]
   323  		if !ok {
   324  			return
   325  		}
   326  		newEnt, found := entryRemove(&hentry, hkey)
   327  		if !found {
   328  			return
   329  		}
   330  		m.length--
   331  		if newEnt == nil {
   332  			delete(m.custom, h)
   333  		} else {
   334  			m.custom[h] = *newEnt
   335  		}
   336  		return
   337  	}
   338  	// not Hashable, check normal
   339  	if m.normal == nil {
   340  		return
   341  	}
   342  	l := len(m.normal)
   343  	delete(m.normal, k)
   344  	if l != len(m.normal) {
   345  		m.length--
   346  	}
   347  }
   348  
   349  // Iter applies func f to every key-value pair in the Map
   350  func (m *Map) Iter(f func(k, v interface{}) error) error {
   351  	if m == nil {
   352  		return nil
   353  	}
   354  	for k, v := range m.normal {
   355  		if err := f(k, v); err != nil {
   356  			return err
   357  		}
   358  	}
   359  	for _, ent := range m.custom {
   360  		if err := entryIter(ent, f); err != nil {
   361  			return err
   362  		}
   363  	}
   364  	return nil
   365  }
   366  
   367  // Keys returns a list of all keys in the Map
   368  func (m *Map) Keys() []interface{} {
   369  	keys := make([]interface{}, m.Len())
   370  	i := 0
   371  	m.Iter(func(k, v interface{}) error {
   372  		keys[i] = k
   373  		i++
   374  		return nil
   375  	})
   376  	return keys
   377  }
   378  
   379  // Values returns a list of all values in the Map
   380  func (m *Map) Values() []interface{} {
   381  	values := make([]interface{}, m.Len())
   382  	i := 0
   383  	m.Iter(func(k, v interface{}) error {
   384  		values[i] = v
   385  		i++
   386  		return nil
   387  	})
   388  	return values
   389  }