github.com/rigado/snapd@v2.42.5-go-mod+incompatible/overlord/snapstate/snapstatetest/devicectx.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2018-2019 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package snapstatetest
    21  
    22  import (
    23  	"github.com/snapcore/snapd/asserts"
    24  	"github.com/snapcore/snapd/asserts/sysdb"
    25  	"github.com/snapcore/snapd/overlord/snapstate"
    26  	"github.com/snapcore/snapd/overlord/state"
    27  )
    28  
    29  type TrivialDeviceContext struct {
    30  	DeviceModel *asserts.Model
    31  	Remodeling  bool
    32  	CtxStore    snapstate.StoreService
    33  }
    34  
    35  func (dc *TrivialDeviceContext) Model() *asserts.Model {
    36  	return dc.DeviceModel
    37  }
    38  
    39  func (dc *TrivialDeviceContext) Store() snapstate.StoreService {
    40  	return dc.CtxStore
    41  }
    42  
    43  func (dc *TrivialDeviceContext) ForRemodeling() bool {
    44  	return dc.Remodeling
    45  }
    46  
    47  func MockDeviceModel(model *asserts.Model) (restore func()) {
    48  	var deviceCtx snapstate.DeviceContext
    49  	if model != nil {
    50  		deviceCtx = &TrivialDeviceContext{DeviceModel: model}
    51  	}
    52  	return MockDeviceContext(deviceCtx)
    53  }
    54  
    55  func MockDeviceContext(deviceCtx snapstate.DeviceContext) (restore func()) {
    56  	deviceCtxHook := func(st *state.State, task *state.Task, providedDeviceCtx snapstate.DeviceContext) (snapstate.DeviceContext, error) {
    57  		if providedDeviceCtx != nil {
    58  			return providedDeviceCtx, nil
    59  		}
    60  		if deviceCtx == nil {
    61  			return nil, state.ErrNoState
    62  		}
    63  		return deviceCtx, nil
    64  	}
    65  	r1 := ReplaceDeviceCtxHook(deviceCtxHook)
    66  	// for convenience reflect from the context whether there is a
    67  	// remodeling
    68  	r2 := ReplaceRemodelingHook(func(*state.State) bool {
    69  		return deviceCtx != nil && deviceCtx.ForRemodeling()
    70  	})
    71  	return func() {
    72  		r1()
    73  		r2()
    74  	}
    75  }
    76  
    77  func ReplaceDeviceCtxHook(deviceCtxHook func(st *state.State, task *state.Task, providedDeviceCtx snapstate.DeviceContext) (snapstate.DeviceContext, error)) (restore func()) {
    78  	oldHook := snapstate.DeviceCtx
    79  	snapstate.DeviceCtx = deviceCtxHook
    80  	return func() {
    81  		snapstate.DeviceCtx = oldHook
    82  	}
    83  }
    84  
    85  func UseFallbackDeviceModel() (restore func()) {
    86  	return MockDeviceModel(sysdb.GenericClassicModel())
    87  }
    88  
    89  func ReplaceRemodelingHook(remodelingHook func(st *state.State) bool) (restore func()) {
    90  	oldHook := snapstate.Remodeling
    91  	snapstate.Remodeling = remodelingHook
    92  	return func() {
    93  		snapstate.Remodeling = oldHook
    94  	}
    95  }