github.com/dfcfw/lua@v0.0.0-20230325031207-0cc7ffb7b8b9/parse/lexer.go (about)

     1  package parse
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"fmt"
     7  	"io"
     8  	"reflect"
     9  	"strconv"
    10  	"strings"
    11  
    12  	"github.com/dfcfw/lua/ast"
    13  )
    14  
    15  const (
    16  	EOF         = -1
    17  	whitespace1 = 1<<'\t' | 1<<' '
    18  	whitespace2 = 1<<'\t' | 1<<'\n' | 1<<'\r' | 1<<' '
    19  )
    20  
// Error is the error value produced by the scanner and lexer. It records
// where the problem was detected, a description, and the nearby token text.
type Error struct {
	// Pos is the position the error is reported at. A Line equal to EOF
	// makes Error() print the end-of-input form of the message.
	Pos ast.Position
	// Message is the human-readable description of the problem.
	Message string
	// Token is the text printed after "near" in the formatted message.
	Token string
}
    26  
    27  func (e *Error) Error() string {
    28  	pos := e.Pos
    29  	if pos.Line == EOF {
    30  		return fmt.Sprintf("%v at EOF:   %s\n", pos.Source, e.Message)
    31  	} else {
    32  		return fmt.Sprintf("%v line:%d(column:%d) near '%v':   %s\n", pos.Source, pos.Line, pos.Column, e.Token, e.Message)
    33  	}
    34  }
    35  
    36  func writeChar(buf *bytes.Buffer, c int) { buf.WriteByte(byte(c)) }
    37  
    38  func isDecimal(ch int) bool { return '0' <= ch && ch <= '9' }
    39  
    40  func isIdent(ch int, pos int) bool {
    41  	return ch == '_' || 'A' <= ch && ch <= 'Z' || 'a' <= ch && ch <= 'z' || isDecimal(ch) && pos > 0
    42  }
    43  
    44  func isDigit(ch int) bool {
    45  	return '0' <= ch && ch <= '9' || 'a' <= ch && ch <= 'f' || 'A' <= ch && ch <= 'F'
    46  }
    47  
// Scanner reads source text one byte at a time from a buffered reader and
// keeps track of the current source position while doing so.
type Scanner struct {
	// Pos is the current position. Next and Newline update Line and Column;
	// Line is set to EOF once the input is exhausted.
	Pos ast.Position
	// reader is the buffered input that bytes are read from; Peek relies on
	// its UnreadByte for one byte of look-ahead.
	reader *bufio.Reader
}
    52  
    53  func NewScanner(reader io.Reader, source string) *Scanner {
    54  	return &Scanner{
    55  		Pos: ast.Position{
    56  			Source: source,
    57  			Line:   1,
    58  			Column: 0,
    59  		},
    60  		reader: bufio.NewReaderSize(reader, 4096),
    61  	}
    62  }
    63  
    64  func (sc *Scanner) Error(tok string, msg string) *Error { return &Error{sc.Pos, msg, tok} }
    65  
    66  func (sc *Scanner) TokenError(tok ast.Token, msg string) *Error { return &Error{tok.Pos, msg, tok.Str} }
    67  
    68  func (sc *Scanner) readNext() int {
    69  	ch, err := sc.reader.ReadByte()
    70  	if err == io.EOF {
    71  		return EOF
    72  	}
    73  	return int(ch)
    74  }
    75  
    76  func (sc *Scanner) Newline(ch int) {
    77  	if ch < 0 {
    78  		return
    79  	}
    80  	sc.Pos.Line += 1
    81  	sc.Pos.Column = 0
    82  	next := sc.Peek()
    83  	if ch == '\n' && next == '\r' || ch == '\r' && next == '\n' {
    84  		sc.reader.ReadByte()
    85  	}
    86  }
    87  
    88  func (sc *Scanner) Next() int {
    89  	ch := sc.readNext()
    90  	switch ch {
    91  	case '\n', '\r':
    92  		sc.Newline(ch)
    93  		ch = int('\n')
    94  	case EOF:
    95  		sc.Pos.Line = EOF
    96  		sc.Pos.Column = 0
    97  	default:
    98  		sc.Pos.Column++
    99  	}
   100  	return ch
   101  }
   102  
   103  func (sc *Scanner) Peek() int {
   104  	ch := sc.readNext()
   105  	if ch != EOF {
   106  		sc.reader.UnreadByte()
   107  	}
   108  	return ch
   109  }
   110  
   111  func (sc *Scanner) skipWhiteSpace(whitespace int64) int {
   112  	ch := sc.Next()
   113  	for ; whitespace&(1<<uint(ch)) != 0; ch = sc.Next() {
   114  	}
   115  	return ch
   116  }
   117  
   118  func (sc *Scanner) skipComments(ch int) error {
   119  	// multiline comment
   120  	if sc.Peek() == '[' {
   121  		ch = sc.Next()
   122  		if sc.Peek() == '[' || sc.Peek() == '=' {
   123  			var buf bytes.Buffer
   124  			if err := sc.scanMultilineString(sc.Next(), &buf); err != nil {
   125  				return sc.Error(buf.String(), "invalid multiline comment")
   126  			}
   127  			return nil
   128  		}
   129  	}
   130  	for {
   131  		if ch == '\n' || ch == '\r' || ch < 0 {
   132  			break
   133  		}
   134  		ch = sc.Next()
   135  	}
   136  	return nil
   137  }
   138  
   139  func (sc *Scanner) scanIdent(ch int, buf *bytes.Buffer) error {
   140  	writeChar(buf, ch)
   141  	for isIdent(sc.Peek(), 1) {
   142  		writeChar(buf, sc.Next())
   143  	}
   144  	return nil
   145  }
   146  
   147  func (sc *Scanner) scanDecimal(ch int, buf *bytes.Buffer) error {
   148  	writeChar(buf, ch)
   149  	for isDecimal(sc.Peek()) {
   150  		writeChar(buf, sc.Next())
   151  	}
   152  	return nil
   153  }
   154  
