github.com/lianghucheng/zrddz@v0.0.0-20200923083010-c71f680932e2/src/gopkg.in/mgo.v2/auth.go (about)

     1  // mgo - MongoDB driver for Go
     2  //
     3  // Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
     4  //
     5  // All rights reserved.
     6  //
     7  // Redistribution and use in source and binary forms, with or without
     8  // modification, are permitted provided that the following conditions are met:
     9  //
    10  // 1. Redistributions of source code must retain the above copyright notice, this
    11  //    list of conditions and the following disclaimer.
    12  // 2. Redistributions in binary form must reproduce the above copyright notice,
    13  //    this list of conditions and the following disclaimer in the documentation
    14  //    and/or other materials provided with the distribution.
    15  //
    16  // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
    17  // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
    18  // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
    19  // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
    20  // ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
    21  // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
    22  // LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
    23  // ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
    24  // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
    25  // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    26  
    27  package mgo
    28  
    29  import (
    30  	"crypto/md5"
    31  	"crypto/sha1"
    32  	"encoding/hex"
    33  	"errors"
    34  	"fmt"
    35  	"sync"
    36  
    37  	"gopkg.in/mgo.v2/bson"
    38  	"gopkg.in/mgo.v2/internal/scram"
    39  )
    40  
// authCmd is the command document sent by loginClassic for legacy
// nonce-based (MONGODB-CR) authentication.
//
// NOTE(review): field order is kept exactly as declared; presumably the
// server requires the command name ("authenticate") to be the first key of
// the document, so do not reorder these fields — confirm before changing.
type authCmd struct {
	Authenticate int

	Nonce string // nonce previously obtained via getNonce
	User  string
	Key   string // hex digest computed by loginClassic from nonce, user and password
}
    48  
// startSaslCmd is a command document with a single "startSasl" key.
//
// NOTE(review): this type is not referenced by any code visible in this
// file; it may be unused or used elsewhere in the package — verify before
// removing.
type startSaslCmd struct {
	StartSASL int `bson:"startSasl"`
}
    52  
// authResult decodes the server reply to the authenticate commands sent by
// loginClassic, loginX509 and loginPlain. When Ok is false, ErrMsg is
// returned to the caller as the error text.
type authResult struct {
	ErrMsg string
	Ok     bool
}
    57  
// getNonceCmd is the command document resetNonce sends to admin.$cmd to
// request a fresh authentication nonce.
type getNonceCmd struct {
	GetNonce int
}
    61  
    62  type getNonceResult struct {
    63  	Nonce string
    64  	Err   string "$err"
    65  	Code  int
    66  }
    67  
// logoutCmd is the command document flushLogout queues, one per
// credential source database awaiting logout.
type logoutCmd struct {
	Logout int
}
    71  
// saslCmd is the command document for one SASL step. loginSASL sets exactly
// one of Start/Continue to 1 (saslStart on the first step, saslContinue
// afterwards); omitempty drops the other from the encoded document.
// loginPlain uses it for a single saslStart with the PLAIN mechanism.
type saslCmd struct {
	Start          int    `bson:"saslStart,omitempty"`
	Continue       int    `bson:"saslContinue,omitempty"`
	ConversationId int    `bson:"conversationId,omitempty"`
	Mechanism      string `bson:"mechanism,omitempty"`
	Payload        []byte // mechanism-specific client message
}
    79  
// saslResult decodes the server reply to a saslCmd step. loginSASL treats the
// step as failed when Ok is false or NotOk is true. Done, together with the
// local stepper's own done flag, ends the conversation. ConversationId and
// Payload are fed into the next step.
type saslResult struct {
	Ok    bool `bson:"ok"`
	NotOk bool `bson:"code"` // Server <= 2.3.2 returns ok=1 & code>0 on errors (WTF?)
	Done  bool

	ConversationId int `bson:"conversationId"`
	Payload        []byte
	ErrMsg         string
}
    89  
