github.com/kubiko/snapd@v0.0.0-20201013125620-d4f3094d9ddf/overlord/assertstate/bulk.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2020 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package assertstate
    21  
    22  import (
    23  	"context"
    24  	"fmt"
    25  	"sort"
    26  	"strings"
    27  
    28  	"github.com/snapcore/snapd/asserts"
    29  	"github.com/snapcore/snapd/overlord/snapstate"
    30  	"github.com/snapcore/snapd/overlord/state"
    31  	"github.com/snapcore/snapd/release"
    32  	"github.com/snapcore/snapd/store"
    33  )
    34  
    35  const storeGroup = "store assertion"
    36  
    37  // maxGroups is the maximum number of assertion groups we set with the
    38  // asserts.Pool used to refresh snap assertions, it corresponds
    39  // roughly to for how many snaps we will request assertions in
    40  // in one /v2/snaps/refresh request.
    41  // Given that requesting assertions for ~500 snaps together with no
    42  // updates can take around 900ms-1s, conservatively set it to half of
    43  // that. Most systems should be done in one request anyway.
    44  var maxGroups = 256
    45  
    46  func bulkRefreshSnapDeclarations(s *state.State, snapStates map[string]*snapstate.SnapState, userID int, deviceCtx snapstate.DeviceContext) error {
    47  	db := cachedDB(s)
    48  
    49  	pool := asserts.NewPool(db, maxGroups)
    50  
    51  	var mergedRPErr *resolvePoolError
    52  	tryResolvePool := func() error {
    53  		err := resolvePool(s, pool, userID, deviceCtx)
    54  		if rpe, ok := err.(*resolvePoolError); ok {
    55  			if mergedRPErr == nil {
    56  				mergedRPErr = rpe
    57  			} else {
    58  				mergedRPErr.merge(rpe)
    59  			}
    60  			return nil
    61  		}
    62  		return err
    63  	}
    64  
    65  	c := 0
    66  	for instanceName, snapst := range snapStates {
    67  		sideInfo := snapst.CurrentSideInfo()
    68  		if sideInfo.SnapID == "" {
    69  			continue
    70  		}
    71  
    72  		declRef := &asserts.Ref{
    73  			Type:       asserts.SnapDeclarationType,
    74  			PrimaryKey: []string{release.Series, sideInfo.SnapID},
    75  		}
    76  		// update snap-declaration (and prereqs) for the snap,
    77  		// they were originally added at install time
    78  		if err := pool.AddToUpdate(declRef, instanceName); err != nil {
    79  			return fmt.Errorf("cannot prepare snap-declaration refresh for snap %q: %v", instanceName, err)
    80  		}
    81  
    82  		c++
    83  		if c%maxGroups == 0 {
    84  			// we have exhausted max groups, resolve
    85  			// what we setup so far and then clear groups
    86  			// to reuse the pool
    87  			if err := tryResolvePool(); err != nil {
    88  				return err
    89  			}
    90  			if err := pool.ClearGroups(); err != nil {
    91  				// this shouldn't happen but if it
    92  				// does fallback
    93  				return &bulkAssertionFallbackError{err}
    94  			}
    95  		}
    96  	}
    97  
    98  	modelAs := deviceCtx.Model()
    99  
   100  	// fetch store assertion if available
   101  	if modelAs.Store() != "" {
   102  		storeRef := asserts.Ref{
   103  			Type:       asserts.StoreType,
   104  			PrimaryKey: []string{modelAs.Store()},
   105  		}
   106  		if err := pool.AddToUpdate(&storeRef, storeGroup); err != nil {
   107  			if !asserts.IsNotFound(err) {
   108  				return fmt.Errorf("cannot prepare store assertion refresh: %v", err)
   109  			}
   110  			// assertion is not present in the db yet,
   111  			// we'll try to resolve it (fetch it) first
   112  			storeAt := &asserts.AtRevision{
   113  				Ref:      storeRef,
   114  				Revision: asserts.RevisionNotKnown,
   115  			}
   116  			err := pool.AddUnresolved(storeAt, storeGroup)
   117  			if err != nil {
   118  				return fmt.Errorf("cannot prepare store assertion fetching: %v", err)
   119  			}
   120  		}
   121  	}
   122  
   123  	if err := tryResolvePool(); err != nil {
   124  		return err
   125  	}
   126  
   127  	if mergedRPErr != nil {
   128  		if e := mergedRPErr.errors[storeGroup]; asserts.IsNotFound(e) || e == asserts.ErrUnresolved {
   129  			// ignore
   130  			delete(mergedRPErr.errors, storeGroup)
   131  		}
   132  		if len(mergedRPErr.errors) == 0 {
   133  			return nil
   134  		}
   135  		mergedRPErr.message = "cannot refresh snap-declarations for snaps"
   136  		return mergedRPErr
   137  	}
   138  
   139  	return nil
   140  }
   141  
   142  // marker error to request falling back to the old implemention for assertion
   143  // refreshes
   144  type bulkAssertionFallbackError struct {
   145  	err error
   146  }
   147  
   148  func (e *bulkAssertionFallbackError) Error() string {
   149  	return fmt.Sprintf("unsuccessful bulk assertion refresh, fallback: %v", e.err)
   150  }
   151  
   152  type resolvePoolError struct {
   153  	message string
   154  	// errors maps groups to errors
   155  	errors map[string]error
   156  }
   157  
   158  func (rpe *resolvePoolError) merge(rpe1 *resolvePoolError) {
   159  	// we expect usually rpe and rpe1 errors to be disjunct, but is also
   160  	// ok for rpe1 errors to win
   161  	for k, e := range rpe1.errors {
   162  		rpe.errors[k] = e
   163  	}
   164  }
   165  
   166  func (rpe *resolvePoolError) Error() string {
   167  	message := rpe.message
   168  	if message == "" {
   169  		message = "cannot fetch and resolve assertions"
   170  	}
   171  	s := make([]string, 0, 1+len(rpe.errors))
   172  	s = append(s, fmt.Sprintf("%s:", message))
   173  	groups := make([]string, 0, len(rpe.errors))
   174  	for g := range rpe.errors {
   175  		groups = append(groups, g)
   176  	}
   177  	sort.Strings(groups)
   178  	for _, g := range groups {
   179  		s = append(s, fmt.Sprintf(" - %s: %v", g, rpe.errors[g]))
   180  	}
   181  	return strings.Join(s, "\n")
   182  }
   183  
