github.com/mvdan/u-root-coreutils@v0.0.0-20230122170626-c2eef2898555/pkg/boot/initrd.go (about)

     1  // Copyright 2020 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package boot
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"io"
    11  	"strings"
    12  
    13  	"github.com/mvdan/u-root-coreutils/pkg/cpio"
    14  	"github.com/mvdan/u-root-coreutils/pkg/uio"
    15  )
    16  
    17  // CatInitrds concatenates initrds on first ReadAt call from a list of
    18  // io.ReaderAts, pads them to a 512 byte boundary.
    19  func CatInitrds(initrds ...io.ReaderAt) io.ReaderAt {
    20  	var names []string
    21  	for _, initrd := range initrds {
    22  		names = append(names, stringer(initrd))
    23  	}
    24  
    25  	return uio.NewLazyOpenerAt(strings.Join(names, ","), func() (io.ReaderAt, error) {
    26  		buf := new(bytes.Buffer)
    27  		for i, ireader := range initrds {
    28  			size, err := buf.ReadFrom(uio.Reader(ireader))
    29  			if err != nil {
    30  				return nil, err
    31  			}
    32  			// Don't pad the ending or an already aligned file.
    33  			if i != len(initrds)-1 && size%512 != 0 {
    34  				padding := make([]byte, 512-(size%512))
    35  				buf.Write(padding)
    36  			}
    37  		}
    38  		// Buffer doesn't implement ReadAt, so wrap in NewReader
    39  		return bytes.NewReader(buf.Bytes()), nil
    40  	})
    41  }
    42  
    43  // CreateInitrd creates an initrd with the collection of files passed in.
    44  func CreateInitrd(files ...string) (io.ReaderAt, error) {
    45  	b := &bytes.Buffer{}
    46  	archiver, err := cpio.Format("newc")
    47  	if err != nil {
    48  		return nil, err
    49  	}
    50  	w := archiver.Writer(b)
    51  	cr := cpio.NewRecorder()
    52  	// to deconflict names, we may want to prepend the names with
    53  	// kexec_extra/ or something.
    54  	for _, n := range files {
    55  		rec, err := cr.GetRecord(n)
    56  		if err != nil {
    57  			return nil, fmt.Errorf("Getting record of %q failed: %v", n, err)
    58  		}
    59  		if err := w.WriteRecord(rec); err != nil {
    60  			return nil, fmt.Errorf("Writing record %q failed: %v", n, err)
    61  		}
    62  	}
    63  	if err := cpio.WriteTrailer(w); err != nil {
    64  		return nil, fmt.Errorf("Error writing trailer record: %v", err)
    65  	}
    66  	return bytes.NewReader(b.Bytes()), nil
    67  }