github.com/benhoyt/goawk@v1.8.1/interp/io.go (about)

     1  // Input/output handling for GoAWK interpreter
     2  
     3  package interp
     4  
     5  import (
     6  	"bufio"
     7  	"bytes"
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	"os"
    12  	"os/exec"
    13  	"strconv"
    14  	"strings"
    15  	"unicode/utf8"
    16  
    17  	. "github.com/benhoyt/goawk/internal/ast"
    18  	. "github.com/benhoyt/goawk/lexer"
    19  )
    20  
    21  // Print a line of output followed by a newline
    22  func (p *interp) printLine(writer io.Writer, line string) error {
    23  	err := writeOutput(writer, line)
    24  	if err != nil {
    25  		return err
    26  	}
    27  	return writeOutput(writer, p.outputRecordSep)
    28  }
    29  
    30  // Implement a buffered version of WriteCloser so output is buffered
    31  // when redirecting to a file (eg: print >"out")
    32  type bufferedWriteCloser struct {
    33  	*bufio.Writer
    34  	io.Closer
    35  }
    36  
    37  func newBufferedWriteClose(w io.WriteCloser) *bufferedWriteCloser {
    38  	writer := bufio.NewWriterSize(w, outputBufSize)
    39  	return &bufferedWriteCloser{writer, w}
    40  }
    41  
    42  func (wc *bufferedWriteCloser) Close() error {
    43  	err := wc.Writer.Flush()
    44  	if err != nil {
    45  		return err
    46  	}
    47  	return wc.Closer.Close()
    48  }
    49  
    50  // Determine the output stream for given redirect token and
    51  // destination (file or pipe name)
    52  func (p *interp) getOutputStream(redirect Token, dest Expr) (io.Writer, error) {
    53  	if redirect == ILLEGAL {
    54  		// Token "ILLEGAL" means send to standard output
    55  		return p.output, nil
    56  	}
    57  
    58  	destValue, err := p.eval(dest)
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  	name := p.toString(destValue)
    63  	if _, ok := p.inputStreams[name]; ok {
    64  		return nil, newError("can't write to reader stream")
    65  	}
    66  	if w, ok := p.outputStreams[name]; ok {
    67  		return w, nil
    68  	}
    69  
    70  	switch redirect {
    71  	case GREATER, APPEND:
    72  		// Write or append to file
    73  		if p.noFileWrites {
    74  			return nil, newError("can't write to file due to NoFileWrites")
    75  		}
    76  		flags := os.O_CREATE | os.O_WRONLY
    77  		if redirect == GREATER {
    78  			flags |= os.O_TRUNC
    79  		} else {
    80  			flags |= os.O_APPEND
    81  		}
    82  		w, err := os.OpenFile(name, flags, 0644)
    83  		if err != nil {
    84  			return nil, newError("output redirection error: %s", err)
    85  		}
    86  		buffered := newBufferedWriteClose(w)
    87  		p.outputStreams[name] = buffered
    88  		return buffered, nil
    89  
    90  	case PIPE:
    91  		// Pipe to command
    92  		if p.noExec {
    93  			return nil, newError("can't write to pipe due to NoExec")
    94  		}
    95  		cmd := exec.Command("sh", "-c", name)
    96  		w, err := cmd.StdinPipe()
    97  		if err != nil {
    98  			return nil, newError("error connecting to stdin pipe: %v", err)
    99  		}
   100  		cmd.Stdout = p.output
   101  		cmd.Stderr = p.errorOutput
   102  		err = cmd.Start()
   103  		if err != nil {
   104  			fmt.Fprintln(p.errorOutput, err)
   105  			return ioutil.Discard, nil
   106  		}
   107  		p.commands[name] = cmd
   108  		buffered := newBufferedWriteClose(w)
   109  		p.outputStreams[name] = buffered
   110  		return buffered, nil
   111  
   112  	default:
   113  		// Should never happen
   114  		panic(fmt.Sprintf("unexpected redirect type %s", redirect))
   115  	}
   116  }
   117  
   118  // Get input Scanner to use for "getline" based on file name
   119  func (p *interp) getInputScannerFile(name string) (*bufio.Scanner, error) {
   120  	if _, ok := p.outputStreams[name]; ok {
   121  		return nil, newError("can't read from writer stream")
   122  	}
   123  	if _, ok := p.inputStreams[name]; ok {
   124  		return p.scanners[name], nil
   125  	}
   126  	if p.noFileReads {
   127  		return nil, newError("can't read from file due to NoFileReads")
   128  	}
   129  	r, err := os.Open(name)
   130  	if err != nil {
   131  		return nil, err // *os.PathError is handled by caller (getline returns -1)
   132  	}
   133  	scanner := p.newScanner(r)
   134  	p.scanners[name] = scanner
   135  	p.inputStreams[name] = r
   136  	return scanner, nil
   137  }
   138  
   139  // Get input Scanner to use for "getline" based on pipe name
   140  func (p *interp) getInputScannerPipe(name string) (*bufio.Scanner, error) {
   141  	if _, ok := p.outputStreams[name]; ok {
   142  		return nil, newError("can't read from writer stream")
   143  	}
   144  	if _, ok := p.inputStreams[name]; ok {
   145  		return p.scanners[name], nil
   146  	}
   147  	if p.noExec {
   148  		return nil, newError("can't read from pipe due to NoExec")
   149  	}
   150  	cmd := exec.Command("sh", "-c", name)
   151  	cmd.Stdin = p.stdin
   152  	cmd.Stderr = p.errorOutput
   153  	r, err := cmd.StdoutPipe()
   154  	if err != nil {
   155  		return nil, newError("error connecting to stdout pipe: %v", err)
   156  	}
   157  	err = cmd.Start()
   158  	if err != nil {
   159  		fmt.Fprintln(p.errorOutput, err)
   160  		return bufio.NewScanner(strings.NewReader("")), nil
   161  	}
   162  	scanner := p.newScanner(r)
   163  	p.commands[name] = cmd
   164  	p.inputStreams[name] = r
   165  	p.scanners[name] = scanner
   166  	return scanner, nil
   167  }
   168  
   169  // Create a new buffered Scanner for reading input records
   170  func (p *interp) newScanner(input io.Reader) *bufio.Scanner {
   171  	scanner := bufio.NewScanner(input)
   172  	switch p.recordSep {
   173  	case "\n":
   174  		// Scanner default is to split on newlines
   175  	case "":
   176  		// Empty string for RS means split on \n\n (blank lines)
   177  		scanner.Split(scanLinesBlank)
   178  	default:
   179  		splitter := byteSplitter{p.recordSep[0]}
   180  		scanner.Split(splitter.scan)
   181  	}
   182  	buffer := make([]byte, inputBufSize)
   183  	scanner.Buffer(buffer, maxRecordLength)
   184  	return scanner
   185  }
   186  
   187  // Copied from bufio/scan.go in the stdlib: I guess it's a bit more
   188  // efficient than bytes.TrimSuffix(data, []byte("\r"))
   189  func dropCR(data []byte) []byte {
   190  	if len(data) > 0 && data[len(data)-1] == '\r' {
   191  		return data[:len(data)-1]
   192  	}
   193  	return data
   194  }
   195  
   196  func dropLF(data []byte) []byte {
   197  	if len(data) > 0 && data[len(data)-1] == '\n' {
   198  		return data[:len(data)-1]
   199  	}
   200  	return data
   201  }
   202  
   203  func scanLinesBlank(data []byte, atEOF bool) (advance int, token []byte, err error) {
   204  	if atEOF && len(data) == 0 {
   205  		return 0, nil, nil
   206  	}
   207  
   208  	// Skip newlines at beginning of data
   209  	i := 0
   210  	for i < len(data) && (data[i] == '\n' || data[i] == '\r') {
   211  		i++
   212  	}
   213  	if i >= len(data) {
   214  		// At end of data after newlines, skip entire data block
   215  		return i, nil, nil
   216  	}
   217  	start := i
   218  
   219  	// Try to find two consecutive newlines (or \n\r\n for Windows)
   220  	for ; i < len(data); i++ {
   221  		if data[i] != '\n' {
   222  			continue
   223  		}
   224  		end := i
   225  		if i+1 < len(data) && data[i+1] == '\n' {
   226  			i += 2
   227  			for i < len(data) && (data[i] == '\n' || data[i] == '\r') {
   228  				i++ // Skip newlines at end of record
   229  			}
   230  			return i, dropCR(data[start:end]), nil
   231  		}
   232  		if i+2 < len(data) && data[i+1] == '\r' && data[i+2] == '\n' {
   233  			i += 3
   234  			for i < len(data) && (data[i] == '\n' || data[i] == '\r') {
   235  				i++ // Skip newlines at end of record
   236  			}
   237  			return i, dropCR(data[start:end]), nil
   238  		}
   239  	}
   240  
   241  	// If we're at EOF, we have one final record; return it
   242  	if atEOF {
   243  		return len(data), dropCR(dropLF(data[start:])), nil
   244  	}
   245  
   246  	// Request more data
   247  	return 0, nil, nil
   248  }
   249  
   250  // Splitter function that splits records on the given separator byte
   251  type byteSplitter struct {
   252  	sep byte
   253  }
   254  
   255  func (s byteSplitter) scan(data []byte, atEOF bool) (advance int, token []byte, err error) {
   256  	if atEOF && len(data) == 0 {
   257  		return 0, nil, nil
   258  	}
   259  	if i := bytes.IndexByte(data, s.sep); i >= 0 {
   260  		// We have a full sep-terminated record
   261  		return i + 1, data[0:i], nil
   262  	}
   263  	// If at EOF, we have a final, non-terminated record; return it
   264  	if atEOF {
   265  		return len(data), data, nil
   266  	}
   267  	// Request more data
   268  	return 0, nil, nil
   269  }
   270  
   271  // Setup for a new input file with given name (empty string if stdin)
   272  func (p *interp) setFile(filename string) {
   273  	p.filename = filename
   274  	p.fileLineNum = 0
   275  }
   276  
   277  // Setup for a new input line (but don't parse it into fields till we
   278  // need to)
   279  func (p *interp) setLine(line string) {
   280  	p.line = line
   281  	p.haveFields = false
   282  }
   283  
   284  // Ensure that the current line is parsed into fields, splitting it
   285  // into fields if it hasn't been already
   286  func (p *interp) ensureFields() {
   287  	if p.haveFields {
   288  		return
   289  	}
   290  	p.haveFields = true
   291  
   292  	if p.fieldSep == " " {
   293  		// FS space (default) means split fields on any whitespace
   294  		p.fields = strings.Fields(p.line)
   295  	} else if p.line == "" {
   296  		p.fields = nil
   297  	} else if utf8.RuneCountInString(p.fieldSep) <= 1 {
   298  		// 1-char FS is handled as plain split (not regex)
   299  		p.fields = strings.Split(p.line, p.fieldSep)
   300  	} else {
   301  		// Split on FS as a regex
   302  		p.fields = p.fieldSepRegex.Split(p.line, -1)
   303  	}
   304  
   305  	// Special case for when RS=="" and FS is single character,
   306  	// split on newline in addition to FS. See more here:
   307  	// https://www.gnu.org/software/gawk/manual/html_node/Multiple-Line.html
   308  	if p.recordSep == "" && utf8.RuneCountInString(p.fieldSep) == 1 {
   309  		fields := make([]string, 0, len(p.fields))
   310  		for _, field := range p.fields {
   311  			lines := strings.Split(field, "\n")
   312  			for _, line := range lines {
   313  				trimmed := strings.TrimSuffix(line, "\r")
   314  				fields = append(fields, trimmed)
   315  			}
   316  		}
   317  		p.fields = fields
   318  	}
   319  
   320  	p.numFields = len(p.fields)
   321  }
   322  