// scanNumber scans a numeric literal whose first character ch has already
// been consumed, appending its text to buf.
//
// Two forms are recognised:
//   - hexadecimal: "0x" or "0X" followed by at least one hex digit;
//   - decimal: digits, an optional '.' with further digits, and an optional
//     exponent introduced by 'e' or 'E' with an optional sign.
//
// The only error returned is "illegal hexadecimal number", for a hex prefix
// with no digits after it. The decimal text is collected without further
// validation.
func (sc *Scanner) scanNumber(ch int, buf *bytes.Buffer) error {
	if ch == '0' { // leading zero: either a hex prefix or a redundant zero
		if sc.Peek() == 'x' || sc.Peek() == 'X' {
			writeChar(buf, ch)
			writeChar(buf, sc.Next())
			hasvalue := false
			for isDigit(sc.Peek()) {
				writeChar(buf, sc.Next())
				hasvalue = true
			}
			if !hasvalue {
				return sc.Error(buf.String(), "illegal hexadecimal number")
			}
			return nil
		} else if sc.Peek() != '.' && isDecimal(sc.Peek()) {
			// A zero directly followed by another digit is dropped from the
			// token text; scanning continues from that next digit.
			ch = sc.Next()
		}
	}
	// Integer part (or, when ch is '.', the dot plus the fraction digits).
	sc.scanDecimal(ch, buf)
	if sc.Peek() == '.' {
		// Fraction: the '.' itself is written by scanDecimal.
		sc.scanDecimal(sc.Next(), buf)
	}
	if ch = sc.Peek(); ch == 'e' || ch == 'E' {
		writeChar(buf, sc.Next())
		if ch = sc.Peek(); ch == '-' || ch == '+' {
			writeChar(buf, sc.Next())
		}
		// NOTE(review): the character after the exponent marker/sign is
		// consumed and written even when it is not a digit.
		sc.scanDecimal(sc.Next(), buf)
	}

	return nil
}
   187  
   188  func (sc *Scanner) scanString(quote int, buf *bytes.Buffer) error {
   189  	ch := sc.Next()
   190  	for ch != quote {
   191  		if ch == '\n' || ch == '\r' || ch < 0 {
   192  			return sc.Error(buf.String(), "unterminated string")
   193  		}
   194  		if ch == '\\' {
   195  			if err := sc.scanEscape(ch, buf); err != nil {
   196  				return err
   197  			}
   198  		} else {
   199  			writeChar(buf, ch)
   200  		}
   201  		ch = sc.Next()
   202  	}
   203  	return nil
   204  }
   205  
   206  func (sc *Scanner) scanEscape(ch int, buf *bytes.Buffer) error {
   207  	ch = sc.Next()
   208  	switch ch {
   209  	case 'a':
   210  		buf.WriteByte('\a')
   211  	case 'b':
   212  		buf.WriteByte('\b')
   213  	case 'f':
   214  		buf.WriteByte('\f')
   215  	case 'n':
   216  		buf.WriteByte('\n')
   217  	case 'r':
   218  		buf.WriteByte('\r')
   219  	case 't':
   220  		buf.WriteByte('\t')
   221  	case 'v':
   222  		buf.WriteByte('\v')
   223  	case '\\':
   224  		buf.WriteByte('\\')
   225  	case '"':
   226  		buf.WriteByte('"')
   227  	case '\'':
   228  		buf.WriteByte('\'')
   229  	case '\n':
   230  		buf.WriteByte('\n')
   231  	case '\r':
   232  		buf.WriteByte('\n')
   233  		sc.Newline('\r')
   234  	default:
   235  		if '0' <= ch && ch <= '9' {
   236  			bytes := []byte{byte(ch)}
   237  			for i := 0; i < 2 && isDecimal(sc.Peek()); i++ {
   238  				bytes = append(bytes, byte(sc.Next()))
   239  			}
   240  			val, _ := strconv.ParseInt(string(bytes), 10, 32)
   241  			writeChar(buf, int(val))
   242  		} else {
   243  			writeChar(buf, ch)
   244  		}
   245  	}
   246  	return nil
   247  }
   248  
   249  func (sc *Scanner) countSep(ch int) (int, int) {
   250  	count := 0
   251  	for ; ch == '='; count = count + 1 {
   252  		ch = sc.Next()
   253  	}
   254  	return count, ch
   255  }
   256  
   257  func (sc *Scanner) scanMultilineString(ch int, buf *bytes.Buffer) error {
   258  	var count1, count2 int
   259  	count1, ch = sc.countSep(ch)
   260  	if ch != '[' {
   261  		return sc.Error(string(rune(ch)), "invalid multiline string")
   262  	}
   263  	ch = sc.Next()
   264  	if ch == '\n' || ch == '\r' {
   265  		ch = sc.Next()
   266  	}
   267  	for {
   268  		if ch < 0 {
   269  			return sc.Error(buf.String(), "unterminated multiline string")
   270  		} else if ch == ']' {
   271  			count2, ch = sc.countSep(sc.Next())
   272  			if count1 == count2 && ch == ']' {
   273  				goto finally
   274  			}
   275  			buf.WriteByte(']')
   276  			buf.WriteString(strings.Repeat("=", count2))
   277  			continue
   278  		}
   279  		writeChar(buf, ch)
   280  		ch = sc.Next()
   281  	}
   282  
   283  finally:
   284  	return nil
   285  }
   286  
   287  var reservedWords = map[string]int{
   288  	"and": TAnd, "break": TBreak, "do": TDo, "else": TElse, "elseif": TElseIf,
   289  	"end": TEnd, "false": TFalse, "for": TFor, "function": TFunction,
   290  	"if": TIf, "in": TIn, "local": TLocal, "nil": TNil, "not": TNot, "or": TOr,
   291  	"return": TReturn, "repeat": TRepeat, "then": TThen, "true": TTrue,
   292  	"until": TUntil, "while": TWhile, "goto": TGoto,
   293  }
   294  
   295  func (sc *Scanner) Scan(lexer *Lexer) (ast.Token, error) {
   296  redo:
   297  	var err error
   298  	tok := ast.Token{}
   299  	newline := false
   300  
   301  	ch := sc.skipWhiteSpace(whitespace1)
   302  	if ch == '\n' || ch == '\r' {
   303  		newline = true
   304  		ch = sc.skipWhiteSpace(whitespace2)
   305  	}
   306  
   307  	if ch == '(' && lexer.PrevTokenType == ')' {
   308  		lexer.PNewLine = newline
   309  	} else {
   310  		lexer.PNewLine = false
   311  	}
   312  
   313  	var _buf bytes.Buffer
   314  	buf := &_buf
   315  	tok.Pos = sc.Pos
   316  
   317  	switch {
   318  	case isIdent(ch, 0):
   319  		tok.Type = TIdent
   320  		err = sc.scanIdent(ch, buf)
   321  		tok.Str = buf.String()
   322  		if err != nil {
   323  			goto finally
   324  		}
   325  		if typ, ok := reservedWords[tok.Str]; ok {
   326  			tok.Type = typ
   327  		}
   328  	case isDecimal(ch):
   329  		tok.Type = TNumber
   330  		err = sc.scanNumber(ch, buf)
   331  		tok.Str = buf.String()
   332  	default:
   333  		switch ch {
   334  		case EOF:
   335  			tok.Type = EOF
   336  		case '-':
   337  			if sc.Peek() == '-' {
   338  				err = sc.skipComments(sc.Next())
   339  				if err != nil {
   340  					goto finally
   341  				}
   342  				goto redo
   343  			} else {
   344  				tok.Type = ch
   345  				tok.Str = string(rune(ch))
   346  			}
   347  		case '"', '\'':
   348  			tok.Type = TString
   349  			err = sc.scanString(ch, buf)
   350  			tok.Str = buf.String()
   351  		case '[':
   352  			if c := sc.Peek(); c == '[' || c == '=' {
   353  				tok.Type = TString
   354  				err = sc.scanMultilineString(sc.Next(), buf)
   355  				tok.Str = buf.String()
   356  			} else {
   357  				tok.Type = ch
   358  				tok.Str = string(rune(ch))
   359  			}
   360  		case '=':
   361  			if sc.Peek() == '=' {
   362  				tok.Type = TEqeq
   363  				tok.Str = "=="
   364  				sc.Next()
   365  			} else {
   366  				tok.Type = ch
   367  				tok.Str = string(rune(ch))
   368  			}
   369  		case '~':
   370  			if sc.Peek() == '=' {
   371  				tok.Type = TNeq
   372  				tok.Str = "~="
   373  				sc.Next()
   374  			} else {
   375  				err = sc.Error("~", "Invalid '~' token")
   376  			}
   377  		case '<':
   378  			if sc.Peek() == '=' {
   379  				tok.Type = TLte
   380  				tok.Str = "<="
   381  				sc.Next()
   382  			} else {
   383  				tok.Type = ch
   384  				tok.Str = string(rune(ch))
   385  			}
   386  		case '>':
   387  			if sc.Peek() == '=' {
   388  				tok.Type = TGte
   389  				tok.Str = ">="
   390  				sc.Next()
   391  			} else {
   392  				tok.Type = ch
   393  				tok.Str = string(rune(ch))
   394  			}
   395  		case '.':
   396  			ch2 := sc.Peek()
   397  			switch {
   398  			case isDecimal(ch2):
   399  				tok.Type = TNumber
   400  				err = sc.scanNumber(ch, buf)
   401  				tok.Str = buf.String()
   402  			case ch2 == '.':
   403  				writeChar(buf, ch)
   404  				writeChar(buf, sc.Next())
   405  				if sc.Peek() == '.' {
   406  					writeChar(buf, sc.Next())
   407  					tok.Type = T3Comma
   408  				} else {
   409  					tok.Type = T2Comma
   410  				}
   411  			default:
   412  				tok.Type = '.'
   413  			}
   414  			tok.Str = buf.String()
   415  		case ':':
   416  			if sc.Peek() == ':' {
   417  				tok.Type = T2Colon
   418  				tok.Str = "::"
   419  				sc.Next()
   420  			} else {
   421  				tok.Type = ch
   422  				tok.Str = string(rune(ch))
   423  			}
   424  		case '+', '*', '/', '%', '^', '#', '(', ')', '{', '}', ']', ';', ',':
   425  			tok.Type = ch
   426  			tok.Str = string(rune(ch))
   427  		default:
   428  			writeChar(buf, ch)
   429  			err = sc.Error(buf.String(), "Invalid token")
   430  			goto finally
   431  		}
   432  	}
   433  
   434  finally:
   435  	tok.Name = TokenName(int(tok.Type))
   436  	return tok, err
   437  }
   438  
   439  // yacc interface {{{
   440  
