github.com/meulengracht/snapd@v0.0.0-20210719210640-8bde69bcc84e/snap/epoch.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2017 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package snap
    21  
    22  import (
    23  	"encoding/json"
    24  	"fmt"
    25  	"strconv"
    26  
    27  	"github.com/snapcore/snapd/logger"
    28  )
    29  
// An Epoch represents the ability of the snap to read and write its data. Most
// developers need not worry about it: snaps default to the 0th epoch, and
// users are only offered refreshes to epoch 0 snaps. Once an epoch bump is in
// order, there's a simplified expression they can use which should cover the
// majority of the cases:
//
//   epoch: N
//
// means a snap can read/write exactly the Nth epoch's data, and
//
//   epoch: N*
//
// means a snap can additionally read the (N-1)th epoch's data, which means it's
// a snap that can migrate epochs (so a user on epoch 0 can get offered a
// refresh to a snap on epoch 1*).
//
// If the above is not enough, a developer can explicitly describe what epochs a
// snap can read and write:
//
//   epoch:
//     read: [1, 2, 3]
//     write: [1, 3]
//
// The read attribute defaults to the value of the write attribute, and the
// write attribute defaults to the last item in the read attribute. If both are
// unset, it's the same as not specifying an epoch at all (i.e. epoch: 0). The
// lists must not have more than 10 elements, they must be strictly increasing,
// and there must be a non-empty intersection between them.
//
// Epoch numbers must be written in base 10, with no zero padding.
//
// In the struct, Read holds the epochs whose data the snap can read and Write
// the epochs whose data it can write. IsZero and CanRead treat a nil or empty
// list the same as {0}.
type Epoch struct {
	Read  []uint32 `yaml:"read"`
	Write []uint32 `yaml:"write"`
}
    64  
    65  // E returns the epoch represented by the expression s. It's meant for use in
    66  // testing, as it panics at the first sign of trouble.
    67  func E(s string) Epoch {
    68  	var e Epoch
    69  	if err := e.fromString(s); err != nil {
    70  		panic(fmt.Errorf("%q: %v", s, err))
    71  	}
    72  	return e
    73  }
    74  
    75  func (e *Epoch) fromString(s string) error {
    76  	if len(s) == 0 || s == "0" {
    77  		e.Read = []uint32{0}
    78  		e.Write = []uint32{0}
    79  		return nil
    80  	}
    81  	star := false
    82  	if s[len(s)-1] == '*' {
    83  		star = true
    84  		s = s[:len(s)-1]
    85  	}
    86  	n, err := parseInt(s)
    87  	if err != nil {
    88  		return err
    89  	}
    90  	if star {
    91  		if n == 0 {
    92  			return &EpochError{Message: epochZeroStar}
    93  		}
    94  		e.Read = []uint32{n - 1, n}
    95  	} else {
    96  		e.Read = []uint32{n}
    97  	}
    98  	e.Write = []uint32{n}
    99  
   100  	return nil
   101  }
   102  
   103  func (e *Epoch) fromStructured(structured structuredEpoch) error {
   104  	if structured.Read == nil {
   105  		if structured.Write == nil {
   106  			structured.Write = []uint32{0}
   107  		}
   108  		structured.Read = structured.Write
   109  	} else if len(structured.Read) == 0 {
   110  		// this means they explicitly set it to []. Bad they!
   111  		return &EpochError{Message: emptyEpochList}
   112  	}
   113  	if structured.Write == nil {
   114  		structured.Write = structured.Read[len(structured.Read)-1:]
   115  	} else if len(structured.Write) == 0 {
   116  		return &EpochError{Message: emptyEpochList}
   117  	}
   118  
   119  	p := &Epoch{Read: structured.Read, Write: structured.Write}
   120  	if err := p.Validate(); err != nil {
   121  		return err
   122  	}
   123  
   124  	*e = *p
   125  
   126  	return nil
   127  }
   128  
   129  func (e *Epoch) UnmarshalJSON(bs []byte) error {
   130  	return e.UnmarshalYAML(func(v interface{}) error {
   131  		return json.Unmarshal(bs, &v)
   132  	})
   133  }
   134  
   135  func (e *Epoch) UnmarshalYAML(unmarshal func(interface{}) error) error {
   136  	var shortEpoch string
   137  	if err := unmarshal(&shortEpoch); err == nil {
   138  		return e.fromString(shortEpoch)
   139  	}
   140  	var structured structuredEpoch
   141  	if err := unmarshal(&structured); err != nil {
   142  		return err
   143  	}
   144  
   145  	return e.fromStructured(structured)
   146  }
   147  
   148  // IsZero checks whether a snap's epoch is not set (or is set to the default
   149  // value of "0").  Also zero are some epochs that would be normalized to "0",
   150  // such as {"read": 0}, as well as some invalid ones like {"read": []}.
   151  func (e *Epoch) IsZero() bool {
   152  	if e == nil {
   153  		return true
   154  	}
   155  
   156  	rZero := len(e.Read) == 0 || (len(e.Read) == 1 && e.Read[0] == 0)
   157  	wZero := len(e.Write) == 0 || (len(e.Write) == 1 && e.Write[0] == 0)
   158  
   159  	return rZero && wZero
   160  }
   161  
   162  func epochListEq(a, b []uint32) bool {
   163  	if len(a) != len(b) {
   164  		return false
   165  	}
   166  	for i := range a {
   167  		if a[i] != b[i] {
   168  			return false
   169  		}
   170  	}
   171  	return true
   172  }
   173  
   174  func (e *Epoch) Equal(other *Epoch) bool {
   175  	if e.IsZero() {
   176  		return other.IsZero()
   177  	}
   178  	return epochListEq(e.Read, other.Read) && epochListEq(e.Write, other.Write)
   179  }
   180  
   181  // Validate checks that the epoch makes sense.
   182  func (e *Epoch) Validate() error {
   183  	if (e.Read != nil && len(e.Read) == 0) || (e.Write != nil && len(e.Write) == 0) {
   184  		// these are invalid, but if both are true then IsZero will be true.
   185  		// In practice this check is redundant because it's caught in deserialise.
   186  		// Belts-and-suspenders all the way down.
   187  		return &EpochError{Message: emptyEpochList}
   188  	}
   189  	if e.IsZero() {
   190  		return nil
   191  	}
   192  	if len(e.Read) > 10 || len(e.Write) > 10 {
   193  		return &EpochError{Message: epochListJustRidiculouslyLong}
   194  	}
   195  	if !isIncreasing(e.Read) || !isIncreasing(e.Write) {
   196  		return &EpochError{Message: epochListNotIncreasing}
   197  	}
   198  
   199  	if intersect(e.Read, e.Write) {
   200  		return nil
   201  	}
   202  	return &EpochError{Message: noEpochIntersection}
   203  }
   204  
   205  func (e *Epoch) simplify() interface{} {
   206  	if e.IsZero() {
   207  		return "0"
   208  	}
   209  	if len(e.Write) == 1 && len(e.Read) == 1 && e.Read[0] == e.Write[0] {
   210  		return strconv.FormatUint(uint64(e.Read[0]), 10)
   211  	}
   212  	if len(e.Write) == 1 && len(e.Read) == 2 && e.Read[0]+1 == e.Read[1] && e.Read[1] == e.Write[0] {
   213  		return strconv.FormatUint(uint64(e.Read[1]), 10) + "*"
   214  	}
   215  	return &structuredEpoch{Read: e.Read, Write: e.Write}
   216  }
   217  
   218  func (e Epoch) MarshalJSON() ([]byte, error) {
   219  	se := &structuredEpoch{Read: e.Read, Write: e.Write}
   220  	if len(se.Read) == 0 {
   221  		se.Read = uint32slice{0}
   222  	}
   223  	if len(se.Write) == 0 {
   224  		se.Write = uint32slice{0}
   225  	}
   226  	return json.Marshal(se)
   227  }
   228  
   229  func (Epoch) MarshalYAML() (interface{}, error) {
   230  	panic("unexpected attempt to marshal an Epoch to YAML")
   231  }
   232  
   233  func (e Epoch) String() string {
   234  	i := e.simplify()
   235  	if s, ok := i.(string); ok {
   236  		return s
   237  	}
   238  
   239  	buf, err := json.Marshal(i)
   240  	if err != nil {
   241  		// can this happen?
   242  		logger.Noticef("trying to marshal %#v, simplified to %#v, got %v", e, i, err)
   243  		return "-1"
   244  	}
   245  	return string(buf)
   246  }
   247  
   248  // CanRead checks whether this epoch can read the data written by the
   249  // other one.
   250  func (e *Epoch) CanRead(other Epoch) bool {
   251  	// the intersection between e.Read and other.Write needs to be non-empty
   252  
   253  	// normalize (empty epoch should be treated like "0" here)
   254  	var rs, ws []uint32
   255  	if e != nil {
   256  		rs = e.Read
   257  	}
   258  	ws = other.Write
   259  	if len(rs) == 0 {
   260  		rs = []uint32{0}
   261  	}
   262  	if len(ws) == 0 {
   263  		ws = []uint32{0}
   264  	}
   265  
   266  	return intersect(rs, ws)
   267  }
   268  
   269  func intersect(rs, ws []uint32) bool {
   270  	// O(𝑚𝑛) instead of O(𝑚log𝑛) for the binary search we could do, but
   271  	// 𝑚 and 𝑛 < 10, so the simple solution is good enough (and if that
   272  	// alone makes you nervous, know that it is ~2× faster in the worst
   273  	// case; bisect starts being faster at ~50 entries).
   274  	for _, r := range rs {
   275  		for _, w := range ws {
   276  			if r == w {
   277  				return true
   278  			}
   279  		}
   280  	}
   281  	return false
   282  }
   283  
   284  // EpochError tracks the details of a failed epoch parse or validation.
   285  type EpochError struct {
   286  	Message string
   287  }
   288  
   289  func (e EpochError) Error() string {
   290  	return e.Message
   291  }
   292  
   293  const (
   294  	epochZeroStar                 = "0* is an invalid epoch"
   295  	hugeEpochNumber               = "epoch numbers must be less than 2³², but got %q"
   296  	badEpochNumber                = "epoch numbers must be base 10 with no zero padding, but got %q"
   297  	badEpochList                  = "epoch read/write attributes must be lists of epoch numbers"
   298  	emptyEpochList                = "epoch list cannot be explicitly empty"
   299  	epochListNotIncreasing        = "epoch list must be a strictly increasing sequence"
   300  	epochListJustRidiculouslyLong = "epoch list must not have more than 10 entries"
   301  	noEpochIntersection           = "epoch read and write lists must have a non-empty intersection"
   302  )
   303  
   304  func parseInt(s string) (uint32, error) {
   305  	if !(len(s) > 1 && s[0] == '0') {
   306  		u, err := strconv.ParseUint(s, 10, 32)
   307  		if err == nil {
   308  			return uint32(u), nil
   309  		}
   310  		if e, ok := err.(*strconv.NumError); ok {
   311  			if e.Err == strconv.ErrRange {
   312  				return 0, &EpochError{
   313  					Message: fmt.Sprintf(hugeEpochNumber, s),
   314  				}
   315  			}
   316  		}
   317  	}
   318  	return 0, &EpochError{
   319  		Message: fmt.Sprintf(badEpochNumber, s),
   320  	}
   321  }
   322  
   323  type uint32slice []uint32
   324  
   325  func (z *uint32slice) UnmarshalYAML(unmarshal func(interface{}) error) error {
   326  	var ss []string
   327  	if err := unmarshal(&ss); err != nil {
   328  		return &EpochError{Message: badEpochList}
   329  	}
   330  	x := make([]uint32, len(ss))
   331  	for i, s := range ss {
   332  		n, err := parseInt(s)
   333  		if err != nil {
   334  			return err
   335  		}
   336  		x[i] = n
   337  	}
   338  	*z = x
   339  	return nil
   340  }
   341  
   342  func (z *uint32slice) UnmarshalJSON(bs []byte) error {
   343  	var ss []json.RawMessage
   344  	if err := json.Unmarshal(bs, &ss); err != nil {
   345  		return &EpochError{Message: badEpochList}
   346  	}
   347  	x := make([]uint32, len(ss))
   348  	for i, s := range ss {
   349  		n, err := parseInt(string(s))
   350  		if err != nil {
   351  			return err
   352  		}
   353  		x[i] = n
   354  	}
   355  	*z = x
   356  	return nil
   357  }
   358  
   359  func isIncreasing(z []uint32) bool {
   360  	if len(z) < 2 {
   361  		return true
   362  	}
   363  	for i := range z[1:] {
   364  		if z[i] >= z[i+1] {
   365  			return false
   366  		}
   367  	}
   368  	return true
   369  }
   370  
// structuredEpoch is the explicit {"read": [...], "write": [...]} form of an
// epoch. UnmarshalYAML decodes into it when the value is not a short string,
// and MarshalJSON and simplify build it for output.
//
// A nil list means the attribute was not given. An empty non-nil list means it
// was given as [], which fromStructured rejects.
//
// The field order and json tags determine the JSON output, so keep them as-is.
type structuredEpoch struct {
	Read  uint32slice `json:"read"`
	Write uint32slice `json:"write"`
}