// saslStepper is the client side of a SASL mechanism as driven by loginSASL.
// Step consumes the server's latest payload (nil on the first call) and
// returns the next client payload, whether the client considers the
// exchange complete, and any local error. Close is deferred by loginSASL
// once the stepper has been created.
type saslStepper interface {
	Step(serverData []byte) (clientData []byte, done bool, err error)
	Close()
}
    94  
    95  func (socket *mongoSocket) getNonce() (nonce string, err error) {
    96  	socket.Lock()
    97  	for socket.cachedNonce == "" && socket.dead == nil {
    98  		debugf("Socket %p to %s: waiting for nonce", socket, socket.addr)
    99  		socket.gotNonce.Wait()
   100  	}
   101  	if socket.cachedNonce == "mongos" {
   102  		socket.Unlock()
   103  		return "", errors.New("Can't authenticate with mongos; see http://j.mp/mongos-auth")
   104  	}
   105  	debugf("Socket %p to %s: got nonce", socket, socket.addr)
   106  	nonce, err = socket.cachedNonce, socket.dead
   107  	socket.cachedNonce = ""
   108  	socket.Unlock()
   109  	if err != nil {
   110  		nonce = ""
   111  	}
   112  	return
   113  }
   114  
   115  func (socket *mongoSocket) resetNonce() {
   116  	debugf("Socket %p to %s: requesting a new nonce", socket, socket.addr)
   117  	op := &queryOp{}
   118  	op.query = &getNonceCmd{GetNonce: 1}
   119  	op.collection = "admin.$cmd"
   120  	op.limit = -1
   121  	op.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) {
   122  		if err != nil {
   123  			socket.kill(errors.New("getNonce: "+err.Error()), true)
   124  			return
   125  		}
   126  		result := &getNonceResult{}
   127  		err = bson.Unmarshal(docData, &result)
   128  		if err != nil {
   129  			socket.kill(errors.New("Failed to unmarshal nonce: "+err.Error()), true)
   130  			return
   131  		}
   132  		debugf("Socket %p to %s: nonce unmarshalled: %#v", socket, socket.addr, result)
   133  		if result.Code == 13390 {
   134  			// mongos doesn't yet support auth (see http://j.mp/mongos-auth)
   135  			result.Nonce = "mongos"
   136  		} else if result.Nonce == "" {
   137  			var msg string
   138  			if result.Err != "" {
   139  				msg = fmt.Sprintf("Got an empty nonce: %s (%d)", result.Err, result.Code)
   140  			} else {
   141  				msg = "Got an empty nonce"
   142  			}
   143  			socket.kill(errors.New(msg), true)
   144  			return
   145  		}
   146  		socket.Lock()
   147  		if socket.cachedNonce != "" {
   148  			socket.Unlock()
   149  			panic("resetNonce: nonce already cached")
   150  		}
   151  		socket.cachedNonce = result.Nonce
   152  		socket.gotNonce.Signal()
   153  		socket.Unlock()
   154  	}
   155  	err := socket.Query(op)
   156  	if err != nil {
   157  		socket.kill(errors.New("resetNonce: "+err.Error()), true)
   158  	}
   159  }
   160  
   161  func (socket *mongoSocket) Login(cred Credential) error {
   162  	socket.Lock()
   163  	if cred.Mechanism == "" && socket.serverInfo.MaxWireVersion >= 3 {
   164  		cred.Mechanism = "SCRAM-SHA-1"
   165  	}
   166  	for _, sockCred := range socket.creds {
   167  		if sockCred == cred {
   168  			debugf("Socket %p to %s: login: db=%q user=%q (already logged in)", socket, socket.addr, cred.Source, cred.Username)
   169  			socket.Unlock()
   170  			return nil
   171  		}
   172  	}
   173  	if socket.dropLogout(cred) {
   174  		debugf("Socket %p to %s: login: db=%q user=%q (cached)", socket, socket.addr, cred.Source, cred.Username)
   175  		socket.creds = append(socket.creds, cred)
   176  		socket.Unlock()
   177  		return nil
   178  	}
   179  	socket.Unlock()
   180  
   181  	debugf("Socket %p to %s: login: db=%q user=%q", socket, socket.addr, cred.Source, cred.Username)
   182  
   183  	var err error
   184  	switch cred.Mechanism {
   185  	case "", "MONGODB-CR", "MONGO-CR": // Name changed to MONGODB-CR in SERVER-8501.
   186  		err = socket.loginClassic(cred)
   187  	case "PLAIN":
   188  		err = socket.loginPlain(cred)
   189  	case "MONGODB-X509":
   190  		err = socket.loginX509(cred)
   191  	default:
   192  		// Try SASL for everything else, if it is available.
   193  		err = socket.loginSASL(cred)
   194  	}
   195  
   196  	if err != nil {
   197  		debugf("Socket %p to %s: login error: %s", socket, socket.addr, err)
   198  	} else {
   199  		debugf("Socket %p to %s: login successful", socket, socket.addr)
   200  	}
   201  	return err
   202  }
   203  
   204  func (socket *mongoSocket) loginClassic(cred Credential) error {
   205  	// Note that this only works properly because this function is
   206  	// synchronous, which means the nonce won't get reset while we're
   207  	// using it and any other login requests will block waiting for a
   208  	// new nonce provided in the defer call below.
   209  	nonce, err := socket.getNonce()
   210  	if err != nil {
   211  		return err
   212  	}
   213  	defer socket.resetNonce()
   214  
   215  	psum := md5.New()
   216  	psum.Write([]byte(cred.Username + ":mongo:" + cred.Password))
   217  
   218  	ksum := md5.New()
   219  	ksum.Write([]byte(nonce + cred.Username))
   220  	ksum.Write([]byte(hex.EncodeToString(psum.Sum(nil))))
   221  
   222  	key := hex.EncodeToString(ksum.Sum(nil))
   223  
   224  	cmd := authCmd{Authenticate: 1, User: cred.Username, Nonce: nonce, Key: key}
   225  	res := authResult{}
   226  	return socket.loginRun(cred.Source, &cmd, &res, func() error {
   227  		if !res.Ok {
   228  			return errors.New(res.ErrMsg)
   229  		}
   230  		socket.Lock()
   231  		socket.dropAuth(cred.Source)
   232  		socket.creds = append(socket.creds, cred)
   233  		socket.Unlock()
   234  		return nil
   235  	})
   236  }
   237  
