github.com/cilium/statedb@v0.3.2/http.go (about)

     1  // SPDX-License-Identifier: Apache-2.0
     2  // Copyright Authors of Cilium
     3  
     4  package statedb
     5  
     6  import (
     7  	"encoding/base64"
     8  	"encoding/json"
     9  	"fmt"
    10  	"io"
    11  	"net/http"
    12  	"time"
    13  
    14  	"github.com/cilium/statedb/part"
    15  )
    16  
    17  func (db *DB) HTTPHandler() http.Handler {
    18  	h := dbHandler{db}
    19  	mux := http.NewServeMux()
    20  	mux.HandleFunc("GET /dump", h.dumpAll)
    21  	mux.HandleFunc("GET /dump/{table}", h.dumpTable)
    22  	mux.HandleFunc("GET /query", h.query)
    23  	mux.HandleFunc("GET /changes/{table}", h.changes)
    24  	return mux
    25  }
    26  
    27  type dbHandler struct {
    28  	db *DB
    29  }
    30  
    31  func (h dbHandler) dumpAll(w http.ResponseWriter, r *http.Request) {
    32  	w.Header().Add("Content-Type", "application/json")
    33  	w.WriteHeader(http.StatusOK)
    34  	h.db.ReadTxn().WriteJSON(w)
    35  }
    36  
    37  func (h dbHandler) dumpTable(w http.ResponseWriter, r *http.Request) {
    38  	w.Header().Add("Content-Type", "application/json")
    39  	w.WriteHeader(http.StatusOK)
    40  
    41  	var err error
    42  	if table := r.PathValue("table"); table != "" {
    43  		err = h.db.ReadTxn().WriteJSON(w, r.PathValue("table"))
    44  	} else {
    45  		err = h.db.ReadTxn().WriteJSON(w)
    46  	}
    47  	if err != nil {
    48  		panic(err)
    49  	}
    50  }
    51  
    52  func (h dbHandler) query(w http.ResponseWriter, r *http.Request) {
    53  	enc := json.NewEncoder(w)
    54  
    55  	var req QueryRequest
    56  	body, err := io.ReadAll(r.Body)
    57  	r.Body.Close()
    58  	if err != nil {
    59  		w.WriteHeader(http.StatusBadRequest)
    60  		enc.Encode(QueryResponse{Err: err.Error()})
    61  		return
    62  	}
    63  
    64  	if err := json.Unmarshal(body, &req); err != nil {
    65  		w.WriteHeader(http.StatusBadRequest)
    66  		enc.Encode(QueryResponse{Err: err.Error()})
    67  		return
    68  	}
    69  
    70  	queryKey, err := base64.StdEncoding.DecodeString(req.Key)
    71  	if err != nil {
    72  		w.WriteHeader(http.StatusBadRequest)
    73  		enc.Encode(QueryResponse{Err: err.Error()})
    74  		return
    75  	}
    76  
    77  	txn := h.db.ReadTxn().getTxn()
    78  
    79  	// Look up the table
    80  	var table TableMeta
    81  	for _, e := range txn.root {
    82  		if e.meta.Name() == req.Table {
    83  			table = e.meta
    84  			break
    85  		}
    86  	}
    87  	if table == nil {
    88  		w.WriteHeader(http.StatusNotFound)
    89  		enc.Encode(QueryResponse{Err: fmt.Sprintf("Table %q not found", req.Table)})
    90  		return
    91  	}
    92  
    93  	indexPos := table.indexPos(req.Index)
    94  
    95  	indexTxn, err := txn.indexReadTxn(table, indexPos)
    96  	if err != nil {
    97  		w.WriteHeader(http.StatusBadRequest)
    98  		enc.Encode(QueryResponse{Err: err.Error()})
    99  		return
   100  	}
   101  
   102  	w.WriteHeader(http.StatusOK)
   103  	onObject := func(obj object) error {
   104  		return enc.Encode(QueryResponse{
   105  			Rev: obj.revision,
   106  			Obj: obj.data,
   107  		})
   108  	}
   109  	runQuery(indexTxn, req.LowerBound, queryKey, onObject)
   110  }
   111  
