github.com/xzzpig/headscale-manager@v1.3.3/db/db.go (about)

     1  package db
     2  
     3  import (
     4  	"context"
     5  	"time"
     6  
     7  	"github.com/xzzpig/headscale-manager/config"
     8  	"go.mongodb.org/mongo-driver/bson"
     9  	"go.mongodb.org/mongo-driver/bson/primitive"
    10  	"go.mongodb.org/mongo-driver/mongo"
    11  	"go.mongodb.org/mongo-driver/mongo/options"
    12  	"go.uber.org/zap"
    13  )
    14  
    15  type DB struct {
    16  	client  *mongo.Client
    17  	timeout time.Duration
    18  }
    19  
    20  type Bsonable interface {
    21  	ToBson() *bson.M
    22  }
    23  
    24  type Saveable interface {
    25  	Bsonable
    26  	GetID() *string
    27  }
    28  
    29  var db *DB
    30  var logger *zap.Logger
    31  
    32  func Connect() {
    33  	logger = zap.L().Named("db")
    34  	timeout := config.GetConfig().Mongo.Timout
    35  	client, err := mongo.NewClient(options.Client().ApplyURI(config.GetConfig().Mongo.Uri))
    36  	if err != nil {
    37  		logger.Panic("Failed to create mongo client", zap.Error(err))
    38  	}
    39  
    40  	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
    41  	defer cancel()
    42  	err = client.Connect(ctx)
    43  	if err != nil {
    44  		logger.Panic("Failed to connect to mongo", zap.Error(err))
    45  	}
    46  
    47  	//ping the database
    48  	err = client.Ping(ctx, nil)
    49  	if err != nil {
    50  		logger.Panic("Failed to ping mongo", zap.Error(err))
    51  	}
    52  
    53  	logger.Info("Connected to mongo")
    54  
    55  	db = &DB{
    56  		client:  client,
    57  		timeout: time.Duration(timeout) * time.Second,
    58  	}
    59  }
    60  
    61  func Get() *DB {
    62  	return db
    63  }
    64  
    65  func (db *DB) Collection(collectionName string) *mongo.Collection {
    66  	return db.client.Database("headscale").Collection(collectionName)
    67  }
    68  
    69  func Find[Type any](db *DB, collectionName string, filter interface{}, opts ...*options.FindOptions) ([]*Type, error) {
    70  	collection := db.Collection(collectionName)
    71  	ctx, cancel := context.WithTimeout(context.Background(), db.timeout)
    72  	defer cancel()
    73  
    74  	var objs []*Type
    75  
    76  	res, err := collection.Find(ctx, filter)
    77  	if err != nil {
    78  		logger.Error("Failed to find", zap.String("collection", collectionName), zap.Any("filter", filter), zap.Error(err))
    79  		return nil, err
    80  	}
    81  	defer res.Close(ctx)
    82  	for res.Next(ctx) {
    83  		var singleObj *Type
    84  		if err = res.Decode(&singleObj); err != nil {
    85  			logger.Error("Failed to decode", zap.String("collection", collectionName), zap.Any("filter", filter), zap.Error(err))
    86  			return nil, err
    87  		}
    88  		objs = append(objs, singleObj)
    89  	}
    90  	logger.Debug("Find", zap.String("collection", collectionName), zap.Any("filter", filter), zap.Any("result", objs))
    91  
    92  	return objs, err
    93  }
    94  
    95  func FindOne[Type any](db *DB, collectionName string, filter interface{}, opts ...*options.FindOneOptions) (*Type, error) {
    96  	collection := db.Collection(collectionName)
    97  	ctx, cancel := context.WithTimeout(context.Background(), db.timeout)
    98  	defer cancel()
    99  
   100  	var obj *Type
   101  	err := collection.FindOne(ctx, filter, opts...).Decode(&obj)
   102  	logger.Debug("FindOne", zap.String("collection", collectionName), zap.Any("filter", filter), zap.Any("result", obj), zap.Error(err))
   103  	return obj, err
   104  }
   105  
   106  func Save[Type Saveable](db *DB, collectionName string, obj Type) (*mongo.UpdateResult, error) {
   107  	collection := db.Collection(collectionName)
   108  	ctx, cancel := context.WithTimeout(context.Background(), db.timeout)
   109  	defer cancel()
   110  	id := obj.GetID()
   111  	if id != nil {
   112  		objId, err := primitive.ObjectIDFromHex(*id)
   113  		if err != nil {
   114  			logger.Error("Failed to parse id", zap.String("collection", collectionName), zap.String("id", *id), zap.Error(err))
   115  			return nil, err
   116  		}
   117  		b := obj.ToBson()
   118  		res, err := collection.UpdateByID(ctx, objId, bson.M{"$set": b}, options.Update().SetUpsert(true))
   119  		if err != nil {
   120  			logger.Error("Failed to update", zap.String("collection", collectionName), zap.Any("obj", b), zap.Error(err))
   121  			return nil, err
   122  		}
   123  		logger.Debug("Update", zap.String("collection", collectionName), zap.Any("obj", b), zap.Any("result", res))
   124  		return res, err
   125  	} else {
   126  		b := obj.ToBson()
   127  		res, err := collection.InsertOne(ctx, b)
   128  		if err != nil {
   129  			logger.Error("Failed to insert", zap.String("collection", collectionName), zap.Any("obj", b), zap.Error(err))
   130  			return nil, err
   131  		}
   132  		logger.Debug("Insert", zap.String("collection", collectionName), zap.Any("obj", b), zap.Any("result", res))
   133  		return &mongo.UpdateResult{UpsertedID: res.InsertedID}, nil
   134  	}
   135  }
   136  
   137  func Delete(db *DB, collectionName string, id string) (*mongo.DeleteResult, error) {
   138  	collection := db.Collection(collectionName)
   139  	ctx, cancel := context.WithTimeout(context.Background(), db.timeout)
   140  	defer cancel()
   141  	objId, err := primitive.ObjectIDFromHex(id)
   142  	if err != nil {
   143  		logger.Error("Failed to parse id", zap.String("collection", collectionName), zap.String("id", id), zap.Error(err))
   144  		return nil, err
   145  	}
   146  	res, err := collection.DeleteOne(ctx, bson.M{"_id": objId})
   147  	if err != nil {
   148  		logger.Error("Failed to delete", zap.String("collection", collectionName), zap.String("id", id), zap.Error(err))
   149  		return nil, err
   150  	}
   151  	logger.Debug("Delete", zap.String("collection", collectionName), zap.String("id", id), zap.Any("result", res))
   152  	return res, err
   153  }