github.com/emcfarlane/larking@v0.0.0-20220605172417-1704b45ee6c3/worker/server.go (about)

     1  // Copyright 2022 Edward McFarlane. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package worker
     6  
import (
	"bytes"
	"context"
	"encoding/base64"
	"fmt"
	"io"
	"regexp"
	"strings"
	"testing"

	"github.com/go-logr/logr"
	"go.starlark.net/starlark"
	"go.starlark.net/syntax"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"

	"github.com/emcfarlane/larking/apipb/controlpb"
	"github.com/emcfarlane/larking/apipb/workerpb"
	"github.com/emcfarlane/larking/starlib"
	"github.com/emcfarlane/larking/starlib/starlarkthread"
	"github.com/emcfarlane/starlarkassert"
)
    29  
    30  type loadFunc func(*starlark.Thread, string) (starlark.StringDict, error)
    31  
    32  type Server struct {
    33  	workerpb.UnimplementedWorkerServer
    34  	load    loadFunc
    35  	control controlpb.ControlClient
    36  	name    string
    37  }
    38  
    39  func NewServer(
    40  	load func(thread *starlark.Thread, module string) (starlark.StringDict, error),
    41  	control controlpb.ControlClient,
    42  	name string,
    43  ) *Server {
    44  	return &Server{
    45  		load:    load,
    46  		control: control,
    47  		name:    name,
    48  	}
    49  }
    50  
    51  func (s *Server) Load(thread *starlark.Thread, module string) (starlark.StringDict, error) {
    52  	if s.load == nil {
    53  		return nil, status.Error(
    54  			codes.Unavailable,
    55  			"module loading not avaliable",
    56  		)
    57  	}
    58  	return s.load(thread, module)
    59  }
    60  
    61  func (s *Server) authorize(ctx context.Context, op *controlpb.Operation) error {
    62  	req := &controlpb.CheckRequest{
    63  		Name:      s.name,
    64  		Operation: op,
    65  	}
    66  
    67  	rsp, err := s.control.Check(ctx, req)
    68  	if err != nil {
    69  		return err
    70  	}
    71  	if s := rsp.Status; s != nil {
    72  		st := status.FromProto(s)
    73  		return st.Err()
    74  	}
    75  	return nil
    76  }
    77  
    78  var (
    79  	errMissingCredentials = status.Error(codes.Unauthenticated, "missing credentials")
    80  	errInvalidCredentials = status.Error(codes.Unauthenticated, "invalid credentials")
    81  )
    82  
    83  func extractCredentials(ctx context.Context) (*controlpb.Credentials, error) {
    84  	md, ok := metadata.FromIncomingContext(ctx)
    85  	if !ok {
    86  		return nil, status.Error(codes.InvalidArgument, "invalid metadata")
    87  	}
    88  
    89  	for _, hdrKey := range []string{"http-authorization", "authorization"} {
    90  		keys := md.Get(hdrKey)
    91  		if len(keys) == 0 {
    92  			continue
    93  		}
    94  		vals := strings.Split(keys[0], " ")
    95  		if len(vals) == 1 && len(vals[0]) == 0 {
    96  			continue
    97  		}
    98  		if len(vals) != 2 {
    99  			return nil, errMissingCredentials
   100  		}
   101  		val := vals[1]
   102  
   103  		switch strings.ToLower(vals[0]) {
   104  		case "bearer":
   105  			return &controlpb.Credentials{
   106  				Type: &controlpb.Credentials_Bearer{
   107  					Bearer: &controlpb.Credentials_BearerToken{
   108  						AccessToken: val,
   109  					},
   110  				},
   111  			}, nil
   112  
   113  		case "basic":
   114  			c, err := base64.StdEncoding.DecodeString(val)
   115  			if err != nil {
   116  				return nil, err
   117  			}
   118  			cs := string(c)
   119  			s := strings.IndexByte(cs, ':')
   120  			if s < 0 {
   121  				return nil, errMissingCredentials
   122  			}
   123  
   124  			return &controlpb.Credentials{
   125  				Type: &controlpb.Credentials_Basic{
   126  					Basic: &controlpb.Credentials_BasicAuth{
   127  						Username: cs[:s],
   128  						Password: cs[s+1:],
   129  					},
   130  				},
   131  			}, nil
   132  
   133  		default:
   134  			return nil, errInvalidCredentials
   135  		}
   136  	}
   137  	return &controlpb.Credentials{
   138  		Type: &controlpb.Credentials_Insecure{
   139  			Insecure: true,
   140  		},
   141  	}, nil
   142  }
   143  
   144  func soleExpr(f *syntax.File) syntax.Expr {
   145  	if len(f.Stmts) == 1 {
   146  		if stmt, ok := f.Stmts[0].(*syntax.ExprStmt); ok {
   147  			return stmt.X
   148  		}
   149  	}
   150  	return nil
   151  }
   152  
   153  // Create ServerStream...
   154  func (s *Server) RunOnThread(stream workerpb.Worker_RunOnThreadServer) (err error) {
   155  	ctx := stream.Context()
   156  	l := logr.FromContextOrDiscard(ctx)
   157  
   158  	cmd, err := stream.Recv()
   159  	if err != nil {
   160  		return err
   161  	}
   162  	l.Info("running on thread", "thread", cmd.Name)
   163  
   164  	creds, err := extractCredentials(ctx)
   165  	if err != nil {
   166  		return err
   167  	}
   168  
   169  	op := &controlpb.Operation{
   170  		Name:        cmd.Name,
   171  		Credentials: creds,
   172  	}
   173  
   174  	if err := s.authorize(ctx, op); err != nil {
   175  		l.Error(err, "failed to authorize request", "name", cmd.Name)
   176  		return err
   177  	}
   178  
   179  	name := strings.TrimPrefix(cmd.Name, "thread/")
   180  
   181  	var buf bytes.Buffer
   182  	thread := &starlark.Thread{
   183  		Name: name,
   184  		Print: func(_ *starlark.Thread, msg string) {
   185  			buf.WriteString(msg) //nolint
   186  		},
   187  		Load: s.load,
   188  	}
   189  
   190  	starlarkthread.SetContext(thread, ctx)
   191  	cleanup := starlarkthread.WithResourceStore(thread)
   192  	defer func() {
   193  		if cerr := cleanup(); err == nil {
   194  			err = cerr
   195  		}
   196  	}()
   197  
   198  	globals := starlib.NewGlobals()
   199  	if name != "" {
   200  		if s.load == nil {
   201  			return status.Error(
   202  				codes.Unavailable,
   203  				"module loading not avaliable",
   204  			)
   205  		}
   206  		predeclared, err := s.load(thread, name)
   207  		if err != nil {
   208  			return err
   209  		}
   210  		for key, val := range predeclared {
   211  			globals[key] = val // copy thread values to globals
   212  		}
   213  		thread.Name = name
   214  	}
   215  
   216  	run := func(input string) error {
   217  		buf.Reset()
   218  		f, err := syntax.Parse(thread.Name, input, 0)
   219  		if err != nil {
   220  			return err
   221  		}
   222  
   223  		if expr := soleExpr(f); expr != nil {
   224  			// eval
   225  			v, err := starlark.EvalExpr(thread, expr, globals)
   226  			if err != nil {
   227  				return err
   228  			}
   229  
   230  			// print
   231  			if v != starlark.None {
   232  				buf.WriteString(v.String())
   233  			}
   234  		} else if err := starlark.ExecREPLChunk(f, thread, globals); err != nil {
   235  			return err
   236  		}
   237  		return nil
   238  	}
   239  
   240  	c := starlib.Completer{StringDict: globals}
   241  	for {
   242  		result := &workerpb.Result{}
   243  
   244  		switch v := cmd.Exec.(type) {
   245  		case *workerpb.Command_Input:
   246  			err := run(v.Input)
   247  			if err != nil {
   248  				l.Info("thread error", "err", err)
   249  			}
   250  			result.Result = &workerpb.Result_Output{
   251  				Output: &workerpb.Output{
   252  					Output: buf.String(),
   253  					Status: errorStatus(err).Proto(),
   254  				},
   255  			}
   256  
   257  		case *workerpb.Command_Complete:
   258  			completions := c.Complete(v.Complete)
   259  			result.Result = &workerpb.Result_Completion{
   260  				Completion: &workerpb.Completion{
   261  					Completions: completions,
   262  				},
   263  			}
   264  
   265  		case *workerpb.Command_Format:
   266  			b, err := Format(ctx, name, v.Format)
   267  			if err != nil {
   268  				l.Info("thread format error", "err", err)
   269  			}
   270  
   271  			result.Result = &workerpb.Result_Output{
   272  				Output: &workerpb.Output{
   273  					Output: string(b),
   274  					Status: errorStatus(err).Proto(),
   275  				},
   276  			}
   277  		}
   278  		if err = stream.Send(result); err != nil {
   279  			return err
   280  		}
   281  
   282  		cmd, err = stream.Recv()
   283  		if err != nil {
   284  			return err
   285  		}
   286  	}
   287  }
   288  
   289  func (s *Server) RunThread(ctx context.Context, req *workerpb.RunThreadRequest) (*workerpb.Output, error) {
   290  
   291  	l := logr.FromContextOrDiscard(ctx)
   292  	l.Info("running thread", "thread", req.Name)
   293  
   294  	creds, err := extractCredentials(ctx)
   295  	if err != nil {
   296  		return nil, err
   297  	}
   298  	op := &controlpb.Operation{
   299  		Name:        req.Name,
   300  		Credentials: creds,
   301  	}
   302  	if err := s.authorize(ctx, op); err != nil {
   303  		l.Error(err, "failed to authorize request", "name", req.Name)
   304  		return nil, err
   305  	}
   306  
   307  	name := strings.TrimPrefix(req.Name, "thread/")
   308  
   309  	var buf bytes.Buffer
   310  	thread := &starlark.Thread{
   311  		Name: name,
   312  		Print: func(_ *starlark.Thread, msg string) {
   313  			buf.WriteString(msg) //nolint
   314  		},
   315  		Load: s.load,
   316  	}
   317  
   318  	starlarkthread.SetContext(thread, ctx)
   319  	cleanup := starlarkthread.WithResourceStore(thread)
   320  	defer func() {
   321  		if cerr := cleanup(); err == nil {
   322  			err = cerr
   323  		}
   324  	}()
   325  
   326  	if name == "" {
   327  		return nil, status.Error(
   328  			codes.InvalidArgument,
   329  			"missing module name",
   330  		)
   331  	}
   332  	if _, err := s.Load(thread, name); err != nil {
   333  		return nil, err
   334  	}
   335  
   336  	return &workerpb.Output{
   337  		Output: buf.String(),
   338  		Status: errorStatus(err).Proto(),
   339  	}, nil
   340  
   341  }
   342  func (s *Server) TestThread(ctx context.Context, req *workerpb.TestThreadRequest) (*workerpb.Output, error) {
   343  	l := logr.FromContextOrDiscard(ctx)
   344  	l.Info("testing thread", "thread", req.Name)
   345  
   346  	creds, err := extractCredentials(ctx)
   347  	if err != nil {
   348  		return nil, err
   349  	}
   350  	op := &controlpb.Operation{
   351  		Name:        req.Name,
   352  		Credentials: creds,
   353  	}
   354  	if err := s.authorize(ctx, op); err != nil {
   355  		l.Error(err, "failed to authorize request", "name", req.Name)
   356  		return nil, err
   357  	}
   358  
   359  	name := strings.TrimPrefix(req.Name, "thread/")
   360  
   361  	var buf bytes.Buffer
   362  	thread := &starlark.Thread{
   363  		Name: name,
   364  		Print: func(_ *starlark.Thread, msg string) {
   365  			buf.WriteString(msg) //nolint
   366  		},
   367  		Load: s.load,
   368  	}
   369  	values, err := s.Load(thread, name)
   370  	if err != nil {
   371  		return nil, err
   372  	}
   373  
   374  	errorf := func(err error) {
   375  		switch err := err.(type) {
   376  		case *starlark.EvalError:
   377  			var found bool
   378  			for i := range err.CallStack {
   379  				posn := err.CallStack.At(i).Pos
   380  				if posn.Filename() == name {
   381  					linenum := int(posn.Line)
   382  					msg := err.Error()
   383  
   384  					fmt.Fprintf(&buf, "\n%s:%d: unexpected error: %v", name, linenum, msg)
   385  					found = true
   386  					break
   387  				}
   388  			}
   389  			if !found {
   390  				fmt.Fprint(&buf, err.Backtrace()) //nolint
   391  			}
   392  		case nil:
   393  			// success
   394  		default:
   395  			fmt.Fprintf(&buf, "\n%s", err) //nolint
   396  		}
   397  	}
   398  
   399  	tests := []testing.InternalTest{{
   400  		Name: name,
   401  		F: func(t *testing.T) {
   402  			for key, val := range values {
   403  				if !strings.HasPrefix(key, "test_") {
   404  					continue // ignore
   405  				}
   406  				if _, ok := val.(starlark.Callable); !ok {
   407  					continue // ignore non callable
   408  				}
   409  
   410  				key, val := key, val
   411  				t.Run(key, func(t *testing.T) {
   412  					tt := starlarkassert.NewTest(t)
   413  					if _, err := starlark.Call(
   414  						thread, val, starlark.Tuple{tt}, nil,
   415  					); err != nil {
   416  						errorf(err)
   417  					}
   418  				})
   419  			}
   420  
   421  		},
   422  	}}
   423  
   424  	var (
   425  		matchPat string
   426  		matchRe  *regexp.Regexp
   427  	)
   428  	deps := starlarkassert.MatchStringOnly(
   429  		func(pat, str string) (result bool, err error) {
   430  			if matchRe == nil || matchPat != pat {
   431  				matchPat = pat
   432  				matchRe, err = regexp.Compile(matchPat)
   433  				if err != nil {
   434  					return
   435  				}
   436  			}
   437  			return matchRe.MatchString(str), nil
   438  		},
   439  	)
   440  	var result *status.Status
   441  	if testing.MainStart(deps, tests, nil, nil, nil).Run() > 0 {
   442  		result = status.New(
   443  			codes.Unknown, // TODO: error code.
   444  			"failed",
   445  		)
   446  	} else {
   447  		result = status.New(
   448  			codes.OK,
   449  			"passed",
   450  		)
   451  	}
   452  
   453  	return &workerpb.Output{
   454  		Output: buf.String(),
   455  		Status: result.Proto(),
   456  	}, nil
   457  }