github.com/Lephar/snapd@v0.0.0-20210825215435-c7fba9cef4d2/overlord/hookstate/ctlcmd/helpers.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2017 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package ctlcmd
    21  
    22  import (
    23  	"bytes"
    24  	"encoding/json"
    25  	"fmt"
    26  	"strings"
    27  	"time"
    28  
    29  	"github.com/snapcore/snapd/i18n"
    30  	"github.com/snapcore/snapd/jsonutil"
    31  	"github.com/snapcore/snapd/overlord/configstate"
    32  	"github.com/snapcore/snapd/overlord/hookstate"
    33  	"github.com/snapcore/snapd/overlord/servicestate"
    34  	"github.com/snapcore/snapd/overlord/snapstate"
    35  	"github.com/snapcore/snapd/overlord/state"
    36  	"github.com/snapcore/snapd/snap"
    37  )
    38  
    39  var finalTasks map[string]bool
    40  
    41  func init() {
    42  	finalTasks = make(map[string]bool, len(snapstate.FinalTasks))
    43  	for _, kind := range snapstate.FinalTasks {
    44  		finalTasks[kind] = true
    45  	}
    46  }
    47  
    48  func getServiceInfos(st *state.State, snapName string, serviceNames []string) ([]*snap.AppInfo, error) {
    49  	st.Lock()
    50  	defer st.Unlock()
    51  
    52  	var snapst snapstate.SnapState
    53  	if err := snapstate.Get(st, snapName, &snapst); err != nil {
    54  		return nil, err
    55  	}
    56  
    57  	info, err := snapst.CurrentInfo()
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  	if len(serviceNames) == 0 {
    62  		// all services
    63  		return info.Services(), nil
    64  	}
    65  
    66  	var svcs []*snap.AppInfo
    67  	for _, svcName := range serviceNames {
    68  		if svcName == snapName {
    69  			// all the services
    70  			return info.Services(), nil
    71  		}
    72  		if !strings.HasPrefix(svcName, snapName+".") {
    73  			return nil, fmt.Errorf(i18n.G("unknown service: %q"), svcName)
    74  		}
    75  		// this doesn't support service aliases
    76  		app, ok := info.Apps[svcName[1+len(snapName):]]
    77  		if !(ok && app.IsService()) {
    78  			return nil, fmt.Errorf(i18n.G("unknown service: %q"), svcName)
    79  		}
    80  		svcs = append(svcs, app)
    81  	}
    82  
    83  	return svcs, nil
    84  }
    85  
