github.com/mvdan/u-root-coreutils@v0.0.0-20230122170626-c2eef2898555/pkg/spidev/spidev_linux_test.go (about) 1 // Copyright 2021 the u-root Authors. All rights reserved 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package spidev 6 7 import ( 8 "encoding/binary" 9 "errors" 10 "os" 11 "reflect" 12 "runtime" 13 "testing" 14 "unsafe" 15 16 "golang.org/x/sys/unix" 17 ) 18 19 // mockSpidev simulates the ioctls for spidev. 20 type mockSpidev struct { 21 // forceErrno when set will always return the given error from syscall. 22 forceErrno unix.Errno 23 24 mode Mode 25 bitsPerWord uint8 26 speedHz uint32 27 transfers []iocTransfer 28 } 29 30 func (s *mockSpidev) syscall(trap, a1, a2 uintptr, a3 unsafe.Pointer) (r1, r2 uintptr, err unix.Errno) { 31 if s.forceErrno != 0 { 32 return 0, 0, s.forceErrno 33 } 34 35 if trap != unix.SYS_IOCTL { 36 return 0, 0, unix.EINVAL 37 } 38 if a1 < 0 { 39 return 0, 0, unix.EINVAL 40 } 41 42 switch a2 { 43 case iocRdBitsPerWord: 44 *(*uint8)(a3) = uint8(s.bitsPerWord) 45 case iocWrBitsPerWord: 46 s.bitsPerWord = *(*uint8)(a3) 47 case iocRdMaxSpeedHz: 48 *(*uint32)(a3) = uint32(s.speedHz) 49 case iocWrMaxSpeedHz: 50 s.speedHz = *(*uint32)(a3) 51 case iocRdMode32: 52 *(*uint32)(a3) = uint32(s.mode) 53 case iocWrMode32: 54 s.mode = Mode(*(*uint32)(a3)) 55 default: 56 if uint32(a2&^sizeMask) != iocMessage(0) { 57 return 0, 0, unix.EINVAL 58 } 59 60 // Parse multiple transfer structs. 61 size := int((a2 & sizeMask) >> sizeShift) 62 if size%binary.Size(iocTransfer{}) != 0 { 63 return 0, 0, unix.EINVAL 64 } 65 66 // Re-create the slice from the pointer. 67 s.transfers = make([]iocTransfer, 0, 0) 68 sh := (*reflect.SliceHeader)(unsafe.Pointer(&s.transfers)) 69 sh.Data = uintptr(a3) 70 sh.Len = size / binary.Size(iocTransfer{}) 71 sh.Cap = size / binary.Size(iocTransfer{}) 72 73 // Make sure the original pointer is not freed up until this point. 74 runtime.KeepAlive(a3) 75 76 // Replace all the non-zero address with 0xdeadbeef because the 77 // pointer addresses might change during the test. 78 for i := range s.transfers { 79 t := &s.transfers[i] 80 if t.txBuf != 0 { 81 t.txBuf = 0xdeadbeef 82 } 83 if t.rxBuf != 0 { 84 t.rxBuf = 0xdeadbeef 85 } 86 } 87 } 88 89 return 0, 0, 0 90 } 91 92 // TestOpenError tests when Open returns an error like file does not exist. 93 func TestOpenError(t *testing.T) { 94 if _, err := Open("/dev/blahblahblahblah"); !os.IsNotExist(err) { 95 t.Fatalf(`Open("/dev/blahblahblahblah got %v; want %v`, err, os.ErrNotExist) 96 } 97 } 98 99 // TestGetters tests the functions which return values like Mode, SpeedHz, ... 100 func TestGetters(t *testing.T) { 101 tmpFile, err := os.CreateTemp("", "") 102 if err != nil { 103 t.Fatalf("Could not create temporary file: %v", err) 104 } 105 defer os.Remove(tmpFile.Name()) 106 107 s, err := Open(tmpFile.Name()) 108 if err != nil { 109 t.Fatalf("Could not open spidev: %v", err) 110 } 111 defer s.Close() 112 113 m := &mockSpidev{ 114 // You wouldn't use these values in practice, but it is good 115 // for a unit test. 116 mode: 0x1234, 117 bitsPerWord: 10, 118 speedHz: 12345, 119 } 120 s.syscall = m.syscall 121 122 // Test syscall with and without error. 123 for _, tt := range []struct { 124 name string 125 forceErrno unix.Errno 126 wantErr error 127 }{ 128 {"", 0, nil}, 129 {"WithErrno", unix.EAGAIN, unix.EAGAIN}, 130 } { 131 m.forceErrno = tt.forceErrno 132 133 t.Run("Mode"+tt.name, func(t *testing.T) { 134 m, err := s.Mode() 135 if !errors.Is(err, tt.wantErr) { 136 t.Errorf("Mode() got error %q; want error %q", err, tt.wantErr) 137 } 138 if err != nil { 139 return 140 } 141 want := Mode(0x1234) 142 if m != want { 143 t.Errorf("Mode() = %#v; want %#v", m, want) 144 } 145 }) 146 147 t.Run("BitsPerWord"+tt.name, func(t *testing.T) { 148 bpw, err := s.BitsPerWord() 149 if !errors.Is(err, tt.wantErr) { 150 t.Errorf("BitsPerWord() got error %q; want error %q", err, tt.wantErr) 151 } 152 if err != nil { 153 return 154 } 155 want := uint8(10) 156 if bpw != want { 157 t.Errorf("BitsPerWord() = %d; want %d", bpw, want) 158 } 159 }) 160 161 t.Run("SpeedHz"+tt.name, func(t *testing.T) { 162 hz, err := s.SpeedHz() 163 if !errors.Is(err, tt.wantErr) { 164 t.Errorf("SpeedHz() got error %q; want error %q", err, tt.wantErr) 165 } 166 if err != nil { 167 return 168 } 169 want := uint32(12345) 170 if hz != want { 171 t.Errorf("SpeedHz() = %d; want %d", hz, want) 172 } 173 }) 174 } 175 } 176 177 // TestSetters tests the functions which set values like SetMode, SetSpeedHz, ... 178 func TestSetters(t *testing.T) { 179 tmpFile, err := os.CreateTemp("", "") 180 if err != nil { 181 t.Fatalf("Could not create temporary file: %v", err) 182 } 183 defer os.Remove(tmpFile.Name()) 184 185 s, err := Open(tmpFile.Name()) 186 if err != nil { 187 t.Fatalf("Could not open spidev: %v", err) 188 } 189 defer s.Close() 190 191 m := &mockSpidev{} 192 s.syscall = m.syscall 193 194 // Test syscall with and without error. 195 for _, tt := range []struct { 196 name string 197 forceErrno unix.Errno 198 wantErr error 199 }{ 200 {"", 0, nil}, 201 {"WithErrno", unix.EAGAIN, unix.EAGAIN}, 202 } { 203 m.forceErrno = tt.forceErrno 204 205 t.Run("SetMode"+tt.name, func(t *testing.T) { 206 if err := s.SetMode(0x12345); !errors.Is(err, tt.wantErr) { 207 t.Errorf("SetMode() got error %q; want error %q", err, tt.wantErr) 208 } 209 if err != nil { 210 return 211 } 212 const want = Mode(0x12345) 213 if m.mode != want { 214 t.Errorf("SetMode() = %#v; want %#v", m.mode, want) 215 } 216 }) 217 218 t.Run("SetBitsPerWord"+tt.name, func(t *testing.T) { 219 if err := s.SetBitsPerWord(20); !errors.Is(err, tt.wantErr) { 220 t.Errorf("SetBitsPerWord() got error %q; want error %q", err, tt.wantErr) 221 } 222 if err != nil { 223 return 224 } 225 const want = 20 226 if m.bitsPerWord != want { 227 t.Errorf("SetBitsPerWord() = %d; want %d", m.bitsPerWord, want) 228 } 229 }) 230 231 t.Run("SetSpeedHz"+tt.name, func(t *testing.T) { 232 if err := s.SetSpeedHz(12345); !errors.Is(err, tt.wantErr) { 233 t.Errorf("SetSpeedHz() got error %q; want error %q", err, tt.wantErr) 234 } 235 if err != nil { 236 return 237 } 238 const want = 12345 239 if m.speedHz != want { 240 t.Errorf("SetSpeedHz() = %d; want %d", m.speedHz, want) 241 } 242 }) 243 } 244 } 245 246 // TestTransfer tests multiple scenarios involving the Transfer method. 247 func TestTransfer(t *testing.T) { 248 // To avoid OOMing the CI, we set the maxTransferSize to a smaller 249 // value temporarily for this test. 250 defer func(x int) { maxTransferSize = x }(maxTransferSize) 251 maxTransferSize = 0x100000 252 253 for _, tt := range []struct { 254 name string 255 transfers []Transfer 256 forceErrno unix.Errno 257 wantTransfers []iocTransfer 258 wantErr error 259 }{ 260 { 261 name: "ErrTxOverflow", 262 transfers: []Transfer{ 263 { 264 Tx: make([]uint8, maxTransferSize+1), 265 }, 266 }, 267 wantErr: ErrTxOverflow{ 268 TxLen: maxTransferSize + 1, 269 TxMax: maxTransferSize, 270 }, 271 }, 272 { 273 name: "ErrRxOverflow", 274 transfers: []Transfer{ 275 { 276 Rx: make([]uint8, maxTransferSize+1), 277 }, 278 }, 279 wantErr: ErrRxOverflow{ 280 RxLen: maxTransferSize + 1, 281 RxMax: maxTransferSize, 282 }, 283 }, 284 { 285 name: "ErrBufferMismatch", 286 transfers: []Transfer{ 287 { 288 Tx: make([]uint8, 10), 289 Rx: make([]uint8, 20), 290 }, 291 }, 292 wantErr: ErrBufferMismatch{ 293 TxLen: 10, 294 RxLen: 20, 295 }, 296 }, 297 { 298 name: "Errno", 299 forceErrno: unix.EAGAIN, 300 transfers: []Transfer{ 301 { 302 Tx: make([]uint8, 10), 303 Rx: make([]uint8, 10), 304 }, 305 }, 306 wantErr: unix.EAGAIN, 307 }, 308 { 309 name: "TxZero", 310 transfers: []Transfer{ 311 { 312 Rx: make([]uint8, 10), 313 }, 314 }, 315 wantTransfers: []iocTransfer{ 316 { 317 rxBuf: 0xdeadbeef, 318 length: 10, 319 }, 320 }, 321 }, 322 { 323 name: "RxZero", 324 transfers: []Transfer{ 325 { 326 Tx: make([]uint8, 10), 327 }, 328 }, 329 wantTransfers: []iocTransfer{ 330 { 331 txBuf: 0xdeadbeef, 332 length: 10, 333 }, 334 }, 335 }, 336 { 337 name: "OneTransfer", 338 transfers: []Transfer{ 339 { 340 Tx: []uint8{1, 2, 3}, 341 Rx: []uint8{0, 0, 0}, 342 SpeedHz: 0x12345678, 343 DelayUSecs: 0x1234, 344 BitsPerWord: 0x8, 345 CSChange: true, 346 TxNBits: 24, 347 RxNBits: 24, 348 WordDelayUSecs: 0x10, 349 }, 350 }, 351 wantTransfers: []iocTransfer{ 352 { 353 txBuf: 0xdeadbeef, 354 rxBuf: 0xdeadbeef, 355 length: 3, 356 speedHz: 0x12345678, 357 delayUSecs: 0x1234, 358 bitsPerWord: 0x8, 359 csChange: 1, 360 txNBits: 24, 361 rxNBits: 24, 362 wordDelayUSecs: 0x10, 363 }, 364 }, 365 }, 366 { 367 name: "TwoTransfers", 368 transfers: []Transfer{ 369 { 370 Tx: []uint8{1, 2, 3}, 371 Rx: []uint8{0, 0, 0}, 372 }, 373 { 374 Tx: []uint8{4, 5, 6, 7}, 375 }, 376 }, 377 wantTransfers: []iocTransfer{ 378 { 379 txBuf: 0xdeadbeef, 380 rxBuf: 0xdeadbeef, 381 length: 3, 382 }, 383 { 384 txBuf: 0xdeadbeef, 385 length: 4, 386 }, 387 }, 388 }, 389 } { 390 t.Run(tt.name, func(t *testing.T) { 391 tmpFile, err := os.CreateTemp("", "") 392 if err != nil { 393 t.Fatalf("Could not create temporary file: %v", err) 394 } 395 defer os.Remove(tmpFile.Name()) 396 397 s, err := Open(tmpFile.Name()) 398 if err != nil { 399 t.Fatalf("Could not open spidev: %v", err) 400 } 401 defer s.Close() 402 403 m := &mockSpidev{ 404 forceErrno: tt.forceErrno, 405 } 406 s.syscall = m.syscall 407 408 gotErr := s.Transfer(tt.transfers) 409 if !errors.Is(gotErr, tt.wantErr) { 410 t.Errorf("Got Transfer err %q; want %q", gotErr, tt.wantErr) 411 } 412 if !reflect.DeepEqual(m.transfers, tt.wantTransfers) { 413 t.Errorf("Got Transfers %#v; want %#v", m.transfers, tt.wantTransfers) 414 } 415 }) 416 } 417 }