// authX509Cmd is the authenticate command document loginX509 sends for
// MONGODB-X509 authentication. No password or nonce is included.
//
// NOTE(review): field order is kept as declared; presumably the command
// name must be the first key — confirm before reordering.
type authX509Cmd struct {
	Authenticate int
	User         string
	Mechanism    string
}
   243  
   244  func (socket *mongoSocket) loginX509(cred Credential) error {
   245  	cmd := authX509Cmd{Authenticate: 1, User: cred.Username, Mechanism: "MONGODB-X509"}
   246  	res := authResult{}
   247  	return socket.loginRun(cred.Source, &cmd, &res, func() error {
   248  		if !res.Ok {
   249  			return errors.New(res.ErrMsg)
   250  		}
   251  		socket.Lock()
   252  		socket.dropAuth(cred.Source)
   253  		socket.creds = append(socket.creds, cred)
   254  		socket.Unlock()
   255  		return nil
   256  	})
   257  }
   258  
   259  func (socket *mongoSocket) loginPlain(cred Credential) error {
   260  	cmd := saslCmd{Start: 1, Mechanism: "PLAIN", Payload: []byte("\x00" + cred.Username + "\x00" + cred.Password)}
   261  	res := authResult{}
   262  	return socket.loginRun(cred.Source, &cmd, &res, func() error {
   263  		if !res.Ok {
   264  			return errors.New(res.ErrMsg)
   265  		}
   266  		socket.Lock()
   267  		socket.dropAuth(cred.Source)
   268  		socket.creds = append(socket.creds, cred)
   269  		socket.Unlock()
   270  		return nil
   271  	})
   272  }
   273  