// servicestateControl is servicestate.Control stored in a package-level
// variable, so it can be swapped out; runServiceCommand calls it through
// this variable. NOTE(review): presumably a test seam — confirm in the tests.
var servicestateControl = servicestate.Control
    87  
    88  func queueCommand(context *hookstate.Context, tts []*state.TaskSet) error {
    89  	hookTask, ok := context.Task()
    90  	if !ok {
    91  		return fmt.Errorf("attempted to queue command with ephemeral context")
    92  	}
    93  
    94  	st := context.State()
    95  	st.Lock()
    96  	defer st.Unlock()
    97  
    98  	change := hookTask.Change()
    99  	hookTaskLanes := hookTask.Lanes()
   100  	tasks := change.LaneTasks(hookTaskLanes...)
   101  
   102  	// When installing or updating multiple snaps, there is one lane per snap.
   103  	// We want service command to join respective lane (it's the lane the hook belongs to).
   104  	// In case there are no lanes, only the default lane no. 0, there is no need to join it.
   105  	if len(hookTaskLanes) == 1 && hookTaskLanes[0] == 0 {
   106  		hookTaskLanes = nil
   107  	}
   108  	for _, l := range hookTaskLanes {
   109  		for _, ts := range tts {
   110  			ts.JoinLane(l)
   111  		}
   112  	}
   113  
   114  	for _, ts := range tts {
   115  		for _, t := range tasks {
   116  			// queue service command after all tasks, except for final tasks which must come after service commands
   117  			if finalTasks[t.Kind()] {
   118  				t.WaitAll(ts)
   119  			} else {
   120  				ts.WaitFor(t)
   121  			}
   122  		}
   123  		change.AddAll(ts)
   124  	}
   125  	// As this can be run from what was originally the last task of a change,
   126  	// make sure the tasks added to the change are considered immediately.
   127  	st.EnsureBefore(0)
   128  
   129  	return nil
   130  }
   131  
   132  func runServiceCommand(context *hookstate.Context, inst *servicestate.Instruction) error {
   133  	if context == nil {
   134  		// this message is reused in health.go
   135  		return fmt.Errorf(i18n.G("cannot %s without a context"), inst.Action)
   136  	}
   137  
   138  	st := context.State()
   139  	appInfos, err := getServiceInfos(st, context.InstanceName(), inst.Names)
   140  	if err != nil {
   141  		return err
   142  	}
   143  
   144  	flags := &servicestate.Flags{CreateExecCommandTasks: true}
   145  	// passing context so we can ignore self-conflicts with the current change
   146  	st.Lock()
   147  	tts, err := servicestateControl(st, appInfos, inst, flags, context)
   148  	st.Unlock()
   149  	if err != nil {
   150  		return err
   151  	}
   152  
   153  	if !context.IsEphemeral() && context.HookName() == "configure" {
   154  		return queueCommand(context, tts)
   155  	}
   156  
   157  	st.Lock()
   158  	chg := st.NewChange("service-control", fmt.Sprintf("Running service command for snap %q", context.InstanceName()))
   159  	for _, ts := range tts {
   160  		chg.AddAll(ts)
   161  	}
   162  	st.EnsureBefore(0)
   163  	st.Unlock()
   164  
   165  	select {
   166  	case <-chg.Ready():
   167  		st.Lock()
   168  		defer st.Unlock()
   169  		return chg.Err()
   170  	case <-time.After(configstate.ConfigureHookTimeout() / 2):
   171  		return fmt.Errorf("%s command is taking too long", inst.Action)
   172  	}
   173  }
   174  
   175  // NoAttributeError indicates that an interface attribute is not set.
   176  type NoAttributeError struct {
   177  	Attribute string
   178  }
   179  
   180  func (e *NoAttributeError) Error() string {
   181  	return fmt.Sprintf("no %q attribute", e.Attribute)
   182  }
   183  
   184  // isNoAttribute returns whether the provided error is a *NoAttributeError.
   185  func isNoAttribute(err error) bool {
   186  	_, ok := err.(*NoAttributeError)
   187  	return ok
   188  }
   189  
   190  func jsonRaw(v interface{}) *json.RawMessage {
   191  	data, err := json.Marshal(v)
   192  	if err != nil {
   193  		panic(fmt.Errorf("internal error: cannot marshal attributes: %v", err))
   194  	}
   195  	raw := json.RawMessage(data)
   196  	return &raw
   197  }
   198  
   199  // getAttribute unmarshals into result the value of the provided key from attributes map.
   200  // If the key does not exist, an error of type *NoAttributeError is returned.
   201  // The provided key may be formed as a dotted key path through nested maps.
   202  // For example, the "a.b.c" key describes the {a: {b: {c: value}}} map.
   203  func getAttribute(snapName string, subkeys []string, pos int, attrs map[string]interface{}, result interface{}) error {
   204  	if pos >= len(subkeys) {
   205  		return fmt.Errorf("internal error: invalid subkeys index %d for subkeys %q", pos, subkeys)
   206  	}
   207  	value, ok := attrs[subkeys[pos]]
   208  	if !ok {
   209  		return &NoAttributeError{Attribute: strings.Join(subkeys[:pos+1], ".")}
   210  	}
   211  
   212  	if pos+1 == len(subkeys) {
   213  		raw, ok := value.(*json.RawMessage)
   214  		if !ok {
   215  			raw = jsonRaw(value)
   216  		}
   217  		if err := jsonutil.DecodeWithNumber(bytes.NewReader(*raw), &result); err != nil {
   218  			key := strings.Join(subkeys, ".")
   219  			return fmt.Errorf("internal error: cannot unmarshal snap %s attribute %q into %T: %s, json: %s", snapName, key, result, err, *raw)
   220  		}
   221  		return nil
   222  	}
   223  
   224  	attrsm, ok := value.(map[string]interface{})
   225  	if !ok {
   226  		raw, ok := value.(*json.RawMessage)
   227  		if !ok {
   228  			raw = jsonRaw(value)
   229  		}
   230  		if err := jsonutil.DecodeWithNumber(bytes.NewReader(*raw), &attrsm); err != nil {
   231  			return fmt.Errorf("snap %q attribute %q is not a map", snapName, strings.Join(subkeys[:pos+1], "."))
   232  		}
   233  	}
   234  	return getAttribute(snapName, subkeys, pos+1, attrsm, result)
   235  }