github.heygears.com/openimsdk/tools@v0.0.49/db/mongoutil/tx.go (about)

     1  // Copyright © 2023 OpenIM. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mongoutil
    16  
    17  import (
    18  	"context"
    19  	"github.com/openimsdk/tools/db/tx"
    20  	"github.com/openimsdk/tools/errs"
    21  	"go.mongodb.org/mongo-driver/bson"
    22  	"go.mongodb.org/mongo-driver/mongo"
    23  )
    24  
    25  func NewMongoTx(ctx context.Context, client *mongo.Client) (tx.Tx, error) {
    26  	mtx := mongoTx{
    27  		client: client,
    28  	}
    29  	if err := mtx.init(ctx); err != nil {
    30  		return nil, err
    31  	}
    32  	return &mtx, nil
    33  }
    34  
    35  func NewMongo(client *mongo.Client) tx.Tx {
    36  	return &mongoTx{
    37  		client: client,
    38  	}
    39  }
    40  
    41  type mongoTx struct {
    42  	client *mongo.Client
    43  	tx     func(context.Context, func(ctx context.Context) error) error
    44  }
    45  
    46  func (m *mongoTx) init(ctx context.Context) error {
    47  	var res map[string]any
    48  	if err := m.client.Database("admin").RunCommand(ctx, bson.M{"isMaster": 1}).Decode(&res); err != nil {
    49  		return errs.WrapMsg(err, "check whether mongo is deployed in a cluster")
    50  	}
    51  	if _, allowTx := res["setName"]; !allowTx {
    52  		return nil // non-clustered transactions are not supported
    53  	}
    54  	m.tx = func(fnctx context.Context, fn func(ctx context.Context) error) error {
    55  		sess, err := m.client.StartSession()
    56  		if err != nil {
    57  			return errs.WrapMsg(err, "mongodb start session failed")
    58  		}
    59  		defer sess.EndSession(fnctx)
    60  		_, err = sess.WithTransaction(fnctx, func(sessCtx mongo.SessionContext) (any, error) {
    61  			return nil, fn(sessCtx)
    62  		})
    63  		return errs.WrapMsg(err, "mongodb transaction failed")
    64  	}
    65  	return nil
    66  }
    67  
    68  func (m *mongoTx) Transaction(ctx context.Context, fn func(ctx context.Context) error) error {
    69  	if m.tx == nil {
    70  		return fn(ctx)
    71  	}
    72  	return m.tx(ctx, fn)
    73  }