github.com/hugh712/snapd@v0.0.0-20200910133618-1a99902bd583/cmd/snap-repair/export_test.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2017 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package main
    21  
    22  import (
    23  	"net/url"
    24  	"time"
    25  
    26  	"gopkg.in/retry.v1"
    27  
    28  	"github.com/snapcore/snapd/asserts"
    29  	"github.com/snapcore/snapd/httputil"
    30  )
    31  
    32  var (
    33  	ParseArgs = parseArgs
    34  	Run       = run
    35  )
    36  
    37  func MockBaseURL(baseurl string) (restore func()) {
    38  	orig := baseURL
    39  	u, err := url.Parse(baseurl)
    40  	if err != nil {
    41  		panic(err)
    42  	}
    43  	baseURL = u
    44  	return func() {
    45  		baseURL = orig
    46  	}
    47  }
    48  
    49  func MockFetchRetryStrategy(strategy retry.Strategy) (restore func()) {
    50  	originalFetchRetryStrategy := fetchRetryStrategy
    51  	fetchRetryStrategy = strategy
    52  	return func() {
    53  		fetchRetryStrategy = originalFetchRetryStrategy
    54  	}
    55  }
    56  
    57  func MockPeekRetryStrategy(strategy retry.Strategy) (restore func()) {
    58  	originalPeekRetryStrategy := peekRetryStrategy
    59  	peekRetryStrategy = strategy
    60  	return func() {
    61  		peekRetryStrategy = originalPeekRetryStrategy
    62  	}
    63  }
    64  
    65  func MockMaxRepairScriptSize(maxSize int) (restore func()) {
    66  	originalMaxSize := maxRepairScriptSize
    67  	maxRepairScriptSize = maxSize
    68  	return func() {
    69  		maxRepairScriptSize = originalMaxSize
    70  	}
    71  }
    72  
    73  func MockTrustedRepairRootKeys(keys []*asserts.AccountKey) (restore func()) {
    74  	original := trustedRepairRootKeys
    75  	trustedRepairRootKeys = keys
    76  	return func() {
    77  		trustedRepairRootKeys = original
    78  	}
    79  }
    80  
    81  func TrustedRepairRootKeys() []*asserts.AccountKey {
    82  	return trustedRepairRootKeys
    83  }
    84  
    85  func (run *Runner) BrandModel() (brand, model string) {
    86  	return run.state.Device.Brand, run.state.Device.Model
    87  }
    88  
    89  func (run *Runner) SetStateModified(modified bool) {
    90  	run.stateModified = modified
    91  }
    92  
    93  func (run *Runner) SetBrandModel(brand, model string) {
    94  	run.state.Device.Brand = brand
    95  	run.state.Device.Model = model
    96  }
    97  
    98  func (run *Runner) TimeLowerBound() time.Time {
    99  	return run.state.TimeLowerBound
   100  }
   101  
   102  func (run *Runner) TLSTime() time.Time {
   103  	return httputil.BaseTransport(run.cli).TLSClientConfig.Time()
   104  }
   105  
   106  func (run *Runner) Sequence(brand string) []*RepairState {
   107  	return run.state.Sequences[brand]
   108  }
   109  
   110  func (run *Runner) SetSequence(brand string, sequence []*RepairState) {
   111  	if run.state.Sequences == nil {
   112  		run.state.Sequences = make(map[string][]*RepairState)
   113  	}
   114  	run.state.Sequences[brand] = sequence
   115  }
   116  
   117  func MockDefaultRepairTimeout(d time.Duration) (restore func()) {
   118  	orig := defaultRepairTimeout
   119  	defaultRepairTimeout = d
   120  	return func() {
   121  		defaultRepairTimeout = orig
   122  	}
   123  }
   124  
   125  func MockErrtrackerReportRepair(mock func(string, string, string, map[string]string) (string, error)) (restore func()) {
   126  	prev := errtrackerReportRepair
   127  	errtrackerReportRepair = mock
   128  	return func() { errtrackerReportRepair = prev }
   129  }
   130  
   131  func MockTimeNow(f func() time.Time) (restore func()) {
   132  	origTimeNow := timeNow
   133  	timeNow = f
   134  	return func() { timeNow = origTimeNow }
   135  }
   136  
   137  func NewCmdShow(args ...string) *cmdShow {
   138  	cmdShow := &cmdShow{}
   139  	cmdShow.Positional.Repair = args
   140  	return cmdShow
   141  }
   142  
   143  func MockOsGetuid(f func() int) (restore func()) {
   144  	origOsGetuid := osGetuid
   145  	osGetuid = f
   146  	return func() { osGetuid = origOsGetuid }
   147  }