github.com/jslyzt/glua@v0.0.0-20210819023911-4030c8e0234a/parse/lexer.go (about)

     1  package parse
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"fmt"
     7  	"io"
     8  	"reflect"
     9  	"strconv"
    10  	"strings"
    11  
    12  	"github.com/jslyzt/cast"
    13  	"github.com/jslyzt/glua/ast"
    14  )
    15  
    16  // 常量定义
    17  const (
    18  	EOF         = -1
    19  	whitespace1 = 1<<'\t' | 1<<' '
    20  	whitespace2 = 1<<'\t' | 1<<'\n' | 1<<'\r' | 1<<' '
    21  )
    22  
    23  // Error 错误
    24  type Error struct {
    25  	Pos     ast.Position
    26  	Message string
    27  	Token   string
    28  }
    29  
    30  // Error 错误
    31  func (e *Error) Error() string {
    32  	pos := e.Pos
    33  	if pos.Line == EOF {
    34  		return fmt.Sprintf("%v at EOF:   %s\n", pos.Source, e.Message)
    35  	}
    36  	return fmt.Sprintf("%v line:%d(column:%d) near '%v':   %s\n", pos.Source, pos.Line, pos.Column, e.Token, e.Message)
    37  }
    38  
    39  func writeChar(buf *bytes.Buffer, c int) { buf.WriteByte(byte(c)) }
    40  
    41  func isDecimal(ch int) bool { return '0' <= ch && ch <= '9' }
    42  
    43  func isIdent(ch int, pos int) bool {
    44  	return ch == '_' || 'A' <= ch && ch <= 'Z' || 'a' <= ch && ch <= 'z' || isDecimal(ch) && pos > 0
    45  }
    46  
    47  func isDigit(ch int) bool {
    48  	return '0' <= ch && ch <= '9' || 'a' <= ch && ch <= 'f' || 'A' <= ch && ch <= 'F'
    49  }
    50  
