github.com/matislovas/ratago@v0.0.0-20240408115641-cc0857415a7a/xslt/match.go (about) 1 package xslt 2 3 import ( 4 "container/list" 5 "strconv" 6 "strings" 7 "unicode/utf8" 8 9 "github.com/matislovas/gokogiri/xml" 10 "github.com/matislovas/gokogiri/xpath" 11 ) 12 13 type StepOperation int 14 15 const ( 16 OP_END StepOperation = iota 17 OP_ROOT 18 OP_ELEM 19 OP_ATTR 20 OP_PARENT 21 OP_ANCESTOR 22 OP_ID 23 OP_KEY 24 OP_NS 25 OP_ALL 26 OP_PI 27 OP_COMMENT 28 OP_TEXT 29 OP_NODE 30 OP_PREDICATE 31 OP_OR 32 OP_ERROR 33 ) 34 35 // An individual step in the pattern 36 type MatchStep struct { 37 Op StepOperation 38 Value string 39 } 40 41 // The compiled match pattern 42 type CompiledMatch struct { 43 pattern string 44 Steps []*MatchStep 45 Template *Template 46 } 47 48 type stateFn func(*lexer) stateFn 49 50 type lexer struct { 51 input string 52 start int 53 pos int 54 width int //really? 55 steps chan *MatchStep 56 } 57 58 func (l *lexer) run() { 59 for state := lexNodeTest; state != nil; { 60 state = state(l) 61 } 62 close(l.steps) 63 64 // the | operator 65 66 // see a ::, either set axis or emit error 67 // see a :, emit op_NS? or just modify next op? 68 // inside a () consume to close, check for validity of arguments 69 } 70 71 const eof = -1 72 73 // emit passes an item back to the client. 74 func (l *lexer) emit(t StepOperation) { 75 l.steps <- &MatchStep{t, l.input[l.start:l.pos]} 76 l.start = l.pos 77 } 78 79 func (l *lexer) next() (r rune) { 80 if l.pos >= len(l.input) { 81 l.width = 0 82 return eof 83 } 84 r, l.width = utf8.DecodeRuneInString(l.input[l.pos:]) 85 l.pos += l.width 86 return r 87 } 88 89 // ignore skips over the pending input before this point. 90 func (l *lexer) ignore() { 91 l.start = l.pos 92 } 93 94 // backup steps back one rune. 95 // Can be called only once per call of next. 96 func (l *lexer) backup() { 97 l.pos -= l.width 98 } 99 100 // peek returns but does not consume 101 // the next rune in the input. 102 func (l *lexer) peek() rune { 103 r := l.next() 104 l.backup() 105 return r 106 } 107 108 func lexNodeTest(l *lexer) stateFn { 109 attr := false 110 for { 111 r := l.next() 112 switch r { 113 case '/': 114 l.backup() 115 if l.pos > l.start { 116 if attr { 117 l.emit(OP_ATTR) 118 } else { 119 l.emit(OP_ELEM) 120 } 121 } 122 return lexParent 123 case '(': 124 l.backup() 125 if attr { 126 return lexAttrNodeTest 127 } else { 128 return lexFunctionCall 129 } 130 case '[': 131 l.backup() 132 if l.pos > l.start { 133 if attr { 134 l.emit(OP_ATTR) 135 } else { 136 l.emit(OP_ELEM) 137 } 138 } 139 return lexPredicate 140 case '@': 141 l.ignore() 142 attr = true 143 case '*': 144 if attr { 145 l.emit(OP_ATTR) 146 } else { 147 return lexAll 148 } 149 case ':': 150 if l.peek() == ':' { 151 //axis specifier 152 _ = l.next() 153 axisName := l.input[l.start:l.pos] 154 if axisName == "attribute::" { 155 attr = true 156 } 157 //TODO: only child and attribute axes allowed in pattern 158 l.ignore() 159 } else { 160 l.backup() 161 l.emit(OP_NS) 162 _ = l.next() 163 l.ignore() 164 } 165 case '|': 166 l.backup() 167 if l.pos > l.start { 168 if attr { 169 l.emit(OP_ATTR) 170 } else { 171 l.emit(OP_ELEM) 172 } 173 } 174 _ = l.next() 175 l.emit(OP_OR) 176 l.ignore() 177 return lexNodeTest 178 case ' ', '\t', '\r': 179 l.ignore() 180 default: 181 } 182 //switch? 183 if r == eof { 184 break 185 } 186 } 187 if l.pos > l.start { 188 if attr { 189 l.emit(OP_ATTR) 190 } else { 191 l.emit(OP_ELEM) 192 } 193 } 194 return nil 195 } 196 197 func lexFunctionCall(l *lexer) stateFn { 198 fnName := l.input[l.start:l.pos] 199 op := OP_ERROR 200 switch fnName { 201 case "comment": 202 op = OP_COMMENT 203 case "text": 204 op = OP_TEXT 205 case "node": 206 op = OP_NODE 207 case "id": 208 op = OP_ID 209 case "key": 210 op = OP_KEY 211 case "processing-instruction": 212 op = OP_PI 213 } 214 l.ignore() 215 depth := 0 216 for { 217 r := l.next() 218 if r == eof { 219 //TODO: parse error 220 break 221 } 222 if r == '(' { 223 depth = depth + 1 224 } 225 if r == ')' { 226 depth = depth - 1 227 if depth == 0 { 228 l.emit(op) 229 } 230 } 231 } 232 return lexNodeTest 233 } 234 235 func lexAttrNodeTest(l *lexer) stateFn { 236 fnName := l.input[l.start:l.pos] 237 op := OP_ERROR 238 switch fnName { 239 case "node": 240 op = OP_ATTR 241 } 242 l.ignore() 243 depth := 0 244 for { 245 r := l.next() 246 if r == eof { 247 //TODO: parse error 248 break 249 } 250 if r == '(' { 251 depth = depth + 1 252 } 253 if r == ')' { 254 depth = depth - 1 255 if depth == 0 { 256 l.steps <- &MatchStep{op, "*"} 257 l.start = l.pos 258 } 259 } 260 } 261 return lexNodeTest 262 } 263 264 func lexPredicate(l *lexer) stateFn { 265 depth := 0 266 for { 267 r := l.next() 268 if r == '[' { 269 depth = depth + 1 270 } 271 if r == ']' { 272 depth = depth - 1 273 if depth == 0 { 274 l.emit(OP_PREDICATE) 275 break 276 } 277 } 278 if r == eof { 279 //TODO: parse error 280 break 281 } 282 } 283 return lexNodeTest 284 } 285 286 func lexParent(l *lexer) stateFn { 287 _ = l.next() 288 if l.peek() == '/' { 289 _ = l.next() 290 //we can ignore it at the root! 291 if l.start == 0 { 292 l.ignore() 293 } else { 294 l.emit(OP_ANCESTOR) 295 } 296 return lexNodeTest 297 } 298 if l.start == 0 { 299 l.emit(OP_ROOT) 300 //return lexNodeTest 301 } 302 l.emit(OP_PARENT) 303 return lexNodeTest 304 } 305 306 func lexAll(l *lexer) stateFn { 307 l.emit(OP_ALL) 308 return lexNodeTest 309 } 310 311 func parseMatchPattern(s string) (steps []*MatchStep) { 312 //create a lexer 313 //run the state machine 314 // each state emits steps into the stream 315 // when it recognizes new state returns new state 316 // state returns nil when out of input 317 // break out of loop and close channel 318 //get the channel of steps 319 320 //range over the steps until we have them all 321 //reverse the array for fast matching? 322 //assign priority/mode 323 324 // for now shortcut the common ROOT 325 if s == "/" { 326 steps = []*MatchStep{{Op: OP_ROOT, Value: s}, {Op: OP_END}} 327 return 328 } 329 330 ls := list.New() 331 ls.PushFront(&MatchStep{Op: OP_END}) 332 333 // parse the expression 334 l := &lexer{input: s, steps: make(chan *MatchStep)} 335 go l.run() 336 337 // prepend steps to avoid reversing later 338 for step := range l.steps { 339 //we don't want predicates at the front 340 if step.Op == OP_PREDICATE { 341 //TODO: fix lexer to trim outer braces 342 step.Value = step.Value[1 : len(step.Value)-1] 343 ls.InsertAfter(step, ls.Front()) 344 } else { 345 ls.PushFront(step) 346 } 347 } 348 349 for i := ls.Front(); i != nil; i = i.Next() { 350 steps = append(steps, i.Value.(*MatchStep)) 351 } 352 return 353 } 354 355 func CompileMatch(s string, t *Template) (matches []*CompiledMatch) { 356 if s == "" { 357 return 358 } 359 steps := parseMatchPattern(s) 360 start := 0 361 for i, step := range steps { 362 if step.Op == OP_OR { 363 matches = append(matches, &CompiledMatch{s, steps[start:i], t}) 364 start = i + 1 365 } 366 } 367 matches = append(matches, &CompiledMatch{s, steps[start:], t}) 368 return 369 } 370 371 // Returns true if the node matches the pattern 372 func (m *CompiledMatch) EvalMatch(node xml.Node, mode string, context *ExecutionContext) bool { 373 cur := node 374 //false if wrong mode 375 // #all is an XSLT 2.0 feature 376 if m.Template != nil && mode != m.Template.Mode && m.Template.Mode != "#all" { 377 return false 378 } 379 380 for i, step := range m.Steps { 381 switch step.Op { 382 case OP_END: 383 return true 384 case OP_ROOT: 385 if cur.NodeType() != xml.XML_DOCUMENT_NODE { 386 return false 387 } 388 case OP_ELEM: 389 if cur.NodeType() != xml.XML_ELEMENT_NODE { 390 return false 391 } 392 if step.Value != cur.Name() && step.Value != "*" { 393 return false 394 } 395 case OP_NS: 396 uri := "" 397 // m.Template.Node 398 if m.Template != nil { 399 uri = context.LookupNamespace(step.Value, m.Template.Node) 400 } else { 401 uri = context.LookupNamespace(step.Value, nil) 402 } 403 if uri != cur.Namespace() { 404 return false 405 } 406 case OP_ATTR: 407 if cur.NodeType() != xml.XML_ATTRIBUTE_NODE { 408 return false 409 } 410 if step.Value != cur.Name() && step.Value != "*" { 411 return false 412 } 413 case OP_TEXT: 414 if cur.NodeType() != xml.XML_TEXT_NODE && cur.NodeType() != xml.XML_CDATA_SECTION_NODE { 415 return false 416 } 417 case OP_COMMENT: 418 if cur.NodeType() != xml.XML_COMMENT_NODE { 419 return false 420 } 421 case OP_ALL: 422 if cur.NodeType() != xml.XML_ELEMENT_NODE { 423 return false 424 } 425 case OP_PI: 426 if cur.NodeType() != xml.XML_PI_NODE { 427 return false 428 } 429 case OP_NODE: 430 switch cur.NodeType() { 431 case xml.XML_ELEMENT_NODE, xml.XML_CDATA_SECTION_NODE, xml.XML_TEXT_NODE, xml.XML_COMMENT_NODE, xml.XML_PI_NODE: 432 // matches any of these node types 433 default: 434 return false 435 } 436 case OP_PARENT: 437 cur = cur.Parent() 438 if cur == nil { 439 return false 440 } 441 case OP_ANCESTOR: 442 next := m.Steps[i+1] 443 if next.Op != OP_ELEM { 444 return false 445 } 446 for { 447 cur = cur.Parent() 448 if cur == nil { 449 return false 450 } 451 if next.Value == cur.Name() { 452 break 453 } 454 } 455 case OP_PREDICATE: 456 // see test REC/5.2-16 457 // see test REC/5.2-22 458 evalFull := true 459 if context != nil { 460 461 prev := m.Steps[i-1] 462 if prev.Op == OP_PREDICATE { 463 prev = m.Steps[i-2] 464 } 465 if prev.Op == OP_ELEM || prev.Op == OP_ALL { 466 parent := cur.Parent() 467 sibs := context.ChildrenOf(parent) 468 var clen, pos int 469 for _, n := range sibs { 470 if n.NodePtr() == cur.NodePtr() { 471 pos = clen + 1 472 clen = clen + 1 473 } else { 474 if n.NodeType() == xml.XML_ELEMENT_NODE { 475 if n.Name() == cur.Name() || prev.Op == OP_ALL { 476 clen = clen + 1 477 } 478 } 479 } 480 } 481 if step.Value == "last()" { 482 if pos != clen { 483 return false 484 } 485 } 486 //eval predicate should do special number handling 487 postest, err := strconv.Atoi(step.Value) 488 if err == nil { 489 if pos != postest { 490 return false 491 } 492 } 493 opos, olen := context.XPathContext.GetContextPosition() 494 context.XPathContext.SetContextPosition(pos, clen) 495 result := cur.EvalXPathAsBoolean(step.Value, context) 496 context.XPathContext.SetContextPosition(opos, olen) 497 if result == false { 498 return false 499 } 500 evalFull = false 501 } 502 } 503 if evalFull { 504 //if we made it this far, fall back to the more expensive option of evaluating 505 // the entire pattern globally 506 //TODO: cache results on first run for given document 507 xp := m.pattern 508 if m.pattern[0] != '/' { 509 xp = "//" + m.pattern 510 } 511 e := xpath.Compile(xp) 512 o, err := node.Search(e) 513 if err != nil { 514 //fmt.Println("ERROR",err) 515 } 516 for _, n := range o { 517 if cur.NodePtr() == n.NodePtr() { 518 return true 519 } 520 } 521 return false 522 } 523 524 case OP_ID: 525 //TODO: fix lexer to only put literal inside step value 526 val := strings.Trim(step.Value, "()\"'") 527 id := cur.MyDocument().NodeById(val) 528 if id == nil || node.NodePtr() != id.NodePtr() { 529 return false 530 } 531 case OP_KEY: 532 // TODO: make this robust 533 if context != nil { 534 val := strings.Trim(step.Value, "()") 535 v := strings.Split(val, ",") 536 keyname := strings.Trim(v[0], "\"'") 537 keyval := strings.Trim(v[1], "\"'") 538 key, _ := context.Style.Keys[keyname] 539 if key != nil { 540 o, _ := key.nodes[keyval] 541 for _, n := range o { 542 if cur.NodePtr() == n.NodePtr() { 543 return true 544 } 545 } 546 } 547 } 548 return false 549 default: 550 return false 551 } 552 } 553 //in theory, OP_END means we never reach here 554 // in practice, we can generate match patterns 555 // that are missing OP_END due to how we handle OP_OR 556 return true 557 } 558 559 func (m *CompiledMatch) Hash() (hash string) { 560 base := m.Steps[0] 561 switch base.Op { 562 case OP_ATTR: 563 return base.Value 564 case OP_ELEM: 565 return base.Value 566 case OP_ALL: 567 return "*" 568 case OP_ROOT: 569 return "/" 570 } 571 return 572 } 573 574 func (m *CompiledMatch) IsElement() bool { 575 op := m.Steps[0].Op 576 if op == OP_ELEM || op == OP_ROOT || op == OP_ALL { 577 return true 578 } 579 return false 580 } 581 582 func (m *CompiledMatch) IsAttr() bool { 583 op := m.Steps[0].Op 584 return op == OP_ATTR 585 } 586 587 func (m *CompiledMatch) IsNode() bool { 588 op := m.Steps[0].Op 589 return op == OP_NODE 590 } 591 592 func (m *CompiledMatch) IsPI() bool { 593 op := m.Steps[0].Op 594 return op == OP_PI 595 } 596 597 func (m *CompiledMatch) IsIdKey() bool { 598 op := m.Steps[0].Op 599 return op == OP_ID || op == OP_KEY 600 } 601 602 func (m *CompiledMatch) IsText() bool { 603 op := m.Steps[0].Op 604 return op == OP_TEXT 605 } 606 607 func (m *CompiledMatch) IsComment() bool { 608 op := m.Steps[0].Op 609 return op == OP_COMMENT 610 } 611 612 func (m *CompiledMatch) endsAfter(n int) bool { 613 steps := len(m.Steps) 614 if n == steps { 615 return true 616 } 617 if n+1 == steps && m.Steps[n].Op == OP_END { 618 return true 619 } 620 return false 621 } 622 623 func (m *CompiledMatch) DefaultPriority() (priority float64) { 624 //TODO: calculate defaults according to spec 625 step := m.Steps[0] 626 // * 627 if step.Op == OP_ALL { 628 if m.endsAfter(1) { 629 return -0.5 630 } 631 // ns:* 632 if m.endsAfter(2) && m.Steps[1].Op == OP_NS { 633 return -0.25 634 } 635 } 636 // @* 637 if step.Op == OP_ATTR && step.Value == "*" { 638 if m.endsAfter(1) { 639 return -0.5 640 } 641 if m.endsAfter(2) && m.Steps[1].Op == OP_NS { 642 return -0.25 643 } 644 } 645 // text(), node(), comment() 646 if step.Op == OP_TEXT || step.Op == OP_NODE || step.Op == OP_COMMENT { 647 if m.endsAfter(1) { 648 return -0.5 649 } 650 } 651 // QName 652 if step.Op == OP_ELEM { 653 if m.endsAfter(1) { 654 return 0 655 } 656 if m.endsAfter(2) && m.Steps[1].Op == OP_NS { 657 return 0 658 } 659 } 660 // @QName 661 if step.Op == OP_ATTR && step.Value != "*" { 662 if m.endsAfter(1) { 663 return 0 664 } 665 if m.endsAfter(2) && m.Steps[1].Op == OP_NS { 666 return 0 667 } 668 } 669 return 0.5 670 }