github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/testserver/cluster.go (about)

     1  package testserver
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"strconv"
     8  	"strings"
     9  	"sync"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/authzed/consistent"
    14  	"github.com/cespare/xxhash/v2"
    15  	"github.com/stretchr/testify/require"
    16  	"google.golang.org/grpc"
    17  	"google.golang.org/grpc/backoff"
    18  	"google.golang.org/grpc/balancer"
    19  	"google.golang.org/grpc/resolver"
    20  
    21  	combineddispatch "github.com/authzed/spicedb/internal/dispatch/combined"
    22  	"github.com/authzed/spicedb/pkg/cmd/server"
    23  	"github.com/authzed/spicedb/pkg/cmd/util"
    24  	"github.com/authzed/spicedb/pkg/datastore"
    25  	"github.com/authzed/spicedb/pkg/secrets"
    26  )
    27  
    28  const TestResolverScheme = "test"
    29  
    30  type TempError struct{}
    31  
    32  func (t TempError) Error() string {
    33  	return "no dialers yet"
    34  }
    35  
    36  func (t TempError) Temporary() bool {
    37  	return true
    38  }
    39  
    40  type dialerFunc func(ctx context.Context, s string) (net.Conn, error)
    41  
    42  // track prefixes used for making test clusters to avoid registering the same
    43  // prometheus subsystem twice in one run
    44  var usedPrefixes sync.Map
    45  
    46  func getPrefix(t testing.TB) string {
    47  	for {
    48  		prefix, err := secrets.TokenHex(8)
    49  		require.NoError(t, err)
    50  		if _, ok := usedPrefixes.Load(prefix); !ok {
    51  			usedPrefixes.Store(prefix, struct{}{})
    52  			return prefix
    53  		}
    54  	}
    55  }
    56  
    57  var testResolverBuilder = &SafeManualResolverBuilder{}
    58  
    59  func init() {
    60  	// register hashring balancer
    61  	balancer.Register(consistent.NewBuilder(xxhash.Sum64))
    62  
    63  	// Register a manual resolver.Builder  that we can feed addresses for tests
    64  	// Registration is not thread safe, so we register a single resolver.Builder
    65  	// to handle all clusters, rather than registering a unique resolver.Builder
    66  	// per cluster.
    67  	resolver.Register(testResolverBuilder)
    68  }
    69  
    70  // SafeManualResolverBuilder is a resolver builder that builds SafeManualResolvers
    71  // it is similar to manual.Resolver in grpc, but is thread safe
    72  type SafeManualResolverBuilder struct {
    73  	resolvers sync.Map
    74  	addrs     sync.Map
    75  }
    76  
    77  func (b *SafeManualResolverBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) {
    78  	if target.URL.Scheme != TestResolverScheme {
    79  		return nil, fmt.Errorf("test resolver builder only works with test:// addresses")
    80  	}
    81  	var addrs []resolver.Address
    82  	addrVal, ok := b.addrs.Load(target.URL.Hostname())
    83  	if !ok {
    84  		addrs = make([]resolver.Address, 0)
    85  	} else {
    86  		addrs = addrVal.([]resolver.Address)
    87  	}
    88  	r := &SafeManualResolver{
    89  		prefix: target.URL.Hostname(),
    90  		cc:     cc,
    91  		opts:   opts,
    92  		addrs:  addrs,
    93  	}
    94  	b.resolvers.Store(target.URL.Hostname(), r)
    95  	return r, nil
    96  }
    97  
    98  func (b *SafeManualResolverBuilder) Scheme() string {
    99  	return "test"
   100  }
   101  
   102  func (b *SafeManualResolverBuilder) SetAddrs(prefix string, addrs []resolver.Address) {
   103  	b.addrs.Store(prefix, addrs)
   104  }
   105  
   106  func (b *SafeManualResolverBuilder) ResolveNow(prefix string) {
   107  	r, ok := b.resolvers.Load(prefix)
   108  	if !ok {
   109  		fmt.Println("NO RESOLVER YET") // shouldn't happen, but log
   110  		return
   111  	}
   112  	r.(*SafeManualResolver).ResolveNow(resolver.ResolveNowOptions{})
   113  }
   114  
   115  // SafeManualResolver is the resolver type that SafeManualResolverBuilder builds
   116  // it returns a static list of addresses
   117  type SafeManualResolver struct {
   118  	prefix string
   119  	cc     resolver.ClientConn
   120  	opts   resolver.BuildOptions
   121  	addrs  []resolver.Address
   122  }
   123  
   124  // ResolveNow implements the resolver.Resolver interface
   125  // It sends the static list of addresses to the underlying resolver.ClientConn
   126  func (r *SafeManualResolver) ResolveNow(_ resolver.ResolveNowOptions) {
   127  	if r.cc == nil {
   128  		return
   129  	}
   130  	if err := r.cc.UpdateState(resolver.State{Addresses: r.addrs}); err != nil {
   131  		fmt.Println("ERROR UPDATING STATE", err) // shouldn't happen, log
   132  	}
   133  }
   134  
   135  // Close implements the resolver.Resolver interface
   136  func (r *SafeManualResolver) Close() {}
   137  
   138  // TestClusterWithDispatch creates a cluster with `size` nodes
   139  // The cluster has a real dispatch stack that uses bufconn grpc connections
   140  func TestClusterWithDispatch(t testing.TB, size uint, ds datastore.Datastore) ([]*grpc.ClientConn, func()) {
   141  	return TestClusterWithDispatchAndCacheConfig(t, size, ds)
   142  }
   143  
   144  // TestClusterWithDispatchAndCacheConfig creates a cluster with `size` nodes and with cache toggled.
   145  func TestClusterWithDispatchAndCacheConfig(t testing.TB, size uint, ds datastore.Datastore) ([]*grpc.ClientConn, func()) {
   146  	// each cluster gets a unique prefix since grpc resolution is process-global
   147  	prefix := getPrefix(t)
   148  
   149  	// make placeholder resolved addresses, 1 per node
   150  	addresses := make([]resolver.Address, 0, size)
   151  	for i := uint(0); i < size; i++ {
   152  		addresses = append(addresses, resolver.Address{
   153  			Addr:       fmt.Sprintf("%s_%d", prefix, i),
   154  			ServerName: "",
   155  		})
   156  	}
   157  	testResolverBuilder.SetAddrs(prefix, addresses)
   158  
   159  	dialers := make([]dialerFunc, 0, size)
   160  	conns := make([]*grpc.ClientConn, 0, size)
   161  	cancelFuncs := make([]func(), 0, size)
   162  
   163  	for i := uint(0); i < size; i++ {
   164  		dispatcherOptions := []combineddispatch.Option{
   165  			combineddispatch.UpstreamAddr("test://" + prefix),
   166  			combineddispatch.PrometheusSubsystem(fmt.Sprintf("%s_%d_client_dispatch", prefix, i)),
   167  			combineddispatch.GrpcDialOpts(
   168  				grpc.WithDefaultServiceConfig(
   169  					(&consistent.BalancerConfig{
   170  						ReplicationFactor: 1500,
   171  						Spread:            1,
   172  					}).MustServiceConfigJSON()),
   173  				grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) {
   174  					// it's possible grpc tries to dial before we have set the
   175  					// buffconn dialers, we have to return a "TempError" so that
   176  					// grpc knows to retry the connection.
   177  					if len(dialers) == 0 {
   178  						return nil, TempError{}
   179  					}
   180  					// "s" here will be the address from the manual resolver
   181  					// like `<prefix>_<node number>`
   182  					i, err := strconv.Atoi(strings.TrimPrefix(s, prefix+"_"))
   183  					require.NoError(t, err)
   184  					return dialers[i](ctx, s)
   185  				}),
   186  			),
   187  		}
   188  
   189  		dispatcher, err := combineddispatch.NewDispatcher(dispatcherOptions...)
   190  		require.NoError(t, err)
   191  
   192  		serverOptions := []server.ConfigOption{
   193  			server.WithDatastore(ds),
   194  			server.WithDispatcher(dispatcher),
   195  			server.WithDispatchMaxDepth(50),
   196  			server.WithMaximumPreconditionCount(1000),
   197  			server.WithMaximumUpdatesPerWrite(1000),
   198  			server.WithGRPCServer(util.GRPCServerConfig{
   199  				Network: util.BufferedNetwork,
   200  				Enabled: true,
   201  			}),
   202  			server.WithMaxRelationshipContextSize(25000),
   203  			server.WithSchemaPrefixesRequired(false),
   204  			server.WithGRPCAuthFunc(func(ctx context.Context) (context.Context, error) {
   205  				return ctx, nil
   206  			}),
   207  			server.WithHTTPGateway(util.HTTPServerConfig{HTTPEnabled: false}),
   208  			server.WithMetricsAPI(util.HTTPServerConfig{HTTPEnabled: false}),
   209  			server.WithDispatchServer(util.GRPCServerConfig{
   210  				Enabled: true,
   211  				Network: util.BufferedNetwork,
   212  			}),
   213  			server.WithDispatchClusterMetricsPrefix(fmt.Sprintf("%s_%d_dispatch", prefix, i)),
   214  		}
   215  
   216  		ctx, cancel := context.WithCancel(context.Background())
   217  		srv, err := server.NewConfigWithOptions(serverOptions...).Complete(ctx)
   218  		require.NoError(t, err)
   219  
   220  		go func() {
   221  			require.NoError(t, srv.Run(ctx))
   222  		}()
   223  		cancelFuncs = append(cancelFuncs, cancel)
   224  
   225  		dialers = append(dialers, srv.DispatchNetDialContext)
   226  		conn, err := srv.GRPCDialContext(ctx,
   227  			grpc.WithReturnConnectionError(),
   228  			grpc.WithBlock(),
   229  			grpc.WithConnectParams(grpc.ConnectParams{
   230  				Backoff: backoff.Config{
   231  					BaseDelay:  1 * time.Second,
   232  					Multiplier: 2,
   233  					MaxDelay:   15 * time.Second,
   234  				},
   235  			}))
   236  		require.NoError(t, err)
   237  		conns = append(conns, conn)
   238  	}
   239  
   240  	// resolve after dialers have been set to initialize connections
   241  	testResolverBuilder.ResolveNow(prefix)
   242  
   243  	return conns, func() {
   244  		for _, c := range conns {
   245  			require.NoError(t, c.Close())
   246  		}
   247  		for _, c := range cancelFuncs {
   248  			c()
   249  		}
   250  	}
   251  }