// loginSASL authenticates cred through a multi-step SASL conversation: one
// saslStart command followed by saslContinue commands until both sides
// report completion.
//
// SCRAM-SHA-1 is driven by the in-package scram client (saslNewScram). Any
// other mechanism goes through saslNew, with cred.ServiceHost as the host
// when set and the server's address otherwise. On success, any credential
// for cred.Source in socket.creds is replaced with cred.
//
// Returns the stepper construction error, a local Step error, or an error
// from loginRun. The latter includes a server reply with ok false or a
// nonzero "code".
func (socket *mongoSocket) loginSASL(cred Credential) error {
	var sasl saslStepper
	var err error
	if cred.Mechanism == "SCRAM-SHA-1" {
		// SCRAM is handled without external libraries.
		sasl = saslNewScram(cred)
	} else if len(cred.ServiceHost) > 0 {
		sasl, err = saslNew(cred, cred.ServiceHost)
	} else {
		sasl, err = saslNew(cred, socket.Server().Addr)
	}
	if err != nil {
		return err
	}
	defer sasl.Close()

	// The goal of this logic is to carry a locked socket until the
	// local SASL step confirms the auth is valid; the socket needs to be
	// locked so that concurrent action doesn't leave the socket in an
	// auth state that doesn't reflect the operations that took place.
	// As a simple case, imagine inverting login=>logout to logout=>login.
	//
	// The logic below works because the lock func isn't called concurrently.
	// lock(false) runs on this goroutine before each round trip. lock(true)
	// runs inside the loginRun callback, and loginRun does not return until
	// that callback has finished.
	locked := false
	lock := func(b bool) {
		if locked != b {
			locked = b
			if b {
				socket.Lock()
			} else {
				socket.Unlock()
			}
		}
	}

	lock(true)
	// Releases the socket lock on every return path if it is still held.
	defer lock(false)

	start := 1
	cmd := saslCmd{}
	res := saslResult{}
	for {
		// Feed the server's latest payload (nil on the first pass) to the
		// local mechanism to get the next client message.
		// Note: this err shadows the outer err declared above.
		payload, done, err := sasl.Step(res.Payload)
		if err != nil {
			return err
		}
		// After processing the server's final payload, both sides may
		// already agree the exchange is complete; no further round trip.
		if done && res.Done {
			socket.dropAuth(cred.Source)
			socket.creds = append(socket.creds, cred)
			break
		}
		// Release the socket lock while waiting on the server; the reply
		// callback below takes it again.
		lock(false)

		// The first command is saslStart (Start=1, Continue=0); every later
		// one is saslContinue (Start=0, Continue=1) on the same conversation.
		cmd = saslCmd{
			Start:          start,
			Continue:       1 - start,
			ConversationId: res.ConversationId,
			Mechanism:      cred.Mechanism,
			Payload:        payload,
		}
		start = 0
		err = socket.loginRun(cred.Source, &cmd, &res, func() error {
			// See the comment on lock for why this is necessary.
			lock(true)
			if !res.Ok || res.NotOk {
				return fmt.Errorf("server returned error on SASL authentication step: %s", res.ErrMsg)
			}
			return nil
		})
		if err != nil {
			return err
		}
		// The server's reply may itself complete an exchange the client
		// already considered done.
		if done && res.Done {
			socket.dropAuth(cred.Source)
			socket.creds = append(socket.creds, cred)
			break
		}
	}

	return nil
}
   355  
   356  func saslNewScram(cred Credential) *saslScram {
   357  	credsum := md5.New()
   358  	credsum.Write([]byte(cred.Username + ":mongo:" + cred.Password))
   359  	client := scram.NewClient(sha1.New, cred.Username, hex.EncodeToString(credsum.Sum(nil)))
   360  	return &saslScram{cred: cred, client: client}
   361  }
   362  
