github.com/emcfarlane/larking@v0.0.0-20220605172417-1704b45ee6c3/starlark.go (about)

     1  package larking
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"sort"
     8  	"strings"
     9  	"sync"
    10  
    11  	"github.com/emcfarlane/larking/starlib/encoding/starlarkproto"
    12  	"github.com/emcfarlane/larking/starlib/starext"
    13  	"github.com/emcfarlane/larking/starlib/starlarkthread"
    14  	"github.com/go-logr/logr"
    15  	"go.starlark.net/starlark"
    16  	"google.golang.org/grpc"
    17  	"google.golang.org/grpc/codes"
    18  	"google.golang.org/grpc/metadata"
    19  	"google.golang.org/grpc/status"
    20  	"google.golang.org/protobuf/proto"
    21  	"google.golang.org/protobuf/reflect/protoreflect"
    22  	"google.golang.org/protobuf/types/dynamicpb"
    23  )
    24  
    25  func (m *Mux) String() string        { return "mux" }
    26  func (m *Mux) Type() string          { return "mux" }
    27  func (m *Mux) Freeze()               {} // immutable
    28  func (m *Mux) Truth() starlark.Bool  { return starlark.True }
    29  func (m *Mux) Hash() (uint32, error) { return 0, nil }
    30  
    31  type muxAttr func(m *Mux) starlark.Value
    32  
    33  var muxMethods = map[string]muxAttr{
    34  	"service": func(m *Mux) starlark.Value {
    35  		return starext.MakeMethod(m, "service", m.openStarlarkService)
    36  	},
    37  	"register_service": func(m *Mux) starlark.Value {
    38  		return starext.MakeMethod(m, "register", m.registerStarlarkService)
    39  	},
    40  }
    41  
    42  func (m *Mux) Attr(name string) (starlark.Value, error) {
    43  	if a := muxMethods[name]; a != nil {
    44  		return a(m), nil
    45  	}
    46  	return nil, nil
    47  }
    48  func (v *Mux) AttrNames() []string {
    49  	names := make([]string, 0, len(muxMethods))
    50  	for name := range muxMethods {
    51  		names = append(names, name)
    52  	}
    53  	sort.Strings(names)
    54  	return names
    55  }
    56  
    57  type StarlarkService struct {
    58  	mux  *Mux
    59  	name string
    60  }
    61  
    62  func (m *Mux) openStarlarkService(_ *starlark.Thread, fnname string, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
    63  	var name string
    64  	if err := starlark.UnpackPositionalArgs(fnname, args, nil, 1, &name); err != nil {
    65  		return nil, err
    66  	}
    67  
    68  	pfx := "/" + name
    69  	if state := m.loadState(); state != nil {
    70  		for method := range state.handlers {
    71  			if strings.HasPrefix(method, pfx) {
    72  				return &StarlarkService{
    73  					mux:  m,
    74  					name: name,
    75  				}, nil
    76  			}
    77  
    78  		}
    79  	}
    80  	return nil, status.Errorf(codes.NotFound, "unknown service: %s", name)
    81  }
    82  
    83  func starlarkUnimplemented(thread *starlark.Thread, fnname string, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
    84  	return nil, status.Errorf(codes.Unimplemented, "method %s not implemented", fnname)
    85  }
    86  
    87  func createStarlarkHandler(
    88  	parent *starlark.Thread,
    89  	fn starlark.Callable,
    90  	sd protoreflect.ServiceDescriptor,
    91  	md protoreflect.MethodDescriptor,
    92  ) *handler {
    93  
    94  	argsDesc := md.Input()
    95  	replyDesc := md.Output()
    96  
    97  	method := fmt.Sprintf("/%s/%s", sd.FullName(), md.Name())
    98  
    99  	isClientStream := md.IsStreamingClient()
   100  	isServerStream := md.IsStreamingServer()
   101  	if isClientStream || isServerStream {
   102  		//sd := &grpc.StreamDesc{
   103  		//	ServerStreams: md.IsStreamingServer(),
   104  		//	ClientStreams: md.IsStreamingClient(),
   105  		//}
   106  		info := &grpc.StreamServerInfo{
   107  			FullMethod:     method,
   108  			IsClientStream: isClientStream,
   109  			IsServerStream: isServerStream,
   110  		}
   111  
   112  		// TODO: check not mutated.
   113  		//globals := starlib.NewGlobals()
   114  
   115  		fn := func(_ interface{}, stream grpc.ServerStream) (err error) {
   116  			ctx := stream.Context()
   117  
   118  			args := dynamicpb.NewMessage(argsDesc)
   119  			//reply := dynamicpb.NewMessage(replyDesc)
   120  
   121  			if err := stream.RecvMsg(args); err != nil {
   122  				return err
   123  			}
   124  
   125  			if md, ok := metadata.FromIncomingContext(ctx); ok {
   126  				ctx = metadata.NewOutgoingContext(ctx, md)
   127  			}
   128  
   129  			// build thread
   130  			l := logr.FromContextOrDiscard(ctx)
   131  			thread := &starlark.Thread{
   132  				Name: parent.Name,
   133  				Print: func(_ *starlark.Thread, msg string) {
   134  					l.Info(msg, "thread", parent.Name)
   135  				},
   136  				Load: parent.Load,
   137  			}
   138  			starlarkthread.SetContext(thread, ctx)
   139  			close := starlarkthread.WithResourceStore(thread)
   140  			defer func() {
   141  				if cerr := close(); err == nil {
   142  					err = cerr
   143  				}
   144  			}()
   145  
   146  			// TODO: streams.
   147  			return fmt.Errorf("unimplemented")
   148  		}
   149  
   150  		h := func(opts *muxOptions, stream grpc.ServerStream) error {
   151  			return opts.stream(nil, stream, info, fn)
   152  		}
   153  
   154  		return &handler{
   155  			method:     method,
   156  			descriptor: md,
   157  			handler:    h,
   158  		}
   159  	} else {
   160  		info := &grpc.UnaryServerInfo{
   161  			Server:     nil,
   162  			FullMethod: method,
   163  		}
   164  		fn := func(ctx context.Context, args interface{}) (reply interface{}, err error) {
   165  
   166  			if md, ok := metadata.FromIncomingContext(ctx); ok {
   167  				ctx = metadata.NewOutgoingContext(ctx, md)
   168  			}
   169  
   170  			l := logr.FromContextOrDiscard(ctx)
   171  			thread := &starlark.Thread{
   172  				Name: parent.Name,
   173  				Print: func(_ *starlark.Thread, msg string) {
   174  					l.Info(msg, "thread", parent.Name)
   175  				},
   176  				Load: parent.Load,
   177  			}
   178  			starlarkthread.SetContext(thread, ctx)
   179  			close := starlarkthread.WithResourceStore(thread)
   180  			defer func() {
   181  				if cerr := close(); err == nil {
   182  					err = cerr
   183  				}
   184  			}()
   185  
   186  			msg, ok := args.(proto.Message)
   187  			if !ok {
   188  				return nil, fmt.Errorf("expected proto message")
   189  			}
   190  
   191  			reqpb, err := starlarkproto.NewMessage(msg.ProtoReflect(), nil, nil)
   192  			if err != nil {
   193  				return nil, err
   194  			}
   195  
   196  			v, err := starlark.Call(thread, fn, starlark.Tuple{reqpb}, nil)
   197  			if err != nil {
   198  				return nil, err
   199  			}
   200  
   201  			rsppb, ok := v.(*starlarkproto.Message)
   202  			if !ok {
   203  				return nil, fmt.Errorf("expected \"proto.message\" received %q", v.Type())
   204  			}
   205  			rspMsg := rsppb.ProtoReflect()
   206  			// Compare FullName for multiple descriptor implementations.
   207  			if got, want := rspMsg.Descriptor().FullName(), replyDesc.FullName(); got != want {
   208  				return nil, fmt.Errorf("invalid response type %s, want %s", got, want)
   209  			}
   210  			return rspMsg.Interface(), nil
   211  		}
   212  		h := func(opts *muxOptions, stream grpc.ServerStream) error {
   213  			ctx := stream.Context()
   214  			args := dynamicpb.NewMessage(argsDesc)
   215  
   216  			if err := stream.RecvMsg(args); err != nil {
   217  				return err
   218  			}
   219  
   220  			reply, err := opts.unary(ctx, args, info, fn)
   221  			if err != nil {
   222  				return err
   223  			}
   224  			return stream.SendMsg(reply)
   225  		}
   226  
   227  		return &handler{
   228  			method:     method,
   229  			descriptor: md,
   230  			handler:    h,
   231  		}
   232  	}
   233  }
   234  
   235  func (m *Mux) registerStarlarkService(thread *starlark.Thread, fnname string, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
   236  
   237  	var name string
   238  	if err := starlark.UnpackPositionalArgs(fnname, args, nil, 1, &name); err != nil {
   239  		return nil, err
   240  	}
   241  
   242  	resolver := starlarkproto.GetProtodescResolver(thread)
   243  	desc, err := resolver.FindDescriptorByName(protoreflect.FullName(name))
   244  	if err != nil {
   245  		return nil, err
   246  	}
   247  
   248  	sd, ok := desc.(protoreflect.ServiceDescriptor)
   249  	if !ok {
   250  		return nil, status.Errorf(codes.InvalidArgument, "%q must be a service descriptor", name)
   251  	}
   252  
   253  	mds := sd.Methods()
   254  
   255  	// for each key, assign the service function.
   256  
   257  	mms := make(map[string]starlark.Callable)
   258  
   259  	for _, kwarg := range kwargs {
   260  		k := string(kwarg[0].(starlark.String))
   261  		v := kwarg[1]
   262  
   263  		// Check
   264  		c, ok := v.(starlark.Callable)
   265  		if !ok {
   266  			return nil, status.Errorf(codes.InvalidArgument, "%s must be callable", k)
   267  		}
   268  		mms[k] = c
   269  	}
   270  
   271  	// Load the state for writing.
   272  	m.mu.Lock()
   273  	defer m.mu.Unlock()
   274  	s := m.loadState().clone()
   275  
   276  	for i, n := 0, mds.Len(); i < n; i++ {
   277  		md := mds.Get(i)
   278  		methodName := string(md.Name())
   279  
   280  		c, ok := mms[methodName]
   281  		if !ok {
   282  			c = starext.MakeMethod(m, methodName, starlarkUnimplemented)
   283  		}
   284  
   285  		opts := md.Options()
   286  
   287  		rule := getExtensionHTTP(opts)
   288  		if rule == nil {
   289  			continue
   290  		}
   291  		hd := createStarlarkHandler(thread, c, sd, md)
   292  		if err := s.appendHandler(rule, md, hd); err != nil {
   293  			return nil, err
   294  		}
   295  	}
   296  
   297  	m.storeState(s)
   298  	return starlark.None, nil
   299  }
   300  
   301  func (s *StarlarkService) String() string        { return s.name }
   302  func (s *StarlarkService) Type() string          { return "grpc.service" }
   303  func (s *StarlarkService) Freeze()               {} // immutable
   304  func (s *StarlarkService) Truth() starlark.Bool  { return starlark.True }
   305  func (s *StarlarkService) Hash() (uint32, error) { return 0, nil }
   306  
   307  // HasAttrs with each one being callable.
   308  func (s *StarlarkService) Attr(name string) (starlark.Value, error) {
   309  	m := "/" + s.name + "/" + name
   310  	hd, err := s.mux.loadState().pickMethodHandler(m)
   311  	if err != nil {
   312  		return nil, nil // swallow error, reports missing attr.
   313  	}
   314  
   315  	if hd.descriptor.IsStreamingClient() || hd.descriptor.IsStreamingServer() {
   316  		ss := &StarlarkStream{
   317  			mux: s.mux,
   318  			hd:  hd,
   319  		}
   320  
   321  		return ss, nil
   322  	}
   323  
   324  	return &StarlarkUnary{
   325  		mux: s.mux,
   326  		hd:  hd,
   327  	}, nil
   328  }
   329  func (s *StarlarkService) AttrNames() []string {
   330  	var attrs []string
   331  
   332  	pfx := "/" + s.name + "/"
   333  	for method := range s.mux.loadState().handlers {
   334  		if strings.HasPrefix(method, pfx) {
   335  			attrs = append(attrs, strings.TrimPrefix(method, pfx))
   336  		}
   337  	}
   338  	sort.Strings(attrs)
   339  	return attrs
   340  }
   341  
   342  type starlarkStream struct {
   343  	ctx        context.Context
   344  	method     string
   345  	sentHeader bool
   346  	header     metadata.MD
   347  	trailer    metadata.MD
   348  	ins        chan func(proto.Message) error
   349  	outs       chan func(proto.Message) error
   350  }
   351  
   352  func (s *starlarkStream) SetHeader(md metadata.MD) error {
   353  	if !s.sentHeader {
   354  		s.header = metadata.Join(s.header, md)
   355  	}
   356  	return nil
   357  
   358  }
   359  func (s *starlarkStream) SendHeader(md metadata.MD) error {
   360  	if s.sentHeader {
   361  		return nil // already sent?
   362  	}
   363  	// TODO: write header?
   364  	s.sentHeader = true
   365  	return nil
   366  }
   367  
   368  func (s *starlarkStream) SetTrailer(md metadata.MD) {
   369  	s.sentHeader = true
   370  	s.trailer = metadata.Join(s.trailer, md)
   371  }
   372  
   373  func (s *starlarkStream) Context() context.Context {
   374  	ctx, _ := newIncomingContext(s.ctx, nil) // TODO: remove me?
   375  	sts := &serverTransportStream{s, s.method}
   376  	return grpc.NewContextWithServerTransportStream(ctx, sts)
   377  }
   378  
   379  func (s *starlarkStream) SendMsg(m interface{}) error {
   380  	reply := m.(proto.Message)
   381  	select {
   382  	case fn := <-s.outs:
   383  		return fn(reply)
   384  	case <-s.ctx.Done():
   385  		return s.ctx.Err()
   386  	}
   387  }
   388  
   389  func (s *starlarkStream) RecvMsg(m interface{}) error {
   390  	args := m.(proto.Message)
   391  	//msg := args.ProtoReflect()
   392  
   393  	select {
   394  	case fn := <-s.ins:
   395  		return fn(args)
   396  	case <-s.ctx.Done():
   397  		return s.ctx.Err()
   398  	}
   399  
   400  }
   401  
   402  type StarlarkUnary struct {
   403  	mux *Mux
   404  	hd  *handler
   405  }
   406  
   407  func (s *StarlarkUnary) String() string        { return s.hd.method }
   408  func (s *StarlarkUnary) Type() string          { return "grpc.unary_method" }
   409  func (s *StarlarkUnary) Freeze()               {} // immutable
   410  func (s *StarlarkUnary) Truth() starlark.Bool  { return starlark.True }
   411  func (s *StarlarkUnary) Hash() (uint32, error) { return 0, nil }
   412  func (s *StarlarkUnary) Name() string          { return "" }
   413  func (s *StarlarkUnary) CallInternal(thread *starlark.Thread, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
   414  	ctx := starlarkthread.GetContext(thread)
   415  	opts := &s.mux.opts
   416  
   417  	// Buffer channels one message for unary.
   418  	stream := &starlarkStream{
   419  		ctx:    ctx,
   420  		method: s.hd.method,
   421  		ins:    make(chan func(proto.Message) error, 1),
   422  		outs:   make(chan func(proto.Message) error, 1),
   423  	}
   424  
   425  	stream.ins <- func(msg proto.Message) error {
   426  		arg := msg.ProtoReflect()
   427  
   428  		// Capture starlark arguments.
   429  		_, err := starlarkproto.NewMessage(arg, args, kwargs)
   430  		return err
   431  	}
   432  
   433  	var rsp *starlarkproto.Message
   434  	stream.outs <- func(msg proto.Message) error {
   435  		arg := msg.ProtoReflect()
   436  
   437  		val, err := starlarkproto.NewMessage(arg, nil, nil)
   438  		rsp = val
   439  		return err
   440  	}
   441  
   442  	if err := s.hd.handler(opts, stream); err != nil {
   443  		return nil, err
   444  	}
   445  	return rsp, nil
   446  }
   447  
   448  type StarlarkStream struct {
   449  	mux *Mux
   450  	hd  *handler
   451  
   452  	once   sync.Once
   453  	cancel func()
   454  	stream *starlarkStream
   455  
   456  	onceErr sync.Once
   457  	err     error
   458  }
   459  
   460  func (s *StarlarkStream) setErr(err error) {
   461  	s.onceErr.Do(func() { s.err = err })
   462  }
   463  func (s *StarlarkStream) getErr() error {
   464  	s.setErr(nil) // blow away onceErr
   465  	return s.err
   466  }
   467  
   468  // init lazy initializes the streaming handler.
   469  func (s *StarlarkStream) init(thread *starlark.Thread) error {
   470  	ctx := starlarkthread.GetContext(thread)
   471  	opts := &s.mux.opts
   472  
   473  	s.once.Do(func() {
   474  		if err := starlarkthread.AddResource(thread, s); err != nil {
   475  			s.setErr(err)
   476  			return
   477  		}
   478  
   479  		ctx, cancel := context.WithCancel(ctx)
   480  		s.cancel = cancel
   481  		s.stream = &starlarkStream{
   482  			ctx:    ctx,
   483  			method: s.hd.method,
   484  			ins:    make(chan func(proto.Message) error),
   485  			outs:   make(chan func(proto.Message) error),
   486  		}
   487  
   488  		// Start the handler
   489  		go func() {
   490  			s.onceErr.Do(func() {
   491  				s.err = s.hd.handler(opts, s.stream)
   492  			})
   493  			cancel()
   494  		}()
   495  	})
   496  	if s.stream == nil || s.stream.ctx.Err() != nil {
   497  		return io.EOF // cancelled before starting or cancelled
   498  	}
   499  	return nil
   500  }
   501  
   502  func (s *StarlarkStream) String() string        { return s.hd.method }
   503  func (s *StarlarkStream) Type() string          { return "grpc.stream_method" }
   504  func (s *StarlarkStream) Freeze()               {} // immutable???
   505  func (s *StarlarkStream) Truth() starlark.Bool  { return starlark.True }
   506  func (s *StarlarkStream) Hash() (uint32, error) { return 0, nil }
   507  func (s *StarlarkStream) Name() string          { return "" }
   508  
   509  func (s *StarlarkStream) Close() error {
   510  	s.once.Do(func() {}) // blow the once away
   511  	if s.cancel == nil {
   512  		return nil // never started
   513  	}
   514  	s.cancel()
   515  	return s.getErr()
   516  }
   517  
   518  func (s *StarlarkStream) Attr(name string) (starlark.Value, error) {
   519  	if a := starlarkStreamAttrs[name]; a != nil {
   520  		return a(s), nil
   521  	}
   522  	return nil, nil
   523  }
   524  func (v *StarlarkStream) AttrNames() []string {
   525  	names := make([]string, 0, len(starlarkStreamAttrs))
   526  	for name := range starlarkStreamAttrs {
   527  		names = append(names, name)
   528  	}
   529  	sort.Strings(names)
   530  	return names
   531  }
   532  
   533  type starlarkStreamAttr func(*StarlarkStream) starlark.Value
   534  
   535  var starlarkStreamAttrs = map[string]starlarkStreamAttr{
   536  	"recv": func(s *StarlarkStream) starlark.Value {
   537  		return starext.MakeMethod(s, "recv", s.recv)
   538  	},
   539  	"send": func(s *StarlarkStream) starlark.Value {
   540  		return starext.MakeMethod(s, "send", s.send)
   541  	},
   542  }
   543  
   544  type starlarkResponse struct {
   545  	val starlark.Value
   546  	err error
   547  }
   548  
   549  func promiseResponse(
   550  	ctx context.Context, args starlark.Tuple, kwargs []starlark.Tuple,
   551  ) (func(proto.Message) error, <-chan starlarkResponse) {
   552  	ch := make(chan starlarkResponse)
   553  
   554  	return func(msg proto.Message) error {
   555  		arg := msg.ProtoReflect()
   556  
   557  		val, err := starlarkproto.NewMessage(arg, args, kwargs)
   558  		select {
   559  		case ch <- starlarkResponse{val: val, err: err}:
   560  			return err
   561  		case <-ctx.Done():
   562  			return ctx.Err()
   563  		}
   564  	}, ch
   565  }
   566  
   567  func (s *StarlarkStream) recv(thread *starlark.Thread, fnname string, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
   568  	ctx := starlarkthread.GetContext(thread)
   569  	if err := s.init(thread); err != nil {
   570  		return nil, err
   571  	}
   572  
   573  	if err := starlark.UnpackPositionalArgs(fnname, args, kwargs, 0); err != nil {
   574  		return nil, err
   575  	}
   576  
   577  	fn, ch := promiseResponse(ctx, nil, nil)
   578  
   579  	select {
   580  	case <-ctx.Done():
   581  		return nil, ctx.Err()
   582  	case <-s.stream.ctx.Done():
   583  		return nil, s.getErr()
   584  	case s.stream.outs <- fn:
   585  		rsp := <-ch
   586  		return rsp.val, rsp.err
   587  	}
   588  }
   589  func (s *StarlarkStream) send(thread *starlark.Thread, fnname string, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
   590  	ctx := starlarkthread.GetContext(thread)
   591  	if err := s.init(thread); err != nil {
   592  		return nil, err
   593  	}
   594  
   595  	fn, ch := promiseResponse(ctx, args, kwargs)
   596  
   597  	select {
   598  	case <-ctx.Done():
   599  		return nil, ctx.Err()
   600  	case <-s.stream.ctx.Done():
   601  		return nil, s.getErr()
   602  	case s.stream.ins <- fn:
   603  		rsp := <-ch
   604  		return starlark.None, rsp.err
   605  	}
   606  }