github.com/emcfarlane/larking@v0.0.0-20220605172417-1704b45ee6c3/rules.go (about) 1 // Copyright 2021 Edward McFarlane. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package larking 6 7 import ( 8 "bytes" 9 "context" 10 "encoding/base64" 11 "encoding/json" 12 "fmt" 13 "net/http" 14 "net/textproto" 15 "net/url" 16 "sort" 17 "strconv" 18 "strings" 19 20 "google.golang.org/genproto/googleapis/api/annotations" 21 _ "google.golang.org/genproto/googleapis/api/httpbody" 22 "google.golang.org/grpc/codes" 23 "google.golang.org/grpc/metadata" 24 "google.golang.org/grpc/status" 25 "google.golang.org/protobuf/encoding/protojson" 26 "google.golang.org/protobuf/proto" 27 "google.golang.org/protobuf/reflect/protoreflect" 28 _ "google.golang.org/protobuf/types/descriptorpb" 29 "google.golang.org/protobuf/types/known/durationpb" 30 "google.golang.org/protobuf/types/known/fieldmaskpb" 31 "google.golang.org/protobuf/types/known/timestamppb" 32 "google.golang.org/protobuf/types/known/wrapperspb" 33 ) 34 35 // getExtensionHTTP 36 func getExtensionHTTP(m proto.Message) *annotations.HttpRule { 37 return proto.GetExtension(m, annotations.E_Http).(*annotations.HttpRule) 38 } 39 40 type variable struct { 41 name string // path.to.field=segment/*/** 42 toks tokens // segment/*/** 43 next *path 44 } 45 46 func (v *variable) String() string { 47 return fmt.Sprintf("%#v", v) 48 } 49 50 type variables []*variable 51 52 func (p variables) Len() int { return len(p) } 53 func (p variables) Less(i, j int) bool { return p[i].name < p[j].name } 54 func (p variables) Swap(i, j int) { p[i], p[j] = p[j], p[i] } 55 56 type path struct { 57 segments map[string]*path // maps constants to path routes 58 variables variables // sorted array of variables 59 methods map[string]*method // maps http methods to grpc methods 60 methodAll *method // maps kind '*' 61 } 62 63 func (p *path) String() string { 64 var s, sp, sv, sm []string 65 for k, pp := range p.segments { 66 sp = append(sp, "\""+k+"\":"+pp.String()) 67 } 68 if len(sp) > 0 { 69 sort.Strings(sp) 70 s = append(s, "segments{"+strings.Join(sp, ",")+"}") 71 } 72 73 for _, vv := range p.variables { 74 sv = append(sv, "\"{"+vv.name+"}\"->"+vv.next.String()) 75 } 76 if len(sv) > 0 { 77 sort.Strings(sv) 78 s = append(s, "variables["+strings.Join(sv, ",")+"]") 79 } 80 81 for k, mm := range p.methods { 82 sm = append(sm, "\""+k+"\":"+mm.String()) 83 } 84 if len(sm) > 0 { 85 sort.Strings(sm) 86 s = append(s, "methods{"+strings.Join(sm, ",")+"}") 87 } 88 return "path{" + strings.Join(s, ",") + "}" 89 } 90 91 func (p *path) findVariable(name string) (*variable, bool) { 92 for _, v := range p.variables { 93 if v.name == name { 94 return v, true 95 } 96 } 97 return nil, false 98 } 99 100 func (p *path) addVariable(toks tokens) *variable { 101 name := toks.String() 102 if v, ok := p.findVariable(name); ok { 103 return v 104 } 105 v := &variable{ 106 name: name, 107 toks: toks, 108 next: newPath(), 109 } 110 p.variables = append(p.variables, v) 111 sort.Sort(p.variables) 112 return v 113 } 114 115 func (p *path) addPath(parent, value token) *path { 116 val := parent.val + value.val 117 if next, ok := p.segments[val]; ok { 118 return next 119 } 120 next := newPath() 121 p.segments[val] = next 122 return next 123 } 124 125 func newPath() *path { 126 return &path{ 127 segments: make(map[string]*path), 128 methods: make(map[string]*method), 129 } 130 } 131 132 type method struct { 133 desc protoreflect.MethodDescriptor 134 body []protoreflect.FieldDescriptor // body 135 vars [][]protoreflect.FieldDescriptor // variables on path 136 hasBody bool // body="*" or body="field.name" or body="" for no body 137 resp []protoreflect.FieldDescriptor // body=[""|"*"] 138 name string // /{ServiceName}/{MethodName} 139 } 140 141 func (m *method) String() string { 142 return m.name 143 } 144 145 func fieldPath(fieldDescs protoreflect.FieldDescriptors, names ...string) []protoreflect.FieldDescriptor { 146 fds := make([]protoreflect.FieldDescriptor, len(names)) 147 for i, name := range names { 148 fd := fieldDescs.ByJSONName(name) 149 if fd == nil { 150 fd = fieldDescs.ByName(protoreflect.Name(name)) 151 } 152 if fd == nil { 153 return nil 154 } 155 156 fds[i] = fd 157 158 // advance 159 if i != len(fds)-1 { 160 msgDesc := fd.Message() 161 if msgDesc == nil { 162 return nil 163 } 164 fieldDescs = msgDesc.Fields() 165 } 166 } 167 return fds 168 } 169 170 func (p *path) alive() bool { 171 return len(p.methods) != 0 || 172 len(p.variables) != 0 || 173 len(p.segments) != 0 174 } 175 176 // clone deep clones the path tree. 177 func (p *path) clone() *path { 178 pc := newPath() 179 if p == nil { 180 return pc 181 } 182 183 for k, s := range p.segments { 184 pc.segments[k] = s.clone() 185 } 186 187 pc.variables = make(variables, len(p.variables)) 188 for i, v := range p.variables { 189 pc.variables[i] = &variable{ 190 name: v.name, // RO 191 toks: v.toks, // RO 192 next: v.next.clone(), 193 } 194 } 195 196 for k, m := range p.methods { 197 pc.methods[k] = m // RO 198 } 199 pc.methodAll = p.methodAll 200 201 return pc 202 } 203 204 // delRule deletes the HTTP rule to the path. 205 func (p *path) delRule(name string) bool { 206 for k, s := range p.segments { 207 if ok := s.delRule(name); ok { 208 if !s.alive() { 209 delete(p.segments, k) 210 } 211 return ok 212 } 213 } 214 215 for i, v := range p.variables { 216 if ok := v.next.delRule(name); ok { 217 if !v.next.alive() { 218 p.variables = append( 219 p.variables[:i], p.variables[i+1:]..., 220 ) 221 } 222 return ok 223 } 224 } 225 226 for k, m := range p.methods { 227 if m.name == name { 228 delete(p.methods, k) 229 return true 230 } 231 } 232 return false 233 } 234 235 // addRule adds the HTTP rule to the path. 236 func (p *path) addRule( 237 rule *annotations.HttpRule, 238 desc protoreflect.MethodDescriptor, 239 name string, 240 ) error { 241 var tmpl, verb string 242 switch v := rule.Pattern.(type) { 243 case *annotations.HttpRule_Get: 244 verb = http.MethodGet 245 tmpl = v.Get 246 case *annotations.HttpRule_Put: 247 verb = http.MethodPut 248 tmpl = v.Put 249 case *annotations.HttpRule_Post: 250 verb = http.MethodPost 251 tmpl = v.Post 252 case *annotations.HttpRule_Delete: 253 verb = http.MethodDelete 254 tmpl = v.Delete 255 case *annotations.HttpRule_Patch: 256 verb = http.MethodPatch 257 tmpl = v.Patch 258 case *annotations.HttpRule_Custom: 259 verb = strings.ToUpper(v.Custom.Kind) 260 tmpl = v.Custom.Path 261 default: 262 return fmt.Errorf("unsupported pattern %v", v) 263 } 264 265 msgDesc := desc.Input() 266 fieldDescs := msgDesc.Fields() 267 268 // Hold state for the lexer. 269 l := &lexer{input: tmpl} 270 if err := lexTemplate(l); err != nil { 271 return err 272 } 273 274 var ( 275 i = 0 276 cursor = p 277 varfds [][]protoreflect.FieldDescriptor 278 ) 279 280 next := func() token { 281 i++ 282 return l.toks[i] 283 } 284 invalid := func(tok token) { panic(fmt.Sprintf("invalid token: %v", tok)) } 285 286 // Segments 287 tok := l.toks[i] 288 for ; tok.typ == tokenSlash; tok = next() { 289 switch val := next(); val.typ { 290 case tokenStar, tokenStarStar: 291 // TODO: Variables that don't capture the path. 292 panic("todo") 293 294 // Literal 295 case tokenValue: 296 cursor = cursor.addPath(tok, val) 297 298 // Variable 299 case tokenVariableStart: 300 // FieldPath 301 tok := next() 302 keys := []string{tok.val} 303 304 nxt := next() 305 for nxt.typ == tokenDot { 306 keys = append(keys, next().val) 307 nxt = next() 308 } 309 310 var vars tokens 311 switch nxt.typ { 312 case tokenEqual: 313 for nxt := next(); nxt.typ != tokenVariableEnd; nxt = next() { 314 vars = append(vars, nxt) 315 } 316 317 case tokenVariableEnd: 318 // default 319 vars = append(vars, token{ 320 typ: tokenStar, 321 val: "*", 322 }) 323 324 default: 325 invalid(nxt) 326 } 327 328 fds := fieldPath(fieldDescs, keys...) 329 if fds == nil { 330 return fmt.Errorf("field not found %v", keys) 331 } 332 varfds = append(varfds, fds) 333 334 v := cursor.addVariable(vars) 335 cursor = v.next 336 337 default: 338 invalid(tok) 339 } 340 } 341 342 switch tok.typ { 343 case tokenVerb: 344 // Literal 345 val := next() 346 cursor = cursor.addPath(tok, val) 347 // eof 348 349 case tokenEOF: 350 // eof 351 352 default: 353 invalid(tok) 354 } 355 356 if y, ok := cursor.methods[verb]; ok || cursor.methodAll != nil { 357 if y.desc.FullName() != desc.FullName() { 358 return fmt.Errorf("duplicate rule %v", rule) 359 } 360 return nil // Method already registered. 361 } 362 363 m := &method{ 364 desc: desc, 365 vars: varfds, 366 name: name, 367 } 368 switch rule.Body { 369 case "*": 370 m.hasBody = true 371 case "": 372 m.hasBody = false 373 default: 374 m.body = fieldPath(fieldDescs, strings.Split(rule.Body, ".")...) 375 if m.body == nil { 376 return fmt.Errorf("body field error %v", rule.Body) 377 } 378 m.hasBody = true 379 } 380 381 switch rule.ResponseBody { 382 case "": 383 default: 384 m.resp = fieldPath(fieldDescs, strings.Split(rule.Body, ".")...) 385 if m.resp == nil { 386 return fmt.Errorf("response body field error %v", rule.ResponseBody) 387 } 388 } 389 390 // register method 391 if verb == "*" { 392 cursor.methodAll = m 393 } else { 394 cursor.methods[verb] = m 395 } 396 397 for _, addRule := range rule.AdditionalBindings { 398 if len(addRule.AdditionalBindings) != 0 { 399 return fmt.Errorf("nested rules") // TODO: errors... 400 } 401 402 if err := p.addRule(addRule, desc, name); err != nil { 403 return err 404 } 405 } 406 407 return nil 408 } 409 410 type param struct { 411 fds []protoreflect.FieldDescriptor 412 val protoreflect.Value 413 } 414 415 func parseParam(fds []protoreflect.FieldDescriptor, raw []byte) (param, error) { 416 if len(fds) == 0 { 417 return param{}, fmt.Errorf("zero field") 418 } 419 fd := fds[len(fds)-1] 420 421 switch kind := fd.Kind(); kind { 422 case protoreflect.BoolKind: 423 var b bool 424 if err := json.Unmarshal(raw, &b); err != nil { 425 return param{}, err 426 } 427 return param{fds, protoreflect.ValueOfBool(b)}, nil 428 429 case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: 430 var x int32 431 if err := json.Unmarshal(raw, &x); err != nil { 432 return param{}, err 433 } 434 return param{fds, protoreflect.ValueOfInt32(x)}, nil 435 436 case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: 437 var x int64 438 if err := json.Unmarshal(raw, &x); err != nil { 439 return param{}, err 440 } 441 return param{fds, protoreflect.ValueOfInt64(x)}, nil 442 443 case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: 444 var x uint32 445 if err := json.Unmarshal(raw, &x); err != nil { 446 return param{}, err 447 } 448 return param{fds, protoreflect.ValueOfUint32(x)}, nil 449 450 case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: 451 var x uint64 452 if err := json.Unmarshal(raw, &x); err != nil { 453 return param{}, err 454 } 455 return param{fds, protoreflect.ValueOfUint64(x)}, nil 456 457 case protoreflect.FloatKind: 458 var x float32 459 if err := json.Unmarshal(raw, &x); err != nil { 460 return param{}, err 461 } 462 return param{fds, protoreflect.ValueOfFloat32(x)}, nil 463 464 case protoreflect.DoubleKind: 465 var x float64 466 if err := json.Unmarshal(raw, &x); err != nil { 467 return param{}, err 468 } 469 return param{fds, protoreflect.ValueOfFloat64(x)}, nil 470 471 case protoreflect.StringKind: 472 return param{fds, protoreflect.ValueOfString(string(raw))}, nil 473 474 case protoreflect.BytesKind: 475 enc := base64.StdEncoding 476 if bytes.ContainsAny(raw, "-_") { 477 enc = base64.URLEncoding 478 } 479 if len(raw)%4 != 0 { 480 enc = enc.WithPadding(base64.NoPadding) 481 } 482 483 dst := make([]byte, enc.DecodedLen(len(raw))) 484 n, err := enc.Decode(dst, raw) 485 if err != nil { 486 return param{}, err 487 } 488 return param{fds, protoreflect.ValueOfBytes(dst[:n])}, nil 489 490 case protoreflect.EnumKind: 491 var x int32 492 if err := json.Unmarshal(raw, &x); err == nil { 493 return param{fds, protoreflect.ValueOfEnum(protoreflect.EnumNumber(x))}, nil 494 } 495 496 s := string(raw) 497 if isNullValue(fd) && s == "null" { 498 return param{fds, protoreflect.ValueOfEnum(0)}, nil 499 } 500 501 enumVal := fd.Enum().Values().ByName(protoreflect.Name(s)) 502 if enumVal == nil { 503 return param{}, fmt.Errorf("unexpected enum %s", raw) 504 } 505 return param{fds, protoreflect.ValueOfEnum(enumVal.Number())}, nil 506 507 case protoreflect.MessageKind: 508 // Well known JSON scalars are decoded to message types. 509 md := fd.Message() 510 switch md.FullName() { 511 case "google.protobuf.Timestamp": 512 var msg timestamppb.Timestamp 513 if err := protojson.Unmarshal(raw, &msg); err != nil { 514 return param{}, err 515 } 516 return param{fds, protoreflect.ValueOfMessage(msg.ProtoReflect())}, nil 517 case "google.protobuf.Duration": 518 var msg durationpb.Duration 519 if err := protojson.Unmarshal(raw, &msg); err != nil { 520 return param{}, err 521 } 522 return param{fds, protoreflect.ValueOfMessage(msg.ProtoReflect())}, nil 523 case "google.protobuf.BoolValue": 524 var msg wrapperspb.BoolValue 525 if err := protojson.Unmarshal(raw, &msg); err != nil { 526 return param{}, err 527 } 528 return param{fds, protoreflect.ValueOfMessage(msg.ProtoReflect())}, nil 529 case "google.protobuf.Int32Value": 530 var msg wrapperspb.Int32Value 531 if err := protojson.Unmarshal(raw, &msg); err != nil { 532 return param{}, err 533 } 534 return param{fds, protoreflect.ValueOfMessage(msg.ProtoReflect())}, nil 535 case "google.protobuf.Int64Value": 536 var msg wrapperspb.Int64Value 537 if err := protojson.Unmarshal(raw, &msg); err != nil { 538 return param{}, err 539 } 540 return param{fds, protoreflect.ValueOfMessage(msg.ProtoReflect())}, nil 541 case "google.protobuf.UInt32Value": 542 var msg wrapperspb.UInt32Value 543 if err := protojson.Unmarshal(raw, &msg); err != nil { 544 return param{}, err 545 } 546 return param{fds, protoreflect.ValueOfMessage(msg.ProtoReflect())}, nil 547 case "google.protobuf.UInt64Value": 548 var msg wrapperspb.UInt64Value 549 if err := protojson.Unmarshal(raw, &msg); err != nil { 550 return param{}, err 551 } 552 return param{fds, protoreflect.ValueOfMessage(msg.ProtoReflect())}, nil 553 case "google.protobuf.FloatValue": 554 var msg wrapperspb.FloatValue 555 if err := protojson.Unmarshal(raw, &msg); err != nil { 556 return param{}, err 557 } 558 return param{fds, protoreflect.ValueOfMessage(msg.ProtoReflect())}, nil 559 case "google.protobuf.DoubleValue": 560 var msg wrapperspb.DoubleValue 561 if err := protojson.Unmarshal(raw, &msg); err != nil { 562 return param{}, err 563 } 564 return param{fds, protoreflect.ValueOfMessage(msg.ProtoReflect())}, nil 565 case "google.protobuf.BytesValue": 566 if n := len(raw); n > 0 && (raw[0] != '"' || raw[n-1] != '"') { 567 raw = []byte(strconv.Quote(string(raw))) 568 } 569 var msg wrapperspb.BytesValue 570 if err := protojson.Unmarshal(raw, &msg); err != nil { 571 return param{}, err 572 } 573 return param{fds, protoreflect.ValueOfMessage(msg.ProtoReflect())}, nil 574 case "google.protobuf.StringValue": 575 if n := len(raw); n > 0 && (raw[0] != '"' || raw[n-1] != '"') { 576 raw = []byte(strconv.Quote(string(raw))) 577 } 578 var msg wrapperspb.StringValue 579 if err := protojson.Unmarshal(raw, &msg); err != nil { 580 return param{}, err 581 } 582 return param{fds, protoreflect.ValueOfMessage(msg.ProtoReflect())}, nil 583 584 case "google.protobuf.FieldMask": 585 var msg fieldmaskpb.FieldMask 586 if err := protojson.Unmarshal(raw, &msg); err != nil { 587 return param{}, err 588 } 589 return param{fds, protoreflect.ValueOfMessage(msg.ProtoReflect())}, nil 590 default: 591 return param{}, fmt.Errorf("unexpected message type %s", md.FullName()) 592 } 593 594 default: 595 return param{}, fmt.Errorf("unknown param type %s", kind) 596 597 } 598 } 599 600 func isNullValue(fd protoreflect.FieldDescriptor) bool { 601 ed := fd.Enum() 602 return ed != nil && ed.FullName() == "google.protobuf.NullValue" 603 } 604 605 type params []param 606 607 func (ps params) set(m proto.Message) error { 608 for _, p := range ps { 609 cur := m.ProtoReflect() 610 for i, fd := range p.fds { 611 if len(p.fds)-1 == i { 612 cur.Set(fd, p.val) 613 break 614 } 615 616 // TODO: more types? 617 cur = cur.Mutable(fd).Message() 618 // IsList() 619 // IsMap() 620 } 621 } 622 return nil 623 } 624 625 func (m *method) parseQueryParams(values url.Values) (params, error) { 626 msgDesc := m.desc.Input() 627 fieldDescs := msgDesc.Fields() 628 629 var ps params 630 for key, vs := range values { 631 fds := fieldPath(fieldDescs, strings.Split(key, ".")...) 632 if fds == nil { 633 continue 634 } 635 636 for _, v := range vs { 637 p, err := parseParam(fds, []byte(v)) 638 if err != nil { 639 return nil, err 640 } 641 ps = append(ps, p) 642 } 643 } 644 return ps, nil 645 } 646 647 // index returns the capture length. 648 func (v *variable) index(toks tokens) int { 649 n := len(toks) 650 651 var i int 652 for _, tok := range v.toks { 653 if i == n { 654 return -1 655 } 656 657 switch tok.typ { 658 case tokenSlash: 659 if toks[i].typ != tok.typ { 660 return -1 661 } 662 i += 1 663 664 case tokenStar: 665 set := newTokenSet(tokenSlash, tokenVerb) 666 if j := toks.indexAny(set); j != -1 { 667 i += j 668 } else { 669 i = n // EOL 670 } 671 672 case tokenStarStar: 673 if j := toks.index(tokenVerb); j != -1 { 674 i += j 675 } else { 676 i = n // EOL 677 } 678 679 case tokenValue: 680 // TODO: tokenPath != tokenValue 681 if toks[i].typ != tokenPath || tok.val != toks[i].val { 682 return -1 683 } 684 i += 1 685 686 default: 687 panic(":(") 688 } 689 } 690 return i 691 } 692 693 // Depth first search preferring path segments over variables. 694 // Variables split the search tree: 695 // /path/{variable/*}/to/{end/**} ?:VERB 696 func (p *path) search(toks tokens, verb string) (*method, params, error) { 697 if n := len(toks); n <= 1 { 698 if m, ok := p.methods[verb]; ok { 699 return m, nil, nil 700 } 701 if m := p.methodAll; m != nil { 702 return m, nil, nil 703 } 704 return nil, nil, status.Error(codes.NotFound, "not found") 705 } 706 707 tt, tv := toks[0], toks[1] 708 segment := tt.val + tv.val 709 710 //fmt.Println("------------------") 711 //fmt.Println("~", segment, "~") 712 //defer fmt.Println("search end") 713 714 if next, ok := p.segments[segment]; ok { 715 if m, ps, err := next.search(toks[2:], verb); err == nil { 716 return m, ps, err 717 } 718 } 719 720 for _, v := range p.variables { 721 l := v.index(toks[1:]) + 1 // bump off / 722 if l == 0 { 723 continue 724 } 725 726 m, ps, err := v.next.search(toks[l:], verb) 727 if err != nil { 728 continue 729 } 730 731 capture := []byte(toks[1:l].String()) 732 733 fds := m.vars[len(m.vars)-len(ps)-1] 734 if p, err := parseParam(fds, capture); err != nil { 735 return nil, nil, err 736 } else { 737 ps = append(ps, p) 738 } 739 return m, ps, err 740 } 741 return nil, nil, status.Error(codes.NotFound, "not found") 742 } 743 744 // match the route to a method. 745 func (p *path) match(route, verb string) (*method, params, error) { 746 l := &lexer{input: route} 747 if err := lexPath(l); err != nil { 748 return nil, nil, status.Errorf(codes.NotFound, "not found: %v", err) 749 } 750 return p.search(l.toks, verb) 751 } 752 753 const httpHeaderPrefix = "http-" 754 755 func newIncomingContext(ctx context.Context, header http.Header) (context.Context, metadata.MD) { 756 md := make(metadata.MD, len(header)) 757 for k, vs := range header { 758 md.Set(httpHeaderPrefix+k, vs...) 759 } 760 return metadata.NewIncomingContext(ctx, md), md 761 } 762 763 func setOutgoingHeader(header http.Header, mds ...metadata.MD) { 764 for _, md := range mds { 765 for k, vs := range md { 766 if !strings.HasPrefix(k, httpHeaderPrefix) { 767 continue 768 } 769 k = k[len(httpHeaderPrefix):] 770 if len(k) == 0 { 771 continue 772 } 773 header[textproto.CanonicalMIMEHeaderKey(k)] = vs 774 } 775 } 776 }