github.com/emcfarlane/larking@v0.0.0-20220605172417-1704b45ee6c3/worker/server_test.go (about)

     1  // Copyright 2022 Edward McFarlane. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package worker_test
     6  
     7  import (
     8  	"context"
     9  	"net"
    10  	"os"
    11  	"testing"
    12  
    13  	"github.com/emcfarlane/larking/apipb/workerpb"
    14  	"github.com/emcfarlane/larking/control"
    15  	"github.com/emcfarlane/larking/worker"
    16  
    17  	"github.com/go-logr/logr"
    18  	testing_logr "github.com/go-logr/logr/testing"
    19  	"github.com/google/go-cmp/cmp"
    20  	starlarkmath "go.starlark.net/lib/math"
    21  	"go.starlark.net/starlark"
    22  	"golang.org/x/sync/errgroup"
    23  	"google.golang.org/grpc"
    24  	"google.golang.org/grpc/credentials/insecure"
    25  	"google.golang.org/protobuf/testing/protocmp"
    26  )
    27  
    28  func testContext(t *testing.T) context.Context {
    29  	ctx := context.Background()
    30  	log := testing_logr.NewTestLogger(t)
    31  	ctx = logr.NewContext(ctx, log)
    32  	return ctx
    33  }
    34  
    35  func TestAPIServer(t *testing.T) {
    36  	//log := testing_logr.NewTestLogger(t)
    37  
    38  	workerServer := worker.NewServer(
    39  		func(_ *starlark.Thread, module string) (starlark.StringDict, error) {
    40  			if module == "math.star" {
    41  				return starlarkmath.Module.Members, nil
    42  			}
    43  			return nil, os.ErrNotExist
    44  		},
    45  		control.InsecureControlClient{},
    46  		"worker",
    47  	)
    48  
    49  	var opts []grpc.ServerOption
    50  	grpcServer := grpc.NewServer(opts...)
    51  	workerpb.RegisterWorkerServer(grpcServer, workerServer)
    52  
    53  	lis, err := net.Listen("tcp", "localhost:0")
    54  	if err != nil {
    55  		t.Fatalf("failed to listen: %v", err)
    56  	}
    57  	defer lis.Close()
    58  
    59  	var g errgroup.Group
    60  	defer func() {
    61  		if err := g.Wait(); err != nil {
    62  			t.Fatal(err)
    63  		}
    64  	}()
    65  	g.Go(func() error {
    66  		if err := grpcServer.Serve(lis); err != nil {
    67  			return err
    68  		}
    69  		return nil
    70  	})
    71  	defer grpcServer.GracefulStop()
    72  
    73  	// Create the client.
    74  	conn, err := grpc.Dial(
    75  		lis.Addr().String(),
    76  		grpc.WithTransportCredentials(insecure.NewCredentials()),
    77  	)
    78  	if err != nil {
    79  		t.Fatalf("cannot connect to server: %v", err)
    80  	}
    81  	defer conn.Close()
    82  
    83  	client := workerpb.NewWorkerClient(conn)
    84  
    85  	tests := []struct {
    86  		name string
    87  		ins  []*workerpb.Command
    88  		outs []*workerpb.Result
    89  	}{{
    90  		name: "fibonacci",
    91  		ins: []*workerpb.Command{{
    92  			Name: "",
    93  			Exec: &workerpb.Command_Input{
    94  				Input: `def fibonacci(n):
    95  	    res = list(range(n))
    96  	    for i in res[2:]:
    97  		res[i] = res[i-2] + res[i-1]
    98  	    return res
    99  `},
   100  		}, {
   101  			Exec: &workerpb.Command_Input{
   102  				Input: "fibonacci(10)\n",
   103  			},
   104  		}},
   105  		outs: []*workerpb.Result{{
   106  			Result: &workerpb.Result_Output{
   107  				Output: &workerpb.Output{
   108  					Output: "",
   109  				},
   110  			},
   111  		}, {
   112  			Result: &workerpb.Result_Output{
   113  				Output: &workerpb.Output{
   114  					Output: "[0, 1, 1, 2, 3, 5, 8, 13, 21, 34]",
   115  				},
   116  			},
   117  		}},
   118  	}, {
   119  		name: "load",
   120  		ins: []*workerpb.Command{{
   121  			Name: "",
   122  			Exec: &workerpb.Command_Input{
   123  				Input: `load("math.star", "pow")`,
   124  			},
   125  		}, {
   126  			Exec: &workerpb.Command_Input{
   127  				Input: "pow(2, 3)",
   128  			},
   129  		}},
   130  		outs: []*workerpb.Result{{
   131  			Result: &workerpb.Result_Output{
   132  				Output: &workerpb.Output{
   133  					Output: "",
   134  				},
   135  			},
   136  		}, {
   137  			Result: &workerpb.Result_Output{
   138  				Output: &workerpb.Output{
   139  					Output: "8.0",
   140  				},
   141  			},
   142  		}},
   143  	}}
   144  	cmpOpts := cmp.Options{protocmp.Transform()}
   145  
   146  	for _, tt := range tests {
   147  		t.Run(tt.name, func(t *testing.T) {
   148  			ctx := testContext(t)
   149  
   150  			if len(tt.ins) < len(tt.outs) {
   151  				t.Fatal("invalid args")
   152  			}
   153  
   154  			stream, err := client.RunOnThread(ctx)
   155  			if err != nil {
   156  				t.Fatal(err)
   157  			}
   158  
   159  			for i := 0; i < len(tt.ins); i++ {
   160  				in := tt.ins[i]
   161  				if err := stream.Send(in); err != nil {
   162  					t.Fatal(err)
   163  				}
   164  
   165  				out, err := stream.Recv()
   166  				if err != nil {
   167  					t.Fatal(err)
   168  				}
   169  				t.Logf("out: %v", out)
   170  
   171  				diff := cmp.Diff(out, tt.outs[i], cmpOpts...)
   172  				if diff != "" {
   173  					t.Error(diff)
   174  				}
   175  			}
   176  		})
   177  	}
   178  	//t.Logf("thread: %v", s.ls.threads["default"])
   179  }