github.com/u-root/u-root@v7.0.1-0.20200915234505-ad7babab0a8e+incompatible/pkg/boot/multiboot/multiboot_test.go (about)

     1  // Copyright 2018 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package multiboot
     6  
     7  import (
     8  	"bytes"
     9  	"encoding/binary"
    10  	"fmt"
    11  	"io"
    12  	"reflect"
    13  	"testing"
    14  )
    15  
    16  func createFile(hdr *header, offset, size int) (io.Reader, error) {
    17  	buf := bytes.Repeat([]byte{0xDE, 0xAD, 0xBE, 0xEF}, (size+4)/4)
    18  	buf = buf[:size]
    19  	if hdr != nil {
    20  		w := bytes.Buffer{}
    21  		if err := binary.Write(&w, binary.LittleEndian, *hdr); err != nil {
    22  			return nil, err
    23  		}
    24  		copy(buf[offset:], w.Bytes())
    25  	}
    26  	return bytes.NewReader(buf), nil
    27  }
    28  
    29  type testFlag string
    30  
    31  const (
    32  	flagGood        testFlag = "good"
    33  	flagUnsupported testFlag = "unsup"
    34  	flagBad         testFlag = "bad"
    35  )
    36  
    37  func createHeader(fl testFlag) header {
    38  	flags := headerFlag(0x00000002)
    39  	var checksum uint32
    40  	switch fl {
    41  	case flagGood:
    42  		checksum = 0xFFFFFFFF - headerMagic - uint32(flags) + 1
    43  	case flagBad:
    44  		checksum = 0xDEADBEEF
    45  	case flagUnsupported:
    46  		flags = 0x0000FFFC
    47  		checksum = 0xFFFFFFFF - headerMagic - uint32(flags) + 1
    48  	}
    49  
    50  	return header{
    51  		mandatory: mandatory{
    52  			Magic:    headerMagic,
    53  			Flags:    flags,
    54  			Checksum: checksum,
    55  		},
    56  		optional: optional{
    57  			HeaderAddr:  1,
    58  			LoadAddr:    2,
    59  			LoadEndAddr: 3,
    60  			BSSEndAddr:  4,
    61  			EntryAddr:   5,
    62  
    63  			ModeType: 6,
    64  			Width:    7,
    65  			Height:   8,
    66  			Depth:    9,
    67  		},
    68  	}
    69  }
    70  
    71  func TestParseHeader(t *testing.T) {
    72  	mandatorySize := binary.Size(mandatory{})
    73  	optionalSize := binary.Size(optional{})
    74  	sizeofHeader := mandatorySize + optionalSize
    75  
    76  	for _, test := range []struct {
    77  		flags  testFlag
    78  		offset int
    79  		size   int
    80  		err    error
    81  	}{
    82  		{flags: flagGood, offset: 0, size: 8192, err: nil},
    83  		{flags: flagGood, offset: 2048, size: 8192, err: nil},
    84  		{flags: flagGood, offset: 8192 - sizeofHeader - 4, size: 8192, err: nil},
    85  		{flags: flagGood, offset: 8192 - sizeofHeader - 1, size: 8192, err: ErrHeaderNotFound},
    86  		{flags: flagGood, offset: 8192 - sizeofHeader, size: 8192, err: nil},
    87  		{flags: flagGood, offset: 8192 - 4, size: 8192, err: ErrHeaderNotFound},
    88  		{flags: flagGood, offset: 8192, size: 16384, err: ErrHeaderNotFound},
    89  		{flags: flagGood, offset: 0, size: 10, err: io.ErrUnexpectedEOF},
    90  		{flags: flagBad, offset: 0, size: 8192, err: ErrHeaderNotFound},
    91  		{flags: flagUnsupported, offset: 0, size: 8192, err: ErrFlagsNotSupported},
    92  		{flags: flagGood, offset: 8192 - mandatorySize, size: 8192, err: nil},
    93  	} {
    94  		t.Run(fmt.Sprintf("flags:%v,off:%v,sz:%v,err:%v", test.flags, test.offset, test.size, test.err), func(t *testing.T) {
    95  			want := createHeader(test.flags)
    96  			r, err := createFile(&want, test.offset, test.size)
    97  			if err != nil {
    98  				t.Fatalf("Cannot create test file: %v", err)
    99  			}
   100  			got, err := parseHeader(r)
   101  			if err != test.err {
   102  				t.Fatalf("parseHeader() got error: %v, want: %v", err, test.err)
   103  			}
   104  
   105  			if err != nil {
   106  				return
   107  			}
   108  			if test.size-test.offset > mandatorySize {
   109  				if !reflect.DeepEqual(*got, want) {
   110  					t.Errorf("parseHeader() got %+v, want %+v", *got, want)
   111  				}
   112  			} else {
   113  				if !reflect.DeepEqual(got.mandatory, want.mandatory) {
   114  					t.Errorf("parseHeader() got %+v, want %+v", got.mandatory, want.mandatory)
   115  				}
   116  			}
   117  
   118  		})
   119  	}
   120  }