// resolvePool repeatedly issues a store SnapAction request carrying pool
// as the assertion query, downloads the assertion streams of each
// returned result into a batch and adds it to pool under the result's
// grouping, until the store returns no more assertion results. It then
// commits the pool to the cached assertion database.
//
// The state lock is released around the store network calls and
// re-acquired afterwards, so the caller must hold it.
//
// It returns:
//   - a *bulkAssertionFallbackError for a SnapActionError that is not a
//     plain "no results" error, or for an UnexpectedHTTPStatusError with
//     a status in the 400-500 range;
//   - a *resolvePoolError with the per-group errors reported by the pool
//     after committing;
//   - any other error from looking up the user, from the store request or
//     from adding a batch to the pool.
func resolvePool(s *state.State, pool *asserts.Pool, userID int, deviceCtx snapstate.DeviceContext) error {
	user, err := userFromUserID(s, userID)
	if err != nil {
		return err
	}
	sto := snapstate.Store(s, deviceCtx)
	db := cachedDB(s)
	// handler given to every batch; presumably decides how assertions
	// the db cannot handle are treated — TODO confirm in handleUnsupported
	unsupported := handleUnsupported(db)

	// each round asks the store what the pool still needs resolved
	for {
		// TODO: pass refresh options?
		// release the state lock while talking to the store
		s.Unlock()
		_, aresults, err := sto.SnapAction(context.TODO(), nil, nil, pool, user, nil)
		s.Lock()
		if err != nil {
			// request fallback on
			//  * unexpected SnapActionErrors or
			//  * unexpected HTTP status of 4xx or 500
			ignore := false
			switch stoErr := err.(type) {
			case *store.SnapActionError:
				if !stoErr.NoResults || len(stoErr.Other) != 0 {
					return &bulkAssertionFallbackError{stoErr}
				}
				// simply no results error, we are likely done
				ignore = true
			case *store.UnexpectedHTTPStatusError:
				if stoErr.StatusCode >= 400 && stoErr.StatusCode <= 500 {
					return &bulkAssertionFallbackError{stoErr}
				}
			}
			if !ignore {
				return err
			}
		}
		if len(aresults) == 0 {
			// everything resolved if no errors
			break
		}

		for _, ares := range aresults {
			b := asserts.NewBatch(unsupported)
			s.Unlock()
			err := sto.DownloadAssertions(ares.StreamURLs, b, user)
			s.Lock()
			if err != nil {
				// a failed download only affects this result's
				// grouping: record it in the pool and continue
				// with the other results.
				// NOTE(review): any error returned by
				// AddGroupingError is ignored here — confirm
				pool.AddGroupingError(err, ares.Grouping)
				continue
			}
			_, err = pool.AddBatch(b, ares.Grouping)
			if err != nil {
				return err
			}
		}
	}

	// NOTE(review): any return value of CommitTo is discarded; problems
	// are expected to surface through pool.Errors() below — confirm
	pool.CommitTo(db)

	errors := pool.Errors()
	if len(errors) != 0 {
		return &resolvePoolError{errors: errors}
	}

	return nil
}