github.com/blend/go-sdk@v1.20220411.3/shardutil/shards.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package shardutil
     9  
    10  import (
    11  	"context"
    12  	"sync"
    13  
    14  	"github.com/blend/go-sdk/db"
    15  	"github.com/blend/go-sdk/ex"
    16  )
    17  
    18  // Shards handles communicating with many underlying databases at once.
    19  type Shards struct {
    20  	Connections []*db.Connection
    21  	Opts        []InvocationOption
    22  }
    23  
    24  // PartitionIndex returns a partition index for a given hashCode.
    25  func (s Shards) PartitionIndex(hashCode int) int {
    26  	return hashCode % len(s.Connections)
    27  }
    28  
    29  // PartitionOptions returns db.InvocationOptions for a given partition.
    30  func (s Shards) PartitionOptions(partitionIndex int, opts ...InvocationOption) []db.InvocationOption {
    31  	var invocationOpts []db.InvocationOption
    32  	for _, opt := range s.Opts {
    33  		invocationOpts = append(invocationOpts, opt(partitionIndex))
    34  	}
    35  	for _, opt := range opts {
    36  		invocationOpts = append(invocationOpts, opt(partitionIndex))
    37  	}
    38  	return invocationOpts
    39  }
    40  
    41  // InvokeAll invokes a given function asynchronously for each connection in the manager.
    42  func (s Shards) InvokeAll(ctx context.Context, action func(int, *db.Invocation) error, opts ...InvocationOption) error {
    43  	wg := new(sync.WaitGroup)
    44  	wg.Add(len(s.Connections))
    45  
    46  	errors := make(chan error, len(s.Connections))
    47  	for index := 0; index < len(s.Connections); index++ {
    48  		go func(partitionIndex int) {
    49  			defer func() {
    50  				if r := recover(); r != nil {
    51  					errors <- ex.New(r)
    52  				}
    53  				wg.Done()
    54  			}()
    55  
    56  			invocation := s.Connections[partitionIndex].Invoke(
    57  				append(s.PartitionOptions(partitionIndex, opts...), db.OptContext(ctx))...,
    58  			)
    59  			if err := action(partitionIndex, invocation); err != nil {
    60  				errors <- err
    61  			}
    62  		}(index)
    63  	}
    64  
    65  	wg.Wait()
    66  	if len(errors) > 0 {
    67  		return <-errors
    68  	}
    69  	return nil
    70  }
    71  
    72  // InvokeOne creates a new db invocation routed to an underlying connection mapped by a given hashcode.
    73  // The underlying connection is determined by `PartitionIndex(hashCode)`.
    74  // The options are special parameterized versions of normal `db.InvocationOptions` that also take a partition index.
    75  // The returned invocation will map to only (1) underlying connection.
    76  func (s Shards) InvokeOne(ctx context.Context, hashCode int, opts ...InvocationOption) *db.Invocation {
    77  	partitionIndex := s.PartitionIndex(hashCode)
    78  	return s.Connections[partitionIndex].Invoke(append(s.PartitionOptions(partitionIndex, opts...), db.OptContext(ctx))...)
    79  }