github.com/xzzpig/headscale-manager@v1.3.3/service/loader/loader.go (about) 1 package loader 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "sync" 8 9 "github.com/graph-gophers/dataloader" 10 "github.com/xzzpig/headscale-manager/db" 11 "github.com/xzzpig/headscale-manager/graph/model" 12 "go.mongodb.org/mongo-driver/bson" 13 "go.mongodb.org/mongo-driver/bson/primitive" 14 "go.mongodb.org/mongo-driver/mongo" 15 ) 16 17 type Loaders struct { 18 MachineByID *dataloader.Loader 19 ProjectByID *dataloader.Loader 20 RouteByID *dataloader.Loader 21 22 loaderMap map[string]*dataloader.Loader 23 lock sync.Mutex 24 } 25 26 type ErrorObjNotFound struct { 27 ID string 28 } 29 30 func (e *ErrorObjNotFound) Error() string { 31 return fmt.Sprintf("obj %s not found by Loader", e.ID) 32 } 33 34 func NewObjNotFoundErr(id string) error { 35 err := &ErrorObjNotFound{ 36 ID: id, 37 } 38 return errors.Join(err, mongo.ErrNoDocuments) 39 } 40 41 func GetLoader[T model.HasID](l *Loaders, tableName string) *dataloader.Loader { 42 l.lock.Lock() 43 defer l.lock.Unlock() 44 loader, ok := l.loaderMap[tableName] 45 if !ok { 46 loader = dataloader.NewBatchedLoader(func(ctx context.Context, keys dataloader.Keys) []*dataloader.Result { 47 return GetByIDsLoader(ctx, keys, func(id []string) ([]*T, error) { 48 objIds := make([]primitive.ObjectID, len(id)) 49 for i, v := range id { 50 objIds[i], _ = primitive.ObjectIDFromHex(v) 51 } 52 return db.Find[T](db.Get(), tableName, bson.M{"_id": bson.M{"$in": objIds}}) 53 }) 54 }) 55 l.loaderMap[tableName] = loader 56 } 57 return loader 58 } 59 60 // NewLoaders instantiates data loaders for the middleware 61 func NewLoaders() *Loaders { 62 loaders := &Loaders{ 63 // MachineByID: dataloader.NewBatchedLoader(MachineByIDsLoader), 64 // ProjectByID: dataloader.NewBatchedLoader(ProjectByIDsLoader), 65 // RouteByID: dataloader.NewBatchedLoader(RouteByIDsLoader), 66 loaderMap: map[string]*dataloader.Loader{}, 67 } 68 return loaders 69 } 70 71 func GetByIDsLoader[T model.HasID]( 72 ctx context.Context, 73 keys dataloader.Keys, 74 getByIDs func([]string) ([]*T, error)) []*dataloader.Result { 75 76 ids := make([]string, len(keys)) 77 for index, key := range keys { 78 ids[index] = key.String() 79 } 80 objs, err := getByIDs(ids) 81 if err != nil { 82 panic(err) 83 } 84 // return User records into a map by ID 85 objById := map[string]*T{} 86 for _, obj := range objs { 87 objById[*(*obj).GetID()] = obj 88 } 89 // return users in the same order requested 90 output := make([]*dataloader.Result, len(keys)) 91 for index, key := range keys { 92 obj, ok := objById[key.String()] 93 if ok { 94 output[index] = &dataloader.Result{Data: obj, Error: nil} 95 } else { 96 err := NewObjNotFoundErr(key.String()) 97 output[index] = &dataloader.Result{Data: nil, Error: err} 98 } 99 } 100 return output 101 } 102 103 // func MachineByIDsLoader(ctx context.Context, keys dataloader.Keys) []*dataloader.Result { 104 // return GetByIDsLoader(ctx, keys, db.Get().MachineByIDs) 105 // } 106 107 // func ProjectByIDsLoader(ctx context.Context, keys dataloader.Keys) []*dataloader.Result { 108 // return GetByIDsLoader(ctx, keys, db.Get().ProjectByIDs) 109 // } 110 111 // func RouteByIDsLoader(ctx context.Context, keys dataloader.Keys) []*dataloader.Result { 112 // return GetByIDsLoader(ctx, keys, db.Get().RouteByIDs) 113 // }