github.com/emcfarlane/larking@v0.0.0-20220605172417-1704b45ee6c3/starlark.go (about) 1 package larking 2 3 import ( 4 "context" 5 "fmt" 6 "io" 7 "sort" 8 "strings" 9 "sync" 10 11 "github.com/emcfarlane/larking/starlib/encoding/starlarkproto" 12 "github.com/emcfarlane/larking/starlib/starext" 13 "github.com/emcfarlane/larking/starlib/starlarkthread" 14 "github.com/go-logr/logr" 15 "go.starlark.net/starlark" 16 "google.golang.org/grpc" 17 "google.golang.org/grpc/codes" 18 "google.golang.org/grpc/metadata" 19 "google.golang.org/grpc/status" 20 "google.golang.org/protobuf/proto" 21 "google.golang.org/protobuf/reflect/protoreflect" 22 "google.golang.org/protobuf/types/dynamicpb" 23 ) 24 25 func (m *Mux) String() string { return "mux" } 26 func (m *Mux) Type() string { return "mux" } 27 func (m *Mux) Freeze() {} // immutable 28 func (m *Mux) Truth() starlark.Bool { return starlark.True } 29 func (m *Mux) Hash() (uint32, error) { return 0, nil } 30 31 type muxAttr func(m *Mux) starlark.Value 32 33 var muxMethods = map[string]muxAttr{ 34 "service": func(m *Mux) starlark.Value { 35 return starext.MakeMethod(m, "service", m.openStarlarkService) 36 }, 37 "register_service": func(m *Mux) starlark.Value { 38 return starext.MakeMethod(m, "register", m.registerStarlarkService) 39 }, 40 } 41 42 func (m *Mux) Attr(name string) (starlark.Value, error) { 43 if a := muxMethods[name]; a != nil { 44 return a(m), nil 45 } 46 return nil, nil 47 } 48 func (v *Mux) AttrNames() []string { 49 names := make([]string, 0, len(muxMethods)) 50 for name := range muxMethods { 51 names = append(names, name) 52 } 53 sort.Strings(names) 54 return names 55 } 56 57 type StarlarkService struct { 58 mux *Mux 59 name string 60 } 61 62 func (m *Mux) openStarlarkService(_ *starlark.Thread, fnname string, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { 63 var name string 64 if err := starlark.UnpackPositionalArgs(fnname, args, nil, 1, &name); err != nil { 65 return nil, err 66 } 67 68 pfx := "/" + name 69 if state := m.loadState(); state != nil { 70 for method := range state.handlers { 71 if strings.HasPrefix(method, pfx) { 72 return &StarlarkService{ 73 mux: m, 74 name: name, 75 }, nil 76 } 77 78 } 79 } 80 return nil, status.Errorf(codes.NotFound, "unknown service: %s", name) 81 } 82 83 func starlarkUnimplemented(thread *starlark.Thread, fnname string, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { 84 return nil, status.Errorf(codes.Unimplemented, "method %s not implemented", fnname) 85 } 86 87 func createStarlarkHandler( 88 parent *starlark.Thread, 89 fn starlark.Callable, 90 sd protoreflect.ServiceDescriptor, 91 md protoreflect.MethodDescriptor, 92 ) *handler { 93 94 argsDesc := md.Input() 95 replyDesc := md.Output() 96 97 method := fmt.Sprintf("/%s/%s", sd.FullName(), md.Name()) 98 99 isClientStream := md.IsStreamingClient() 100 isServerStream := md.IsStreamingServer() 101 if isClientStream || isServerStream { 102 //sd := &grpc.StreamDesc{ 103 // ServerStreams: md.IsStreamingServer(), 104 // ClientStreams: md.IsStreamingClient(), 105 //} 106 info := &grpc.StreamServerInfo{ 107 FullMethod: method, 108 IsClientStream: isClientStream, 109 IsServerStream: isServerStream, 110 } 111 112 // TODO: check not mutated. 113 //globals := starlib.NewGlobals() 114 115 fn := func(_ interface{}, stream grpc.ServerStream) (err error) { 116 ctx := stream.Context() 117 118 args := dynamicpb.NewMessage(argsDesc) 119 //reply := dynamicpb.NewMessage(replyDesc) 120 121 if err := stream.RecvMsg(args); err != nil { 122 return err 123 } 124 125 if md, ok := metadata.FromIncomingContext(ctx); ok { 126 ctx = metadata.NewOutgoingContext(ctx, md) 127 } 128 129 // build thread 130 l := logr.FromContextOrDiscard(ctx) 131 thread := &starlark.Thread{ 132 Name: parent.Name, 133 Print: func(_ *starlark.Thread, msg string) { 134 l.Info(msg, "thread", parent.Name) 135 }, 136 Load: parent.Load, 137 } 138 starlarkthread.SetContext(thread, ctx) 139 close := starlarkthread.WithResourceStore(thread) 140 defer func() { 141 if cerr := close(); err == nil { 142 err = cerr 143 } 144 }() 145 146 // TODO: streams. 147 return fmt.Errorf("unimplemented") 148 } 149 150 h := func(opts *muxOptions, stream grpc.ServerStream) error { 151 return opts.stream(nil, stream, info, fn) 152 } 153 154 return &handler{ 155 method: method, 156 descriptor: md, 157 handler: h, 158 } 159 } else { 160 info := &grpc.UnaryServerInfo{ 161 Server: nil, 162 FullMethod: method, 163 } 164 fn := func(ctx context.Context, args interface{}) (reply interface{}, err error) { 165 166 if md, ok := metadata.FromIncomingContext(ctx); ok { 167 ctx = metadata.NewOutgoingContext(ctx, md) 168 } 169 170 l := logr.FromContextOrDiscard(ctx) 171 thread := &starlark.Thread{ 172 Name: parent.Name, 173 Print: func(_ *starlark.Thread, msg string) { 174 l.Info(msg, "thread", parent.Name) 175 }, 176 Load: parent.Load, 177 } 178 starlarkthread.SetContext(thread, ctx) 179 close := starlarkthread.WithResourceStore(thread) 180 defer func() { 181 if cerr := close(); err == nil { 182 err = cerr 183 } 184 }() 185 186 msg, ok := args.(proto.Message) 187 if !ok { 188 return nil, fmt.Errorf("expected proto message") 189 } 190 191 reqpb, err := starlarkproto.NewMessage(msg.ProtoReflect(), nil, nil) 192 if err != nil { 193 return nil, err 194 } 195 196 v, err := starlark.Call(thread, fn, starlark.Tuple{reqpb}, nil) 197 if err != nil { 198 return nil, err 199 } 200 201 rsppb, ok := v.(*starlarkproto.Message) 202 if !ok { 203 return nil, fmt.Errorf("expected \"proto.message\" received %q", v.Type()) 204 } 205 rspMsg := rsppb.ProtoReflect() 206 // Compare FullName for multiple descriptor implementations. 207 if got, want := rspMsg.Descriptor().FullName(), replyDesc.FullName(); got != want { 208 return nil, fmt.Errorf("invalid response type %s, want %s", got, want) 209 } 210 return rspMsg.Interface(), nil 211 } 212 h := func(opts *muxOptions, stream grpc.ServerStream) error { 213 ctx := stream.Context() 214 args := dynamicpb.NewMessage(argsDesc) 215 216 if err := stream.RecvMsg(args); err != nil { 217 return err 218 } 219 220 reply, err := opts.unary(ctx, args, info, fn) 221 if err != nil { 222 return err 223 } 224 return stream.SendMsg(reply) 225 } 226 227 return &handler{ 228 method: method, 229 descriptor: md, 230 handler: h, 231 } 232 } 233 } 234 235 func (m *Mux) registerStarlarkService(thread *starlark.Thread, fnname string, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { 236 237 var name string 238 if err := starlark.UnpackPositionalArgs(fnname, args, nil, 1, &name); err != nil { 239 return nil, err 240 } 241 242 resolver := starlarkproto.GetProtodescResolver(thread) 243 desc, err := resolver.FindDescriptorByName(protoreflect.FullName(name)) 244 if err != nil { 245 return nil, err 246 } 247 248 sd, ok := desc.(protoreflect.ServiceDescriptor) 249 if !ok { 250 return nil, status.Errorf(codes.InvalidArgument, "%q must be a service descriptor", name) 251 } 252 253 mds := sd.Methods() 254 255 // for each key, assign the service function. 256 257 mms := make(map[string]starlark.Callable) 258 259 for _, kwarg := range kwargs { 260 k := string(kwarg[0].(starlark.String)) 261 v := kwarg[1] 262 263 // Check 264 c, ok := v.(starlark.Callable) 265 if !ok { 266 return nil, status.Errorf(codes.InvalidArgument, "%s must be callable", k) 267 } 268 mms[k] = c 269 } 270 271 // Load the state for writing. 272 m.mu.Lock() 273 defer m.mu.Unlock() 274 s := m.loadState().clone() 275 276 for i, n := 0, mds.Len(); i < n; i++ { 277 md := mds.Get(i) 278 methodName := string(md.Name()) 279 280 c, ok := mms[methodName] 281 if !ok { 282 c = starext.MakeMethod(m, methodName, starlarkUnimplemented) 283 } 284 285 opts := md.Options() 286 287 rule := getExtensionHTTP(opts) 288 if rule == nil { 289 continue 290 } 291 hd := createStarlarkHandler(thread, c, sd, md) 292 if err := s.appendHandler(rule, md, hd); err != nil { 293 return nil, err 294 } 295 } 296 297 m.storeState(s) 298 return starlark.None, nil 299 } 300 301 func (s *StarlarkService) String() string { return s.name } 302 func (s *StarlarkService) Type() string { return "grpc.service" } 303 func (s *StarlarkService) Freeze() {} // immutable 304 func (s *StarlarkService) Truth() starlark.Bool { return starlark.True } 305 func (s *StarlarkService) Hash() (uint32, error) { return 0, nil } 306 307 // HasAttrs with each one being callable. 308 func (s *StarlarkService) Attr(name string) (starlark.Value, error) { 309 m := "/" + s.name + "/" + name 310 hd, err := s.mux.loadState().pickMethodHandler(m) 311 if err != nil { 312 return nil, nil // swallow error, reports missing attr. 313 } 314 315 if hd.descriptor.IsStreamingClient() || hd.descriptor.IsStreamingServer() { 316 ss := &StarlarkStream{ 317 mux: s.mux, 318 hd: hd, 319 } 320 321 return ss, nil 322 } 323 324 return &StarlarkUnary{ 325 mux: s.mux, 326 hd: hd, 327 }, nil 328 } 329 func (s *StarlarkService) AttrNames() []string { 330 var attrs []string 331 332 pfx := "/" + s.name + "/" 333 for method := range s.mux.loadState().handlers { 334 if strings.HasPrefix(method, pfx) { 335 attrs = append(attrs, strings.TrimPrefix(method, pfx)) 336 } 337 } 338 sort.Strings(attrs) 339 return attrs 340 } 341 342 type starlarkStream struct { 343 ctx context.Context 344 method string 345 sentHeader bool 346 header metadata.MD 347 trailer metadata.MD 348 ins chan func(proto.Message) error 349 outs chan func(proto.Message) error 350 } 351 352 func (s *starlarkStream) SetHeader(md metadata.MD) error { 353 if !s.sentHeader { 354 s.header = metadata.Join(s.header, md) 355 } 356 return nil 357 358 } 359 func (s *starlarkStream) SendHeader(md metadata.MD) error { 360 if s.sentHeader { 361 return nil // already sent? 362 } 363 // TODO: write header? 364 s.sentHeader = true 365 return nil 366 } 367 368 func (s *starlarkStream) SetTrailer(md metadata.MD) { 369 s.sentHeader = true 370 s.trailer = metadata.Join(s.trailer, md) 371 } 372 373 func (s *starlarkStream) Context() context.Context { 374 ctx, _ := newIncomingContext(s.ctx, nil) // TODO: remove me? 375 sts := &serverTransportStream{s, s.method} 376 return grpc.NewContextWithServerTransportStream(ctx, sts) 377 } 378 379 func (s *starlarkStream) SendMsg(m interface{}) error { 380 reply := m.(proto.Message) 381 select { 382 case fn := <-s.outs: 383 return fn(reply) 384 case <-s.ctx.Done(): 385 return s.ctx.Err() 386 } 387 } 388 389 func (s *starlarkStream) RecvMsg(m interface{}) error { 390 args := m.(proto.Message) 391 //msg := args.ProtoReflect() 392 393 select { 394 case fn := <-s.ins: 395 return fn(args) 396 case <-s.ctx.Done(): 397 return s.ctx.Err() 398 } 399 400 } 401 402 type StarlarkUnary struct { 403 mux *Mux 404 hd *handler 405 } 406 407 func (s *StarlarkUnary) String() string { return s.hd.method } 408 func (s *StarlarkUnary) Type() string { return "grpc.unary_method" } 409 func (s *StarlarkUnary) Freeze() {} // immutable 410 func (s *StarlarkUnary) Truth() starlark.Bool { return starlark.True } 411 func (s *StarlarkUnary) Hash() (uint32, error) { return 0, nil } 412 func (s *StarlarkUnary) Name() string { return "" } 413 func (s *StarlarkUnary) CallInternal(thread *starlark.Thread, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { 414 ctx := starlarkthread.GetContext(thread) 415 opts := &s.mux.opts 416 417 // Buffer channels one message for unary. 418 stream := &starlarkStream{ 419 ctx: ctx, 420 method: s.hd.method, 421 ins: make(chan func(proto.Message) error, 1), 422 outs: make(chan func(proto.Message) error, 1), 423 } 424 425 stream.ins <- func(msg proto.Message) error { 426 arg := msg.ProtoReflect() 427 428 // Capture starlark arguments. 429 _, err := starlarkproto.NewMessage(arg, args, kwargs) 430 return err 431 } 432 433 var rsp *starlarkproto.Message 434 stream.outs <- func(msg proto.Message) error { 435 arg := msg.ProtoReflect() 436 437 val, err := starlarkproto.NewMessage(arg, nil, nil) 438 rsp = val 439 return err 440 } 441 442 if err := s.hd.handler(opts, stream); err != nil { 443 return nil, err 444 } 445 return rsp, nil 446 } 447 448 type StarlarkStream struct { 449 mux *Mux 450 hd *handler 451 452 once sync.Once 453 cancel func() 454 stream *starlarkStream 455 456 onceErr sync.Once 457 err error 458 } 459 460 func (s *StarlarkStream) setErr(err error) { 461 s.onceErr.Do(func() { s.err = err }) 462 } 463 func (s *StarlarkStream) getErr() error { 464 s.setErr(nil) // blow away onceErr 465 return s.err 466 } 467 468 // init lazy initializes the streaming handler. 469 func (s *StarlarkStream) init(thread *starlark.Thread) error { 470 ctx := starlarkthread.GetContext(thread) 471 opts := &s.mux.opts 472 473 s.once.Do(func() { 474 if err := starlarkthread.AddResource(thread, s); err != nil { 475 s.setErr(err) 476 return 477 } 478 479 ctx, cancel := context.WithCancel(ctx) 480 s.cancel = cancel 481 s.stream = &starlarkStream{ 482 ctx: ctx, 483 method: s.hd.method, 484 ins: make(chan func(proto.Message) error), 485 outs: make(chan func(proto.Message) error), 486 } 487 488 // Start the handler 489 go func() { 490 s.onceErr.Do(func() { 491 s.err = s.hd.handler(opts, s.stream) 492 }) 493 cancel() 494 }() 495 }) 496 if s.stream == nil || s.stream.ctx.Err() != nil { 497 return io.EOF // cancelled before starting or cancelled 498 } 499 return nil 500 } 501 502 func (s *StarlarkStream) String() string { return s.hd.method } 503 func (s *StarlarkStream) Type() string { return "grpc.stream_method" } 504 func (s *StarlarkStream) Freeze() {} // immutable??? 505 func (s *StarlarkStream) Truth() starlark.Bool { return starlark.True } 506 func (s *StarlarkStream) Hash() (uint32, error) { return 0, nil } 507 func (s *StarlarkStream) Name() string { return "" } 508 509 func (s *StarlarkStream) Close() error { 510 s.once.Do(func() {}) // blow the once away 511 if s.cancel == nil { 512 return nil // never started 513 } 514 s.cancel() 515 return s.getErr() 516 } 517 518 func (s *StarlarkStream) Attr(name string) (starlark.Value, error) { 519 if a := starlarkStreamAttrs[name]; a != nil { 520 return a(s), nil 521 } 522 return nil, nil 523 } 524 func (v *StarlarkStream) AttrNames() []string { 525 names := make([]string, 0, len(starlarkStreamAttrs)) 526 for name := range starlarkStreamAttrs { 527 names = append(names, name) 528 } 529 sort.Strings(names) 530 return names 531 } 532 533 type starlarkStreamAttr func(*StarlarkStream) starlark.Value 534 535 var starlarkStreamAttrs = map[string]starlarkStreamAttr{ 536 "recv": func(s *StarlarkStream) starlark.Value { 537 return starext.MakeMethod(s, "recv", s.recv) 538 }, 539 "send": func(s *StarlarkStream) starlark.Value { 540 return starext.MakeMethod(s, "send", s.send) 541 }, 542 } 543 544 type starlarkResponse struct { 545 val starlark.Value 546 err error 547 } 548 549 func promiseResponse( 550 ctx context.Context, args starlark.Tuple, kwargs []starlark.Tuple, 551 ) (func(proto.Message) error, <-chan starlarkResponse) { 552 ch := make(chan starlarkResponse) 553 554 return func(msg proto.Message) error { 555 arg := msg.ProtoReflect() 556 557 val, err := starlarkproto.NewMessage(arg, args, kwargs) 558 select { 559 case ch <- starlarkResponse{val: val, err: err}: 560 return err 561 case <-ctx.Done(): 562 return ctx.Err() 563 } 564 }, ch 565 } 566 567 func (s *StarlarkStream) recv(thread *starlark.Thread, fnname string, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { 568 ctx := starlarkthread.GetContext(thread) 569 if err := s.init(thread); err != nil { 570 return nil, err 571 } 572 573 if err := starlark.UnpackPositionalArgs(fnname, args, kwargs, 0); err != nil { 574 return nil, err 575 } 576 577 fn, ch := promiseResponse(ctx, nil, nil) 578 579 select { 580 case <-ctx.Done(): 581 return nil, ctx.Err() 582 case <-s.stream.ctx.Done(): 583 return nil, s.getErr() 584 case s.stream.outs <- fn: 585 rsp := <-ch 586 return rsp.val, rsp.err 587 } 588 } 589 func (s *StarlarkStream) send(thread *starlark.Thread, fnname string, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { 590 ctx := starlarkthread.GetContext(thread) 591 if err := s.init(thread); err != nil { 592 return nil, err 593 } 594 595 fn, ch := promiseResponse(ctx, args, kwargs) 596 597 select { 598 case <-ctx.Done(): 599 return nil, ctx.Err() 600 case <-s.stream.ctx.Done(): 601 return nil, s.getErr() 602 case s.stream.ins <- fn: 603 rsp := <-ch 604 return starlark.None, rsp.err 605 } 606 }