github.com/arnodel/golua@v0.0.0-20230215163904-e0b5347eaaa1/lib/iolib/read.go (about)

     1  package iolib
     2  
     3  import (
     4  	"errors"
     5  	"io"
     6  
     7  	rt "github.com/arnodel/golua/runtime"
     8  )
     9  
    10  func ioread(t *rt.Thread, c *rt.GoCont) (rt.Cont, error) {
    11  	next := c.Next()
    12  	readers, fmtErr := getFormatReaders(c.Etc())
    13  	if fmtErr != nil {
    14  		return nil, fmtErr
    15  	}
    16  	ioErr := read(t.Runtime, getIoData(t.Runtime).defaultInputFile(), readers, next)
    17  	if ioErr != nil && ioErr != io.EOF {
    18  		return t.ProcessIoError(c.Next(), ioErr)
    19  	}
    20  	return next, nil
    21  }
    22  
    23  func fileread(t *rt.Thread, c *rt.GoCont) (rt.Cont, error) {
    24  	if err := c.Check1Arg(); err != nil {
    25  		return nil, err
    26  	}
    27  	f, err := FileArg(c, 0)
    28  	if err != nil {
    29  		return nil, err
    30  	}
    31  	next := c.Next()
    32  	readers, fmtErr := getFormatReaders(c.Etc())
    33  	if fmtErr != nil {
    34  		return nil, fmtErr
    35  	}
    36  	ioErr := read(t.Runtime, f, readers, next)
    37  	if ioErr != nil && ioErr != io.EOF {
    38  		return t.ProcessIoError(c.Next(), ioErr)
    39  	}
    40  	return next, nil
    41  }
    42  
    43  type formatReader func(*File) (rt.Value, error)
    44  
    45  var errInvalidFormat = errors.New("invalid format")
    46  var errFormatOutOfRange = errors.New("format out of range")
    47  
    48  func getFormatReader(fmt rt.Value) (reader formatReader, err error) {
    49  	if n, ok := rt.ToInt(fmt); ok {
    50  		if n < 0 {
    51  			return nil, errFormatOutOfRange
    52  		}
    53  		reader = func(f *File) (rt.Value, error) { return f.Read(int(n)) }
    54  	} else if s, ok := fmt.TryString(); ok && len(s) > 0 {
    55  		switch s {
    56  		case "n", "*n":
    57  			reader = (*File).ReadNumber
    58  		case "a", "*a", "all":
    59  			reader = (*File).ReadAll
    60  		case "l", "*l":
    61  			reader = lineReader(false)
    62  		case "L", "*L":
    63  			reader = lineReader(true)
    64  		default:
    65  			return nil, errInvalidFormat
    66  		}
    67  	} else {
    68  		return nil, errInvalidFormat
    69  	}
    70  	return
    71  }
    72  
    73  func getFormatReaders(fmts []rt.Value) ([]formatReader, error) {
    74  	readers := make([]formatReader, len(fmts))
    75  	for i, fmt := range fmts {
    76  		reader, err := getFormatReader(fmt)
    77  		if err != nil {
    78  			return nil, err
    79  		}
    80  		readers[i] = reader
    81  	}
    82  	return readers, nil
    83  }
    84  
    85  func read(r *rt.Runtime, f *File, readers []formatReader, next rt.Cont) error {
    86  	if f.IsClosed() {
    87  		return errFileAlreadyClosed
    88  	}
    89  	if len(readers) == 0 {
    90  		readers = []formatReader{lineReader(false)}
    91  	}
    92  	for i, reader := range readers {
    93  		val, readErr := reader(f)
    94  		if readErr == nil {
    95  			r.Push1(next, val)
    96  		} else if i == 0 || readErr != io.EOF {
    97  			return readErr
    98  		}
    99  	}
   100  	return nil
   101  }
   102  
   103  func lineReader(withEnd bool) formatReader {
   104  	return func(f *File) (rt.Value, error) {
   105  		return f.ReadLine(withEnd)
   106  	}
   107  }