github.com/mvdan/u-root-coreutils@v0.0.0-20230122170626-c2eef2898555/pkg/spidev/spidev_linux_test.go (about)

     1  // Copyright 2021 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package spidev
     6  
     7  import (
     8  	"encoding/binary"
     9  	"errors"
    10  	"os"
    11  	"reflect"
    12  	"runtime"
    13  	"testing"
    14  	"unsafe"
    15  
    16  	"golang.org/x/sys/unix"
    17  )
    18  
    19  // mockSpidev simulates the ioctls for spidev.
    20  type mockSpidev struct {
    21  	// forceErrno when set will always return the given error from syscall.
    22  	forceErrno unix.Errno
    23  
    24  	mode        Mode
    25  	bitsPerWord uint8
    26  	speedHz     uint32
    27  	transfers   []iocTransfer
    28  }
    29  
    30  func (s *mockSpidev) syscall(trap, a1, a2 uintptr, a3 unsafe.Pointer) (r1, r2 uintptr, err unix.Errno) {
    31  	if s.forceErrno != 0 {
    32  		return 0, 0, s.forceErrno
    33  	}
    34  
    35  	if trap != unix.SYS_IOCTL {
    36  		return 0, 0, unix.EINVAL
    37  	}
    38  	if a1 < 0 {
    39  		return 0, 0, unix.EINVAL
    40  	}
    41  
    42  	switch a2 {
    43  	case iocRdBitsPerWord:
    44  		*(*uint8)(a3) = uint8(s.bitsPerWord)
    45  	case iocWrBitsPerWord:
    46  		s.bitsPerWord = *(*uint8)(a3)
    47  	case iocRdMaxSpeedHz:
    48  		*(*uint32)(a3) = uint32(s.speedHz)
    49  	case iocWrMaxSpeedHz:
    50  		s.speedHz = *(*uint32)(a3)
    51  	case iocRdMode32:
    52  		*(*uint32)(a3) = uint32(s.mode)
    53  	case iocWrMode32:
    54  		s.mode = Mode(*(*uint32)(a3))
    55  	default:
    56  		if uint32(a2&^sizeMask) != iocMessage(0) {
    57  			return 0, 0, unix.EINVAL
    58  		}
    59  
    60  		// Parse multiple transfer structs.
    61  		size := int((a2 & sizeMask) >> sizeShift)
    62  		if size%binary.Size(iocTransfer{}) != 0 {
    63  			return 0, 0, unix.EINVAL
    64  		}
    65  
    66  		// Re-create the slice from the pointer.
    67  		s.transfers = make([]iocTransfer, 0, 0)
    68  		sh := (*reflect.SliceHeader)(unsafe.Pointer(&s.transfers))
    69  		sh.Data = uintptr(a3)
    70  		sh.Len = size / binary.Size(iocTransfer{})
    71  		sh.Cap = size / binary.Size(iocTransfer{})
    72  
    73  		// Make sure the original pointer is not freed up until this point.
    74  		runtime.KeepAlive(a3)
    75  
    76  		// Replace all the non-zero address with 0xdeadbeef because the
    77  		// pointer addresses might change during the test.
    78  		for i := range s.transfers {
    79  			t := &s.transfers[i]
    80  			if t.txBuf != 0 {
    81  				t.txBuf = 0xdeadbeef
    82  			}
    83  			if t.rxBuf != 0 {
    84  				t.rxBuf = 0xdeadbeef
    85  			}
    86  		}
    87  	}
    88  
    89  	return 0, 0, 0
    90  }
    91  
    92  // TestOpenError tests when Open returns an error like file does not exist.
    93  func TestOpenError(t *testing.T) {
    94  	if _, err := Open("/dev/blahblahblahblah"); !os.IsNotExist(err) {
    95  		t.Fatalf(`Open("/dev/blahblahblahblah got %v; want %v`, err, os.ErrNotExist)
    96  	}
    97  }
    98  
    99  // TestGetters tests the functions which return values like Mode, SpeedHz, ...
   100  func TestGetters(t *testing.T) {
   101  	tmpFile, err := os.CreateTemp("", "")
   102  	if err != nil {
   103  		t.Fatalf("Could not create temporary file: %v", err)
   104  	}
   105  	defer os.Remove(tmpFile.Name())
   106  
   107  	s, err := Open(tmpFile.Name())
   108  	if err != nil {
   109  		t.Fatalf("Could not open spidev: %v", err)
   110  	}
   111  	defer s.Close()
   112  
   113  	m := &mockSpidev{
   114  		// You wouldn't use these values in practice, but it is good
   115  		// for a unit test.
   116  		mode:        0x1234,
   117  		bitsPerWord: 10,
   118  		speedHz:     12345,
   119  	}
   120  	s.syscall = m.syscall
   121  
   122  	// Test syscall with and without error.
   123  	for _, tt := range []struct {
   124  		name       string
   125  		forceErrno unix.Errno
   126  		wantErr    error
   127  	}{
   128  		{"", 0, nil},
   129  		{"WithErrno", unix.EAGAIN, unix.EAGAIN},
   130  	} {
   131  		m.forceErrno = tt.forceErrno
   132  
   133  		t.Run("Mode"+tt.name, func(t *testing.T) {
   134  			m, err := s.Mode()
   135  			if !errors.Is(err, tt.wantErr) {
   136  				t.Errorf("Mode() got error %q; want error %q", err, tt.wantErr)
   137  			}
   138  			if err != nil {
   139  				return
   140  			}
   141  			want := Mode(0x1234)
   142  			if m != want {
   143  				t.Errorf("Mode() = %#v; want %#v", m, want)
   144  			}
   145  		})
   146  
   147  		t.Run("BitsPerWord"+tt.name, func(t *testing.T) {
   148  			bpw, err := s.BitsPerWord()
   149  			if !errors.Is(err, tt.wantErr) {
   150  				t.Errorf("BitsPerWord() got error %q; want error %q", err, tt.wantErr)
   151  			}
   152  			if err != nil {
   153  				return
   154  			}
   155  			want := uint8(10)
   156  			if bpw != want {
   157  				t.Errorf("BitsPerWord() = %d; want %d", bpw, want)
   158  			}
   159  		})
   160  
   161  		t.Run("SpeedHz"+tt.name, func(t *testing.T) {
   162  			hz, err := s.SpeedHz()
   163  			if !errors.Is(err, tt.wantErr) {
   164  				t.Errorf("SpeedHz() got error %q; want error %q", err, tt.wantErr)
   165  			}
   166  			if err != nil {
   167  				return
   168  			}
   169  			want := uint32(12345)
   170  			if hz != want {
   171  				t.Errorf("SpeedHz() = %d; want %d", hz, want)
   172  			}
   173  		})
   174  	}
   175  }
   176  
   177  // TestSetters tests the functions which set values like SetMode, SetSpeedHz, ...
   178  func TestSetters(t *testing.T) {
   179  	tmpFile, err := os.CreateTemp("", "")
   180  	if err != nil {
   181  		t.Fatalf("Could not create temporary file: %v", err)
   182  	}
   183  	defer os.Remove(tmpFile.Name())
   184  
   185  	s, err := Open(tmpFile.Name())
   186  	if err != nil {
   187  		t.Fatalf("Could not open spidev: %v", err)
   188  	}
   189  	defer s.Close()
   190  
   191  	m := &mockSpidev{}
   192  	s.syscall = m.syscall
   193  
   194  	// Test syscall with and without error.
   195  	for _, tt := range []struct {
   196  		name       string
   197  		forceErrno unix.Errno
   198  		wantErr    error
   199  	}{
   200  		{"", 0, nil},
   201  		{"WithErrno", unix.EAGAIN, unix.EAGAIN},
   202  	} {
   203  		m.forceErrno = tt.forceErrno
   204  
   205  		t.Run("SetMode"+tt.name, func(t *testing.T) {
   206  			if err := s.SetMode(0x12345); !errors.Is(err, tt.wantErr) {
   207  				t.Errorf("SetMode() got error %q; want error %q", err, tt.wantErr)
   208  			}
   209  			if err != nil {
   210  				return
   211  			}
   212  			const want = Mode(0x12345)
   213  			if m.mode != want {
   214  				t.Errorf("SetMode() = %#v; want %#v", m.mode, want)
   215  			}
   216  		})
   217  
   218  		t.Run("SetBitsPerWord"+tt.name, func(t *testing.T) {
   219  			if err := s.SetBitsPerWord(20); !errors.Is(err, tt.wantErr) {
   220  				t.Errorf("SetBitsPerWord() got error %q; want error %q", err, tt.wantErr)
   221  			}
   222  			if err != nil {
   223  				return
   224  			}
   225  			const want = 20
   226  			if m.bitsPerWord != want {
   227  				t.Errorf("SetBitsPerWord() = %d; want %d", m.bitsPerWord, want)
   228  			}
   229  		})
   230  
   231  		t.Run("SetSpeedHz"+tt.name, func(t *testing.T) {
   232  			if err := s.SetSpeedHz(12345); !errors.Is(err, tt.wantErr) {
   233  				t.Errorf("SetSpeedHz() got error %q; want error %q", err, tt.wantErr)
   234  			}
   235  			if err != nil {
   236  				return
   237  			}
   238  			const want = 12345
   239  			if m.speedHz != want {
   240  				t.Errorf("SetSpeedHz() = %d; want %d", m.speedHz, want)
   241  			}
   242  		})
   243  	}
   244  }
   245  
   246  // TestTransfer tests multiple scenarios involving the Transfer method.
   247  func TestTransfer(t *testing.T) {
   248  	// To avoid OOMing the CI, we set the maxTransferSize to a smaller
   249  	// value temporarily for this test.
   250  	defer func(x int) { maxTransferSize = x }(maxTransferSize)
   251  	maxTransferSize = 0x100000
   252  
   253  	for _, tt := range []struct {
   254  		name          string
   255  		transfers     []Transfer
   256  		forceErrno    unix.Errno
   257  		wantTransfers []iocTransfer
   258  		wantErr       error
   259  	}{
   260  		{
   261  			name: "ErrTxOverflow",
   262  			transfers: []Transfer{
   263  				{
   264  					Tx: make([]uint8, maxTransferSize+1),
   265  				},
   266  			},
   267  			wantErr: ErrTxOverflow{
   268  				TxLen: maxTransferSize + 1,
   269  				TxMax: maxTransferSize,
   270  			},
   271  		},
   272  		{
   273  			name: "ErrRxOverflow",
   274  			transfers: []Transfer{
   275  				{
   276  					Rx: make([]uint8, maxTransferSize+1),
   277  				},
   278  			},
   279  			wantErr: ErrRxOverflow{
   280  				RxLen: maxTransferSize + 1,
   281  				RxMax: maxTransferSize,
   282  			},
   283  		},
   284  		{
   285  			name: "ErrBufferMismatch",
   286  			transfers: []Transfer{
   287  				{
   288  					Tx: make([]uint8, 10),
   289  					Rx: make([]uint8, 20),
   290  				},
   291  			},
   292  			wantErr: ErrBufferMismatch{
   293  				TxLen: 10,
   294  				RxLen: 20,
   295  			},
   296  		},
   297  		{
   298  			name:       "Errno",
   299  			forceErrno: unix.EAGAIN,
   300  			transfers: []Transfer{
   301  				{
   302  					Tx: make([]uint8, 10),
   303  					Rx: make([]uint8, 10),
   304  				},
   305  			},
   306  			wantErr: unix.EAGAIN,
   307  		},
   308  		{
   309  			name: "TxZero",
   310  			transfers: []Transfer{
   311  				{
   312  					Rx: make([]uint8, 10),
   313  				},
   314  			},
   315  			wantTransfers: []iocTransfer{
   316  				{
   317  					rxBuf:  0xdeadbeef,
   318  					length: 10,
   319  				},
   320  			},
   321  		},
   322  		{
   323  			name: "RxZero",
   324  			transfers: []Transfer{
   325  				{
   326  					Tx: make([]uint8, 10),
   327  				},
   328  			},
   329  			wantTransfers: []iocTransfer{
   330  				{
   331  					txBuf:  0xdeadbeef,
   332  					length: 10,
   333  				},
   334  			},
   335  		},
   336  		{
   337  			name: "OneTransfer",
   338  			transfers: []Transfer{
   339  				{
   340  					Tx:             []uint8{1, 2, 3},
   341  					Rx:             []uint8{0, 0, 0},
   342  					SpeedHz:        0x12345678,
   343  					DelayUSecs:     0x1234,
   344  					BitsPerWord:    0x8,
   345  					CSChange:       true,
   346  					TxNBits:        24,
   347  					RxNBits:        24,
   348  					WordDelayUSecs: 0x10,
   349  				},
   350  			},
   351  			wantTransfers: []iocTransfer{
   352  				{
   353  					txBuf:          0xdeadbeef,
   354  					rxBuf:          0xdeadbeef,
   355  					length:         3,
   356  					speedHz:        0x12345678,
   357  					delayUSecs:     0x1234,
   358  					bitsPerWord:    0x8,
   359  					csChange:       1,
   360  					txNBits:        24,
   361  					rxNBits:        24,
   362  					wordDelayUSecs: 0x10,
   363  				},
   364  			},
   365  		},
   366  		{
   367  			name: "TwoTransfers",
   368  			transfers: []Transfer{
   369  				{
   370  					Tx: []uint8{1, 2, 3},
   371  					Rx: []uint8{0, 0, 0},
   372  				},
   373  				{
   374  					Tx: []uint8{4, 5, 6, 7},
   375  				},
   376  			},
   377  			wantTransfers: []iocTransfer{
   378  				{
   379  					txBuf:  0xdeadbeef,
   380  					rxBuf:  0xdeadbeef,
   381  					length: 3,
   382  				},
   383  				{
   384  					txBuf:  0xdeadbeef,
   385  					length: 4,
   386  				},
   387  			},
   388  		},
   389  	} {
   390  		t.Run(tt.name, func(t *testing.T) {
   391  			tmpFile, err := os.CreateTemp("", "")
   392  			if err != nil {
   393  				t.Fatalf("Could not create temporary file: %v", err)
   394  			}
   395  			defer os.Remove(tmpFile.Name())
   396  
   397  			s, err := Open(tmpFile.Name())
   398  			if err != nil {
   399  				t.Fatalf("Could not open spidev: %v", err)
   400  			}
   401  			defer s.Close()
   402  
   403  			m := &mockSpidev{
   404  				forceErrno: tt.forceErrno,
   405  			}
   406  			s.syscall = m.syscall
   407  
   408  			gotErr := s.Transfer(tt.transfers)
   409  			if !errors.Is(gotErr, tt.wantErr) {
   410  				t.Errorf("Got Transfer err %q; want %q", gotErr, tt.wantErr)
   411  			}
   412  			if !reflect.DeepEqual(m.transfers, tt.wantTransfers) {
   413  				t.Errorf("Got Transfers %#v; want %#v", m.transfers, tt.wantTransfers)
   414  			}
   415  		})
   416  	}
   417  }