// saslScram adapts the in-package scram client to the saslStepper interface
// for the SCRAM-SHA-1 mechanism. It is constructed by saslNewScram.
type saslScram struct {
	cred   Credential
	client *scram.Client
}
   367  
// Close implements saslStepper. It is a no-op: this adapter has nothing of
// its own to release.
func (s *saslScram) Close() {}
   369  
   370  func (s *saslScram) Step(serverData []byte) (clientData []byte, done bool, err error) {
   371  	more := s.client.Step(serverData)
   372  	return s.client.Out(), !more, s.client.Err()
   373  }
   374  
   375  func (socket *mongoSocket) loginRun(db string, query, result interface{}, f func() error) error {
   376  	var mutex sync.Mutex
   377  	var replyErr error
   378  	mutex.Lock()
   379  
   380  	op := queryOp{}
   381  	op.query = query
   382  	op.collection = db + ".$cmd"
   383  	op.limit = -1
   384  	op.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) {
   385  		defer mutex.Unlock()
   386  
   387  		if err != nil {
   388  			replyErr = err
   389  			return
   390  		}
   391  
   392  		err = bson.Unmarshal(docData, result)
   393  		if err != nil {
   394  			replyErr = err
   395  		} else {
   396  			// Must handle this within the read loop for the socket, so
   397  			// that concurrent login requests are properly ordered.
   398  			replyErr = f()
   399  		}
   400  	}
   401  
   402  	err := socket.Query(&op)
   403  	if err != nil {
   404  		return err
   405  	}
   406  	mutex.Lock() // Wait.
   407  	return replyErr
   408  }
   409  
   410  func (socket *mongoSocket) Logout(db string) {
   411  	socket.Lock()
   412  	cred, found := socket.dropAuth(db)
   413  	if found {
   414  		debugf("Socket %p to %s: logout: db=%q (flagged)", socket, socket.addr, db)
   415  		socket.logout = append(socket.logout, cred)
   416  	}
   417  	socket.Unlock()
   418  }
   419  
   420  func (socket *mongoSocket) LogoutAll() {
   421  	socket.Lock()
   422  	if l := len(socket.creds); l > 0 {
   423  		debugf("Socket %p to %s: logout all (flagged %d)", socket, socket.addr, l)
   424  		socket.logout = append(socket.logout, socket.creds...)
   425  		socket.creds = socket.creds[0:0]
   426  	}
   427  	socket.Unlock()
   428  }
   429  
   430  func (socket *mongoSocket) flushLogout() (ops []interface{}) {
   431  	socket.Lock()
   432  	if l := len(socket.logout); l > 0 {
   433  		debugf("Socket %p to %s: logout all (flushing %d)", socket, socket.addr, l)
   434  		for i := 0; i != l; i++ {
   435  			op := queryOp{}
   436  			op.query = &logoutCmd{1}
   437  			op.collection = socket.logout[i].Source + ".$cmd"
   438  			op.limit = -1
   439  			ops = append(ops, &op)
   440  		}
   441  		socket.logout = socket.logout[0:0]
   442  	}
   443  	socket.Unlock()
   444  	return
   445  }
   446  
   447  func (socket *mongoSocket) dropAuth(db string) (cred Credential, found bool) {
   448  	for i, sockCred := range socket.creds {
   449  		if sockCred.Source == db {
   450  			copy(socket.creds[i:], socket.creds[i+1:])
   451  			socket.creds = socket.creds[:len(socket.creds)-1]
   452  			return sockCred, true
   453  		}
   454  	}
   455  	return cred, false
   456  }
   457  
   458  func (socket *mongoSocket) dropLogout(cred Credential) (found bool) {
   459  	for i, sockCred := range socket.logout {
   460  		if sockCred == cred {
   461  			copy(socket.logout[i:], socket.logout[i+1:])
   462  			socket.logout = socket.logout[:len(socket.logout)-1]
   463  			return true
   464  		}
   465  	}
   466  	return false
   467  }