// Scanner reads source text one byte at a time from a buffered reader and
// tracks the position of what it has consumed.
type Scanner struct {
	Pos    ast.Position  // current source/line/column; updated by Next and Newline
	reader *bufio.Reader // buffered input the characters are read from
}
    56  
    57  // NewScanner 扫描器
    58  func NewScanner(reader io.Reader, source string) *Scanner {
    59  	return &Scanner{
    60  		Pos: ast.Position{
    61  			Source: source,
    62  			Line:   1,
    63  			Column: 0,
    64  		},
    65  		reader: bufio.NewReaderSize(reader, 4096),
    66  	}
    67  }
    68  
    69  // Error 错误
    70  func (sc *Scanner) Error(tok string, msg string) *Error {
    71  	return &Error{sc.Pos, msg, tok}
    72  }
    73  
    74  // TokenError token错误
    75  func (sc *Scanner) TokenError(tok ast.Token, msg string) *Error {
    76  	return &Error{tok.Pos, msg, tok.Str}
    77  }
    78  
    79  func (sc *Scanner) readNext() int {
    80  	ch, err := sc.reader.ReadByte()
    81  	if err == io.EOF {
    82  		return EOF
    83  	}
    84  	return int(ch)
    85  }
    86  
    87  // Newline new line
    88  func (sc *Scanner) Newline(ch int) {
    89  	if ch < 0 {
    90  		return
    91  	}
    92  	sc.Pos.Line++
    93  	sc.Pos.Column = 0
    94  	next := sc.Peek()
    95  	if ch == '\n' && next == '\r' || ch == '\r' && next == '\n' {
    96  		sc.reader.ReadByte()
    97  	}
    98  }
    99  
   100  // Next 下一步
   101  func (sc *Scanner) Next() int {
   102  	ch := sc.readNext()
   103  	switch ch {
   104  	case '\n', '\r':
   105  		sc.Newline(ch)
   106  		ch = int('\n')
   107  	case EOF:
   108  		sc.Pos.Line = EOF
   109  		sc.Pos.Column = 0
   110  	default:
   111  		sc.Pos.Column++
   112  	}
   113  	return ch
   114  }
   115  
   116  // Peek 尝试读取
   117  func (sc *Scanner) Peek() int {
   118  	ch := sc.readNext()
   119  	if ch != EOF {
   120  		sc.reader.UnreadByte()
   121  	}
   122  	return ch
   123  }
   124  
   125  func (sc *Scanner) skipWhiteSpace(whitespace int64) int {
   126  	ch := sc.Next()
   127  	for ; whitespace&(1<<uint(ch)) != 0; ch = sc.Next() {
   128  	}
   129  	return ch
   130  }
   131  
   132  func (sc *Scanner) skipComments(ch int) error {
   133  	// multiline comment
   134  	if sc.Peek() == '[' {
   135  		ch = sc.Next()
   136  		if sc.Peek() == '[' || sc.Peek() == '=' {
   137  			var buf bytes.Buffer
   138  			if err := sc.scanMultilineString(sc.Next(), &buf); err != nil {
   139  				return sc.Error(buf.String(), "invalid multiline comment")
   140  			}
   141  			return nil
   142  		}
   143  	}
   144  	for {
   145  		if ch == '\n' || ch == '\r' || ch < 0 {
   146  			break
   147  		}
   148  		ch = sc.Next()
   149  	}
   150  	return nil
   151  }
   152  
   153  func (sc *Scanner) scanIdent(ch int, buf *bytes.Buffer) error {
   154  	writeChar(buf, ch)
   155  	for isIdent(sc.Peek(), 1) {
   156  		writeChar(buf, sc.Next())
   157  	}
   158  	return nil
   159  }
   160  
   161  func (sc *Scanner) scanDecimal(ch int, buf *bytes.Buffer) error {
   162  	writeChar(buf, ch)
   163  	for isDecimal(sc.Peek()) {
   164  		writeChar(buf, sc.Next())
   165  	}
   166  	return nil
   167  }
   168  
   169  func (sc *Scanner) scanNumber(ch int, buf *bytes.Buffer) error {
   170  	if ch == '0' { // octal
   171  		if sc.Peek() == 'x' || sc.Peek() == 'X' {
   172  			writeChar(buf, ch)
   173  			writeChar(buf, sc.Next())
   174  			hasvalue := false
   175  			for isDigit(sc.Peek()) {
   176  				writeChar(buf, sc.Next())
   177  				hasvalue = true
   178  			}
   179  			if !hasvalue {
   180  				return sc.Error(buf.String(), "illegal hexadecimal number")
   181  			}
   182  			return nil
   183  		} else if sc.Peek() != '.' && isDecimal(sc.Peek()) {
   184  			ch = sc.Next()
   185  		}
   186  	}
   187  	sc.scanDecimal(ch, buf)
   188  	if sc.Peek() == '.' {
   189  		sc.scanDecimal(sc.Next(), buf)
   190  	}
   191  	if ch = sc.Peek(); ch == 'e' || ch == 'E' {
   192  		writeChar(buf, sc.Next())
   193  		if ch = sc.Peek(); ch == '-' || ch == '+' {
   194  			writeChar(buf, sc.Next())
   195  		}
   196  		sc.scanDecimal(sc.Next(), buf)
   197  	}
   198  
   199  	return nil
   200  }
   201  
   202  func (sc *Scanner) scanString(quote int, buf *bytes.Buffer) error {
   203  	ch := sc.Next()
   204  	for ch != quote {
   205  		if ch == '\n' || ch == '\r' || ch < 0 {
   206  			return sc.Error(buf.String(), "unterminated string")
   207  		}
   208  		if ch == '\\' {
   209  			if err := sc.scanEscape(buf); err != nil {
   210  				return err
   211  			}
   212  		} else {
   213  			writeChar(buf, ch)
   214  		}
   215  		ch = sc.Next()
   216  	}
   217  	return nil
   218  }
   219  
   220  func (sc *Scanner) scanEscape(buf *bytes.Buffer) error {
   221  	ch := sc.Next()
   222  	switch ch {
   223  	case 'a':
   224  		buf.WriteByte('\a')
   225  	case 'b':
   226  		buf.WriteByte('\b')
   227  	case 'f':
   228  		buf.WriteByte('\f')
   229  	case 'n':
   230  		buf.WriteByte('\n')
   231  	case 'r':
   232  		buf.WriteByte('\r')
   233  	case 't':
   234  		buf.WriteByte('\t')
   235  	case 'v':
   236  		buf.WriteByte('\v')
   237  	case '\\':
   238  		buf.WriteByte('\\')
   239  	case '"':
   240  		buf.WriteByte('"')
   241  	case '\'':
   242  		buf.WriteByte('\'')
   243  	case '\n':
   244  		buf.WriteByte('\n')
   245  	case '\r':
   246  		buf.WriteByte('\n')
   247  		sc.Newline('\r')
   248  	default:
   249  		if '0' <= ch && ch <= '9' {
   250  			bytes := []byte{byte(ch)}
   251  			for i := 0; i < 2 && isDecimal(sc.Peek()); i++ {
   252  				bytes = append(bytes, byte(sc.Next()))
   253  			}
   254  			val, _ := strconv.ParseInt(string(bytes), 10, 32)
   255  			writeChar(buf, int(val))
   256  		} else {
   257  			writeChar(buf, ch)
   258  		}
   259  	}
   260  	return nil
   261  }
   262  
   263  func (sc *Scanner) countSep(ch int) (int, int) {
   264  	count := 0
   265  	for ; ch == '='; count = count + 1 {
   266  		ch = sc.Next()
   267  	}
   268  	return count, ch
   269  }
   270  
   271  func (sc *Scanner) scanMultilineString(ch int, buf *bytes.Buffer) error {
   272  	var count1, count2 int
   273  	count1, ch = sc.countSep(ch)
   274  	if ch != '[' {
   275  		return sc.Error(cast.ToString(ch), "invalid multiline string")
   276  	}
   277  	ch = sc.Next()
   278  	if ch == '\n' || ch == '\r' {
   279  		ch = sc.Next()
   280  	}
   281  	for {
   282  		if ch < 0 {
   283  			return sc.Error(buf.String(), "unterminated multiline string")
   284  		} else if ch == ']' {
   285  			count2, ch = sc.countSep(sc.Next())
   286  			if count1 == count2 && ch == ']' {
   287  				goto finally
   288  			}
   289  			buf.WriteByte(']')
   290  			buf.WriteString(strings.Repeat("=", count2))
   291  			continue
   292  		}
   293  		writeChar(buf, ch)
   294  		ch = sc.Next()
   295  	}
   296  
   297  finally:
   298  	return nil
   299  }
   300  
   301  var reservedWords = map[string]int{
   302  	"and": TAnd, "break": TBreak, "do": TDo, "else": TElse, "elseif": TElseIf,
   303  	"end": TEnd, "false": TFalse, "for": TFor, "function": TFunction,
   304  	"if": TIf, "in": TIn, "local": TLocal, "nil": TNil, "not": TNot, "or": TOr,
   305  	"return": TReturn, "repeat": TRepeat, "then": TThen, "true": TTrue,
   306  	"until": TUntil, "while": TWhile,
   307  }
   308  
   309  // Scan 扫描
   310  func (sc *Scanner) Scan(lexer *Lexer) (ast.Token, error) {
   311  redo:
   312  	var err error
   313  	tok := ast.Token{}
   314  	newline := false
   315  
   316  	ch := sc.skipWhiteSpace(whitespace1)
   317  	if ch == '\n' || ch == '\r' {
   318  		newline = true
   319  		ch = sc.skipWhiteSpace(whitespace2)
   320  	}
   321  
   322  	if ch == '(' && lexer.PrevTokenType == ')' {
   323  		lexer.PNewLine = newline
   324  	} else {
   325  		lexer.PNewLine = false
   326  	}
   327  
   328  	var _buf bytes.Buffer
   329  	buf := &_buf
   330  	tok.Pos = sc.Pos
   331  
   332  	switch {
   333  	case isIdent(ch, 0):
   334  		tok.Type = TIdent
   335  		err = sc.scanIdent(ch, buf)
   336  		tok.Str = buf.String()
   337  		if err != nil {
   338  			goto finally
   339  		}
   340  		if typ, ok := reservedWords[tok.Str]; ok {
   341  			tok.Type = typ
   342  		}
   343  	case isDecimal(ch):
   344  		tok.Type = TNumber
   345  		err = sc.scanNumber(ch, buf)
   346  		tok.Str = buf.String()
   347  	default:
   348  		switch ch {
   349  		case EOF:
   350  			tok.Type = EOF
   351  		case '-':
   352  			if sc.Peek() == '-' {
   353  				err = sc.skipComments(sc.Next())
   354  				if err != nil {
   355  					goto finally
   356  				}
   357  				goto redo
   358  			} else {
   359  				tok.Type = ch
   360  				tok.Str = cast.ToString(ch)
   361  			}
   362  		case '"', '\'':
   363  			tok.Type = TString
   364  			err = sc.scanString(ch, buf)
   365  			tok.Str = buf.String()
   366  		case '[':
   367  			if c := sc.Peek(); c == '[' || c == '=' {
   368  				tok.Type = TString
   369  				err = sc.scanMultilineString(sc.Next(), buf)
   370  				tok.Str = buf.String()
   371  			} else {
   372  				tok.Type = ch
   373  				tok.Str = cast.ToString(ch)
   374  			}
   375  		case '=':
   376  			if sc.Peek() == '=' {
   377  				tok.Type = TEqeq
   378  				tok.Str = "=="
   379  				sc.Next()
   380  			} else {
   381  				tok.Type = ch
   382  				tok.Str = cast.ToString(ch)
   383  			}
   384  		case '~':
   385  			if sc.Peek() == '=' {
   386  				tok.Type = TNeq
   387  				tok.Str = "~="
   388  				sc.Next()
   389  			} else {
   390  				err = sc.Error("~", "Invalid '~' token")
   391  			}
   392  		case '<':
   393  			if sc.Peek() == '=' {
   394  				tok.Type = TLte
   395  				tok.Str = "<="
   396  				sc.Next()
   397  			} else {
   398  				tok.Type = ch
   399  				tok.Str = cast.ToString(ch)
   400  			}
   401  		case '>':
   402  			if sc.Peek() == '=' {
   403  				tok.Type = TGte
   404  				tok.Str = ">="
   405  				sc.Next()
   406  			} else {
   407  				tok.Type = ch
   408  				tok.Str = cast.ToString(ch)
   409  			}
   410  		case '.':
   411  			ch2 := sc.Peek()
   412  			switch {
   413  			case isDecimal(ch2):
   414  				tok.Type = TNumber
   415  				err = sc.scanNumber(ch, buf)
   416  				tok.Str = buf.String()
   417  			case ch2 == '.':
   418  				writeChar(buf, ch)
   419  				writeChar(buf, sc.Next())
   420  				if sc.Peek() == '.' {
   421  					writeChar(buf, sc.Next())
   422  					tok.Type = T3Comma
   423  				} else {
   424  					tok.Type = T2Comma
   425  				}
   426  			default:
   427  				tok.Type = '.'
   428  			}
   429  			tok.Str = buf.String()
   430  		case '+', '*', '/', '%', '^', '#', '(', ')', '{', '}', ']', ';', ':', ',':
   431  			tok.Type = ch
   432  			tok.Str = cast.ToString(ch)
   433  		default:
   434  			writeChar(buf, ch)
   435  			err = sc.Error(buf.String(), "Invalid token")
   436  			goto finally
   437  		}
   438  	}
   439  
   440  finally:
   441  	tok.Name = TokenName(int(tok.Type))
   442  	return tok, err
   443  }
   444  
   445  // yacc interface {{{
   446  
