github.com/freetocompute/snapd@v0.0.0-20210618182524-2fb355d72fd9/cmd/snap-repair/export_test.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2017 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package main
    21  
    22  import (
    23  	"net/url"
    24  	"time"
    25  
    26  	"gopkg.in/retry.v1"
    27  
    28  	"github.com/snapcore/snapd/asserts"
    29  	"github.com/snapcore/snapd/httputil"
    30  )
    31  
    32  var (
    33  	ParseArgs = parseArgs
    34  	Run       = run
    35  )
    36  
    37  func MockBaseURL(baseurl string) (restore func()) {
    38  	orig := baseURL
    39  	u, err := url.Parse(baseurl)
    40  	if err != nil {
    41  		panic(err)
    42  	}
    43  	baseURL = u
    44  	return func() {
    45  		baseURL = orig
    46  	}
    47  }
    48  
    49  func MockFetchRetryStrategy(strategy retry.Strategy) (restore func()) {
    50  	originalFetchRetryStrategy := fetchRetryStrategy
    51  	fetchRetryStrategy = strategy
    52  	return func() {
    53  		fetchRetryStrategy = originalFetchRetryStrategy
    54  	}
    55  }
    56  
    57  func MockPeekRetryStrategy(strategy retry.Strategy) (restore func()) {
    58  	originalPeekRetryStrategy := peekRetryStrategy
    59  	peekRetryStrategy = strategy
    60  	return func() {
    61  		peekRetryStrategy = originalPeekRetryStrategy
    62  	}
    63  }
    64  
    65  func MockMaxRepairScriptSize(maxSize int) (restore func()) {
    66  	originalMaxSize := maxRepairScriptSize
    67  	maxRepairScriptSize = maxSize
    68  	return func() {
    69  		maxRepairScriptSize = originalMaxSize
    70  	}
    71  }
    72  
    73  func MockTrustedRepairRootKeys(keys []*asserts.AccountKey) (restore func()) {
    74  	original := trustedRepairRootKeys
    75  	trustedRepairRootKeys = keys
    76  	return func() {
    77  		trustedRepairRootKeys = original
    78  	}
    79  }
    80  
    81  func TrustedRepairRootKeys() []*asserts.AccountKey {
    82  	return trustedRepairRootKeys
    83  }
    84  
    85  func (run *Runner) BrandModel() (brand, model string) {
    86  	return run.state.Device.Brand, run.state.Device.Model
    87  }
    88  
    89  func (run *Runner) BaseMode() (base, mode string) {
    90  	return run.state.Device.Base, run.state.Device.Mode
    91  }
    92  
    93  func (run *Runner) SetStateModified(modified bool) {
    94  	run.stateModified = modified
    95  }
    96  
    97  func (run *Runner) SetBrandModel(brand, model string) {
    98  	run.state.Device.Brand = brand
    99  	run.state.Device.Model = model
   100  }
   101  
   102  func (run *Runner) TimeLowerBound() time.Time {
   103  	return run.state.TimeLowerBound
   104  }
   105  
   106  func (run *Runner) TLSTime() time.Time {
   107  	return httputil.BaseTransport(run.cli).TLSClientConfig.Time()
   108  }
   109  
   110  func (run *Runner) Sequence(brand string) []*RepairState {
   111  	return run.state.Sequences[brand]
   112  }
   113  
   114  func (run *Runner) SetSequence(brand string, sequence []*RepairState) {
   115  	if run.state.Sequences == nil {
   116  		run.state.Sequences = make(map[string][]*RepairState)
   117  	}
   118  	run.state.Sequences[brand] = sequence
   119  }
   120  
   121  func MockDefaultRepairTimeout(d time.Duration) (restore func()) {
   122  	orig := defaultRepairTimeout
   123  	defaultRepairTimeout = d
   124  	return func() {
   125  		defaultRepairTimeout = orig
   126  	}
   127  }
   128  
   129  func MockErrtrackerReportRepair(mock func(string, string, string, map[string]string) (string, error)) (restore func()) {
   130  	prev := errtrackerReportRepair
   131  	errtrackerReportRepair = mock
   132  	return func() { errtrackerReportRepair = prev }
   133  }
   134  
   135  func MockTimeNow(f func() time.Time) (restore func()) {
   136  	origTimeNow := timeNow
   137  	timeNow = f
   138  	return func() { timeNow = origTimeNow }
   139  }
   140  
   141  func NewCmdShow(args ...string) *cmdShow {
   142  	cmdShow := &cmdShow{}
   143  	cmdShow.Positional.Repair = args
   144  	return cmdShow
   145  }
   146  
   147  func MockOsGetuid(f func() int) (restore func()) {
   148  	origOsGetuid := osGetuid
   149  	osGetuid = f
   150  	return func() { osGetuid = origOsGetuid }
   151  }