github.com/matislovas/ratago@v0.0.0-20240408115641-cc0857415a7a/xslt/match.go (about)

     1  package xslt
     2  
     3  import (
     4  	"container/list"
     5  	"strconv"
     6  	"strings"
     7  	"unicode/utf8"
     8  
     9  	"github.com/matislovas/gokogiri/xml"
    10  	"github.com/matislovas/gokogiri/xpath"
    11  )
    12  
// StepOperation identifies what a single MatchStep tests or does while a
// compiled pattern is evaluated against a node.
type StepOperation int

// Step operations produced by the pattern lexer and interpreted by
// CompiledMatch.EvalMatch.
const (
	// OP_END terminates a step list; EvalMatch reports a match on reaching it.
	OP_END StepOperation = iota
	// OP_ROOT requires the current node to be the document node.
	OP_ROOT
	// OP_ELEM requires an element whose name equals the step value (or "*").
	OP_ELEM
	// OP_ATTR requires an attribute whose name equals the step value (or "*").
	OP_ATTR
	// OP_PARENT moves evaluation to the parent of the current node.
	OP_PARENT
	// OP_ANCESTOR walks up until an ancestor named by the following step is found.
	OP_ANCESTOR
	// OP_ID is emitted for id(...) patterns.
	OP_ID
	// OP_KEY is emitted for key(...) patterns.
	OP_KEY
	// OP_NS carries a namespace prefix that must resolve to the node's namespace.
	OP_NS
	// OP_ALL is emitted for a bare "*" and requires an element node.
	OP_ALL
	// OP_PI is emitted for processing-instruction(...) patterns.
	OP_PI
	// OP_COMMENT is emitted for comment() patterns.
	OP_COMMENT
	// OP_TEXT is emitted for text() patterns; it accepts text and CDATA nodes.
	OP_TEXT
	// OP_NODE is emitted for node() patterns.
	OP_NODE
	// OP_PREDICATE carries the text of a [...] predicate.
	OP_PREDICATE
	// OP_OR marks a "|" separating alternative patterns.
	OP_OR
	// OP_ERROR marks an unrecognised function-style test; EvalMatch rejects it.
	OP_ERROR
)
    34  
// MatchStep is an individual step in a match pattern: the operation to
// perform and the slice of pattern text the lexer attached to it (a name,
// a namespace prefix, a predicate body, function arguments, ...).
type MatchStep struct {
	Op    StepOperation
	Value string
}
    40  
// CompiledMatch is one alternative of a compiled match pattern together
// with the template it was compiled for.
type CompiledMatch struct {
	// pattern is the full source text passed to CompileMatch; EvalMatch
	// re-evaluates it as an XPath expression when it has to fall back.
	pattern string
	// Steps are evaluated in order starting at the candidate node, so the
	// right-most part of the pattern comes first.
	Steps []*MatchStep
	// Template is the owning template; EvalMatch tolerates nil here.
	Template *Template
}
    47  
// stateFn is one state of the lexer: it consumes input from the lexer and
// returns the next state, or nil when lexing is finished.
type stateFn func(*lexer) stateFn
    49  
// lexer holds the scanning state for a single match pattern.
type lexer struct {
	input string          // the pattern being scanned
	start int             // byte offset where the pending (not yet emitted) text begins
	pos   int             // byte offset of the next rune to read
	width int             // byte width of the rune most recently returned by next; used by backup
	steps chan *MatchStep // emitted steps; closed by run when lexing ends
}
    57  
    58  func (l *lexer) run() {
    59  	for state := lexNodeTest; state != nil; {
    60  		state = state(l)
    61  	}
    62  	close(l.steps)
    63  
    64  	// the | operator
    65  
    66  	// see a ::, either set axis or emit error
    67  	// see a :, emit op_NS?  or just modify next op?
    68  	// inside a () consume to close, check for validity of arguments
    69  }
    70  
