github.com/meulengracht/snapd@v0.0.0-20210719210640-8bde69bcc84e/overlord/assertstate/bulk.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2020 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package assertstate
    21  
    22  import (
    23  	"context"
    24  	"fmt"
    25  	"sort"
    26  	"strings"
    27  
    28  	"github.com/snapcore/snapd/asserts"
    29  	"github.com/snapcore/snapd/asserts/snapasserts"
    30  	"github.com/snapcore/snapd/overlord/snapstate"
    31  	"github.com/snapcore/snapd/overlord/state"
    32  	"github.com/snapcore/snapd/release"
    33  	"github.com/snapcore/snapd/store"
    34  )
    35  
// storeGroup is the asserts.Pool group name under which the store
// assertion named by the device model is refreshed (or fetched) together
// with the snap-declarations in bulkRefreshSnapDeclarations.
const storeGroup = "store assertion"

// maxGroups is the maximum number of assertion groups we set with the
// asserts.Pool used to refresh snap assertions; it corresponds
// roughly to how many snaps we will request assertions for
// in one /v2/snaps/refresh request.
// Given that requesting assertions for ~500 snaps together with no
// updates can take around 900ms-1s, conservatively set it to half of
// that. Most systems should be done in one request anyway.
// NOTE(review): declared as a var rather than a const, presumably so
// tests can lower it — confirm before turning it into a constant.
var maxGroups = 256
    46  
    47  func bulkRefreshSnapDeclarations(s *state.State, snapStates map[string]*snapstate.SnapState, userID int, deviceCtx snapstate.DeviceContext) error {
    48  	db := cachedDB(s)
    49  
    50  	pool := asserts.NewPool(db, maxGroups)
    51  
    52  	var mergedRPErr *resolvePoolError
    53  	tryResolvePool := func() error {
    54  		err := resolvePool(s, pool, nil, userID, deviceCtx)
    55  		if rpe, ok := err.(*resolvePoolError); ok {
    56  			if mergedRPErr == nil {
    57  				mergedRPErr = rpe
    58  			} else {
    59  				mergedRPErr.merge(rpe)
    60  			}
    61  			return nil
    62  		}
    63  		return err
    64  	}
    65  
    66  	c := 0
    67  	for instanceName, snapst := range snapStates {
    68  		sideInfo := snapst.CurrentSideInfo()
    69  		if sideInfo.SnapID == "" {
    70  			continue
    71  		}
    72  
    73  		declRef := &asserts.Ref{
    74  			Type:       asserts.SnapDeclarationType,
    75  			PrimaryKey: []string{release.Series, sideInfo.SnapID},
    76  		}
    77  		// update snap-declaration (and prereqs) for the snap,
    78  		// they were originally added at install time
    79  		if err := pool.AddToUpdate(declRef, instanceName); err != nil {
    80  			return fmt.Errorf("cannot prepare snap-declaration refresh for snap %q: %v", instanceName, err)
    81  		}
    82  
    83  		c++
    84  		if c%maxGroups == 0 {
    85  			// we have exhausted max groups, resolve
    86  			// what we setup so far and then clear groups
    87  			// to reuse the pool
    88  			if err := tryResolvePool(); err != nil {
    89  				return err
    90  			}
    91  			if err := pool.ClearGroups(); err != nil {
    92  				// this shouldn't happen but if it
    93  				// does fallback
    94  				return &bulkAssertionFallbackError{err}
    95  			}
    96  		}
    97  	}
    98  
    99  	modelAs := deviceCtx.Model()
   100  
   101  	// fetch store assertion if available
   102  	if modelAs.Store() != "" {
   103  		storeRef := asserts.Ref{
   104  			Type:       asserts.StoreType,
   105  			PrimaryKey: []string{modelAs.Store()},
   106  		}
   107  		if err := pool.AddToUpdate(&storeRef, storeGroup); err != nil {
   108  			if !asserts.IsNotFound(err) {
   109  				return fmt.Errorf("cannot prepare store assertion refresh: %v", err)
   110  			}
   111  			// assertion is not present in the db yet,
   112  			// we'll try to resolve it (fetch it) first
   113  			storeAt := &asserts.AtRevision{
   114  				Ref:      storeRef,
   115  				Revision: asserts.RevisionNotKnown,
   116  			}
   117  			err := pool.AddUnresolved(storeAt, storeGroup)
   118  			if err != nil {
   119  				return fmt.Errorf("cannot prepare store assertion fetching: %v", err)
   120  			}
   121  		}
   122  	}
   123  
   124  	if err := tryResolvePool(); err != nil {
   125  		return err
   126  	}
   127  
   128  	if mergedRPErr != nil {
   129  		if e := mergedRPErr.errors[storeGroup]; asserts.IsNotFound(e) || e == asserts.ErrUnresolved {
   130  			// ignore
   131  			delete(mergedRPErr.errors, storeGroup)
   132  		}
   133  		if len(mergedRPErr.errors) == 0 {
   134  			return nil
   135  		}
   136  		mergedRPErr.message = "cannot refresh snap-declarations for snaps"
   137  		return mergedRPErr
   138  	}
   139  
   140  	return nil
   141  }
   142  
   143  func bulkRefreshValidationSetAsserts(s *state.State, vsets map[string]*ValidationSetTracking, beforeCommitChecker func(*asserts.Database, asserts.Backstore) error, userID int, deviceCtx snapstate.DeviceContext) error {
   144  	db := cachedDB(s)
   145  	pool := asserts.NewPool(db, maxGroups)
   146  
   147  	ignoreNotFound := make(map[string]bool)
   148  
   149  	for _, vs := range vsets {
   150  		var atSeq *asserts.AtSequence
   151  		if vs.PinnedAt > 0 {
   152  			// pinned to specific sequence, update to latest revision for same
   153  			// sequence.
   154  			atSeq = &asserts.AtSequence{
   155  				Type:        asserts.ValidationSetType,
   156  				SequenceKey: []string{release.Series, vs.AccountID, vs.Name},
   157  				Sequence:    vs.PinnedAt,
   158  				Pinned:      true,
   159  			}
   160  		} else {
   161  			// not pinned, update to latest sequence
   162  			atSeq = &asserts.AtSequence{
   163  				Type:        asserts.ValidationSetType,
   164  				SequenceKey: []string{release.Series, vs.AccountID, vs.Name},
   165  				Sequence:    vs.Current,
   166  			}
   167  		}
   168  		// every sequence to resolve has own group
   169  		group := atSeq.Unique()
   170  		if vs.LocalOnly {
   171  			ignoreNotFound[group] = true
   172  		}
   173  		if err := pool.AddSequenceToUpdate(atSeq, group); err != nil {
   174  			return err
   175  		}
   176  	}
   177  
   178  	err := resolvePoolNoFallback(s, pool, beforeCommitChecker, userID, deviceCtx)
   179  	if err == nil {
   180  		return nil
   181  	}
   182  
   183  	if _, ok := err.(*snapasserts.ValidationSetsConflictError); ok {
   184  		return err
   185  	}
   186  
   187  	if rerr, ok := err.(*resolvePoolError); ok {
   188  		// ignore resolving errors for validation sets that are local only (no
   189  		// assertion in the store).
   190  		for group := range ignoreNotFound {
   191  			if e := rerr.errors[group]; asserts.IsNotFound(e) || e == asserts.ErrUnresolved {
   192  				delete(rerr.errors, group)
   193  			}
   194  		}
   195  		if len(rerr.errors) == 0 {
   196  			return nil
   197  		}
   198  	}
   199  
   200  	return fmt.Errorf("cannot refresh validation set assertions: %v", err)
   201  }
   202  
   203  // marker error to request falling back to the old implemention for assertion
   204  // refreshes
   205  type bulkAssertionFallbackError struct {
   206  	err error
   207  }
   208  
   209  func (e *bulkAssertionFallbackError) Error() string {
   210  	return fmt.Sprintf("unsuccessful bulk assertion refresh, fallback: %v", e.err)
   211  }
   212  
   213  type resolvePoolError struct {
   214  	message string
   215  	// errors maps groups to errors
   216  	errors map[string]error
   217  }
   218  
   219  func (rpe *resolvePoolError) merge(rpe1 *resolvePoolError) {
   220  	// we expect usually rpe and rpe1 errors to be disjunct, but is also
   221  	// ok for rpe1 errors to win
   222  	for k, e := range rpe1.errors {
   223  		rpe.errors[k] = e
   224  	}
   225  }
   226  
   227  func (rpe *resolvePoolError) Error() string {
   228  	message := rpe.message
   229  	if message == "" {
   230  		message = "cannot fetch and resolve assertions"
   231  	}
   232  	s := make([]string, 0, 1+len(rpe.errors))
   233  	s = append(s, fmt.Sprintf("%s:", message))
   234  	groups := make([]string, 0, len(rpe.errors))
   235  	for g := range rpe.errors {
   236  		groups = append(groups, g)
   237  	}
   238  	sort.Strings(groups)
   239  	for _, g := range groups {
   240  		s = append(s, fmt.Sprintf(" - %s: %v", g, rpe.errors[g]))
   241  	}
   242  	return strings.Join(s, "\n")
   243  }
   244  
   245  func resolvePool(s *state.State, pool *asserts.Pool, checkBeforeCommit func(*asserts.Database, asserts.Backstore) error, userID int, deviceCtx snapstate.DeviceContext) error {
   246  	user, err := userFromUserID(s, userID)
   247  	if err != nil {
   248  		return err
   249  	}
   250  	sto := snapstate.Store(s, deviceCtx)
   251  	db := cachedDB(s)
   252  	unsupported := handleUnsupported(db)
   253  
   254  	for {
   255  		// TODO: pass refresh options?
   256  		s.Unlock()
   257  		_, aresults, err := sto.SnapAction(context.TODO(), nil, nil, pool, user, nil)
   258  		s.Lock()
   259  		if err != nil {
   260  			// request fallback on
   261  			//  * unexpected SnapActionErrors or
   262  			//  * unexpected HTTP status of 4xx or 500
   263  			ignore := false
   264  			switch stoErr := err.(type) {
   265  			case *store.SnapActionError:
   266  				if !stoErr.NoResults || len(stoErr.Other) != 0 {
   267  					return &bulkAssertionFallbackError{stoErr}
   268  				}
   269  				// simply no results error, we are likely done
   270  				ignore = true
   271  			case *store.UnexpectedHTTPStatusError:
   272  				if stoErr.StatusCode >= 400 && stoErr.StatusCode <= 500 {
   273  					return &bulkAssertionFallbackError{stoErr}
   274  				}
   275  			}
   276  			if !ignore {
   277  				return err
   278  			}
   279  		}
   280  		if len(aresults) == 0 {
   281  			// everything resolved if no errors
   282  			break
   283  		}
   284  
   285  		for _, ares := range aresults {
   286  			b := asserts.NewBatch(unsupported)
   287  			s.Unlock()
   288  			err := sto.DownloadAssertions(ares.StreamURLs, b, user)
   289  			s.Lock()
   290  			if err != nil {
   291  				pool.AddGroupingError(err, ares.Grouping)
   292  				continue
   293  			}
   294  			_, err = pool.AddBatch(b, ares.Grouping)
   295  			if err != nil {
   296  				return err
   297  			}
   298  		}
   299  	}
   300  
   301  	if checkBeforeCommit != nil {
   302  		if err := checkBeforeCommit(db, pool.Backstore()); err != nil {
   303  			return err
   304  		}
   305  	}
   306  	pool.CommitTo(db)
   307  
   308  	errors := pool.Errors()
   309  	if len(errors) != 0 {
   310  		return &resolvePoolError{errors: errors}
   311  	}
   312  
   313  	return nil
   314  }
   315  
   316  func resolvePoolNoFallback(s *state.State, pool *asserts.Pool, checkBeforeCommit func(*asserts.Database, asserts.Backstore) error, userID int, deviceCtx snapstate.DeviceContext) error {
   317  	err := resolvePool(s, pool, checkBeforeCommit, userID, deviceCtx)
   318  	if err != nil {
   319  		// no fallback, report inner error.
   320  		if ferr, ok := err.(*bulkAssertionFallbackError); ok {
   321  			err = ferr.err
   322  		}
   323  	}
   324  	return err
   325  }