github.com/rajeev159/opa@v0.45.0/ast/internal/scanner/scanner.go (about)

     1  // Copyright 2020 The OPA Authors.  All rights reserved.
     2  // Use of this source code is governed by an Apache2
     3  // license that can be found in the LICENSE file.
     4  
     5  package scanner
     6  
     7  import (
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	"unicode"
    12  	"unicode/utf8"
    13  
    14  	"github.com/open-policy-agent/opa/ast/internal/tokens"
    15  )
    16  
    17  const bom = 0xFEFF
    18  
    19  // Scanner is used to tokenize an input stream of
    20  // Rego source code.
    21  type Scanner struct {
    22  	offset   int
    23  	row      int
    24  	col      int
    25  	bs       []byte
    26  	curr     rune
    27  	width    int
    28  	errors   []Error
    29  	keywords map[string]tokens.Token
    30  }
    31  
    32  // Error represents a scanner error.
    33  type Error struct {
    34  	Pos     Position
    35  	Message string
    36  }
    37  
    38  // Position represents a point in the scanned source code.
    39  type Position struct {
    40  	Offset int // start offset in bytes
    41  	End    int // end offset in bytes
    42  	Row    int // line number computed in bytes
    43  	Col    int // column number computed in bytes
    44  }
    45  
    46  // New returns an initialized scanner that will scan
    47  // through the source code provided by the io.Reader.
    48  func New(r io.Reader) (*Scanner, error) {
    49  
    50  	bs, err := ioutil.ReadAll(r)
    51  	if err != nil {
    52  		return nil, err
    53  	}
    54  
    55  	s := &Scanner{
    56  		offset:   0,
    57  		row:      1,
    58  		col:      0,
    59  		bs:       bs,
    60  		curr:     -1,
    61  		width:    0,
    62  		keywords: tokens.Keywords(),
    63  	}
    64  
    65  	s.next()
    66  
    67  	if s.curr == bom {
    68  		s.next()
    69  	}
    70  
    71  	return s, nil
    72  }
    73  
    74  // Bytes returns the raw bytes for the full source
    75  // which the scanner has read in.
    76  func (s *Scanner) Bytes() []byte {
    77  	return s.bs
    78  }
    79  
    80  // String returns a human readable string of the current scanner state.
    81  func (s *Scanner) String() string {
    82  	return fmt.Sprintf("<curr: %q, offset: %d, len: %d>", s.curr, s.offset, len(s.bs))
    83  }
    84  
    85  // Keyword will return a token for the passed in
    86  // literal value. If the value is a Rego keyword
    87  // then the appropriate token is returned. Everything
    88  // else is an Ident.
    89  func (s *Scanner) Keyword(lit string) tokens.Token {
    90  	if tok, ok := s.keywords[lit]; ok {
    91  		return tok
    92  	}
    93  	return tokens.Ident
    94  }
    95  
    96  // AddKeyword adds a string -> token mapping to this Scanner instance.
    97  func (s *Scanner) AddKeyword(kw string, tok tokens.Token) {
    98  	s.keywords[kw] = tok
    99  
   100  	switch tok {
   101  	case tokens.Every: // importing 'every' means also importing 'in'
   102  		s.keywords["in"] = tokens.In
   103  	}
   104  }
   105  
   106  // WithKeywords returns a new copy of the Scanner struct `s`, with the set
   107  // of known keywords being that of `s` with `kws` added.
   108  func (s *Scanner) WithKeywords(kws map[string]tokens.Token) *Scanner {
   109  	cpy := *s
   110  	cpy.keywords = make(map[string]tokens.Token, len(s.keywords)+len(kws))
   111  	for kw, tok := range s.keywords {
   112  		cpy.AddKeyword(kw, tok)
   113  	}
   114  	for k, t := range kws {
   115  		cpy.AddKeyword(k, t)
   116  	}
   117  	return &cpy
   118  }
   119  
   120  // WithoutKeywords returns a new copy of the Scanner struct `s`, with the
   121  // set of known keywords being that of `s` with `kws` removed.
   122  // The previously known keywords are returned for a convenient reset.
   123  func (s *Scanner) WithoutKeywords(kws map[string]tokens.Token) (*Scanner, map[string]tokens.Token) {
   124  	cpy := *s
   125  	kw := s.keywords
   126  	cpy.keywords = make(map[string]tokens.Token, len(s.keywords)-len(kws))
   127  	for kw, tok := range s.keywords {
   128  		if _, ok := kws[kw]; !ok {
   129  			cpy.AddKeyword(kw, tok)
   130  		}
   131  	}
   132  	return &cpy, kw
   133  }
   134  
   135  // Scan will increment the scanners position in the source
   136  // code until the next token is found. The token, starting position
   137  // of the token, string literal, and any errors encountered are
   138  // returned. A token will always be returned, the caller must check
   139  // for any errors before using the other values.
   140  func (s *Scanner) Scan() (tokens.Token, Position, string, []Error) {
   141  
   142  	pos := Position{Offset: s.offset - s.width, Row: s.row, Col: s.col}
   143  	var tok tokens.Token
   144  	var lit string
   145  
   146  	if s.isWhitespace() {
   147  		lit = string(s.curr)
   148  		s.next()
   149  		tok = tokens.Whitespace
   150  	} else if isLetter(s.curr) {
   151  		lit = s.scanIdentifier()
   152  		tok = s.Keyword(lit)
   153  	} else if isDecimal(s.curr) {
   154  		lit = s.scanNumber()
   155  		tok = tokens.Number
   156  	} else {
   157  		ch := s.curr
   158  		s.next()
   159  		switch ch {
   160  		case -1:
   161  			tok = tokens.EOF
   162  		case '#':
   163  			lit = s.scanComment()
   164  			tok = tokens.Comment
   165  		case '"':
   166  			lit = s.scanString()
   167  			tok = tokens.String
   168  		case '`':
   169  			lit = s.scanRawString()
   170  			tok = tokens.String
   171  		case '[':
   172  			tok = tokens.LBrack
   173  		case ']':
   174  			tok = tokens.RBrack
   175  		case '{':
   176  			tok = tokens.LBrace
   177  		case '}':
   178  			tok = tokens.RBrace
   179  		case '(':
   180  			tok = tokens.LParen
   181  		case ')':
   182  			tok = tokens.RParen
   183  		case ',':
   184  			tok = tokens.Comma
   185  		case ':':
   186  			if s.curr == '=' {
   187  				s.next()
   188  				tok = tokens.Assign
   189  			} else {
   190  				tok = tokens.Colon
   191  			}
   192  		case '+':
   193  			tok = tokens.Add
   194  		case '-':
   195  			tok = tokens.Sub
   196  		case '*':
   197  			tok = tokens.Mul
   198  		case '/':
   199  			tok = tokens.Quo
   200  		case '%':
   201  			tok = tokens.Rem
   202  		case '&':
   203  			tok = tokens.And
   204  		case '|':
   205  			tok = tokens.Or
   206  		case '=':
   207  			if s.curr == '=' {
   208  				s.next()
   209  				tok = tokens.Equal
   210  			} else {
   211  				tok = tokens.Unify
   212  			}
   213  		case '>':
   214  			if s.curr == '=' {
   215  				s.next()
   216  				tok = tokens.Gte
   217  			} else {
   218  				tok = tokens.Gt
   219  			}
   220  		case '<':
   221  			if s.curr == '=' {
   222  				s.next()
   223  				tok = tokens.Lte
   224  			} else {
   225  				tok = tokens.Lt
   226  			}
   227  		case '!':
   228  			if s.curr == '=' {
   229  				s.next()
   230  				tok = tokens.Neq
   231  			} else {
   232  				s.error("illegal ! character")
   233  			}
   234  		case ';':
   235  			tok = tokens.Semicolon
   236  		case '.':
   237  			tok = tokens.Dot
   238  		}
   239  	}
   240  
   241  	pos.End = s.offset - s.width
   242  	errs := s.errors
   243  	s.errors = nil
   244  
   245  	return tok, pos, lit, errs
   246  }
   247  
   248  func (s *Scanner) scanIdentifier() string {
   249  	start := s.offset - 1
   250  	for isLetter(s.curr) || isDigit(s.curr) {
   251  		s.next()
   252  	}
   253  	return string(s.bs[start : s.offset-1])
   254  }
   255  
   256  func (s *Scanner) scanNumber() string {
   257  
   258  	start := s.offset - 1
   259  
   260  	if s.curr != '.' {
   261  		for isDecimal(s.curr) {
   262  			s.next()
   263  		}
   264  	}
   265  
   266  	if s.curr == '.' {
   267  		s.next()
   268  		var found bool
   269  		for isDecimal(s.curr) {
   270  			s.next()
   271  			found = true
   272  		}
   273  		if !found {
   274  			s.error("expected fraction")
   275  		}
   276  	}
   277  
   278  	if lower(s.curr) == 'e' {
   279  		s.next()
   280  		if s.curr == '+' || s.curr == '-' {
   281  			s.next()
   282  		}
   283  		var found bool
   284  		for isDecimal(s.curr) {
   285  			s.next()
   286  			found = true
   287  		}
   288  		if !found {
   289  			s.error("expected exponent")
   290  		}
   291  	}
   292  
   293  	// Scan any digits following the decimals to get the
   294  	// entire invalid number/identifier.
   295  	// Example: 0a2b should be a single invalid number "0a2b"
   296  	// rather than a number "0", followed by identifier "a2b".
   297  	if isLetter(s.curr) {
   298  		s.error("illegal number format")
   299  		for isLetter(s.curr) || isDigit(s.curr) {
   300  			s.next()
   301  		}
   302  	}
   303  
   304  	return string(s.bs[start : s.offset-1])
   305  }
   306  
   307  func (s *Scanner) scanString() string {
   308  	start := s.literalStart()
   309  	for {
   310  		ch := s.curr
   311  
   312  		if ch == '\n' || ch < 0 {
   313  			s.error("non-terminated string")
   314  			break
   315  		}
   316  
   317  		s.next()
   318  
   319  		if ch == '"' {
   320  			break
   321  		}
   322  
   323  		if ch == '\\' {
   324  			switch s.curr {
   325  			case '\\', '"', '/', 'b', 'f', 'n', 'r', 't':
   326  				s.next()
   327  			case 'u':
   328  				s.next()
   329  				s.next()
   330  				s.next()
   331  				s.next()
   332  			default:
   333  				s.error("illegal escape sequence")
   334  			}
   335  		}
   336  	}
   337  
   338  	return string(s.bs[start : s.offset-1])
   339  }
   340  
   341  func (s *Scanner) scanRawString() string {
   342  	start := s.literalStart()
   343  	for {
   344  		ch := s.curr
   345  		s.next()
   346  		if ch == '`' {
   347  			break
   348  		} else if ch < 0 {
   349  			s.error("non-terminated string")
   350  			break
   351  		}
   352  	}
   353  	return string(s.bs[start : s.offset-1])
   354  }
   355  
   356  func (s *Scanner) scanComment() string {
   357  	start := s.literalStart()
   358  	for s.curr != '\n' && s.curr != -1 {
   359  		s.next()
   360  	}
   361  	end := s.offset - 1
   362  	// Trim carriage returns that precede the newline
   363  	if s.offset > 1 && s.bs[s.offset-2] == '\r' {
   364  		end = end - 1
   365  	}
   366  	return string(s.bs[start:end])
   367  }
   368  
   369  func (s *Scanner) next() {
   370  
   371  	if s.offset >= len(s.bs) {
   372  		s.curr = -1
   373  		s.offset = len(s.bs) + 1
   374  		return
   375  	}
   376  
   377  	s.curr = rune(s.bs[s.offset])
   378  	s.width = 1
   379  
   380  	if s.curr == 0 {
   381  		s.error("illegal null character")
   382  	} else if s.curr >= utf8.RuneSelf {
   383  		s.curr, s.width = utf8.DecodeRune(s.bs[s.offset:])
   384  		if s.curr == utf8.RuneError && s.width == 1 {
   385  			s.error("illegal utf-8 character")
   386  		} else if s.curr == bom && s.offset > 0 {
   387  			s.error("illegal byte-order mark")
   388  		}
   389  	}
   390  
   391  	s.offset += s.width
   392  
   393  	if s.curr == '\n' {
   394  		s.row++
   395  		s.col = 0
   396  	} else {
   397  		s.col++
   398  	}
   399  }
   400  
   401  func (s *Scanner) literalStart() int {
   402  	// The current offset is at the first character past the literal delimiter (#, ", `, etc.)
   403  	// Need to subtract width of first character (plus one for the delimiter).
   404  	return s.offset - (s.width + 1)
   405  }
   406  
   407  // From the Go scanner (src/go/scanner/scanner.go)
   408  
   409  func isLetter(ch rune) bool {
   410  	return 'a' <= lower(ch) && lower(ch) <= 'z' || ch == '_'
   411  }
   412  
   413  func isDigit(ch rune) bool {
   414  	return isDecimal(ch) || ch >= utf8.RuneSelf && unicode.IsDigit(ch)
   415  }
   416  
   417  func isDecimal(ch rune) bool { return '0' <= ch && ch <= '9' }
   418  
   419  func lower(ch rune) rune { return ('a' - 'A') | ch } // returns lower-case ch iff ch is ASCII letter
   420  
   421  func (s *Scanner) isWhitespace() bool {
   422  	return s.curr == ' ' || s.curr == '\t' || s.curr == '\n' || s.curr == '\r'
   423  }
   424  
   425  func (s *Scanner) error(reason string) {
   426  	s.errors = append(s.errors, Error{Pos: Position{
   427  		Offset: s.offset,
   428  		Row:    s.row,
   429  		Col:    s.col,
   430  	}, Message: reason})
   431  }