// Lexer feeds tokens from a Scanner to the parser through its Lex method and
// holds the state shared between scanning and parsing.
type Lexer struct {
	scanner       *Scanner   // source of tokens
	Stmts         []ast.Stmt // result returned by Parse; presumably filled in by the generated parser
	PNewLine      bool       // set by Scanner.Scan: a '(' token followed a ')' after a line break
	Token         ast.Token  // token most recently returned by Lex
	PrevTokenType int        // type of the token before the current one; updated at the start of Lex
}
   455  
   456  // Lex 分析
   457  func (lx *Lexer) Lex(lval *yySymType) int {
   458  	lx.PrevTokenType = lx.Token.Type
   459  	tok, err := lx.scanner.Scan(lx)
   460  	if err != nil {
   461  		panic(err)
   462  	}
   463  	if tok.Type < 0 {
   464  		return 0
   465  	}
   466  	lval.token = tok
   467  	lx.Token = tok
   468  	return int(tok.Type)
   469  }
   470  
   471  // Error 错误
   472  func (lx *Lexer) Error(message string) {
   473  	panic(lx.scanner.Error(lx.Token.Str, message))
   474  }
   475  
   476  // TokenError token 错误
   477  func (lx *Lexer) TokenError(tok ast.Token, message string) {
   478  	panic(lx.scanner.TokenError(tok, message))
   479  }
   480  
   481  // Parse 解析
   482  func Parse(reader io.Reader, name string) (chunk []ast.Stmt, err error) {
   483  	lexer := &Lexer{NewScanner(reader, name), nil, false, ast.Token{Str: ""}, TNil}
   484  	chunk = nil
   485  	defer func() {
   486  		if e := recover(); e != nil {
   487  			err, _ = e.(error)
   488  		}
   489  	}()
   490  	yyParse(lexer)
   491  	chunk = lexer.Stmts
   492  	return
   493  }
   494  
   495  // Dump {{{
   496  
   497  func isInlineDumpNode(rv reflect.Value) bool {
   498  	switch rv.Kind() {
   499  	case reflect.Struct, reflect.Slice, reflect.Interface, reflect.Ptr:
   500  		return false
   501  	default:
   502  		return true
   503  	}
   504  }
   505  
   506  func dump(node interface{}, level int, s string) string {
   507  	rt := reflect.TypeOf(node)
   508  	if fmt.Sprint(rt) == "<nil>" {
   509  		return strings.Repeat(s, level) + "<nil>"
   510  	}
   511  
   512  	rv := reflect.ValueOf(node)
   513  	buf := []string{}
   514  	switch rt.Kind() {
   515  	case reflect.Slice:
   516  		if rv.Len() == 0 {
   517  			return strings.Repeat(s, level) + "<empty>"
   518  		}
   519  		for i := 0; i < rv.Len(); i++ {
   520  			buf = append(buf, dump(rv.Index(i).Interface(), level, s))
   521  		}
   522  	case reflect.Ptr:
   523  		vt := rv.Elem()
   524  		tt := rt.Elem()
   525  		indicies := []int{}
   526  		for i := 0; i < tt.NumField(); i++ {
   527  			if strings.Contains(tt.Field(i).Name, "Base") {
   528  				continue
   529  			}
   530  			indicies = append(indicies, i)
   531  		}
   532  		switch {
   533  		case len(indicies) == 0:
   534  			return strings.Repeat(s, level) + "<empty>"
   535  		case len(indicies) == 1 && isInlineDumpNode(vt.Field(indicies[0])):
   536  			for _, i := range indicies {
   537  				buf = append(buf, strings.Repeat(s, level)+"- Node$"+tt.Name()+": "+dump(vt.Field(i).Interface(), 0, s))
   538  			}
   539  		default:
   540  			buf = append(buf, strings.Repeat(s, level)+"- Node$"+tt.Name())
   541  			for _, i := range indicies {
   542  				if isInlineDumpNode(vt.Field(i)) {
   543  					inf := dump(vt.Field(i).Interface(), 0, s)
   544  					buf = append(buf, strings.Repeat(s, level+1)+tt.Field(i).Name+": "+inf)
   545  				} else {
   546  					buf = append(buf, strings.Repeat(s, level+1)+tt.Field(i).Name+": ")
   547  					buf = append(buf, dump(vt.Field(i).Interface(), level+2, s))
   548  				}
   549  			}
   550  		}
   551  	default:
   552  		buf = append(buf, strings.Repeat(s, level)+fmt.Sprint(node))
   553  	}
   554  	return strings.Join(buf, "\n")
   555  }
   556  
   557  // Dump 存储
   558  func Dump(chunk []ast.Stmt) string {
   559  	return dump(chunk, 0, "   ")
   560  }
   561  
   562  // }}