// QueryRequest is the JSON request body accepted by the GET /query endpoint.
type QueryRequest struct {
	Key        string `json:"key"` // Base64 encoded query key
	Table      string `json:"table"` // table name, compared against TableMeta.Name()
	Index      string `json:"index"` // index name, resolved with the table's indexPos
	LowerBound bool   `json:"lowerbound"` // if true, run a LowerBound query instead of an exact-match Prefix query
}
   118  
// QueryResponse is one JSON value written by the GET /query endpoint. Success
// responses carry Rev and Obj for one matching object. Error responses carry
// only Err. The changes endpoint also uses it for its "table not found" error.
type QueryResponse struct {
	Rev uint64 `json:"rev"`           // revision of the returned object
	Obj any    `json:"obj"`           // the stored object itself
	Err string `json:"err,omitempty"` // error message; omitted when empty
}
   124  
   125  func runQuery(indexTxn indexReadTxn, lowerbound bool, queryKey []byte, onObject func(object) error) {
   126  	var iter *part.Iterator[object]
   127  	if lowerbound {
   128  		iter = indexTxn.LowerBound(queryKey)
   129  	} else {
   130  		iter, _ = indexTxn.Prefix(queryKey)
   131  	}
   132  	var match func([]byte) bool
   133  	switch {
   134  	case lowerbound:
   135  		match = func([]byte) bool { return true }
   136  	case indexTxn.unique:
   137  		match = func(k []byte) bool { return len(k) == len(queryKey) }
   138  	default:
   139  		match = func(k []byte) bool {
   140  			secondary, _ := decodeNonUniqueKey(k)
   141  			return len(secondary) == len(queryKey)
   142  		}
   143  	}
   144  	for key, obj, ok := iter.Next(); ok; key, obj, ok = iter.Next() {
   145  		if !match(key) {
   146  			continue
   147  		}
   148  		if err := onObject(obj); err != nil {
   149  			panic(err)
   150  		}
   151  	}
   152  }
   153  
   154  func (h dbHandler) changes(w http.ResponseWriter, r *http.Request) {
   155  	const keepaliveInterval = 30 * time.Second
   156  
   157  	enc := json.NewEncoder(w)
   158  	tableName := r.PathValue("table")
   159  
   160  	// Look up the table
   161  	var tableMeta TableMeta
   162  	for _, e := range h.db.ReadTxn().getTxn().root {
   163  		if e.meta.Name() == tableName {
   164  			tableMeta = e.meta
   165  			break
   166  		}
   167  	}
   168  	if tableMeta == nil {
   169  		w.WriteHeader(http.StatusNotFound)
   170  		enc.Encode(QueryResponse{Err: fmt.Sprintf("Table %q not found", tableName)})
   171  		return
   172  	}
   173  
   174  	// Register for changes.
   175  	wtxn := h.db.WriteTxn(tableMeta)
   176  	changeIter, err := tableMeta.anyChanges(wtxn)
   177  	wtxn.Commit()
   178  	if err != nil {
   179  		w.WriteHeader(http.StatusInternalServerError)
   180  		return
   181  	}
   182  
   183  	w.WriteHeader(http.StatusOK)
   184  
   185  	ticker := time.NewTicker(keepaliveInterval)
   186  	defer ticker.Stop()
   187  
   188  	for {
   189  		changes, watch := changeIter.nextAny(h.db.ReadTxn())
   190  		for change := range changes {
   191  			err := enc.Encode(change)
   192  			if err != nil {
   193  				panic(err)
   194  			}
   195  		}
   196  		w.(http.Flusher).Flush()
   197  		select {
   198  		case <-r.Context().Done():
   199  			return
   200  
   201  		case <-ticker.C:
   202  			// Send an empty keep-alive
   203  			enc.Encode(Change[any]{})
   204  
   205  		case <-watch:
   206  		}
   207  	}
   208  }