github.com/emcfarlane/larking@v0.0.0-20220605172417-1704b45ee6c3/worker/server.go (about) 1 // Copyright 2022 Edward McFarlane. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package worker 6 7 import ( 8 "bytes" 9 "context" 10 "encoding/base64" 11 "fmt" 12 "regexp" 13 "strings" 14 "testing" 15 16 "github.com/go-logr/logr" 17 "go.starlark.net/starlark" 18 "go.starlark.net/syntax" 19 "google.golang.org/grpc/codes" 20 "google.golang.org/grpc/metadata" 21 "google.golang.org/grpc/status" 22 23 "github.com/emcfarlane/larking/apipb/controlpb" 24 "github.com/emcfarlane/larking/apipb/workerpb" 25 "github.com/emcfarlane/larking/starlib" 26 "github.com/emcfarlane/larking/starlib/starlarkthread" 27 "github.com/emcfarlane/starlarkassert" 28 ) 29 30 type loadFunc func(*starlark.Thread, string) (starlark.StringDict, error) 31 32 type Server struct { 33 workerpb.UnimplementedWorkerServer 34 load loadFunc 35 control controlpb.ControlClient 36 name string 37 } 38 39 func NewServer( 40 load func(thread *starlark.Thread, module string) (starlark.StringDict, error), 41 control controlpb.ControlClient, 42 name string, 43 ) *Server { 44 return &Server{ 45 load: load, 46 control: control, 47 name: name, 48 } 49 } 50 51 func (s *Server) Load(thread *starlark.Thread, module string) (starlark.StringDict, error) { 52 if s.load == nil { 53 return nil, status.Error( 54 codes.Unavailable, 55 "module loading not avaliable", 56 ) 57 } 58 return s.load(thread, module) 59 } 60 61 func (s *Server) authorize(ctx context.Context, op *controlpb.Operation) error { 62 req := &controlpb.CheckRequest{ 63 Name: s.name, 64 Operation: op, 65 } 66 67 rsp, err := s.control.Check(ctx, req) 68 if err != nil { 69 return err 70 } 71 if s := rsp.Status; s != nil { 72 st := status.FromProto(s) 73 return st.Err() 74 } 75 return nil 76 } 77 78 var ( 79 errMissingCredentials = status.Error(codes.Unauthenticated, "missing credentials") 80 errInvalidCredentials = status.Error(codes.Unauthenticated, "invalid credentials") 81 ) 82 83 func extractCredentials(ctx context.Context) (*controlpb.Credentials, error) { 84 md, ok := metadata.FromIncomingContext(ctx) 85 if !ok { 86 return nil, status.Error(codes.InvalidArgument, "invalid metadata") 87 } 88 89 for _, hdrKey := range []string{"http-authorization", "authorization"} { 90 keys := md.Get(hdrKey) 91 if len(keys) == 0 { 92 continue 93 } 94 vals := strings.Split(keys[0], " ") 95 if len(vals) == 1 && len(vals[0]) == 0 { 96 continue 97 } 98 if len(vals) != 2 { 99 return nil, errMissingCredentials 100 } 101 val := vals[1] 102 103 switch strings.ToLower(vals[0]) { 104 case "bearer": 105 return &controlpb.Credentials{ 106 Type: &controlpb.Credentials_Bearer{ 107 Bearer: &controlpb.Credentials_BearerToken{ 108 AccessToken: val, 109 }, 110 }, 111 }, nil 112 113 case "basic": 114 c, err := base64.StdEncoding.DecodeString(val) 115 if err != nil { 116 return nil, err 117 } 118 cs := string(c) 119 s := strings.IndexByte(cs, ':') 120 if s < 0 { 121 return nil, errMissingCredentials 122 } 123 124 return &controlpb.Credentials{ 125 Type: &controlpb.Credentials_Basic{ 126 Basic: &controlpb.Credentials_BasicAuth{ 127 Username: cs[:s], 128 Password: cs[s+1:], 129 }, 130 }, 131 }, nil 132 133 default: 134 return nil, errInvalidCredentials 135 } 136 } 137 return &controlpb.Credentials{ 138 Type: &controlpb.Credentials_Insecure{ 139 Insecure: true, 140 }, 141 }, nil 142 } 143 144 func soleExpr(f *syntax.File) syntax.Expr { 145 if len(f.Stmts) == 1 { 146 if stmt, ok := f.Stmts[0].(*syntax.ExprStmt); ok { 147 return stmt.X 148 } 149 } 150 return nil 151 } 152 153 // Create ServerStream... 154 func (s *Server) RunOnThread(stream workerpb.Worker_RunOnThreadServer) (err error) { 155 ctx := stream.Context() 156 l := logr.FromContextOrDiscard(ctx) 157 158 cmd, err := stream.Recv() 159 if err != nil { 160 return err 161 } 162 l.Info("running on thread", "thread", cmd.Name) 163 164 creds, err := extractCredentials(ctx) 165 if err != nil { 166 return err 167 } 168 169 op := &controlpb.Operation{ 170 Name: cmd.Name, 171 Credentials: creds, 172 } 173 174 if err := s.authorize(ctx, op); err != nil { 175 l.Error(err, "failed to authorize request", "name", cmd.Name) 176 return err 177 } 178 179 name := strings.TrimPrefix(cmd.Name, "thread/") 180 181 var buf bytes.Buffer 182 thread := &starlark.Thread{ 183 Name: name, 184 Print: func(_ *starlark.Thread, msg string) { 185 buf.WriteString(msg) //nolint 186 }, 187 Load: s.load, 188 } 189 190 starlarkthread.SetContext(thread, ctx) 191 cleanup := starlarkthread.WithResourceStore(thread) 192 defer func() { 193 if cerr := cleanup(); err == nil { 194 err = cerr 195 } 196 }() 197 198 globals := starlib.NewGlobals() 199 if name != "" { 200 if s.load == nil { 201 return status.Error( 202 codes.Unavailable, 203 "module loading not avaliable", 204 ) 205 } 206 predeclared, err := s.load(thread, name) 207 if err != nil { 208 return err 209 } 210 for key, val := range predeclared { 211 globals[key] = val // copy thread values to globals 212 } 213 thread.Name = name 214 } 215 216 run := func(input string) error { 217 buf.Reset() 218 f, err := syntax.Parse(thread.Name, input, 0) 219 if err != nil { 220 return err 221 } 222 223 if expr := soleExpr(f); expr != nil { 224 // eval 225 v, err := starlark.EvalExpr(thread, expr, globals) 226 if err != nil { 227 return err 228 } 229 230 // print 231 if v != starlark.None { 232 buf.WriteString(v.String()) 233 } 234 } else if err := starlark.ExecREPLChunk(f, thread, globals); err != nil { 235 return err 236 } 237 return nil 238 } 239 240 c := starlib.Completer{StringDict: globals} 241 for { 242 result := &workerpb.Result{} 243 244 switch v := cmd.Exec.(type) { 245 case *workerpb.Command_Input: 246 err := run(v.Input) 247 if err != nil { 248 l.Info("thread error", "err", err) 249 } 250 result.Result = &workerpb.Result_Output{ 251 Output: &workerpb.Output{ 252 Output: buf.String(), 253 Status: errorStatus(err).Proto(), 254 }, 255 } 256 257 case *workerpb.Command_Complete: 258 completions := c.Complete(v.Complete) 259 result.Result = &workerpb.Result_Completion{ 260 Completion: &workerpb.Completion{ 261 Completions: completions, 262 }, 263 } 264 265 case *workerpb.Command_Format: 266 b, err := Format(ctx, name, v.Format) 267 if err != nil { 268 l.Info("thread format error", "err", err) 269 } 270 271 result.Result = &workerpb.Result_Output{ 272 Output: &workerpb.Output{ 273 Output: string(b), 274 Status: errorStatus(err).Proto(), 275 }, 276 } 277 } 278 if err = stream.Send(result); err != nil { 279 return err 280 } 281 282 cmd, err = stream.Recv() 283 if err != nil { 284 return err 285 } 286 } 287 } 288 289 func (s *Server) RunThread(ctx context.Context, req *workerpb.RunThreadRequest) (*workerpb.Output, error) { 290 291 l := logr.FromContextOrDiscard(ctx) 292 l.Info("running thread", "thread", req.Name) 293 294 creds, err := extractCredentials(ctx) 295 if err != nil { 296 return nil, err 297 } 298 op := &controlpb.Operation{ 299 Name: req.Name, 300 Credentials: creds, 301 } 302 if err := s.authorize(ctx, op); err != nil { 303 l.Error(err, "failed to authorize request", "name", req.Name) 304 return nil, err 305 } 306 307 name := strings.TrimPrefix(req.Name, "thread/") 308 309 var buf bytes.Buffer 310 thread := &starlark.Thread{ 311 Name: name, 312 Print: func(_ *starlark.Thread, msg string) { 313 buf.WriteString(msg) //nolint 314 }, 315 Load: s.load, 316 } 317 318 starlarkthread.SetContext(thread, ctx) 319 cleanup := starlarkthread.WithResourceStore(thread) 320 defer func() { 321 if cerr := cleanup(); err == nil { 322 err = cerr 323 } 324 }() 325 326 if name == "" { 327 return nil, status.Error( 328 codes.InvalidArgument, 329 "missing module name", 330 ) 331 } 332 if _, err := s.Load(thread, name); err != nil { 333 return nil, err 334 } 335 336 return &workerpb.Output{ 337 Output: buf.String(), 338 Status: errorStatus(err).Proto(), 339 }, nil 340 341 } 342 func (s *Server) TestThread(ctx context.Context, req *workerpb.TestThreadRequest) (*workerpb.Output, error) { 343 l := logr.FromContextOrDiscard(ctx) 344 l.Info("testing thread", "thread", req.Name) 345 346 creds, err := extractCredentials(ctx) 347 if err != nil { 348 return nil, err 349 } 350 op := &controlpb.Operation{ 351 Name: req.Name, 352 Credentials: creds, 353 } 354 if err := s.authorize(ctx, op); err != nil { 355 l.Error(err, "failed to authorize request", "name", req.Name) 356 return nil, err 357 } 358 359 name := strings.TrimPrefix(req.Name, "thread/") 360 361 var buf bytes.Buffer 362 thread := &starlark.Thread{ 363 Name: name, 364 Print: func(_ *starlark.Thread, msg string) { 365 buf.WriteString(msg) //nolint 366 }, 367 Load: s.load, 368 } 369 values, err := s.Load(thread, name) 370 if err != nil { 371 return nil, err 372 } 373 374 errorf := func(err error) { 375 switch err := err.(type) { 376 case *starlark.EvalError: 377 var found bool 378 for i := range err.CallStack { 379 posn := err.CallStack.At(i).Pos 380 if posn.Filename() == name { 381 linenum := int(posn.Line) 382 msg := err.Error() 383 384 fmt.Fprintf(&buf, "\n%s:%d: unexpected error: %v", name, linenum, msg) 385 found = true 386 break 387 } 388 } 389 if !found { 390 fmt.Fprint(&buf, err.Backtrace()) //nolint 391 } 392 case nil: 393 // success 394 default: 395 fmt.Fprintf(&buf, "\n%s", err) //nolint 396 } 397 } 398 399 tests := []testing.InternalTest{{ 400 Name: name, 401 F: func(t *testing.T) { 402 for key, val := range values { 403 if !strings.HasPrefix(key, "test_") { 404 continue // ignore 405 } 406 if _, ok := val.(starlark.Callable); !ok { 407 continue // ignore non callable 408 } 409 410 key, val := key, val 411 t.Run(key, func(t *testing.T) { 412 tt := starlarkassert.NewTest(t) 413 if _, err := starlark.Call( 414 thread, val, starlark.Tuple{tt}, nil, 415 ); err != nil { 416 errorf(err) 417 } 418 }) 419 } 420 421 }, 422 }} 423 424 var ( 425 matchPat string 426 matchRe *regexp.Regexp 427 ) 428 deps := starlarkassert.MatchStringOnly( 429 func(pat, str string) (result bool, err error) { 430 if matchRe == nil || matchPat != pat { 431 matchPat = pat 432 matchRe, err = regexp.Compile(matchPat) 433 if err != nil { 434 return 435 } 436 } 437 return matchRe.MatchString(str), nil 438 }, 439 ) 440 var result *status.Status 441 if testing.MainStart(deps, tests, nil, nil, nil).Run() > 0 { 442 result = status.New( 443 codes.Unknown, // TODO: error code. 444 "failed", 445 ) 446 } else { 447 result = status.New( 448 codes.OK, 449 "passed", 450 ) 451 } 452 453 return &workerpb.Output{ 454 Output: buf.String(), 455 Status: result.Proto(), 456 }, nil 457 }