github.com/clly/consul@v1.4.5/agent/consul/state/session.go (about)

     1  package state
     2  
     3  import (
     4  	"fmt"
     5  	"time"
     6  
     7  	"github.com/hashicorp/consul/agent/structs"
     8  	"github.com/hashicorp/consul/api"
     9  	"github.com/hashicorp/go-memdb"
    10  )
    11  
    12  // sessionsTableSchema returns a new table schema used for storing session
    13  // information.
    14  func sessionsTableSchema() *memdb.TableSchema {
    15  	return &memdb.TableSchema{
    16  		Name: "sessions",
    17  		Indexes: map[string]*memdb.IndexSchema{
    18  			"id": &memdb.IndexSchema{
    19  				Name:         "id",
    20  				AllowMissing: false,
    21  				Unique:       true,
    22  				Indexer: &memdb.UUIDFieldIndex{
    23  					Field: "ID",
    24  				},
    25  			},
    26  			"node": &memdb.IndexSchema{
    27  				Name:         "node",
    28  				AllowMissing: false,
    29  				Unique:       false,
    30  				Indexer: &memdb.StringFieldIndex{
    31  					Field:     "Node",
    32  					Lowercase: true,
    33  				},
    34  			},
    35  		},
    36  	}
    37  }
    38  
    39  // sessionChecksTableSchema returns a new table schema used for storing session
    40  // checks.
    41  func sessionChecksTableSchema() *memdb.TableSchema {
    42  	return &memdb.TableSchema{
    43  		Name: "session_checks",
    44  		Indexes: map[string]*memdb.IndexSchema{
    45  			"id": &memdb.IndexSchema{
    46  				Name:         "id",
    47  				AllowMissing: false,
    48  				Unique:       true,
    49  				Indexer: &memdb.CompoundIndex{
    50  					Indexes: []memdb.Indexer{
    51  						&memdb.StringFieldIndex{
    52  							Field:     "Node",
    53  							Lowercase: true,
    54  						},
    55  						&memdb.StringFieldIndex{
    56  							Field:     "CheckID",
    57  							Lowercase: true,
    58  						},
    59  						&memdb.UUIDFieldIndex{
    60  							Field: "Session",
    61  						},
    62  					},
    63  				},
    64  			},
    65  			"node_check": &memdb.IndexSchema{
    66  				Name:         "node_check",
    67  				AllowMissing: false,
    68  				Unique:       false,
    69  				Indexer: &memdb.CompoundIndex{
    70  					Indexes: []memdb.Indexer{
    71  						&memdb.StringFieldIndex{
    72  							Field:     "Node",
    73  							Lowercase: true,
    74  						},
    75  						&memdb.StringFieldIndex{
    76  							Field:     "CheckID",
    77  							Lowercase: true,
    78  						},
    79  					},
    80  				},
    81  			},
    82  			"session": &memdb.IndexSchema{
    83  				Name:         "session",
    84  				AllowMissing: false,
    85  				Unique:       false,
    86  				Indexer: &memdb.UUIDFieldIndex{
    87  					Field: "Session",
    88  				},
    89  			},
    90  		},
    91  	}
    92  }
    93  
    94  func init() {
    95  	registerSchema(sessionsTableSchema)
    96  	registerSchema(sessionChecksTableSchema)
    97  }
    98  
    99  // Sessions is used to pull the full list of sessions for use during snapshots.
   100  func (s *Snapshot) Sessions() (memdb.ResultIterator, error) {
   101  	iter, err := s.tx.Get("sessions", "id")
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  	return iter, nil
   106  }
   107  
   108  // Session is used when restoring from a snapshot. For general inserts, use
   109  // SessionCreate.
   110  func (s *Restore) Session(sess *structs.Session) error {
   111  	// Insert the session.
   112  	if err := s.tx.Insert("sessions", sess); err != nil {
   113  		return fmt.Errorf("failed inserting session: %s", err)
   114  	}
   115  
   116  	// Insert the check mappings.
   117  	for _, checkID := range sess.Checks {
   118  		mapping := &sessionCheck{
   119  			Node:    sess.Node,
   120  			CheckID: checkID,
   121  			Session: sess.ID,
   122  		}
   123  		if err := s.tx.Insert("session_checks", mapping); err != nil {
   124  			return fmt.Errorf("failed inserting session check mapping: %s", err)
   125  		}
   126  	}
   127  
   128  	// Update the index.
   129  	if err := indexUpdateMaxTxn(s.tx, sess.ModifyIndex, "sessions"); err != nil {
   130  		return fmt.Errorf("failed updating index: %s", err)
   131  	}
   132  
   133  	return nil
   134  }
   135  
   136  // SessionCreate is used to register a new session in the state store.
   137  func (s *Store) SessionCreate(idx uint64, sess *structs.Session) error {
   138  	tx := s.db.Txn(true)
   139  	defer tx.Abort()
   140  
   141  	// This code is technically able to (incorrectly) update an existing
   142  	// session but we never do that in practice. The upstream endpoint code
   143  	// always adds a unique ID when doing a create operation so we never hit
   144  	// an existing session again. It isn't worth the overhead to verify
   145  	// that here, but it's worth noting that we should never do this in the
   146  	// future.
   147  
   148  	// Call the session creation
   149  	if err := s.sessionCreateTxn(tx, idx, sess); err != nil {
   150  		return err
   151  	}
   152  
   153  	tx.Commit()
   154  	return nil
   155  }
   156  
   157  // sessionCreateTxn is the inner method used for creating session entries in
   158  // an open transaction. Any health checks registered with the session will be
   159  // checked for failing status. Returns any error encountered.
   160  func (s *Store) sessionCreateTxn(tx *memdb.Txn, idx uint64, sess *structs.Session) error {
   161  	// Check that we have a session ID
   162  	if sess.ID == "" {
   163  		return ErrMissingSessionID
   164  	}
   165  
   166  	// Verify the session behavior is valid
   167  	switch sess.Behavior {
   168  	case "":
   169  		// Release by default to preserve backwards compatibility
   170  		sess.Behavior = structs.SessionKeysRelease
   171  	case structs.SessionKeysRelease:
   172  	case structs.SessionKeysDelete:
   173  	default:
   174  		return fmt.Errorf("Invalid session behavior: %s", sess.Behavior)
   175  	}
   176  
   177  	// Assign the indexes. ModifyIndex likely will not be used but
   178  	// we set it here anyways for sanity.
   179  	sess.CreateIndex = idx
   180  	sess.ModifyIndex = idx
   181  
   182  	// Check that the node exists
   183  	node, err := tx.First("nodes", "id", sess.Node)
   184  	if err != nil {
   185  		return fmt.Errorf("failed node lookup: %s", err)
   186  	}
   187  	if node == nil {
   188  		return ErrMissingNode
   189  	}
   190  
   191  	// Go over the session checks and ensure they exist.
   192  	for _, checkID := range sess.Checks {
   193  		check, err := tx.First("checks", "id", sess.Node, string(checkID))
   194  		if err != nil {
   195  			return fmt.Errorf("failed check lookup: %s", err)
   196  		}
   197  		if check == nil {
   198  			return fmt.Errorf("Missing check '%s' registration", checkID)
   199  		}
   200  
   201  		// Check that the check is not in critical state
   202  		status := check.(*structs.HealthCheck).Status
   203  		if status == api.HealthCritical {
   204  			return fmt.Errorf("Check '%s' is in %s state", checkID, status)
   205  		}
   206  	}
   207  
   208  	// Insert the session
   209  	if err := tx.Insert("sessions", sess); err != nil {
   210  		return fmt.Errorf("failed inserting session: %s", err)
   211  	}
   212  
   213  	// Insert the check mappings
   214  	for _, checkID := range sess.Checks {
   215  		mapping := &sessionCheck{
   216  			Node:    sess.Node,
   217  			CheckID: checkID,
   218  			Session: sess.ID,
   219  		}
   220  		if err := tx.Insert("session_checks", mapping); err != nil {
   221  			return fmt.Errorf("failed inserting session check mapping: %s", err)
   222  		}
   223  	}
   224  
   225  	// Update the index
   226  	if err := tx.Insert("index", &IndexEntry{"sessions", idx}); err != nil {
   227  		return fmt.Errorf("failed updating index: %s", err)
   228  	}
   229  
   230  	return nil
   231  }
   232  
   233  // SessionGet is used to retrieve an active session from the state store.
   234  func (s *Store) SessionGet(ws memdb.WatchSet, sessionID string) (uint64, *structs.Session, error) {
   235  	tx := s.db.Txn(false)
   236  	defer tx.Abort()
   237  
   238  	// Get the table index.
   239  	idx := maxIndexTxn(tx, "sessions")
   240  
   241  	// Look up the session by its ID
   242  	watchCh, session, err := tx.FirstWatch("sessions", "id", sessionID)
   243  	if err != nil {
   244  		return 0, nil, fmt.Errorf("failed session lookup: %s", err)
   245  	}
   246  	ws.Add(watchCh)
   247  	if session != nil {
   248  		return idx, session.(*structs.Session), nil
   249  	}
   250  	return idx, nil, nil
   251  }
   252  
   253  // SessionList returns a slice containing all of the active sessions.
   254  func (s *Store) SessionList(ws memdb.WatchSet) (uint64, structs.Sessions, error) {
   255  	tx := s.db.Txn(false)
   256  	defer tx.Abort()
   257  
   258  	// Get the table index.
   259  	idx := maxIndexTxn(tx, "sessions")
   260  
   261  	// Query all of the active sessions.
   262  	sessions, err := tx.Get("sessions", "id")
   263  	if err != nil {
   264  		return 0, nil, fmt.Errorf("failed session lookup: %s", err)
   265  	}
   266  	ws.Add(sessions.WatchCh())
   267  
   268  	// Go over the sessions and create a slice of them.
   269  	var result structs.Sessions
   270  	for session := sessions.Next(); session != nil; session = sessions.Next() {
   271  		result = append(result, session.(*structs.Session))
   272  	}
   273  	return idx, result, nil
   274  }
   275  
   276  // NodeSessions returns a set of active sessions associated
   277  // with the given node ID. The returned index is the highest
   278  // index seen from the result set.
   279  func (s *Store) NodeSessions(ws memdb.WatchSet, nodeID string) (uint64, structs.Sessions, error) {
   280  	tx := s.db.Txn(false)
   281  	defer tx.Abort()
   282  
   283  	// Get the table index.
   284  	idx := maxIndexTxn(tx, "sessions")
   285  
   286  	// Get all of the sessions which belong to the node
   287  	sessions, err := tx.Get("sessions", "node", nodeID)
   288  	if err != nil {
   289  		return 0, nil, fmt.Errorf("failed session lookup: %s", err)
   290  	}
   291  	ws.Add(sessions.WatchCh())
   292  
   293  	// Go over all of the sessions and return them as a slice
   294  	var result structs.Sessions
   295  	for session := sessions.Next(); session != nil; session = sessions.Next() {
   296  		result = append(result, session.(*structs.Session))
   297  	}
   298  	return idx, result, nil
   299  }
   300  
   301  // SessionDestroy is used to remove an active session. This will
   302  // implicitly invalidate the session and invoke the specified
   303  // session destroy behavior.
   304  func (s *Store) SessionDestroy(idx uint64, sessionID string) error {
   305  	tx := s.db.Txn(true)
   306  	defer tx.Abort()
   307  
   308  	// Call the session deletion.
   309  	if err := s.deleteSessionTxn(tx, idx, sessionID); err != nil {
   310  		return err
   311  	}
   312  
   313  	tx.Commit()
   314  	return nil
   315  }
   316  
   317  // deleteSessionTxn is the inner method, which is used to do the actual
   318  // session deletion and handle session invalidation, etc.
   319  func (s *Store) deleteSessionTxn(tx *memdb.Txn, idx uint64, sessionID string) error {
   320  	// Look up the session.
   321  	sess, err := tx.First("sessions", "id", sessionID)
   322  	if err != nil {
   323  		return fmt.Errorf("failed session lookup: %s", err)
   324  	}
   325  	if sess == nil {
   326  		return nil
   327  	}
   328  
   329  	// Delete the session and write the new index.
   330  	if err := tx.Delete("sessions", sess); err != nil {
   331  		return fmt.Errorf("failed deleting session: %s", err)
   332  	}
   333  	if err := tx.Insert("index", &IndexEntry{"sessions", idx}); err != nil {
   334  		return fmt.Errorf("failed updating index: %s", err)
   335  	}
   336  
   337  	// Enforce the max lock delay.
   338  	session := sess.(*structs.Session)
   339  	delay := session.LockDelay
   340  	if delay > structs.MaxLockDelay {
   341  		delay = structs.MaxLockDelay
   342  	}
   343  
   344  	// Snag the current now time so that all the expirations get calculated
   345  	// the same way.
   346  	now := time.Now()
   347  
   348  	// Get an iterator over all of the keys with the given session.
   349  	entries, err := tx.Get("kvs", "session", sessionID)
   350  	if err != nil {
   351  		return fmt.Errorf("failed kvs lookup: %s", err)
   352  	}
   353  	var kvs []interface{}
   354  	for entry := entries.Next(); entry != nil; entry = entries.Next() {
   355  		kvs = append(kvs, entry)
   356  	}
   357  
   358  	// Invalidate any held locks.
   359  	switch session.Behavior {
   360  	case structs.SessionKeysRelease:
   361  		for _, obj := range kvs {
   362  			// Note that we clone here since we are modifying the
   363  			// returned object and want to make sure our set op
   364  			// respects the transaction we are in.
   365  			e := obj.(*structs.DirEntry).Clone()
   366  			e.Session = ""
   367  			if err := s.kvsSetTxn(tx, idx, e, true); err != nil {
   368  				return fmt.Errorf("failed kvs update: %s", err)
   369  			}
   370  
   371  			// Apply the lock delay if present.
   372  			if delay > 0 {
   373  				s.lockDelay.SetExpiration(e.Key, now, delay)
   374  			}
   375  		}
   376  	case structs.SessionKeysDelete:
   377  		for _, obj := range kvs {
   378  			e := obj.(*structs.DirEntry)
   379  			if err := s.kvsDeleteTxn(tx, idx, e.Key); err != nil {
   380  				return fmt.Errorf("failed kvs delete: %s", err)
   381  			}
   382  
   383  			// Apply the lock delay if present.
   384  			if delay > 0 {
   385  				s.lockDelay.SetExpiration(e.Key, now, delay)
   386  			}
   387  		}
   388  	default:
   389  		return fmt.Errorf("unknown session behavior %#v", session.Behavior)
   390  	}
   391  
   392  	// Delete any check mappings.
   393  	mappings, err := tx.Get("session_checks", "session", sessionID)
   394  	if err != nil {
   395  		return fmt.Errorf("failed session checks lookup: %s", err)
   396  	}
   397  	{
   398  		var objs []interface{}
   399  		for mapping := mappings.Next(); mapping != nil; mapping = mappings.Next() {
   400  			objs = append(objs, mapping)
   401  		}
   402  
   403  		// Do the delete in a separate loop so we don't trash the iterator.
   404  		for _, obj := range objs {
   405  			if err := tx.Delete("session_checks", obj); err != nil {
   406  				return fmt.Errorf("failed deleting session check: %s", err)
   407  			}
   408  		}
   409  	}
   410  
   411  	// Delete any prepared queries.
   412  	queries, err := tx.Get("prepared-queries", "session", sessionID)
   413  	if err != nil {
   414  		return fmt.Errorf("failed prepared query lookup: %s", err)
   415  	}
   416  	{
   417  		var ids []string
   418  		for wrapped := queries.Next(); wrapped != nil; wrapped = queries.Next() {
   419  			ids = append(ids, toPreparedQuery(wrapped).ID)
   420  		}
   421  
   422  		// Do the delete in a separate loop so we don't trash the iterator.
   423  		for _, id := range ids {
   424  			if err := s.preparedQueryDeleteTxn(tx, idx, id); err != nil {
   425  				return fmt.Errorf("failed prepared query delete: %s", err)
   426  			}
   427  		}
   428  	}
   429  
   430  	return nil
   431  }