github.com/stolowski/snapd@v0.0.0-20210407085831-115137ce5a22/osutil/disks/mockdisk.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2020 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package disks
    21  
    22  import (
    23  	"fmt"
    24  
    25  	"github.com/snapcore/snapd/osutil"
    26  )
    27  
    28  // MockDiskMapping is an implementation of Disk for mocking purposes, it is
    29  // exported so that other packages can easily mock a specific disk layout
    30  // without needing to mock the mount setup, sysfs, or udev commands just to test
    31  // high level logic.
    32  // DevNum must be a unique string per unique mocked disk, if only one disk is
    33  // being mocked it can be left empty.
    34  type MockDiskMapping struct {
    35  	// FilesystemLabelToPartUUID is a mapping of the udev encoded filesystem
    36  	// labels to the expected partition uuids.
    37  	FilesystemLabelToPartUUID map[string]string
    38  	// PartitionLabelToPartUUID is a mapping of the udev encoded partition
    39  	// labels to the expected partition uuids.
    40  	PartitionLabelToPartUUID map[string]string
    41  	DiskHasPartitions        bool
    42  	DevNum                   string
    43  }
    44  
    45  // FindMatchingPartitionUUIDWithFsLabel returns a matching PartitionUUID
    46  // for the specified filesystem label if it exists. Part of the Disk interface.
    47  func (d *MockDiskMapping) FindMatchingPartitionUUIDWithFsLabel(label string) (string, error) {
    48  	osutil.MustBeTestBinary("mock disks only to be used in tests")
    49  	if partuuid, ok := d.FilesystemLabelToPartUUID[label]; ok {
    50  		return partuuid, nil
    51  	}
    52  	return "", PartitionNotFoundError{
    53  		SearchType:  "filesystem-label",
    54  		SearchQuery: label,
    55  	}
    56  }
    57  
    58  // FindMatchingPartitionUUIDWithPartLabel returns a matching PartitionUUID
    59  // for the specified filesystem label if it exists. Part of the Disk interface.
    60  func (d *MockDiskMapping) FindMatchingPartitionUUIDWithPartLabel(label string) (string, error) {
    61  	osutil.MustBeTestBinary("mock disks only to be used in tests")
    62  	if partuuid, ok := d.PartitionLabelToPartUUID[label]; ok {
    63  		return partuuid, nil
    64  	}
    65  	return "", PartitionNotFoundError{
    66  		SearchType:  "partition-label",
    67  		SearchQuery: label,
    68  	}
    69  }
    70  
    71  // HasPartitions returns if the mock disk has partitions or not. Part of the
    72  // Disk interface.
    73  func (d *MockDiskMapping) HasPartitions() bool {
    74  	return d.DiskHasPartitions
    75  }
    76  
    77  // MountPointIsFromDisk returns if the disk that the specified mount point comes
    78  // from is the same disk as the object. Part of the Disk interface.
    79  func (d *MockDiskMapping) MountPointIsFromDisk(mountpoint string, opts *Options) (bool, error) {
    80  	osutil.MustBeTestBinary("mock disks only to be used in tests")
    81  
    82  	// this is relying on the fact that DiskFromMountPoint should have been
    83  	// mocked for us to be using this mockDisk method anyways
    84  	otherDisk, err := DiskFromMountPoint(mountpoint, opts)
    85  	if err != nil {
    86  		return false, err
    87  	}
    88  
    89  	if otherDisk.Dev() == d.Dev() && otherDisk.HasPartitions() == d.HasPartitions() {
    90  		return true, nil
    91  	}
    92  
    93  	return false, nil
    94  }
    95  
    96  // Dev returns a unique representation of the mock disk that is suitable for
    97  // comparing two mock disks to see if they are the same. Part of the Disk
    98  // interface.
    99  func (d *MockDiskMapping) Dev() string {
   100  	return d.DevNum
   101  }
   102  
   103  // Mountpoint is a combination of a mountpoint location and whether that
   104  // mountpoint is a decrypted device. It is only used in identifying mount points
   105  // with MountPointIsFromDisk and DiskFromMountPoint with
   106  // MockMountPointDisksToPartitionMapping.
   107  type Mountpoint struct {
   108  	Mountpoint        string
   109  	IsDecryptedDevice bool
   110  }
   111  
   112  // MockDeviceNameDisksToPartitionMapping will mock DiskFromDeviceName such that
   113  // the provided map of device names to mock disks is used instead of the actual
   114  // implementation using udev.
   115  func MockDeviceNameDisksToPartitionMapping(mockedMountPoints map[string]*MockDiskMapping) (restore func()) {
   116  	osutil.MustBeTestBinary("mock disks only to be used in tests")
   117  
   118  	// note that devices can have many names that are recognized by
   119  	// udev/kernel, so we don't do any validation of the mapping here like we do
   120  	// for MockMountPointDisksToPartitionMapping
   121  
   122  	old := diskFromDeviceName
   123  	diskFromDeviceName = func(deviceName string) (Disk, error) {
   124  		disk, ok := mockedMountPoints[deviceName]
   125  		if !ok {
   126  			return nil, fmt.Errorf("device name %q not mocked", deviceName)
   127  		}
   128  		return disk, nil
   129  	}
   130  
   131  	return func() {
   132  		diskFromDeviceName = old
   133  	}
   134  }
   135  
   136  // MockMountPointDisksToPartitionMapping will mock DiskFromMountPoint such that
   137  // the specified mapping is returned/used. Specifically, keys in the provided
   138  // map are mountpoints, and the values for those keys are the disks that will
   139  // be returned from DiskFromMountPoint or used internally in
   140  // MountPointIsFromDisk.
   141  func MockMountPointDisksToPartitionMapping(mockedMountPoints map[Mountpoint]*MockDiskMapping) (restore func()) {
   142  	osutil.MustBeTestBinary("mock disks only to be used in tests")
   143  
   144  	// verify that all unique MockDiskMapping's have unique DevNum's and that
   145  	// the srcMntPt's are all consistent
   146  	// we can't have the same mountpoint exist both as a decrypted device and
   147  	// not as a decrypted device, this is an impossible mapping, but we need to
   148  	// expose functionality to mock the same mountpoint as a decrypted device
   149  	// and as an unencrypyted device for different tests, but never at the same
   150  	// time with the same mapping
   151  	alreadySeen := make(map[string]*MockDiskMapping, len(mockedMountPoints))
   152  	seenSrcMntPts := make(map[string]bool, len(mockedMountPoints))
   153  	for srcMntPt, mockDisk := range mockedMountPoints {
   154  		if decryptedVal, ok := seenSrcMntPts[srcMntPt.Mountpoint]; ok {
   155  			if decryptedVal != srcMntPt.IsDecryptedDevice {
   156  				msg := fmt.Sprintf("mocked source mountpoint %s is duplicated with different options - previous option for IsDecryptedDevice was %t, current option is %t", srcMntPt.Mountpoint, decryptedVal, srcMntPt.IsDecryptedDevice)
   157  				panic(msg)
   158  			}
   159  		}
   160  		seenSrcMntPts[srcMntPt.Mountpoint] = srcMntPt.IsDecryptedDevice
   161  		if old, ok := alreadySeen[mockDisk.DevNum]; ok {
   162  			if mockDisk != old {
   163  				// we already saw a disk with this DevNum as a different pointer
   164  				// so just assume it's different
   165  				msg := fmt.Sprintf("mocked disks %+v and %+v have the same DevNum (%s) but are not the same object", old, mockDisk, mockDisk.DevNum)
   166  				panic(msg)
   167  			}
   168  			// otherwise same ptr, no point in comparing them
   169  		} else {
   170  			// didn't see it before, save it now
   171  			alreadySeen[mockDisk.DevNum] = mockDisk
   172  		}
   173  	}
   174  
   175  	old := diskFromMountPoint
   176  
   177  	diskFromMountPoint = func(mountpoint string, opts *Options) (Disk, error) {
   178  		if opts == nil {
   179  			opts = &Options{}
   180  		}
   181  		m := Mountpoint{mountpoint, opts.IsDecryptedDevice}
   182  		if mockedDisk, ok := mockedMountPoints[m]; ok {
   183  			return mockedDisk, nil
   184  		}
   185  		return nil, fmt.Errorf("mountpoint %s not mocked", mountpoint)
   186  	}
   187  	return func() {
   188  		diskFromMountPoint = old
   189  	}
   190  }