// eof is the sentinel rune value next returns once the input is exhausted.
const eof = -1
    72  
    73  // emit passes an item back to the client.
    74  func (l *lexer) emit(t StepOperation) {
    75  	l.steps <- &MatchStep{t, l.input[l.start:l.pos]}
    76  	l.start = l.pos
    77  }
    78  
    79  func (l *lexer) next() (r rune) {
    80  	if l.pos >= len(l.input) {
    81  		l.width = 0
    82  		return eof
    83  	}
    84  	r, l.width = utf8.DecodeRuneInString(l.input[l.pos:])
    85  	l.pos += l.width
    86  	return r
    87  }
    88  
// ignore discards the pending input: everything consumed since the last
// emit or ignore is dropped by moving start up to the current position.
func (l *lexer) ignore() {
	l.start = l.pos
}
    93  
// backup un-reads the rune returned by the most recent call to next by
// rewinding pos by that rune's recorded width.
// Can be called only once per call of next. Note that peek also calls
// next, so the recorded width afterwards is that of the peeked rune; after
// next has returned eof the width is zero and backup does nothing.
func (l *lexer) backup() {
	l.pos -= l.width
}
    99  
   100  // peek returns but does not consume
   101  // the next rune in the input.
   102  func (l *lexer) peek() rune {
   103  	r := l.next()
   104  	l.backup()
   105  	return r
   106  }
   107  
   108  func lexNodeTest(l *lexer) stateFn {
   109  	attr := false
   110  	for {
   111  		r := l.next()
   112  		switch r {
   113  		case '/':
   114  			l.backup()
   115  			if l.pos > l.start {
   116  				if attr {
   117  					l.emit(OP_ATTR)
   118  				} else {
   119  					l.emit(OP_ELEM)
   120  				}
   121  			}
   122  			return lexParent
   123  		case '(':
   124  			l.backup()
   125  			if attr {
   126  				return lexAttrNodeTest
   127  			} else {
   128  				return lexFunctionCall
   129  			}
   130  		case '[':
   131  			l.backup()
   132  			if l.pos > l.start {
   133  				if attr {
   134  					l.emit(OP_ATTR)
   135  				} else {
   136  					l.emit(OP_ELEM)
   137  				}
   138  			}
   139  			return lexPredicate
   140  		case '@':
   141  			l.ignore()
   142  			attr = true
   143  		case '*':
   144  			if attr {
   145  				l.emit(OP_ATTR)
   146  			} else {
   147  				return lexAll
   148  			}
   149  		case ':':
   150  			if l.peek() == ':' {
   151  				//axis specifier
   152  				_ = l.next()
   153  				axisName := l.input[l.start:l.pos]
   154  				if axisName == "attribute::" {
   155  					attr = true
   156  				}
   157  				//TODO: only child and attribute axes allowed in pattern
   158  				l.ignore()
   159  			} else {
   160  				l.backup()
   161  				l.emit(OP_NS)
   162  				_ = l.next()
   163  				l.ignore()
   164  			}
   165  		case '|':
   166  			l.backup()
   167  			if l.pos > l.start {
   168  				if attr {
   169  					l.emit(OP_ATTR)
   170  				} else {
   171  					l.emit(OP_ELEM)
   172  				}
   173  			}
   174  			_ = l.next()
   175  			l.emit(OP_OR)
   176  			l.ignore()
   177  			return lexNodeTest
   178  		case ' ', '\t', '\r':
   179  			l.ignore()
   180  		default:
   181  		}
   182  		//switch?
   183  		if r == eof {
   184  			break
   185  		}
   186  	}
   187  	if l.pos > l.start {
   188  		if attr {
   189  			l.emit(OP_ATTR)
   190  		} else {
   191  			l.emit(OP_ELEM)
   192  		}
   193  	}
   194  	return nil
   195  }
   196  
   197  func lexFunctionCall(l *lexer) stateFn {
   198  	fnName := l.input[l.start:l.pos]
   199  	op := OP_ERROR
   200  	switch fnName {
   201  	case "comment":
   202  		op = OP_COMMENT
   203  	case "text":
   204  		op = OP_TEXT
   205  	case "node":
   206  		op = OP_NODE
   207  	case "id":
   208  		op = OP_ID
   209  	case "key":
   210  		op = OP_KEY
   211  	case "processing-instruction":
   212  		op = OP_PI
   213  	}
   214  	l.ignore()
   215  	depth := 0
   216  	for {
   217  		r := l.next()
   218  		if r == eof {
   219  			//TODO: parse error
   220  			break
   221  		}
   222  		if r == '(' {
   223  			depth = depth + 1
   224  		}
   225  		if r == ')' {
   226  			depth = depth - 1
   227  			if depth == 0 {
   228  				l.emit(op)
   229  			}
   230  		}
   231  	}
   232  	return lexNodeTest
   233  }
   234  
   235  func lexAttrNodeTest(l *lexer) stateFn {
   236  	fnName := l.input[l.start:l.pos]
   237  	op := OP_ERROR
   238  	switch fnName {
   239  	case "node":
   240  		op = OP_ATTR
   241  	}
   242  	l.ignore()
   243  	depth := 0
   244  	for {
   245  		r := l.next()
   246  		if r == eof {
   247  			//TODO: parse error
   248  			break
   249  		}
   250  		if r == '(' {
   251  			depth = depth + 1
   252  		}
   253  		if r == ')' {
   254  			depth = depth - 1
   255  			if depth == 0 {
   256  				l.steps <- &MatchStep{op, "*"}
   257  				l.start = l.pos
   258  			}
   259  		}
   260  	}
   261  	return lexNodeTest
   262  }
   263  
   264  func lexPredicate(l *lexer) stateFn {
   265  	depth := 0
   266  	for {
   267  		r := l.next()
   268  		if r == '[' {
   269  			depth = depth + 1
   270  		}
   271  		if r == ']' {
   272  			depth = depth - 1
   273  			if depth == 0 {
   274  				l.emit(OP_PREDICATE)
   275  				break
   276  			}
   277  		}
   278  		if r == eof {
   279  			//TODO: parse error
   280  			break
   281  		}
   282  	}
   283  	return lexNodeTest
   284  }
   285  
   286  func lexParent(l *lexer) stateFn {
   287  	_ = l.next()
   288  	if l.peek() == '/' {
   289  		_ = l.next()
   290  		//we can ignore it at the root!
   291  		if l.start == 0 {
   292  			l.ignore()
   293  		} else {
   294  			l.emit(OP_ANCESTOR)
   295  		}
   296  		return lexNodeTest
   297  	}
   298  	if l.start == 0 {
   299  		l.emit(OP_ROOT)
   300  		//return lexNodeTest
   301  	}
   302  	l.emit(OP_PARENT)
   303  	return lexNodeTest
   304  }
   305  
