github.com/lianghucheng/zrddz@v0.0.0-20200923083010-c71f680932e2/src/gopkg.in/mgo.v2/internal/sasl/sasl_windows.go (about)

     1  package sasl
     2  
     3  // #include "sasl_windows.h"
     4  import "C"
     5  
     6  import (
     7  	"fmt"
     8  	"strings"
     9  	"sync"
    10  	"unsafe"
    11  )
    12  
// saslStepper is the interface New hands back to callers: it drives an
// authentication exchange one message at a time and releases the session's
// resources when the exchange is over.
type saslStepper interface {
	// Step consumes the latest payload received from the server and returns
	// the payload to send back. done reports that the exchange has finished;
	// err is non-nil when the step failed.
	Step(serverData []byte) (clientData []byte, done bool, err error)
	// Close frees resources owned by the session.
	Close()
}
    17  
// saslSession is the SSPI-backed implementation of saslStepper. It carries
// the identity being authenticated, the progress of the exchange, the SSPI
// handles used by the C helpers, and the C allocations that Close must free.
type saslSession struct {
	// Credentials
	mech          string // mechanism name as passed to New
	service       string // service name; New defaults it to "mongodb"
	host          string // server host with any ":port" suffix removed by New
	userPlusRealm string // full username as passed to New ("user@realm")
	target        string // "<service>/<host>", passed to sspi_step
	domain        string // the part of the username after the first "@"

	// Internal state
	authComplete bool // set once a step returned SEC_E_OK
	errored      bool // set when an SSPI call reported a failure
	step         int  // number of Step calls so far; Step refuses to go past 10

	// C internal state
	credHandle C.CredHandle // filled in by sspi_acquire_credentials_handle
	context    C.CtxtHandle // security context handed to the sspi_* helpers
	hasContext C.int        // 0 until a step returned SEC_I_CONTINUE_NEEDED, then 1

	// Keep track of pointers we need to explicitly free
	stringsToFree []*C.char // every C string created through cstr; released in Close
}
    40  
    41  var initError error
    42  var initOnce sync.Once
    43  
    44  func initSSPI() {
    45  	rc := C.load_secur32_dll()
    46  	if rc != 0 {
    47  		initError = fmt.Errorf("Error loading libraries: %v", rc)
    48  	}
    49  }
    50  
    51  func New(username, password, mechanism, service, host string) (saslStepper, error) {
    52  	initOnce.Do(initSSPI)
    53  	ss := &saslSession{mech: mechanism, hasContext: 0, userPlusRealm: username}
    54  	if service == "" {
    55  		service = "mongodb"
    56  	}
    57  	if i := strings.Index(host, ":"); i >= 0 {
    58  		host = host[:i]
    59  	}
    60  	ss.service = service
    61  	ss.host = host
    62  
    63  	usernameComponents := strings.Split(username, "@")
    64  	if len(usernameComponents) < 2 {
    65  		return nil, fmt.Errorf("Username '%v' doesn't contain a realm!", username)
    66  	}
    67  	user := usernameComponents[0]
    68  	ss.domain = usernameComponents[1]
    69  	ss.target = fmt.Sprintf("%s/%s", ss.service, ss.host)
    70  
    71  	var status C.SECURITY_STATUS
    72  	// Step 0: call AcquireCredentialsHandle to get a nice SSPI CredHandle
    73  	if len(password) > 0 {
    74  		status = C.sspi_acquire_credentials_handle(&ss.credHandle, ss.cstr(user), ss.cstr(password), ss.cstr(ss.domain))
    75  	} else {
    76  		status = C.sspi_acquire_credentials_handle(&ss.credHandle, ss.cstr(user), nil, ss.cstr(ss.domain))
    77  	}
    78  	if status != C.SEC_E_OK {
    79  		ss.errored = true
    80  		return nil, fmt.Errorf("Couldn't create new SSPI client, error code %v", status)
    81  	}
    82  	return ss, nil
    83  }
    84  
    85  func (ss *saslSession) cstr(s string) *C.char {
    86  	cstr := C.CString(s)
    87  	ss.stringsToFree = append(ss.stringsToFree, cstr)
    88  	return cstr
    89  }
    90  
    91  func (ss *saslSession) Close() {
    92  	for _, cstr := range ss.stringsToFree {
    93  		C.free(unsafe.Pointer(cstr))
    94  	}
    95  }
    96  
    97  func (ss *saslSession) Step(serverData []byte) (clientData []byte, done bool, err error) {
    98  	ss.step++
    99  	if ss.step > 10 {
   100  		return nil, false, fmt.Errorf("too many SSPI steps without authentication")
   101  	}
   102  	var buffer C.PVOID
   103  	var bufferLength C.ULONG
   104  	var outBuffer C.PVOID
   105  	var outBufferLength C.ULONG
   106  	if len(serverData) > 0 {
   107  		buffer = (C.PVOID)(unsafe.Pointer(&serverData[0]))
   108  		bufferLength = C.ULONG(len(serverData))
   109  	}
   110  	var status C.int
   111  	if ss.authComplete {
   112  		// Step 3: last bit of magic to use the correct server credentials
   113  		status = C.sspi_send_client_authz_id(&ss.context, &outBuffer, &outBufferLength, ss.cstr(ss.userPlusRealm))
   114  	} else {
   115  		// Step 1 + Step 2: set up security context with the server and TGT
   116  		status = C.sspi_step(&ss.credHandle, ss.hasContext, &ss.context, buffer, bufferLength, &outBuffer, &outBufferLength, ss.cstr(ss.target))
   117  	}
   118  	if outBuffer != C.PVOID(nil) {
   119  		defer C.free(unsafe.Pointer(outBuffer))
   120  	}
   121  	if status != C.SEC_E_OK && status != C.SEC_I_CONTINUE_NEEDED {
   122  		ss.errored = true
   123  		return nil, false, ss.handleSSPIErrorCode(status)
   124  	}
   125  
   126  	clientData = C.GoBytes(unsafe.Pointer(outBuffer), C.int(outBufferLength))
   127  	if status == C.SEC_E_OK {
   128  		ss.authComplete = true
   129  		return clientData, true, nil
   130  	} else {
   131  		ss.hasContext = 1
   132  		return clientData, false, nil
   133  	}
   134  }
   135  
   136  func (ss *saslSession) handleSSPIErrorCode(code C.int) error {
   137  	switch {
   138  	case code == C.SEC_E_TARGET_UNKNOWN:
   139  		return fmt.Errorf("Target %v@%v not found", ss.target, ss.domain)
   140  	}
   141  	return fmt.Errorf("Unknown error doing step %v, error code %v", ss.step, code)
   142  }