github.com/xzzpig/headscale-manager@v1.3.3/service/loader/loader.go (about)

     1  package loader
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"sync"
     8  
     9  	"github.com/graph-gophers/dataloader"
    10  	"github.com/xzzpig/headscale-manager/db"
    11  	"github.com/xzzpig/headscale-manager/graph/model"
    12  	"go.mongodb.org/mongo-driver/bson"
    13  	"go.mongodb.org/mongo-driver/bson/primitive"
    14  	"go.mongodb.org/mongo-driver/mongo"
    15  )
    16  
    17  type Loaders struct {
    18  	MachineByID *dataloader.Loader
    19  	ProjectByID *dataloader.Loader
    20  	RouteByID   *dataloader.Loader
    21  
    22  	loaderMap map[string]*dataloader.Loader
    23  	lock      sync.Mutex
    24  }
    25  
    26  type ErrorObjNotFound struct {
    27  	ID string
    28  }
    29  
    30  func (e *ErrorObjNotFound) Error() string {
    31  	return fmt.Sprintf("obj %s not found by Loader", e.ID)
    32  }
    33  
    34  func NewObjNotFoundErr(id string) error {
    35  	err := &ErrorObjNotFound{
    36  		ID: id,
    37  	}
    38  	return errors.Join(err, mongo.ErrNoDocuments)
    39  }
    40  
    41  func GetLoader[T model.HasID](l *Loaders, tableName string) *dataloader.Loader {
    42  	l.lock.Lock()
    43  	defer l.lock.Unlock()
    44  	loader, ok := l.loaderMap[tableName]
    45  	if !ok {
    46  		loader = dataloader.NewBatchedLoader(func(ctx context.Context, keys dataloader.Keys) []*dataloader.Result {
    47  			return GetByIDsLoader(ctx, keys, func(id []string) ([]*T, error) {
    48  				objIds := make([]primitive.ObjectID, len(id))
    49  				for i, v := range id {
    50  					objIds[i], _ = primitive.ObjectIDFromHex(v)
    51  				}
    52  				return db.Find[T](db.Get(), tableName, bson.M{"_id": bson.M{"$in": objIds}})
    53  			})
    54  		})
    55  		l.loaderMap[tableName] = loader
    56  	}
    57  	return loader
    58  }
    59  
    60  // NewLoaders instantiates data loaders for the middleware
    61  func NewLoaders() *Loaders {
    62  	loaders := &Loaders{
    63  		// MachineByID: dataloader.NewBatchedLoader(MachineByIDsLoader),
    64  		// ProjectByID: dataloader.NewBatchedLoader(ProjectByIDsLoader),
    65  		// RouteByID:   dataloader.NewBatchedLoader(RouteByIDsLoader),
    66  		loaderMap: map[string]*dataloader.Loader{},
    67  	}
    68  	return loaders
    69  }
    70  
    71  func GetByIDsLoader[T model.HasID](
    72  	ctx context.Context,
    73  	keys dataloader.Keys,
    74  	getByIDs func([]string) ([]*T, error)) []*dataloader.Result {
    75  
    76  	ids := make([]string, len(keys))
    77  	for index, key := range keys {
    78  		ids[index] = key.String()
    79  	}
    80  	objs, err := getByIDs(ids)
    81  	if err != nil {
    82  		panic(err)
    83  	}
    84  	// return User records into a map by ID
    85  	objById := map[string]*T{}
    86  	for _, obj := range objs {
    87  		objById[*(*obj).GetID()] = obj
    88  	}
    89  	// return users in the same order requested
    90  	output := make([]*dataloader.Result, len(keys))
    91  	for index, key := range keys {
    92  		obj, ok := objById[key.String()]
    93  		if ok {
    94  			output[index] = &dataloader.Result{Data: obj, Error: nil}
    95  		} else {
    96  			err := NewObjNotFoundErr(key.String())
    97  			output[index] = &dataloader.Result{Data: nil, Error: err}
    98  		}
    99  	}
   100  	return output
   101  }
   102  
   103  // func MachineByIDsLoader(ctx context.Context, keys dataloader.Keys) []*dataloader.Result {
   104  // 	return GetByIDsLoader(ctx, keys, db.Get().MachineByIDs)
   105  // }
   106  
   107  // func ProjectByIDsLoader(ctx context.Context, keys dataloader.Keys) []*dataloader.Result {
   108  // 	return GetByIDsLoader(ctx, keys, db.Get().ProjectByIDs)
   109  // }
   110  
   111  // func RouteByIDsLoader(ctx context.Context, keys dataloader.Keys) []*dataloader.Result {
   112  // 	return GetByIDsLoader(ctx, keys, db.Get().RouteByIDs)
   113  // }