github.com/outbrain/consul@v1.4.5/agent/txn_endpoint.go (about)

     1  package agent
     2  
     3  import (
     4  	"encoding/base64"
     5  	"fmt"
     6  	"net/http"
     7  	"strings"
     8  
     9  	"github.com/hashicorp/consul/agent/structs"
    10  	"github.com/hashicorp/consul/api"
    11  	"github.com/hashicorp/consul/types"
    12  )
    13  
    14  const (
    15  	// maxTxnOps is used to set an upper limit on the number of operations
    16  	// inside a transaction. If there are more operations than this, then the
    17  	// client is likely abusing transactions.
    18  	maxTxnOps = 64
    19  )
    20  
    21  // decodeValue decodes the value member of the given operation.
    22  func decodeValue(rawKV interface{}) error {
    23  	rawMap, ok := rawKV.(map[string]interface{})
    24  	if !ok {
    25  		return fmt.Errorf("unexpected raw KV type: %T", rawKV)
    26  	}
    27  	for k, v := range rawMap {
    28  		switch strings.ToLower(k) {
    29  		case "value":
    30  			// Leave the byte slice nil if we have a nil
    31  			// value.
    32  			if v == nil {
    33  				return nil
    34  			}
    35  
    36  			// Otherwise, base64 decode it.
    37  			s, ok := v.(string)
    38  			if !ok {
    39  				return fmt.Errorf("unexpected value type: %T", v)
    40  			}
    41  			decoded, err := base64.StdEncoding.DecodeString(s)
    42  			if err != nil {
    43  				return fmt.Errorf("failed to decode value: %v", err)
    44  			}
    45  			rawMap[k] = decoded
    46  			return nil
    47  		}
    48  	}
    49  	return nil
    50  }
    51  
    52  // fixupTxnOp looks for non-nil Txn operations and passes them on for
    53  // value conversion.
    54  func fixupTxnOp(rawOp interface{}) error {
    55  	rawMap, ok := rawOp.(map[string]interface{})
    56  	if !ok {
    57  		return fmt.Errorf("unexpected raw op type: %T", rawOp)
    58  	}
    59  	for k, v := range rawMap {
    60  		switch strings.ToLower(k) {
    61  		case "kv":
    62  			if v == nil {
    63  				return nil
    64  			}
    65  			return decodeValue(v)
    66  		}
    67  	}
    68  	return nil
    69  }
    70  
    71  // fixupTxnOps takes the raw decoded JSON and base64 decodes values in Txn ops,
    72  // replacing them with byte arrays.
    73  func fixupTxnOps(raw interface{}) error {
    74  	rawSlice, ok := raw.([]interface{})
    75  	if !ok {
    76  		return fmt.Errorf("unexpected raw type: %t", raw)
    77  	}
    78  	for _, rawOp := range rawSlice {
    79  		if err := fixupTxnOp(rawOp); err != nil {
    80  			return err
    81  		}
    82  	}
    83  	return nil
    84  }
    85  
    86  // isWrite returns true if the given operation alters the state store.
    87  func isWrite(op api.KVOp) bool {
    88  	switch op {
    89  	case api.KVSet, api.KVDelete, api.KVDeleteCAS, api.KVDeleteTree, api.KVCAS, api.KVLock, api.KVUnlock:
    90  		return true
    91  	}
    92  	return false
    93  }
    94  