// Fetch next line (record) of input from current input file, opening
// next input file if done with previous one.
//
// Input sources are taken in order from the ARGV array (indexes
// p.filenameIndex up to p.argc): "var=value" arguments are applied as
// variable assignments, empty arguments are skipped, "-" selects stdin,
// and anything else is opened as a file. If ARGV is exhausted without any
// file having been seen, stdin is used.
//
// Returns the record text, or io.EOF once all input is exhausted. Errors
// from variable assignment and os.Open are returned as-is; a scanner read
// error is wrapped with "error reading from input". On success the
// lineNum and fileLineNum counters are each incremented.
func (p *interp) nextLine() (string, error) {
	for {
		// A nil scanner means there is no current input: either this is
		// the first call or the previous input has been used up
		if p.scanner == nil {
			if prevInput, ok := p.input.(io.Closer); ok && p.input != p.stdin {
				// Previous input is file, close it
				_ = prevInput.Close()
			}
			if p.filenameIndex >= p.argc && !p.hadFiles {
				// Moved past number of ARGV args and haven't seen
				// any files yet, use stdin
				p.input = p.stdin
				p.setFile("")
				p.hadFiles = true
			} else {
				if p.filenameIndex >= p.argc {
					// Done with ARGV args, all done with input
					return "", io.EOF
				}
				// Fetch next filename from ARGV. Can't use
				// getArrayValue() here as it would set the value if
				// not present
				index := strconv.Itoa(p.filenameIndex)
				argvIndex := p.program.Arrays["ARGV"]
				argvArray := p.arrays[p.getArrayIndex(ScopeGlobal, argvIndex)]
				filename := p.toString(argvArray[index])
				p.filenameIndex++

				// Is it actually a var=value assignment?
				matches := varRegex.FindStringSubmatch(filename)
				if len(matches) >= 3 {
					// Yep, set variable to value and keep going
					err := p.setVarByName(matches[1], matches[2])
					if err != nil {
						return "", err
					}
					continue
				} else if filename == "" {
					// ARGV arg is empty string, skip
					p.input = nil
					continue
				} else if filename == "-" {
					// ARGV arg is "-" meaning stdin
					// NOTE(review): unlike the other input branches this
					// one does not set p.hadFiles, so stdin can be
					// selected again once ARGV is exhausted -- confirm
					// that is intended
					p.input = p.stdin
					p.setFile("")
				} else {
					// A regular file name, open it
					if p.noFileReads {
						return "", newError("can't read from file due to NoFileReads")
					}
					input, err := os.Open(filename)
					if err != nil {
						return "", err
					}
					p.input = input
					p.setFile(filename)
					p.hadFiles = true
				}
			}
			p.scanner = p.newScanner(p.input)
		}
		if p.scanner.Scan() {
			// We scanned some input, break and return it
			break
		}
		// Scan returned false: that is either a read error or plain end
		// of this input
		if err := p.scanner.Err(); err != nil {
			return "", fmt.Errorf("error reading from input: %s", err)
		}
		// Signal loop to move onto next file
		p.scanner = nil
	}

	// Got a line (record) of input, return it
	p.lineNum++
	p.fileLineNum++
	return p.scanner.Text(), nil
}
   401  
   402  // Write output string to given writer, producing correct line endings
   403  // on Windows (CR LF)
   404  func writeOutput(w io.Writer, s string) error {
   405  	if crlfNewline {
   406  		// First normalize to \n, then convert all newlines to \r\n
   407  		// (on Windows). NOTE: creating two new strings is almost
   408  		// certainly slow; would be better to create a custom Writer.
   409  		s = strings.Replace(s, "\r\n", "\n", -1)
   410  		s = strings.Replace(s, "\n", "\r\n", -1)
   411  	}
   412  	_, err := io.WriteString(w, s)
   413  	return err
   414  }
   415  
   416  // Close all streams, commands, etc (after program execution)
   417  func (p *interp) closeAll() {
   418  	if prevInput, ok := p.input.(io.Closer); ok {
   419  		_ = prevInput.Close()
   420  	}
   421  	for _, r := range p.inputStreams {
   422  		_ = r.Close()
   423  	}
   424  	for _, w := range p.outputStreams {
   425  		_ = w.Close()
   426  	}
   427  	for _, cmd := range p.commands {
   428  		_ = cmd.Wait()
   429  	}
   430  	if f, ok := p.output.(flusher); ok {
   431  		_ = f.Flush()
   432  	}
   433  	if f, ok := p.errorOutput.(flusher); ok {
   434  		_ = f.Flush()
   435  	}
   436  }
   437  
   438  // Flush all output streams as well as standard output. Report whether all
   439  // streams were flushed successfully (logging error(s) if not).
   440  func (p *interp) flushAll() bool {
   441  	allGood := true
   442  	for name, writer := range p.outputStreams {
   443  		allGood = allGood && p.flushWriter(name, writer)
   444  	}
   445  	if _, ok := p.output.(flusher); ok {
   446  		// User-provided output may or may not be flushable
   447  		allGood = allGood && p.flushWriter("stdout", p.output)
   448  	}
   449  	return allGood
   450  }
   451  
   452  // Flush a single, named output stream, and report whether it was flushed
   453  // successfully (logging an error if not).
   454  func (p *interp) flushStream(name string) bool {
   455  	writer := p.outputStreams[name]
   456  	if writer == nil {
   457  		fmt.Fprintf(p.errorOutput, "error flushing %q: not an output file or pipe\n", name)
   458  		return false
   459  	}
   460  	return p.flushWriter(name, writer)
   461  }
   462  
// flusher is satisfied by any value with a Flush() error method. It is
// used with type assertions to find out whether a writer can be flushed.
type flusher interface {
	Flush() error
}
   466  
   467  // Flush given output writer, and report whether it was flushed successfully
   468  // (logging an error if not).
   469  func (p *interp) flushWriter(name string, writer io.Writer) bool {
   470  	flusher := writer.(flusher)
   471  	err := flusher.Flush()
   472  	if err != nil {
   473  		fmt.Fprintf(p.errorOutput, "error flushing %q: %v\n", name, err)
   474  		return false
   475  	}
   476  	return true
   477  }