github.com/cloud-foundations/dominator@v0.0.0-20221004181915-6e4fee580046/hypervisor/tftpbootd/impl.go (about)

     1  package tftpbootd
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"net"
    10  	"os"
    11  	"strings"
    12  	"time"
    13  
    14  	imageclient "github.com/Cloud-Foundations/Dominator/imageserver/client"
    15  	"github.com/Cloud-Foundations/Dominator/lib/filesystem"
    16  	"github.com/Cloud-Foundations/Dominator/lib/format"
    17  	"github.com/Cloud-Foundations/Dominator/lib/log"
    18  	"github.com/Cloud-Foundations/Dominator/lib/log/prefixlogger"
    19  	objectclient "github.com/Cloud-Foundations/Dominator/lib/objectserver/client"
    20  	"github.com/Cloud-Foundations/Dominator/lib/srpc"
    21  	"github.com/pin/tftp"
    22  )
    23  
    24  const tftpbootPrefix = "/tftpboot"
    25  
    26  func cleanPath(filename string) string {
    27  	if strings.HasPrefix(filename, tftpbootPrefix) {
    28  		return filename[len(tftpbootPrefix):]
    29  	} else if filename[0] != '/' {
    30  		return "/" + filename
    31  	} else {
    32  		return filename
    33  	}
    34  }
    35  
    36  func readHandler(rf io.ReaderFrom, reader io.Reader,
    37  	logger log.DebugLogger) error {
    38  	startTime := time.Now()
    39  	nRead, err := rf.ReadFrom(reader)
    40  	if err != nil {
    41  		io.Copy(ioutil.Discard, reader)
    42  		return err
    43  	}
    44  	timeTaken := time.Since(startTime)
    45  	speed := uint64(float64(nRead) / timeTaken.Seconds())
    46  	logger.Printf("%d bytes sent in %s (%s/s)\n",
    47  		nRead, format.Duration(timeTaken), format.FormatBytes(speed))
    48  	return nil
    49  }
    50  
    51  func newServer(imageServerAddress, imageStreamName string,
    52  	logger log.DebugLogger) (*TftpbootServer, error) {
    53  	s := &TftpbootServer{
    54  		cachedFileSystems:  make(map[string]*cachedFileSystem),
    55  		filesForIPs:        make(map[string]map[string][]byte),
    56  		imageServerAddress: imageServerAddress,
    57  		imageStreamName:    imageStreamName,
    58  		logger:             logger,
    59  		closeClientTimer:   time.NewTimer(time.Minute),
    60  	}
    61  	s.tftpdServer = tftp.NewServer(s.readHandler, nil)
    62  	go func() {
    63  		if err := s.tftpdServer.ListenAndServe(":69"); err != nil {
    64  			s.logger.Println(err)
    65  		}
    66  	}()
    67  	go s.imageServerClientCloser()
    68  	return s, nil
    69  }
    70  
    71  func (s *TftpbootServer) closeImageServerClient() {
    72  	s.lock.Lock()
    73  	defer s.lock.Unlock()
    74  	if s.imageServerClientInUse {
    75  		return
    76  	}
    77  	if s.imageServerClient != nil {
    78  		s.imageServerClient.Close()
    79  		s.imageServerClient = nil
    80  		s.logger.Debugf(0, "closed connection to: %s\n", s.imageServerAddress)
    81  	}
    82  }
    83  
    84  func (s *TftpbootServer) getFileSystem(imageStreamName string,
    85  	client *srpc.Client) (*filesystem.FileSystem, error) {
    86  	if fs, err := s.getCachedFileSystem(imageStreamName); err != nil {
    87  		return nil, err
    88  	} else if fs != nil {
    89  		return fs, nil
    90  	}
    91  	imageName, err := imageclient.FindLatestImage(client, imageStreamName,
    92  		false)
    93  	if err != nil {
    94  		return nil, fmt.Errorf("error finding latest image in stream: %s: %s",
    95  			imageStreamName, err)
    96  	}
    97  	if imageName == "" {
    98  		return nil, fmt.Errorf("no images in stream: %s", imageStreamName)
    99  	}
   100  	image, err := imageclient.GetImage(client, imageName)
   101  	if err != nil {
   102  		return nil, fmt.Errorf("error getting image: %s: %s", imageName, err)
   103  	}
   104  	if err := image.FileSystem.RebuildInodePointers(); err != nil {
   105  		return nil, err
   106  	}
   107  	entry := cachedFileSystem{
   108  		deleteTimer: time.NewTimer(time.Minute),
   109  		fileSystem:  image.FileSystem,
   110  	}
   111  	s.lock.Lock()
   112  	s.cachedFileSystems[imageStreamName] = &entry
   113  	s.lock.Unlock()
   114  	go func() {
   115  		<-entry.deleteTimer.C
   116  		s.lock.Lock()
   117  		delete(s.cachedFileSystems, imageStreamName)
   118  		s.lock.Unlock()
   119  		s.logger.Debugf(0, "removed from cache: %s\n", imageStreamName)
   120  	}()
   121  	return image.FileSystem, nil
   122  }
   123  
   124  func (s *TftpbootServer) getCachedFileSystem(imageStreamName string) (
   125  	*filesystem.FileSystem, error) {
   126  	if imageStreamName == "" {
   127  		return nil, errors.New("no image stream defined")
   128  	}
   129  	s.lock.Lock()
   130  	defer s.lock.Unlock()
   131  	if entry, ok := s.cachedFileSystems[imageStreamName]; ok {
   132  		entry.deleteTimer.Reset(time.Minute)
   133  		return entry.fileSystem, nil
   134  	}
   135  	return nil, nil
   136  }
   137  
   138  func (s *TftpbootServer) getImageServerClient() *srpc.Client {
   139  	s.lock.Lock()
   140  	s.imageServerClientInUse = true
   141  	s.lock.Unlock()
   142  	s.imageServerClientLock.Lock()
   143  	if s.imageServerClient != nil {
   144  		return s.imageServerClient
   145  	}
   146  	for ; ; time.Sleep(time.Second * 15) {
   147  		client, err := srpc.DialHTTP("tcp", s.imageServerAddress, 0)
   148  		if err != nil {
   149  			s.logger.Println(err)
   150  			continue
   151  		}
   152  		s.logger.Debugf(0, "Connected to: %s\n", s.imageServerAddress)
   153  		s.imageServerClient = client
   154  		return s.imageServerClient
   155  	}
   156  }
   157  
   158  func (s *TftpbootServer) imageServerClientCloser() {
   159  	for range s.closeClientTimer.C {
   160  		s.closeImageServerClient()
   161  	}
   162  }
   163  
   164  func (s *TftpbootServer) readHandler(filename string, rf io.ReaderFrom) error {
   165  	filename = cleanPath(filename)
   166  	rAddr := rf.(tftp.OutgoingTransfer).RemoteAddr().IP.String()
   167  	logger := prefixlogger.New("tftpd("+rAddr+":"+filename+"): ", s.logger)
   168  	logger.Debugln(1, "received request")
   169  	if err := s.readHandlerInternal(filename, rf, rAddr, logger); err != nil {
   170  		logger.Println(err)
   171  		return err
   172  	}
   173  	return nil
   174  }
   175  
   176  func (s *TftpbootServer) readHandlerInternal(filename string, rf io.ReaderFrom,
   177  	remoteAddr string, logger log.DebugLogger) error {
   178  	s.lock.Lock()
   179  	if files, ok := s.filesForIPs[remoteAddr]; ok {
   180  		if data, ok := files[filename]; ok {
   181  			s.lock.Unlock()
   182  			rf.(tftp.OutgoingTransfer).SetSize(int64(len(data)))
   183  			return readHandler(rf, bytes.NewReader(data), logger)
   184  		}
   185  	}
   186  	imageStreamName := s.imageStreamName
   187  	s.lock.Unlock()
   188  	client := s.getImageServerClient()
   189  	defer s.releaseImageServerClient()
   190  	fs, err := s.getFileSystem(imageStreamName, client)
   191  	if err != nil {
   192  		return err
   193  	}
   194  	defer s.getCachedFileSystem(imageStreamName) // Reset expiration timer.
   195  	filenameToInodeTable := fs.FilenameToInodeTable()
   196  	if inum, ok := filenameToInodeTable[filename]; !ok {
   197  		return os.ErrNotExist
   198  	} else if gInode, ok := fs.InodeTable[inum]; !ok {
   199  		return fmt.Errorf("inode: %d does not exist", inum)
   200  	} else if inode, ok := gInode.(*filesystem.RegularInode); !ok {
   201  		return fmt.Errorf("inode is not a regular file: %d", inum)
   202  	} else {
   203  		objSrv := objectclient.AttachObjectClient(client)
   204  		defer objSrv.Close()
   205  		if size, reader, err := objSrv.GetObject(inode.Hash); err != nil {
   206  			return err
   207  		} else {
   208  			defer reader.Close()
   209  			rf.(tftp.OutgoingTransfer).SetSize(int64(size))
   210  			return readHandler(rf, reader, logger)
   211  		}
   212  	}
   213  }
   214  
   215  func (s *TftpbootServer) registerFiles(ipAddr net.IP, files map[string][]byte) {
   216  	address := ipAddr.String()
   217  	cleanedFiles := make(map[string][]byte, len(files))
   218  	for filename, data := range files {
   219  		cleanedFiles[cleanPath(filename)] = data
   220  	}
   221  	s.lock.Lock()
   222  	defer s.lock.Unlock()
   223  	if len(files) < 1 {
   224  		delete(s.filesForIPs, address)
   225  	} else {
   226  		s.filesForIPs[address] = cleanedFiles
   227  	}
   228  }
   229  
   230  func (s *TftpbootServer) releaseImageServerClient() {
   231  	s.closeClientTimer.Reset(time.Minute)
   232  	s.lock.Lock()
   233  	s.imageServerClientInUse = false
   234  	s.lock.Unlock()
   235  	s.imageServerClientLock.Unlock()
   236  }
   237  
   238  func (s *TftpbootServer) setImageStreamName(name string) {
   239  	s.lock.Lock()
   240  	defer s.lock.Unlock()
   241  	s.imageStreamName = name
   242  }
   243  
   244  func (s *TftpbootServer) unregisterFiles(ipAddr net.IP) {
   245  	address := ipAddr.String()
   246  	s.lock.Lock()
   247  	defer s.lock.Unlock()
   248  	delete(s.filesForIPs, address)
   249  }