// Lexer connects a Scanner to the parser driven by yyParse: Lex supplies
// tokens and Error reports parse failures.
type Lexer struct {
	// scanner is the token source.
	scanner *Scanner
	// Stmts is what Parse returns as the chunk once yyParse has finished;
	// presumably it is filled in by the generated parser (not visible here).
	Stmts []ast.Stmt
	// PNewLine is set by Scanner.Scan: true when the current token is a '('
	// that follows a ')' token across a line break.
	PNewLine bool
	// Token is the most recent token returned by Lex.
	Token ast.Token
	// PrevTokenType is the type of Token as it was before the current call
	// to Lex scanned a new one.
	PrevTokenType int
}
   448  
   449  func (lx *Lexer) Lex(lval *yySymType) int {
   450  	lx.PrevTokenType = lx.Token.Type
   451  	tok, err := lx.scanner.Scan(lx)
   452  	if err != nil {
   453  		panic(err)
   454  	}
   455  	if tok.Type < 0 {
   456  		return 0
   457  	}
   458  	lval.token = tok
   459  	lx.Token = tok
   460  	return int(tok.Type)
   461  }
   462  
   463  func (lx *Lexer) Error(message string) {
   464  	panic(lx.scanner.Error(lx.Token.Str, message))
   465  }
   466  
   467  func (lx *Lexer) TokenError(tok ast.Token, message string) {
   468  	panic(lx.scanner.TokenError(tok, message))
   469  }
   470  
   471  func Parse(reader io.Reader, name string) (chunk []ast.Stmt, err error) {
   472  	lexer := &Lexer{NewScanner(reader, name), nil, false, ast.Token{Str: ""}, TNil}
   473  	chunk = nil
   474  	defer func() {
   475  		if e := recover(); e != nil {
   476  			err, _ = e.(error)
   477  		}
   478  	}()
   479  	yyParse(lexer)
   480  	chunk = lexer.Stmts
   481  	return
   482  }
   483  
   484  // }}}
   485  
   486  // Dump {{{
   487  
   488  func isInlineDumpNode(rv reflect.Value) bool {
   489  	switch rv.Kind() {
   490  	case reflect.Struct, reflect.Slice, reflect.Interface, reflect.Ptr:
   491  		return false
   492  	default:
   493  		return true
   494  	}
   495  }
   496  
   497  func dump(node interface{}, level int, s string) string {
   498  	rt := reflect.TypeOf(node)
   499  	if fmt.Sprint(rt) == "<nil>" {
   500  		return strings.Repeat(s, level) + "<nil>"
   501  	}
   502  
   503  	rv := reflect.ValueOf(node)
   504  	buf := []string{}
   505  	switch rt.Kind() {
   506  	case reflect.Slice:
   507  		if rv.Len() == 0 {
   508  			return strings.Repeat(s, level) + "<empty>"
   509  		}
   510  		for i := 0; i < rv.Len(); i++ {
   511  			buf = append(buf, dump(rv.Index(i).Interface(), level, s))
   512  		}
   513  	case reflect.Ptr:
   514  		vt := rv.Elem()
   515  		tt := rt.Elem()
   516  		indicies := []int{}
   517  		for i := 0; i < tt.NumField(); i++ {
   518  			if strings.Index(tt.Field(i).Name, "Base") > -1 {
   519  				continue
   520  			}
   521  			indicies = append(indicies, i)
   522  		}
   523  		switch {
   524  		case len(indicies) == 0:
   525  			return strings.Repeat(s, level) + "<empty>"
   526  		case len(indicies) == 1 && isInlineDumpNode(vt.Field(indicies[0])):
   527  			for _, i := range indicies {
   528  				buf = append(buf, strings.Repeat(s, level)+"- Node$"+tt.Name()+": "+dump(vt.Field(i).Interface(), 0, s))
   529  			}
   530  		default:
   531  			buf = append(buf, strings.Repeat(s, level)+"- Node$"+tt.Name())
   532  			for _, i := range indicies {
   533  				if isInlineDumpNode(vt.Field(i)) {
   534  					inf := dump(vt.Field(i).Interface(), 0, s)
   535  					buf = append(buf, strings.Repeat(s, level+1)+tt.Field(i).Name+": "+inf)
   536  				} else {
   537  					buf = append(buf, strings.Repeat(s, level+1)+tt.Field(i).Name+": ")
   538  					buf = append(buf, dump(vt.Field(i).Interface(), level+2, s))
   539  				}
   540  			}
   541  		}
   542  	default:
   543  		buf = append(buf, strings.Repeat(s, level)+fmt.Sprint(node))
   544  	}
   545  	return strings.Join(buf, "\n")
   546  }
   547  
   548  func Dump(chunk []ast.Stmt) string {
   549  	return dump(chunk, 0, "   ")
   550  }
   551  
// }}}