// convertOps takes the incoming body in API format and converts it to the
// internal RPC format. This returns a count of the number of write ops, and
// a boolean, that if false means an error response has been generated and
// processing should stop.
//
// On failure the HTTP status and a plain-text message have already been
// written to resp, and the returned ops and count are nil and 0. Failure
// cases are:
//   - 400 if the body cannot be decoded;
//   - 413 if there are more than maxTxnOps operations;
//   - 413 if any single KV value exceeds maxKVSize;
//   - 413 if the KV values together exceed maxKVSize.
//
// maxKVSize is defined elsewhere in this package.
//
// Each operation is converted from whichever of its KV, Node, Service or
// Check members is checked first and found non-nil, in that order. An
// operation with none of them set is silently skipped. It still counts
// toward the maxTxnOps limit, since that check runs on the decoded slice.
func (s *HTTPServer) convertOps(resp http.ResponseWriter, req *http.Request) (structs.TxnOps, int, bool) {
	// Note the body is in API format, and not the RPC format. If we can't
	// decode it, we will return a 400 since we don't have enough context to
	// associate the error with a given operation.
	var ops api.TxnOps
	if err := decodeBody(req, &ops, fixupTxnOps); err != nil {
		resp.WriteHeader(http.StatusBadRequest)
		fmt.Fprintf(resp, "Failed to parse body: %v", err)
		return nil, 0, false
	}

	// Enforce a reasonable upper limit on the number of operations in a
	// transaction in order to curb abuse.
	if size := len(ops); size > maxTxnOps {
		resp.WriteHeader(http.StatusRequestEntityTooLarge)
		fmt.Fprintf(resp, "Transaction contains too many operations (%d > %d)",
			size, maxTxnOps)

		return nil, 0, false
	}

	// Convert each operation from the API format into the RPC format. Note
	// that fixupTxnOps (run by decodeBody above) will have already converted
	// the base64 encoded KV values into byte arrays so we can assign right
	// over.
	var opsRPC structs.TxnOps
	// writes counts operations whose verb is not a read. The caller uses it
	// to pick the read-only or the write path.
	var writes int
	// netKVSize accumulates the size of all KV values for the overall limit
	// checked after the loop.
	var netKVSize int
	for _, in := range ops {
		switch {
		case in.KV != nil:
			// Reject any single value that is too large before adding it to
			// the running total.
			size := len(in.KV.Value)
			if size > maxKVSize {
				resp.WriteHeader(http.StatusRequestEntityTooLarge)
				fmt.Fprintf(resp, "Value for key %q is too large (%d > %d bytes)", in.KV.Key, size, maxKVSize)
				return nil, 0, false
			}
			netKVSize += size

			verb := api.KVOp(in.KV.Verb)
			if isWrite(verb) {
				writes++
			}

			// The API's Index field maps onto the entry's ModifyIndex.
			out := &structs.TxnOp{
				KV: &structs.TxnKVOp{
					Verb: verb,
					DirEnt: structs.DirEntry{
						Key:     in.KV.Key,
						Value:   in.KV.Value,
						Flags:   in.KV.Flags,
						Session: in.KV.Session,
						RaftIndex: structs.RaftIndex{
							ModifyIndex: in.KV.Index,
						},
					},
				},
			}
			opsRPC = append(opsRPC, out)

		case in.Node != nil:
			// Every node verb other than NodeGet counts as a write.
			if in.Node.Verb != api.NodeGet {
				writes++
			}

			// Setup the default DC if not provided. This writes through the
			// in.Node pointer, so it modifies the decoded request object.
			if in.Node.Node.Datacenter == "" {
				in.Node.Node.Datacenter = s.agent.config.Datacenter
			}

			node := in.Node.Node
			out := &structs.TxnOp{
				Node: &structs.TxnNodeOp{
					Verb: in.Node.Verb,
					Node: structs.Node{
						ID:              types.NodeID(node.ID),
						Node:            node.Node,
						Address:         node.Address,
						Datacenter:      node.Datacenter,
						TaggedAddresses: node.TaggedAddresses,
						Meta:            node.Meta,
						RaftIndex: structs.RaftIndex{
							ModifyIndex: node.ModifyIndex,
						},
					},
				},
			}
			opsRPC = append(opsRPC, out)

		case in.Service != nil:
			// Every service verb other than ServiceGet counts as a write.
			if in.Service.Verb != api.ServiceGet {
				writes++
			}

			// Weights is copied into a freshly allocated structs.Weights.
			// NOTE(review): this dereferences svc.Weights fields directly,
			// which assumes api.AgentService.Weights is a value, not a
			// pointer, in this API version — confirm.
			svc := in.Service.Service
			out := &structs.TxnOp{
				Service: &structs.TxnServiceOp{
					Verb: in.Service.Verb,
					Node: in.Service.Node,
					Service: structs.NodeService{
						ID:      svc.ID,
						Service: svc.Service,
						Tags:    svc.Tags,
						Address: svc.Address,
						Meta:    svc.Meta,
						Port:    svc.Port,
						Weights: &structs.Weights{
							Passing: svc.Weights.Passing,
							Warning: svc.Weights.Warning,
						},
						EnableTagOverride: svc.EnableTagOverride,
						RaftIndex: structs.RaftIndex{
							ModifyIndex: svc.ModifyIndex,
						},
					},
				},
			}
			opsRPC = append(opsRPC, out)

		case in.Check != nil:
			// Every check verb other than CheckGet counts as a write.
			if in.Check.Verb != api.CheckGet {
				writes++
			}

			// Only the definition fields listed below (HTTP/TCP settings and
			// the parsed *Duration values) are carried over into the RPC
			// form.
			check := in.Check.Check
			out := &structs.TxnOp{
				Check: &structs.TxnCheckOp{
					Verb: in.Check.Verb,
					Check: structs.HealthCheck{
						Node:        check.Node,
						CheckID:     types.CheckID(check.CheckID),
						Name:        check.Name,
						Status:      check.Status,
						Notes:       check.Notes,
						Output:      check.Output,
						ServiceID:   check.ServiceID,
						ServiceName: check.ServiceName,
						ServiceTags: check.ServiceTags,
						Definition: structs.HealthCheckDefinition{
							HTTP:                           check.Definition.HTTP,
							TLSSkipVerify:                  check.Definition.TLSSkipVerify,
							Header:                         check.Definition.Header,
							Method:                         check.Definition.Method,
							TCP:                            check.Definition.TCP,
							Interval:                       check.Definition.IntervalDuration,
							Timeout:                        check.Definition.TimeoutDuration,
							DeregisterCriticalServiceAfter: check.Definition.DeregisterCriticalServiceAfterDuration,
						},
						RaftIndex: structs.RaftIndex{
							ModifyIndex: check.ModifyIndex,
						},
					},
				},
			}
			opsRPC = append(opsRPC, out)
		}
	}

	// Enforce an overall size limit to help prevent abuse. The cumulative KV
	// value size uses the same maxKVSize bound as a single value.
	if netKVSize > maxKVSize {
		resp.WriteHeader(http.StatusRequestEntityTooLarge)
		fmt.Fprintf(resp, "Cumulative size of key data is too large (%d > %d bytes)",
			netKVSize, maxKVSize)

		return nil, 0, false
	}

	return opsRPC, writes, true
}
   267  
   268  // Txn handles requests to apply multiple operations in a single, atomic
   269  // transaction. A transaction consisting of only read operations will be fast-
   270  // pathed to an endpoint that supports consistency modes (but not blocking),
   271  // and everything else will be routed through Raft like a normal write.
   272  func (s *HTTPServer) Txn(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
   273  	// Convert the ops from the API format to the internal format.
   274  	ops, writes, ok := s.convertOps(resp, req)
   275  	if !ok {
   276  		return nil, nil
   277  	}
   278  
   279  	// Fast-path a transaction with only writes to the read-only endpoint,
   280  	// which bypasses Raft, and allows for staleness.
   281  	conflict := false
   282  	var ret interface{}
   283  	if writes == 0 {
   284  		args := structs.TxnReadRequest{Ops: ops}
   285  		if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
   286  			return nil, nil
   287  		}
   288  
   289  		var reply structs.TxnReadResponse
   290  		if err := s.agent.RPC("Txn.Read", &args, &reply); err != nil {
   291  			return nil, err
   292  		}
   293  
   294  		// Since we don't do blocking, we only add the relevant headers
   295  		// for metadata.
   296  		setLastContact(resp, reply.LastContact)
   297  		setKnownLeader(resp, reply.KnownLeader)
   298  
   299  		ret, conflict = reply, len(reply.Errors) > 0
   300  	} else {
   301  		args := structs.TxnRequest{Ops: ops}
   302  		s.parseDC(req, &args.Datacenter)
   303  		s.parseToken(req, &args.Token)
   304  
   305  		var reply structs.TxnResponse
   306  		if err := s.agent.RPC("Txn.Apply", &args, &reply); err != nil {
   307  			return nil, err
   308  		}
   309  		ret, conflict = reply, len(reply.Errors) > 0
   310  	}
   311  
   312  	// If there was a conflict return the response object but set a special
   313  	// status code.
   314  	if conflict {
   315  		var buf []byte
   316  		var err error
   317  		buf, err = s.marshalJSON(req, ret)
   318  		if err != nil {
   319  			return nil, err
   320  		}
   321  
   322  		resp.Header().Set("Content-Type", "application/json")
   323  		resp.WriteHeader(http.StatusConflict)
   324  		resp.Write(buf)
   325  		return nil, nil
   326  	}
   327  
   328  	// Otherwise, return the results of the successful transaction.
   329  	return ret, nil
   330  }