github.com/lianghucheng/zrddz@v0.0.0-20200923083010-c71f680932e2/src/gopkg.in/mgo.v2/internal/sasl/sasl_windows.go (about) 1 package sasl 2 3 // #include "sasl_windows.h" 4 import "C" 5 6 import ( 7 "fmt" 8 "strings" 9 "sync" 10 "unsafe" 11 ) 12 13 type saslStepper interface { 14 Step(serverData []byte) (clientData []byte, done bool, err error) 15 Close() 16 } 17 18 type saslSession struct { 19 // Credentials 20 mech string 21 service string 22 host string 23 userPlusRealm string 24 target string 25 domain string 26 27 // Internal state 28 authComplete bool 29 errored bool 30 step int 31 32 // C internal state 33 credHandle C.CredHandle 34 context C.CtxtHandle 35 hasContext C.int 36 37 // Keep track of pointers we need to explicitly free 38 stringsToFree []*C.char 39 } 40 41 var initError error 42 var initOnce sync.Once 43 44 func initSSPI() { 45 rc := C.load_secur32_dll() 46 if rc != 0 { 47 initError = fmt.Errorf("Error loading libraries: %v", rc) 48 } 49 } 50 51 func New(username, password, mechanism, service, host string) (saslStepper, error) { 52 initOnce.Do(initSSPI) 53 ss := &saslSession{mech: mechanism, hasContext: 0, userPlusRealm: username} 54 if service == "" { 55 service = "mongodb" 56 } 57 if i := strings.Index(host, ":"); i >= 0 { 58 host = host[:i] 59 } 60 ss.service = service 61 ss.host = host 62 63 usernameComponents := strings.Split(username, "@") 64 if len(usernameComponents) < 2 { 65 return nil, fmt.Errorf("Username '%v' doesn't contain a realm!", username) 66 } 67 user := usernameComponents[0] 68 ss.domain = usernameComponents[1] 69 ss.target = fmt.Sprintf("%s/%s", ss.service, ss.host) 70 71 var status C.SECURITY_STATUS 72 // Step 0: call AcquireCredentialsHandle to get a nice SSPI CredHandle 73 if len(password) > 0 { 74 status = C.sspi_acquire_credentials_handle(&ss.credHandle, ss.cstr(user), ss.cstr(password), ss.cstr(ss.domain)) 75 } else { 76 status = C.sspi_acquire_credentials_handle(&ss.credHandle, ss.cstr(user), nil, ss.cstr(ss.domain)) 77 } 78 if status != C.SEC_E_OK { 79 ss.errored = true 80 return nil, fmt.Errorf("Couldn't create new SSPI client, error code %v", status) 81 } 82 return ss, nil 83 } 84 85 func (ss *saslSession) cstr(s string) *C.char { 86 cstr := C.CString(s) 87 ss.stringsToFree = append(ss.stringsToFree, cstr) 88 return cstr 89 } 90 91 func (ss *saslSession) Close() { 92 for _, cstr := range ss.stringsToFree { 93 C.free(unsafe.Pointer(cstr)) 94 } 95 } 96 97 func (ss *saslSession) Step(serverData []byte) (clientData []byte, done bool, err error) { 98 ss.step++ 99 if ss.step > 10 { 100 return nil, false, fmt.Errorf("too many SSPI steps without authentication") 101 } 102 var buffer C.PVOID 103 var bufferLength C.ULONG 104 var outBuffer C.PVOID 105 var outBufferLength C.ULONG 106 if len(serverData) > 0 { 107 buffer = (C.PVOID)(unsafe.Pointer(&serverData[0])) 108 bufferLength = C.ULONG(len(serverData)) 109 } 110 var status C.int 111 if ss.authComplete { 112 // Step 3: last bit of magic to use the correct server credentials 113 status = C.sspi_send_client_authz_id(&ss.context, &outBuffer, &outBufferLength, ss.cstr(ss.userPlusRealm)) 114 } else { 115 // Step 1 + Step 2: set up security context with the server and TGT 116 status = C.sspi_step(&ss.credHandle, ss.hasContext, &ss.context, buffer, bufferLength, &outBuffer, &outBufferLength, ss.cstr(ss.target)) 117 } 118 if outBuffer != C.PVOID(nil) { 119 defer C.free(unsafe.Pointer(outBuffer)) 120 } 121 if status != C.SEC_E_OK && status != C.SEC_I_CONTINUE_NEEDED { 122 ss.errored = true 123 return nil, false, ss.handleSSPIErrorCode(status) 124 } 125 126 clientData = C.GoBytes(unsafe.Pointer(outBuffer), C.int(outBufferLength)) 127 if status == C.SEC_E_OK { 128 ss.authComplete = true 129 return clientData, true, nil 130 } else { 131 ss.hasContext = 1 132 return clientData, false, nil 133 } 134 } 135 136 func (ss *saslSession) handleSSPIErrorCode(code C.int) error { 137 switch { 138 case code == C.SEC_E_TARGET_UNKNOWN: 139 return fmt.Errorf("Target %v@%v not found", ss.target, ss.domain) 140 } 141 return fmt.Errorf("Unknown error doing step %v, error code %v", ss.step, code) 142 }