// lexAll emits the pending input as an OP_ALL step and returns to
// lexNodeTest. lexNodeTest enters this state right after consuming a "*"
// outside attribute context.
func lexAll(l *lexer) stateFn {
	l.emit(OP_ALL)
	return lexNodeTest
}
   310  
   311  func parseMatchPattern(s string) (steps []*MatchStep) {
   312  	//create a lexer
   313  	//run the state machine
   314  	// each state emits steps into the stream
   315  	// when it recognizes new state returns new state
   316  	// state returns nil when out of input
   317  	// break out of loop and close channel
   318  	//get the channel of steps
   319  
   320  	//range over the steps until we have them all
   321  	//reverse the array for fast matching?
   322  	//assign priority/mode
   323  
   324  	// for now shortcut the common ROOT
   325  	if s == "/" {
   326  		steps = []*MatchStep{{Op: OP_ROOT, Value: s}, {Op: OP_END}}
   327  		return
   328  	}
   329  
   330  	ls := list.New()
   331  	ls.PushFront(&MatchStep{Op: OP_END})
   332  
   333  	// parse the expression
   334  	l := &lexer{input: s, steps: make(chan *MatchStep)}
   335  	go l.run()
   336  
   337  	// prepend steps to avoid reversing later
   338  	for step := range l.steps {
   339  		//we don't want predicates at the front
   340  		if step.Op == OP_PREDICATE {
   341  			//TODO: fix lexer to trim outer braces
   342  			step.Value = step.Value[1 : len(step.Value)-1]
   343  			ls.InsertAfter(step, ls.Front())
   344  		} else {
   345  			ls.PushFront(step)
   346  		}
   347  	}
   348  
   349  	for i := ls.Front(); i != nil; i = i.Next() {
   350  		steps = append(steps, i.Value.(*MatchStep))
   351  	}
   352  	return
   353  }
   354  
   355  func CompileMatch(s string, t *Template) (matches []*CompiledMatch) {
   356  	if s == "" {
   357  		return
   358  	}
   359  	steps := parseMatchPattern(s)
   360  	start := 0
   361  	for i, step := range steps {
   362  		if step.Op == OP_OR {
   363  			matches = append(matches, &CompiledMatch{s, steps[start:i], t})
   364  			start = i + 1
   365  		}
   366  	}
   367  	matches = append(matches, &CompiledMatch{s, steps[start:], t})
   368  	return
   369  }
   370  
   371  // Returns true if the node matches the pattern
   372  func (m *CompiledMatch) EvalMatch(node xml.Node, mode string, context *ExecutionContext) bool {
   373  	cur := node
   374  	//false if wrong mode
   375  	// #all is an XSLT 2.0 feature
   376  	if m.Template != nil && mode != m.Template.Mode && m.Template.Mode != "#all" {
   377  		return false
   378  	}
   379  
   380  	for i, step := range m.Steps {
   381  		switch step.Op {
   382  		case OP_END:
   383  			return true
   384  		case OP_ROOT:
   385  			if cur.NodeType() != xml.XML_DOCUMENT_NODE {
   386  				return false
   387  			}
   388  		case OP_ELEM:
   389  			if cur.NodeType() != xml.XML_ELEMENT_NODE {
   390  				return false
   391  			}
   392  			if step.Value != cur.Name() && step.Value != "*" {
   393  				return false
   394  			}
   395  		case OP_NS:
   396  			uri := ""
   397  			// m.Template.Node
   398  			if m.Template != nil {
   399  				uri = context.LookupNamespace(step.Value, m.Template.Node)
   400  			} else {
   401  				uri = context.LookupNamespace(step.Value, nil)
   402  			}
   403  			if uri != cur.Namespace() {
   404  				return false
   405  			}
   406  		case OP_ATTR:
   407  			if cur.NodeType() != xml.XML_ATTRIBUTE_NODE {
   408  				return false
   409  			}
   410  			if step.Value != cur.Name() && step.Value != "*" {
   411  				return false
   412  			}
   413  		case OP_TEXT:
   414  			if cur.NodeType() != xml.XML_TEXT_NODE && cur.NodeType() != xml.XML_CDATA_SECTION_NODE {
   415  				return false
   416  			}
   417  		case OP_COMMENT:
   418  			if cur.NodeType() != xml.XML_COMMENT_NODE {
   419  				return false
   420  			}
   421  		case OP_ALL:
   422  			if cur.NodeType() != xml.XML_ELEMENT_NODE {
   423  				return false
   424  			}
   425  		case OP_PI:
   426  			if cur.NodeType() != xml.XML_PI_NODE {
   427  				return false
   428  			}
   429  		case OP_NODE:
   430  			switch cur.NodeType() {
   431  			case xml.XML_ELEMENT_NODE, xml.XML_CDATA_SECTION_NODE, xml.XML_TEXT_NODE, xml.XML_COMMENT_NODE, xml.XML_PI_NODE:
   432  				// matches any of these node types
   433  			default:
   434  				return false
   435  			}
   436  		case OP_PARENT:
   437  			cur = cur.Parent()
   438  			if cur == nil {
   439  				return false
   440  			}
   441  		case OP_ANCESTOR:
   442  			next := m.Steps[i+1]
   443  			if next.Op != OP_ELEM {
   444  				return false
   445  			}
   446  			for {
   447  				cur = cur.Parent()
   448  				if cur == nil {
   449  					return false
   450  				}
   451  				if next.Value == cur.Name() {
   452  					break
   453  				}
   454  			}
   455  		case OP_PREDICATE:
   456  			// see test REC/5.2-16
   457  			// see test REC/5.2-22
   458  			evalFull := true
   459  			if context != nil {
   460  
   461  				prev := m.Steps[i-1]
   462  				if prev.Op == OP_PREDICATE {
   463  					prev = m.Steps[i-2]
   464  				}
   465  				if prev.Op == OP_ELEM || prev.Op == OP_ALL {
   466  					parent := cur.Parent()
   467  					sibs := context.ChildrenOf(parent)
   468  					var clen, pos int
   469  					for _, n := range sibs {
   470  						if n.NodePtr() == cur.NodePtr() {
   471  							pos = clen + 1
   472  							clen = clen + 1
   473  						} else {
   474  							if n.NodeType() == xml.XML_ELEMENT_NODE {
   475  								if n.Name() == cur.Name() || prev.Op == OP_ALL {
   476  									clen = clen + 1
   477  								}
   478  							}
   479  						}
   480  					}
   481  					if step.Value == "last()" {
   482  						if pos != clen {
   483  							return false
   484  						}
   485  					}
   486  					//eval predicate should do special number handling
   487  					postest, err := strconv.Atoi(step.Value)
   488  					if err == nil {
   489  						if pos != postest {
   490  							return false
   491  						}
   492  					}
   493  					opos, olen := context.XPathContext.GetContextPosition()
   494  					context.XPathContext.SetContextPosition(pos, clen)
   495  					result := cur.EvalXPathAsBoolean(step.Value, context)
   496  					context.XPathContext.SetContextPosition(opos, olen)
   497  					if result == false {
   498  						return false
   499  					}
   500  					evalFull = false
   501  				}
   502  			}
   503  			if evalFull {
   504  				//if we made it this far, fall back to the more expensive option of evaluating
   505  				// the entire pattern globally
   506  				//TODO: cache results on first run for given document
   507  				xp := m.pattern
   508  				if m.pattern[0] != '/' {
   509  					xp = "//" + m.pattern
   510  				}
   511  				e := xpath.Compile(xp)
   512  				o, err := node.Search(e)
   513  				if err != nil {
   514  					//fmt.Println("ERROR",err)
   515  				}
   516  				for _, n := range o {
   517  					if cur.NodePtr() == n.NodePtr() {
   518  						return true
   519  					}
   520  				}
   521  				return false
   522  			}
   523  
   524  		case OP_ID:
   525  			//TODO: fix lexer to only put literal inside step value
   526  			val := strings.Trim(step.Value, "()\"'")
   527  			id := cur.MyDocument().NodeById(val)
   528  			if id == nil || node.NodePtr() != id.NodePtr() {
   529  				return false
   530  			}
   531  		case OP_KEY:
   532  			//  TODO: make this robust
   533  			if context != nil {
   534  				val := strings.Trim(step.Value, "()")
   535  				v := strings.Split(val, ",")
   536  				keyname := strings.Trim(v[0], "\"'")
   537  				keyval := strings.Trim(v[1], "\"'")
   538  				key, _ := context.Style.Keys[keyname]
   539  				if key != nil {
   540  					o, _ := key.nodes[keyval]
   541  					for _, n := range o {
   542  						if cur.NodePtr() == n.NodePtr() {
   543  							return true
   544  						}
   545  					}
   546  				}
   547  			}
   548  			return false
   549  		default:
   550  			return false
   551  		}
   552  	}
   553  	//in theory, OP_END means we never reach here
   554  	// in practice, we can generate match patterns
   555  	// that are missing OP_END due to how we handle OP_OR
   556  	return true
   557  }
   558  
   559  func (m *CompiledMatch) Hash() (hash string) {
   560  	base := m.Steps[0]
   561  	switch base.Op {
   562  	case OP_ATTR:
   563  		return base.Value
   564  	case OP_ELEM:
   565  		return base.Value
   566  	case OP_ALL:
   567  		return "*"
   568  	case OP_ROOT:
   569  		return "/"
   570  	}
   571  	return
   572  }
   573  
   574  func (m *CompiledMatch) IsElement() bool {
   575  	op := m.Steps[0].Op
   576  	if op == OP_ELEM || op == OP_ROOT || op == OP_ALL {
   577  		return true
   578  	}
   579  	return false
   580  }
   581  
   582  func (m *CompiledMatch) IsAttr() bool {
   583  	op := m.Steps[0].Op
   584  	return op == OP_ATTR
   585  }
   586  
   587  func (m *CompiledMatch) IsNode() bool {
   588  	op := m.Steps[0].Op
   589  	return op == OP_NODE
   590  }
   591  
   592  func (m *CompiledMatch) IsPI() bool {
   593  	op := m.Steps[0].Op
   594  	return op == OP_PI
   595  }
   596  
   597  func (m *CompiledMatch) IsIdKey() bool {
   598  	op := m.Steps[0].Op
   599  	return op == OP_ID || op == OP_KEY
   600  }
   601  
   602  func (m *CompiledMatch) IsText() bool {
   603  	op := m.Steps[0].Op
   604  	return op == OP_TEXT
   605  }
   606  
   607  func (m *CompiledMatch) IsComment() bool {
   608  	op := m.Steps[0].Op
   609  	return op == OP_COMMENT
   610  }
   611  
   612  func (m *CompiledMatch) endsAfter(n int) bool {
   613  	steps := len(m.Steps)
   614  	if n == steps {
   615  		return true
   616  	}
   617  	if n+1 == steps && m.Steps[n].Op == OP_END {
   618  		return true
   619  	}
   620  	return false
   621  }
   622  
   623  func (m *CompiledMatch) DefaultPriority() (priority float64) {
   624  	//TODO: calculate defaults according to spec
   625  	step := m.Steps[0]
   626  	// *
   627  	if step.Op == OP_ALL {
   628  		if m.endsAfter(1) {
   629  			return -0.5
   630  		}
   631  		// ns:*
   632  		if m.endsAfter(2) && m.Steps[1].Op == OP_NS {
   633  			return -0.25
   634  		}
   635  	}
   636  	// @*
   637  	if step.Op == OP_ATTR && step.Value == "*" {
   638  		if m.endsAfter(1) {
   639  			return -0.5
   640  		}
   641  		if m.endsAfter(2) && m.Steps[1].Op == OP_NS {
   642  			return -0.25
   643  		}
   644  	}
   645  	// text(), node(), comment()
   646  	if step.Op == OP_TEXT || step.Op == OP_NODE || step.Op == OP_COMMENT {
   647  		if m.endsAfter(1) {
   648  			return -0.5
   649  		}
   650  	}
   651  	// QName
   652  	if step.Op == OP_ELEM {
   653  		if m.endsAfter(1) {
   654  			return 0
   655  		}
   656  		if m.endsAfter(2) && m.Steps[1].Op == OP_NS {
   657  			return 0
   658  		}
   659  	}
   660  	// @QName
   661  	if step.Op == OP_ATTR && step.Value != "*" {
   662  		if m.endsAfter(1) {
   663  			return 0
   664  		}
   665  		if m.endsAfter(2) && m.Steps[1].Op == OP_NS {
   666  			return 0
   667  		}
   668  	}
   669  	return 0.5
   670  }