github.com/hugh712/snapd@v0.0.0-20200910133618-1a99902bd583/asserts/internal/grouping.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2020 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package internal
    21  
    22  import (
    23  	"bytes"
    24  	"encoding/base64"
    25  	"encoding/binary"
    26  	"fmt"
    27  	"sort"
    28  )
    29  
    30  // Groupings maintain labels to identify membership to one or more groups.
    31  // Labels are implemented as subsets of integers from 0
    32  // up to an excluded maximum, where the integers represent the groups.
    33  // Assumptions:
    34  //  - most labels are for one group or very few
    35  //  - a few labels are sparse with more groups in them
    36  //  - very few comprise the universe of all groups
    37  type Groupings struct {
    38  	n               uint
    39  	maxGroup        uint16
    40  	bitsetThreshold uint16
    41  }
    42  
    43  // NewGroupings creates a new Groupings supporting labels for membership
    44  // to up n groups. n must be a positive multiple of 16 and <=65536.
    45  func NewGroupings(n int) (*Groupings, error) {
    46  	if n <= 0 || n > 65536 {
    47  		return nil, fmt.Errorf("n=%d groups is outside of valid range (0, 65536]", n)
    48  	}
    49  	if n%16 != 0 {
    50  		return nil, fmt.Errorf("n=%d groups is not a multiple of 16", n)
    51  	}
    52  	return &Groupings{n: uint(n), bitsetThreshold: uint16(n / 16)}, nil
    53  }
    54  
    55  // WithinRange checks whether group is within the admissible range for
    56  // labeling otherwise it returns an error.
    57  func (gr *Groupings) WithinRange(group uint16) error {
    58  	if uint(group) >= gr.n {
    59  		return fmt.Errorf("group exceeds admissible maximum: %d >= %d", group, gr.n)
    60  	}
    61  	return nil
    62  }
    63  
    64  type Grouping struct {
    65  	size  uint16
    66  	elems []uint16
    67  }
    68  
    69  func (g Grouping) Copy() Grouping {
    70  	elems2 := make([]uint16, len(g.elems), cap(g.elems))
    71  	copy(elems2[:], g.elems[:])
    72  	g.elems = elems2
    73  	return g
    74  }
    75  
    76  // search locates group among the sorted Grouping elements, it returns:
    77  //  * true if found
    78  //  * false if not found
    79  //  * the index at which group should be inserted to keep the
    80  //    elements sorted if not found and the bit-set representation is not in use
    81  func (gr *Groupings) search(g *Grouping, group uint16) (found bool, j uint16) {
    82  	if g.size > gr.bitsetThreshold {
    83  		return bitsetContains(g, group), 0
    84  	}
    85  	j = uint16(sort.Search(int(g.size), func(i int) bool { return g.elems[i] >= group }))
    86  	if j < g.size && g.elems[j] == group {
    87  		return true, 0
    88  	}
    89  	return false, j
    90  }
    91  
    92  func bitsetContains(g *Grouping, group uint16) bool {
    93  	return (g.elems[group/16] & (1 << (group % 16))) != 0
    94  }
    95  
    96  // AddTo adds the given group to the grouping.
    97  func (gr *Groupings) AddTo(g *Grouping, group uint16) error {
    98  	if err := gr.WithinRange(group); err != nil {
    99  		return err
   100  	}
   101  	if group > gr.maxGroup {
   102  		gr.maxGroup = group
   103  	}
   104  	if g.size == 0 {
   105  		g.size = 1
   106  		g.elems = []uint16{group}
   107  		return nil
   108  	}
   109  	found, j := gr.search(g, group)
   110  	if found {
   111  		return nil
   112  	}
   113  	newsize := g.size + 1
   114  	if newsize > gr.bitsetThreshold {
   115  		// switching to a bit-set representation after the size point
   116  		// where the space cost is the same, the representation uses
   117  		// bitsetThreshold-many 16-bits words stored in elems.
   118  		// We don't always use the bit-set representation because
   119  		// * we expect small groupings and iteration to be common,
   120  		//   iteration is more costly over the bit-set representation
   121  		// * serialization matches more or less what we do in memory,
   122  		//   so again is more efficient for small groupings in the
   123  		//   extensive representation.
   124  		if g.size == gr.bitsetThreshold {
   125  			prevelems := g.elems
   126  			g.elems = make([]uint16, gr.bitsetThreshold)
   127  			for _, e := range prevelems {
   128  				bitsetAdd(g, e)
   129  			}
   130  		}
   131  		g.size = newsize
   132  		bitsetAdd(g, group)
   133  		return nil
   134  	}
   135  	var newelems []uint16
   136  	if int(g.size) == cap(g.elems) {
   137  		newelems = make([]uint16, newsize, cap(g.elems)*2)
   138  		copy(newelems, g.elems[:j])
   139  	} else {
   140  		newelems = g.elems[:newsize]
   141  	}
   142  	if j < g.size {
   143  		copy(newelems[j+1:], g.elems[j:])
   144  	}
   145  	// inserting new group at j index keeping the elements sorted
   146  	newelems[j] = group
   147  	g.size = newsize
   148  	g.elems = newelems
   149  	return nil
   150  }
   151  
   152  func bitsetAdd(g *Grouping, group uint16) {
   153  	g.elems[group/16] |= 1 << (group % 16)
   154  }
   155  
   156  // Contains returns whether the given group is a member of the grouping.
   157  func (gr *Groupings) Contains(g *Grouping, group uint16) bool {
   158  	found, _ := gr.search(g, group)
   159  	return found
   160  }
   161  
   162  // Serialize produces a string encoding the given integers.
   163  func Serialize(elems []uint16) string {
   164  	b := bytes.NewBuffer(make([]byte, 0, len(elems)*2))
   165  	binary.Write(b, binary.LittleEndian, elems)
   166  	return base64.RawURLEncoding.EncodeToString(b.Bytes())
   167  }
   168  
   169  // Serialize produces a string representing the grouping label.
   170  func (gr *Groupings) Serialize(g *Grouping) string {
   171  	// groupings are serialized as:
   172  	//  * the actual element groups if there are up to
   173  	//    bitsetThreshold elements: elems[0], elems[1], ...
   174  	//  * otherwise the number of elements, followed by the bitset
   175  	//    representation comprised of bitsetThreshold-many 16-bits words
   176  	//    (stored using elems as well)
   177  	if g.size > gr.bitsetThreshold {
   178  		return gr.bitsetSerialize(g)
   179  	}
   180  	return Serialize(g.elems)
   181  }
   182  
   183  func (gr *Groupings) bitsetSerialize(g *Grouping) string {
   184  	b := bytes.NewBuffer(make([]byte, 0, (gr.bitsetThreshold+1)*2))
   185  	binary.Write(b, binary.LittleEndian, g.size)
   186  	binary.Write(b, binary.LittleEndian, g.elems)
   187  	return base64.RawURLEncoding.EncodeToString(b.Bytes())
   188  }
   189  
   190  const errSerializedLabelFmt = "invalid serialized grouping label: %v"
   191  
   192  // Deserialize reconstructs a grouping out of the serialized label.
   193  func (gr *Groupings) Deserialize(label string) (*Grouping, error) {
   194  	b, err := base64.RawURLEncoding.DecodeString(label)
   195  	if err != nil {
   196  		return nil, fmt.Errorf(errSerializedLabelFmt, err)
   197  	}
   198  	if len(b)%2 != 0 {
   199  		return nil, fmt.Errorf(errSerializedLabelFmt, "not divisible into 16-bits words")
   200  	}
   201  	m := len(b) / 2
   202  	var g Grouping
   203  	if m == int(gr.bitsetThreshold+1) {
   204  		// deserialize number of elements + bitset representation
   205  		// comprising bitsetThreshold-many 16-bits words
   206  		return gr.bitsetDeserialize(&g, b)
   207  	}
   208  	if m > int(gr.bitsetThreshold) {
   209  		return nil, fmt.Errorf(errSerializedLabelFmt, "too large")
   210  	}
   211  	g.size = uint16(m)
   212  	esz := uint16(1)
   213  	for esz < g.size {
   214  		esz *= 2
   215  	}
   216  	g.elems = make([]uint16, g.size, esz)
   217  	binary.Read(bytes.NewBuffer(b), binary.LittleEndian, g.elems)
   218  	for i, e := range g.elems {
   219  		if e > gr.maxGroup {
   220  			return nil, fmt.Errorf(errSerializedLabelFmt, "element larger than maximum group")
   221  		}
   222  		if i > 0 && g.elems[i-1] >= e {
   223  			return nil, fmt.Errorf(errSerializedLabelFmt, "not sorted")
   224  		}
   225  	}
   226  	return &g, nil
   227  }
   228  
   229  func (gr *Groupings) bitsetDeserialize(g *Grouping, b []byte) (*Grouping, error) {
   230  	buf := bytes.NewBuffer(b)
   231  	binary.Read(buf, binary.LittleEndian, &g.size)
   232  	if g.size > gr.maxGroup+1 {
   233  		return nil, fmt.Errorf(errSerializedLabelFmt, "bitset size cannot be possibly larger than maximum group plus 1")
   234  	}
   235  	if g.size <= gr.bitsetThreshold {
   236  		// should not have used a bitset repr for so few elements
   237  		return nil, fmt.Errorf(errSerializedLabelFmt, "bitset for too few elements")
   238  	}
   239  	g.elems = make([]uint16, gr.bitsetThreshold)
   240  	binary.Read(buf, binary.LittleEndian, g.elems)
   241  	return g, nil
   242  }
   243  
   244  // Iter iterates over the groups in the grouping and calls f with each of
   245  // them. If f returns an error Iter immediately returns with it.
   246  func (gr *Groupings) Iter(g *Grouping, f func(group uint16) error) error {
   247  	if g.size > gr.bitsetThreshold {
   248  		return gr.bitsetIter(g, f)
   249  	}
   250  	for _, e := range g.elems {
   251  		if err := f(e); err != nil {
   252  			return err
   253  		}
   254  	}
   255  	return nil
   256  }
   257  
   258  func (gr *Groupings) bitsetIter(g *Grouping, f func(group uint16) error) error {
   259  	c := g.size
   260  	for i := uint16(0); i <= gr.maxGroup/16; i++ {
   261  		w := g.elems[i]
   262  		if w == 0 {
   263  			continue
   264  		}
   265  		for j := uint16(0); w != 0; j++ {
   266  			if w&1 != 0 {
   267  				if err := f(i*16 + j); err != nil {
   268  					return err
   269  				}
   270  				c--
   271  				if c == 0 {
   272  					// found all elements
   273  					return nil
   274  				}
   275  			}
   276  			w >>= 1
   277  		}
   278  	}
   279  	return nil
   280  }