github.com/mvdan/u-root-coreutils@v0.0.0-20230122170626-c2eef2898555/pkg/memio/mmap_test.go (about)

     1  // Copyright 2012-2019 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package memio
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"os"
    11  	"reflect"
    12  	"strings"
    13  	"testing"
    14  )
    15  
    16  func TestIORealSyscalls(t *testing.T) {
    17  	for _, tt := range []struct {
    18  		name                string
    19  		addr                int64
    20  		writeData, readData UintN
    21  		err                 string
    22  	}{
    23  		{
    24  			name:      "uint8",
    25  			addr:      0x10,
    26  			writeData: &[]Uint8{0x12}[0],
    27  			readData:  new(Uint8),
    28  		},
    29  		{
    30  			name:      "uint16",
    31  			addr:      0x20,
    32  			writeData: &[]Uint16{0x1234}[0],
    33  			readData:  new(Uint16),
    34  		},
    35  		{
    36  			name:      "uint32",
    37  			addr:      0x30,
    38  			writeData: &[]Uint32{0x12345678}[0],
    39  			readData:  new(Uint32),
    40  		},
    41  		{
    42  			name:      "uint64",
    43  			addr:      0x40,
    44  			writeData: &[]Uint64{0x1234567890abcdef}[0],
    45  			readData:  new(Uint64),
    46  		},
    47  		{
    48  			name:      "byte slice",
    49  			addr:      0x50,
    50  			writeData: &[]ByteSlice{[]byte("Hello")}[0],
    51  			readData:  &[]ByteSlice{make([]byte, 5)}[0],
    52  		},
    53  	} {
    54  		t.Run(fmt.Sprintf(tt.name), func(t *testing.T) {
    55  			tmpFile, err := os.CreateTemp("", "io_test")
    56  			if err != nil {
    57  				t.Fatal(err)
    58  			}
    59  			tmpFile.Write(make([]byte, 10000))
    60  			tmpFile.Close()
    61  			defer os.Remove(tmpFile.Name())
    62  			m, err := NewMMap(tmpFile.Name())
    63  			if err != nil {
    64  				t.Errorf("%q failed at NewMMap: %q", tt.name, err)
    65  			}
    66  			defer m.Close()
    67  			// Write to the file.
    68  			if err := m.WriteAt(tt.addr, tt.writeData); err != nil {
    69  				if err.Error() == tt.err {
    70  					return
    71  				}
    72  				t.Fatal(err)
    73  			}
    74  
    75  			// Read back the value.
    76  			if err := m.ReadAt(tt.addr, tt.readData); err != nil {
    77  				if err.Error() == tt.err {
    78  					return
    79  				}
    80  				t.Fatal(err)
    81  			}
    82  
    83  			want := tt.writeData
    84  			got := tt.readData
    85  			if !reflect.DeepEqual(want, got) {
    86  				t.Fatalf("Write(%#016x, %v) = %v; want %v",
    87  					tt.addr, want, got, want)
    88  			}
    89  		})
    90  		t.Run(tt.name+"Deprecated", func(t *testing.T) {
    91  			tmpFile, err := os.CreateTemp("", "io_test")
    92  			if err != nil {
    93  				t.Fatal(err)
    94  			}
    95  			tmpFile.Write(make([]byte, 10000))
    96  			tmpFile.Close()
    97  			defer os.Remove(tmpFile.Name())
    98  			memPath = tmpFile.Name()
    99  			defer func() { memPath = "/dev/mem" }()
   100  
   101  			// Write to the file.
   102  			if err := Write(tt.addr, tt.writeData); err != nil {
   103  				if err.Error() == tt.err {
   104  					return
   105  				}
   106  				t.Fatal(err)
   107  			}
   108  
   109  			// Read back the value.
   110  			if err := Read(tt.addr, tt.readData); err != nil {
   111  				if err.Error() == tt.err {
   112  					return
   113  				}
   114  				t.Fatal(err)
   115  			}
   116  
   117  			want := tt.writeData
   118  			got := tt.readData
   119  			if !reflect.DeepEqual(want, got) {
   120  				t.Fatalf("Write(%#016x, %v) = %v; want %v",
   121  					tt.addr, want, got, want)
   122  			}
   123  		})
   124  	}
   125  }
   126  func TestNetMMapFail(t *testing.T) {
   127  	_, err := NewMMap("file-does-not-exist")
   128  	if !errors.Is(err, os.ErrNotExist) {
   129  		t.Errorf("TestNetMapFail failed: %q", err)
   130  	}
   131  }
   132  
   133  func TestReadWriteErrorWrongPath(t *testing.T) {
   134  	memPath = "file-does-not-exist"
   135  	defer func() { memPath = "/dev/mem" }()
   136  
   137  	var data UintN
   138  	if err := Write(0x35, data); !errors.Is(err, os.ErrNotExist) {
   139  		t.Errorf("TestReadWriteErrorWrongPath failed at Write(..): %q", err)
   140  	}
   141  	if err := Read(0x35, data); !errors.Is(err, os.ErrNotExist) {
   142  		t.Errorf("TestReadWriteErrorWrongPath failed at Read(..): %q", err)
   143  	}
   144  }
   145  
   146  type fakeSyscalls struct {
   147  	errMmap   error
   148  	errMunMap error
   149  	retBytes  []byte
   150  }
   151  
   152  func (f *fakeSyscalls) Mmap(fd int, page int64, mapSize int, prot int, callid int) ([]byte, error) {
   153  	return f.retBytes, f.errMmap
   154  }
   155  
   156  func (f *fakeSyscalls) Munmap(mem []byte) error {
   157  	return f.errMunMap
   158  }
   159  
   160  func TestMemIOAbstractSyscalls(t *testing.T) {
   161  	tmpFile, err := os.CreateTemp("", "io_test")
   162  	if err != nil {
   163  		t.Fatal(err)
   164  	}
   165  	tmpFile.Write(make([]byte, 10000))
   166  	tmpFile.Close()
   167  	defer os.Remove(tmpFile.Name())
   168  	m, err := NewMMap(tmpFile.Name())
   169  	if err != nil {
   170  		t.Errorf("TestMemIOAbstractSyscalls failed at NewMMap: %q", err)
   171  	}
   172  	defer m.Close()
   173  	for _, tt := range []struct {
   174  		name      string
   175  		errMmap   string
   176  		errMunMap string
   177  		retbyte   []byte
   178  		data      UintN
   179  	}{
   180  		{
   181  			name:    "TestMmapError",
   182  			errMmap: "force mmap error",
   183  			data:    &[]Uint8{0x12}[0],
   184  		},
   185  	} {
   186  		{
   187  			m.syscalls = &fakeSyscalls{
   188  				errMmap:   errors.New(tt.errMmap),
   189  				errMunMap: errors.New(tt.errMunMap),
   190  				retBytes:  tt.retbyte,
   191  			}
   192  			t.Run(tt.name, func(t *testing.T) {
   193  				if err := m.ReadAt(0x23, tt.data); !strings.Contains(err.Error(), tt.errMmap) {
   194  					t.Errorf("%q_ReadAt failed. Want: %q, Got: %q", tt.name, tt.errMmap, err)
   195  				}
   196  			})
   197  			t.Run(tt.name, func(t *testing.T) {
   198  				if err := m.WriteAt(0x23, tt.data); !strings.Contains(err.Error(), tt.errMmap) {
   199  					t.Errorf("%q_WriteAt failed. Want: %q, Got: %q", tt.name, tt.errMmap, err)
   200  				}
   201  			})
   202  		}
   203  	}
   204  
   205  }