github.com/jcmturner/gokrb5/v8@v8.4.4/client/session.go (about)

     1  package client
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"sort"
     7  	"strings"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/jcmturner/gokrb5/v8/iana/nametype"
    12  	"github.com/jcmturner/gokrb5/v8/krberror"
    13  	"github.com/jcmturner/gokrb5/v8/messages"
    14  	"github.com/jcmturner/gokrb5/v8/types"
    15  )
    16  
    17  // sessions hold TGTs and are keyed on the realm name
    18  type sessions struct {
    19  	Entries map[string]*session
    20  	mux     sync.RWMutex
    21  }
    22  
    23  // destroy erases all sessions
    24  func (s *sessions) destroy() {
    25  	s.mux.Lock()
    26  	defer s.mux.Unlock()
    27  	for k, e := range s.Entries {
    28  		e.destroy()
    29  		delete(s.Entries, k)
    30  	}
    31  }
    32  
    33  // update replaces a session with the one provided or adds it as a new one
    34  func (s *sessions) update(sess *session) {
    35  	s.mux.Lock()
    36  	defer s.mux.Unlock()
    37  	// if a session already exists for this, cancel its auto renew.
    38  	if i, ok := s.Entries[sess.realm]; ok {
    39  		if i != sess {
    40  			// Session in the sessions cache is not the same as one provided.
    41  			// Cancel the one in the cache and add this one.
    42  			i.mux.Lock()
    43  			defer i.mux.Unlock()
    44  			if i.cancel != nil {
    45  				i.cancel <- true
    46  			}
    47  			s.Entries[sess.realm] = sess
    48  			return
    49  		}
    50  	}
    51  	// No session for this realm was found so just add it
    52  	s.Entries[sess.realm] = sess
    53  }
    54  
    55  // get returns the session for the realm specified
    56  func (s *sessions) get(realm string) (*session, bool) {
    57  	s.mux.RLock()
    58  	defer s.mux.RUnlock()
    59  	sess, ok := s.Entries[realm]
    60  	return sess, ok
    61  }
    62  
    63  // session holds the TGT details for a realm
    64  type session struct {
    65  	realm                string
    66  	authTime             time.Time
    67  	endTime              time.Time
    68  	renewTill            time.Time
    69  	tgt                  messages.Ticket
    70  	sessionKey           types.EncryptionKey
    71  	sessionKeyExpiration time.Time
    72  	cancel               chan bool
    73  	mux                  sync.RWMutex
    74  }
    75  
    76  // jsonSession is used to enable marshaling some information of a session in a JSON format
    77  type jsonSession struct {
    78  	Realm                string
    79  	AuthTime             time.Time
    80  	EndTime              time.Time
    81  	RenewTill            time.Time
    82  	SessionKeyExpiration time.Time
    83  }
    84  
    85  // AddSession adds a session for a realm with a TGT to the client's session cache.
    86  // A goroutine is started to automatically renew the TGT before expiry.
    87  func (cl *Client) addSession(tgt messages.Ticket, dep messages.EncKDCRepPart) {
    88  	if strings.ToLower(tgt.SName.NameString[0]) != "krbtgt" {
    89  		// Not a TGT
    90  		return
    91  	}
    92  	realm := tgt.SName.NameString[len(tgt.SName.NameString)-1]
    93  	s := &session{
    94  		realm:                realm,
    95  		authTime:             dep.AuthTime,
    96  		endTime:              dep.EndTime,
    97  		renewTill:            dep.RenewTill,
    98  		tgt:                  tgt,
    99  		sessionKey:           dep.Key,
   100  		sessionKeyExpiration: dep.KeyExpiration,
   101  	}
   102  	cl.sessions.update(s)
   103  	cl.enableAutoSessionRenewal(s)
   104  	cl.Log("TGT session added for %s (EndTime: %v)", realm, dep.EndTime)
   105  }
   106  
   107  // update overwrites the session details with those from the TGT and decrypted encPart
   108  func (s *session) update(tgt messages.Ticket, dep messages.EncKDCRepPart) {
   109  	s.mux.Lock()
   110  	defer s.mux.Unlock()
   111  	s.authTime = dep.AuthTime
   112  	s.endTime = dep.EndTime
   113  	s.renewTill = dep.RenewTill
   114  	s.tgt = tgt
   115  	s.sessionKey = dep.Key
   116  	s.sessionKeyExpiration = dep.KeyExpiration
   117  }
   118  
   119  // destroy will cancel any auto renewal of the session and set the expiration times to the current time
   120  func (s *session) destroy() {
   121  	s.mux.Lock()
   122  	defer s.mux.Unlock()
   123  	if s.cancel != nil {
   124  		s.cancel <- true
   125  	}
   126  	s.endTime = time.Now().UTC()
   127  	s.renewTill = s.endTime
   128  	s.sessionKeyExpiration = s.endTime
   129  }
   130  
   131  // valid informs if the TGT is still within the valid time window
   132  func (s *session) valid() bool {
   133  	s.mux.RLock()
   134  	defer s.mux.RUnlock()
   135  	t := time.Now().UTC()
   136  	if t.Before(s.endTime) && s.authTime.Before(t) {
   137  		return true
   138  	}
   139  	return false
   140  }
   141  
   142  // tgtDetails is a thread safe way to get the session's realm, TGT and session key values
   143  func (s *session) tgtDetails() (string, messages.Ticket, types.EncryptionKey) {
   144  	s.mux.RLock()
   145  	defer s.mux.RUnlock()
   146  	return s.realm, s.tgt, s.sessionKey
   147  }
   148  
   149  // timeDetails is a thread safe way to get the session's validity time values
   150  func (s *session) timeDetails() (string, time.Time, time.Time, time.Time, time.Time) {
   151  	s.mux.RLock()
   152  	defer s.mux.RUnlock()
   153  	return s.realm, s.authTime, s.endTime, s.renewTill, s.sessionKeyExpiration
   154  }
   155  
   156  // JSON return information about the held sessions in a JSON format.
   157  func (s *sessions) JSON() (string, error) {
   158  	s.mux.RLock()
   159  	defer s.mux.RUnlock()
   160  	var js []jsonSession
   161  	keys := make([]string, 0, len(s.Entries))
   162  	for k := range s.Entries {
   163  		keys = append(keys, k)
   164  	}
   165  	sort.Strings(keys)
   166  	for _, k := range keys {
   167  		r, at, et, rt, kt := s.Entries[k].timeDetails()
   168  		j := jsonSession{
   169  			Realm:                r,
   170  			AuthTime:             at,
   171  			EndTime:              et,
   172  			RenewTill:            rt,
   173  			SessionKeyExpiration: kt,
   174  		}
   175  		js = append(js, j)
   176  	}
   177  	b, err := json.MarshalIndent(js, "", "  ")
   178  	if err != nil {
   179  		return "", err
   180  	}
   181  	return string(b), nil
   182  }
   183  
   184  // enableAutoSessionRenewal turns on the automatic renewal for the client's TGT session.
   185  func (cl *Client) enableAutoSessionRenewal(s *session) {
   186  	var timer *time.Timer
   187  	s.mux.Lock()
   188  	s.cancel = make(chan bool, 1)
   189  	s.mux.Unlock()
   190  	go func(s *session) {
   191  		for {
   192  			s.mux.RLock()
   193  			w := (s.endTime.Sub(time.Now().UTC()) * 5) / 6
   194  			s.mux.RUnlock()
   195  			if w < 0 {
   196  				return
   197  			}
   198  			timer = time.NewTimer(w)
   199  			select {
   200  			case <-timer.C:
   201  				renewal, err := cl.refreshSession(s)
   202  				if err != nil {
   203  					cl.Log("error refreshing session: %v", err)
   204  				}
   205  				if !renewal && err == nil {
   206  					// end this goroutine as there will have been a new login and new auto renewal goroutine created.
   207  					return
   208  				}
   209  			case <-s.cancel:
   210  				// cancel has been called. Stop the timer and exit.
   211  				timer.Stop()
   212  				return
   213  			}
   214  		}
   215  	}(s)
   216  }
   217  
   218  // renewTGT renews the client's TGT session.
   219  func (cl *Client) renewTGT(s *session) error {
   220  	realm, tgt, skey := s.tgtDetails()
   221  	spn := types.PrincipalName{
   222  		NameType:   nametype.KRB_NT_SRV_INST,
   223  		NameString: []string{"krbtgt", realm},
   224  	}
   225  	_, tgsRep, err := cl.TGSREQGenerateAndExchange(spn, cl.Credentials.Domain(), tgt, skey, true)
   226  	if err != nil {
   227  		return krberror.Errorf(err, krberror.KRBMsgError, "error renewing TGT for %s", realm)
   228  	}
   229  	s.update(tgsRep.Ticket, tgsRep.DecryptedEncPart)
   230  	cl.sessions.update(s)
   231  	cl.Log("TGT session renewed for %s (EndTime: %v)", realm, tgsRep.DecryptedEncPart.EndTime)
   232  	return nil
   233  }
   234  
   235  // refreshSession updates either through renewal or creating a new login.
   236  // The boolean indicates if the update was a renewal.
   237  func (cl *Client) refreshSession(s *session) (bool, error) {
   238  	s.mux.RLock()
   239  	realm := s.realm
   240  	renewTill := s.renewTill
   241  	s.mux.RUnlock()
   242  	cl.Log("refreshing TGT session for %s", realm)
   243  	if time.Now().UTC().Before(renewTill) {
   244  		err := cl.renewTGT(s)
   245  		return true, err
   246  	}
   247  	err := cl.realmLogin(realm)
   248  	return false, err
   249  }
   250  
   251  // ensureValidSession makes sure there is a valid session for the realm
   252  func (cl *Client) ensureValidSession(realm string) error {
   253  	s, ok := cl.sessions.get(realm)
   254  	if ok {
   255  		s.mux.RLock()
   256  		d := s.endTime.Sub(s.authTime) / 6
   257  		if s.endTime.Sub(time.Now().UTC()) > d {
   258  			s.mux.RUnlock()
   259  			return nil
   260  		}
   261  		s.mux.RUnlock()
   262  		_, err := cl.refreshSession(s)
   263  		return err
   264  	}
   265  	return cl.realmLogin(realm)
   266  }
   267  
   268  // sessionTGTDetails is a thread safe way to get the TGT and session key values for a realm
   269  func (cl *Client) sessionTGT(realm string) (tgt messages.Ticket, sessionKey types.EncryptionKey, err error) {
   270  	err = cl.ensureValidSession(realm)
   271  	if err != nil {
   272  		return
   273  	}
   274  	s, ok := cl.sessions.get(realm)
   275  	if !ok {
   276  		err = fmt.Errorf("could not find TGT session for %s", realm)
   277  		return
   278  	}
   279  	_, tgt, sessionKey = s.tgtDetails()
   280  	return
   281  }
   282  
   283  // sessionTimes provides the timing information with regards to a session for the realm specified.
   284  func (cl *Client) sessionTimes(realm string) (authTime, endTime, renewTime, sessionExp time.Time, err error) {
   285  	s, ok := cl.sessions.get(realm)
   286  	if !ok {
   287  		err = fmt.Errorf("could not find TGT session for %s", realm)
   288  		return
   289  	}
   290  	_, authTime, endTime, renewTime, sessionExp = s.timeDetails()
   291  	return
   292  }
   293  
   294  // spnRealm resolves the realm name of a service principal name
   295  func (cl *Client) spnRealm(spn types.PrincipalName) string {
   296  	return cl.Config.ResolveRealm(spn.NameString[len(spn.NameString)-1])
   297  }