github.com/xzzpig/headscale-manager@v1.3.3/db/db.go (about) 1 package db 2 3 import ( 4 "context" 5 "time" 6 7 "github.com/xzzpig/headscale-manager/config" 8 "go.mongodb.org/mongo-driver/bson" 9 "go.mongodb.org/mongo-driver/bson/primitive" 10 "go.mongodb.org/mongo-driver/mongo" 11 "go.mongodb.org/mongo-driver/mongo/options" 12 "go.uber.org/zap" 13 ) 14 15 type DB struct { 16 client *mongo.Client 17 timeout time.Duration 18 } 19 20 type Bsonable interface { 21 ToBson() *bson.M 22 } 23 24 type Saveable interface { 25 Bsonable 26 GetID() *string 27 } 28 29 var db *DB 30 var logger *zap.Logger 31 32 func Connect() { 33 logger = zap.L().Named("db") 34 timeout := config.GetConfig().Mongo.Timout 35 client, err := mongo.NewClient(options.Client().ApplyURI(config.GetConfig().Mongo.Uri)) 36 if err != nil { 37 logger.Panic("Failed to create mongo client", zap.Error(err)) 38 } 39 40 ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) 41 defer cancel() 42 err = client.Connect(ctx) 43 if err != nil { 44 logger.Panic("Failed to connect to mongo", zap.Error(err)) 45 } 46 47 //ping the database 48 err = client.Ping(ctx, nil) 49 if err != nil { 50 logger.Panic("Failed to ping mongo", zap.Error(err)) 51 } 52 53 logger.Info("Connected to mongo") 54 55 db = &DB{ 56 client: client, 57 timeout: time.Duration(timeout) * time.Second, 58 } 59 } 60 61 func Get() *DB { 62 return db 63 } 64 65 func (db *DB) Collection(collectionName string) *mongo.Collection { 66 return db.client.Database("headscale").Collection(collectionName) 67 } 68 69 func Find[Type any](db *DB, collectionName string, filter interface{}, opts ...*options.FindOptions) ([]*Type, error) { 70 collection := db.Collection(collectionName) 71 ctx, cancel := context.WithTimeout(context.Background(), db.timeout) 72 defer cancel() 73 74 var objs []*Type 75 76 res, err := collection.Find(ctx, filter) 77 if err != nil { 78 logger.Error("Failed to find", zap.String("collection", collectionName), zap.Any("filter", filter), zap.Error(err)) 79 return nil, err 80 } 81 defer res.Close(ctx) 82 for res.Next(ctx) { 83 var singleObj *Type 84 if err = res.Decode(&singleObj); err != nil { 85 logger.Error("Failed to decode", zap.String("collection", collectionName), zap.Any("filter", filter), zap.Error(err)) 86 return nil, err 87 } 88 objs = append(objs, singleObj) 89 } 90 logger.Debug("Find", zap.String("collection", collectionName), zap.Any("filter", filter), zap.Any("result", objs)) 91 92 return objs, err 93 } 94 95 func FindOne[Type any](db *DB, collectionName string, filter interface{}, opts ...*options.FindOneOptions) (*Type, error) { 96 collection := db.Collection(collectionName) 97 ctx, cancel := context.WithTimeout(context.Background(), db.timeout) 98 defer cancel() 99 100 var obj *Type 101 err := collection.FindOne(ctx, filter, opts...).Decode(&obj) 102 logger.Debug("FindOne", zap.String("collection", collectionName), zap.Any("filter", filter), zap.Any("result", obj), zap.Error(err)) 103 return obj, err 104 } 105 106 func Save[Type Saveable](db *DB, collectionName string, obj Type) (*mongo.UpdateResult, error) { 107 collection := db.Collection(collectionName) 108 ctx, cancel := context.WithTimeout(context.Background(), db.timeout) 109 defer cancel() 110 id := obj.GetID() 111 if id != nil { 112 objId, err := primitive.ObjectIDFromHex(*id) 113 if err != nil { 114 logger.Error("Failed to parse id", zap.String("collection", collectionName), zap.String("id", *id), zap.Error(err)) 115 return nil, err 116 } 117 b := obj.ToBson() 118 res, err := collection.UpdateByID(ctx, objId, bson.M{"$set": b}, options.Update().SetUpsert(true)) 119 if err != nil { 120 logger.Error("Failed to update", zap.String("collection", collectionName), zap.Any("obj", b), zap.Error(err)) 121 return nil, err 122 } 123 logger.Debug("Update", zap.String("collection", collectionName), zap.Any("obj", b), zap.Any("result", res)) 124 return res, err 125 } else { 126 b := obj.ToBson() 127 res, err := collection.InsertOne(ctx, b) 128 if err != nil { 129 logger.Error("Failed to insert", zap.String("collection", collectionName), zap.Any("obj", b), zap.Error(err)) 130 return nil, err 131 } 132 logger.Debug("Insert", zap.String("collection", collectionName), zap.Any("obj", b), zap.Any("result", res)) 133 return &mongo.UpdateResult{UpsertedID: res.InsertedID}, nil 134 } 135 } 136 137 func Delete(db *DB, collectionName string, id string) (*mongo.DeleteResult, error) { 138 collection := db.Collection(collectionName) 139 ctx, cancel := context.WithTimeout(context.Background(), db.timeout) 140 defer cancel() 141 objId, err := primitive.ObjectIDFromHex(id) 142 if err != nil { 143 logger.Error("Failed to parse id", zap.String("collection", collectionName), zap.String("id", id), zap.Error(err)) 144 return nil, err 145 } 146 res, err := collection.DeleteOne(ctx, bson.M{"_id": objId}) 147 if err != nil { 148 logger.Error("Failed to delete", zap.String("collection", collectionName), zap.String("id", id), zap.Error(err)) 149 return nil, err 150 } 151 logger.Debug("Delete", zap.String("collection", collectionName), zap.String("id", id), zap.Any("result", res)) 152 return res, err 153 }