github.com/safing/portbase@v0.19.5/api/database.go (about)

     1  package api
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"net/http"
     9  	"sync"
    10  
    11  	"github.com/gorilla/websocket"
    12  	"github.com/tevino/abool"
    13  	"github.com/tidwall/gjson"
    14  	"github.com/tidwall/sjson"
    15  
    16  	"github.com/safing/portbase/container"
    17  	"github.com/safing/portbase/database"
    18  	"github.com/safing/portbase/database/iterator"
    19  	"github.com/safing/portbase/database/query"
    20  	"github.com/safing/portbase/database/record"
    21  	"github.com/safing/portbase/formats/dsd"
    22  	"github.com/safing/portbase/formats/varint"
    23  	"github.com/safing/portbase/log"
    24  )
    25  
const (
	// Message types, sent as the second field of every response message.
	dbMsgTypeOk      = "ok"      // a record result of get, query or qsub
	dbMsgTypeError   = "error"   // the operation failed
	dbMsgTypeDone    = "done"    // a query or subscription feed ended
	dbMsgTypeSuccess = "success" // a create, update, insert or delete succeeded
	dbMsgTypeUpd     = "upd"     // subscription: a record was updated
	dbMsgTypeNew     = "new"     // subscription: a record was created
	dbMsgTypeDel     = "del"     // subscription: a record was deleted
	dbMsgTypeWarning = "warning" // a single record failed; the operation continues

	// dbAPISeperator separates the fields of a message.
	// (The "Seperator" spelling is kept because other code references it.)
	dbAPISeperator = "|"
	emptyString    = ""
)

var (
	// dbAPISeperatorBytes is dbAPISeperator as a byte slice, used when
	// building response messages.
	dbAPISeperatorBytes = []byte(dbAPISeperator)
	// dbCompatibilityPermission is passed as both permission arguments when
	// registering the database websocket endpoint in init.
	dbCompatibilityPermission = PermitAdmin
)
    44  
    45  func init() {
    46  	RegisterHandler("/api/database/v1", WrapInAuthHandler(
    47  		startDatabaseWebsocketAPI,
    48  		// Default to admin read/write permissions until the database gets support
    49  		// for api permissions.
    50  		dbCompatibilityPermission,
    51  		dbCompatibilityPermission,
    52  	))
    53  }
    54  
// DatabaseAPI is a generic database API interface.
//
// It parses messages of a pipe-separated text protocol (see Handle), runs the
// requested database operations in their own goroutines and emits every
// response through sendBytes.
type DatabaseAPI struct {
	// queries maps operation IDs to the iterators of running queries, so that
	// they can be canceled. Guarded by queriesLock.
	queriesLock sync.Mutex
	queries     map[string]*iterator.Iterator

	// subs maps operation IDs to active subscriptions, so that they can be
	// canceled. Guarded by subsLock.
	subsLock sync.Mutex
	subs     map[string]*database.Subscription

	// shutdownSignal is closed to make running queries and subscriptions stop.
	// shuttingDown makes sure the websocket API's shutdown runs only once.
	// db is the database interface all operations are run against.
	// NOTE(review): nothing in this file closes shutdownSignal for an API
	// created via CreateDatabaseAPI; confirm how such users stop operations.
	shutdownSignal chan struct{}
	shuttingDown   *abool.AtomicBool
	db             *database.Interface

	// sendBytes delivers one compiled response message to the client. It is
	// called concurrently from the operation goroutines started by Handle.
	sendBytes func(data []byte)
}
    69  
// DatabaseWebsocketAPI is a database websocket API interface.
//
// It serves a DatabaseAPI over a single websocket connection: a handler worker
// reads incoming messages and passes them to Handle, and a writer worker
// drains sendQueue into the connection.
type DatabaseWebsocketAPI struct {
	DatabaseAPI

	// sendQueue buffers outgoing messages for the writer; an empty message
	// makes the writer stop. conn is the underlying websocket connection.
	sendQueue chan []byte
	conn      *websocket.Conn
}
    77  
    78  func allowAnyOrigin(r *http.Request) bool {
    79  	return true
    80  }
    81  
    82  // CreateDatabaseAPI creates a new database interface.
    83  func CreateDatabaseAPI(sendFunction func(data []byte)) DatabaseAPI {
    84  	return DatabaseAPI{
    85  		queries:        make(map[string]*iterator.Iterator),
    86  		subs:           make(map[string]*database.Subscription),
    87  		shutdownSignal: make(chan struct{}),
    88  		shuttingDown:   abool.NewBool(false),
    89  		db:             database.NewInterface(nil),
    90  		sendBytes:      sendFunction,
    91  	}
    92  }
    93  
    94  func startDatabaseWebsocketAPI(w http.ResponseWriter, r *http.Request) {
    95  	upgrader := websocket.Upgrader{
    96  		CheckOrigin:     allowAnyOrigin,
    97  		ReadBufferSize:  1024,
    98  		WriteBufferSize: 65536,
    99  	}
   100  	wsConn, err := upgrader.Upgrade(w, r, nil)
   101  	if err != nil {
   102  		errMsg := fmt.Sprintf("could not upgrade: %s", err)
   103  		log.Error(errMsg)
   104  		http.Error(w, errMsg, http.StatusBadRequest)
   105  		return
   106  	}
   107  
   108  	newDBAPI := &DatabaseWebsocketAPI{
   109  		DatabaseAPI: DatabaseAPI{
   110  			queries:        make(map[string]*iterator.Iterator),
   111  			subs:           make(map[string]*database.Subscription),
   112  			shutdownSignal: make(chan struct{}),
   113  			shuttingDown:   abool.NewBool(false),
   114  			db:             database.NewInterface(nil),
   115  		},
   116  
   117  		sendQueue: make(chan []byte, 100),
   118  		conn:      wsConn,
   119  	}
   120  
   121  	newDBAPI.sendBytes = func(data []byte) {
   122  		newDBAPI.sendQueue <- data
   123  	}
   124  
   125  	module.StartWorker("database api handler", newDBAPI.handler)
   126  	module.StartWorker("database api writer", newDBAPI.writer)
   127  
   128  	log.Tracer(r.Context()).Infof("api request: init websocket %s %s", r.RemoteAddr, r.RequestURI)
   129  }
   130  
   131  func (api *DatabaseWebsocketAPI) handler(context.Context) error {
   132  	defer func() {
   133  		_ = api.shutdown(nil)
   134  	}()
   135  
   136  	for {
   137  		_, msg, err := api.conn.ReadMessage()
   138  		if err != nil {
   139  			return api.shutdown(err)
   140  		}
   141  
   142  		api.Handle(msg)
   143  	}
   144  }
   145  
   146  func (api *DatabaseWebsocketAPI) writer(ctx context.Context) error {
   147  	defer func() {
   148  		_ = api.shutdown(nil)
   149  	}()
   150  
   151  	var data []byte
   152  	var err error
   153  
   154  	for {
   155  		select {
   156  		// prioritize direct writes
   157  		case data = <-api.sendQueue:
   158  			if len(data) == 0 {
   159  				return nil
   160  			}
   161  		case <-ctx.Done():
   162  			return nil
   163  		case <-api.shutdownSignal:
   164  			return nil
   165  		}
   166  
   167  		// log.Tracef("api: sending %s", string(*msg))
   168  		err = api.conn.WriteMessage(websocket.BinaryMessage, data)
   169  		if err != nil {
   170  			return api.shutdown(err)
   171  		}
   172  	}
   173  }
   174  
   175  func (api *DatabaseWebsocketAPI) shutdown(err error) error {
   176  	// Check if we are the first to shut down.
   177  	if !api.shuttingDown.SetToIf(false, true) {
   178  		return nil
   179  	}
   180  
   181  	// Check the given error.
   182  	if err != nil {
   183  		if websocket.IsCloseError(err,
   184  			websocket.CloseNormalClosure,
   185  			websocket.CloseGoingAway,
   186  			websocket.CloseAbnormalClosure,
   187  		) {
   188  			log.Infof("api: websocket connection to %s closed", api.conn.RemoteAddr())
   189  		} else {
   190  			log.Warningf("api: websocket connection error with %s: %s", api.conn.RemoteAddr(), err)
   191  		}
   192  	}
   193  
   194  	// Trigger shutdown.
   195  	close(api.shutdownSignal)
   196  	_ = api.conn.Close()
   197  	return nil
   198  }
   199  
   200  // Handle handles a message for the database API.
   201  func (api *DatabaseAPI) Handle(msg []byte) {
   202  	// 123|get|<key>
   203  	//    123|ok|<key>|<data>
   204  	//    123|error|<message>
   205  	// 124|query|<query>
   206  	//    124|ok|<key>|<data>
   207  	//    124|done
   208  	//    124|error|<message>
   209  	//    124|warning|<message> // error with single record, operation continues
   210  	// 124|cancel
   211  	// 125|sub|<query>
   212  	//    125|upd|<key>|<data>
   213  	//    125|new|<key>|<data>
   214  	//    127|del|<key>
   215  	//    125|warning|<message> // error with single record, operation continues
   216  	// 125|cancel
   217  	// 127|qsub|<query>
   218  	//    127|ok|<key>|<data>
   219  	//    127|done
   220  	//    127|error|<message>
   221  	//    127|upd|<key>|<data>
   222  	//    127|new|<key>|<data>
   223  	//    127|del|<key>
   224  	//    127|warning|<message> // error with single record, operation continues
   225  	// 127|cancel
   226  
   227  	// 128|create|<key>|<data>
   228  	//    128|success
   229  	//    128|error|<message>
   230  	// 129|update|<key>|<data>
   231  	//    129|success
   232  	//    129|error|<message>
   233  	// 130|insert|<key>|<data>
   234  	//    130|success
   235  	//    130|error|<message>
   236  	// 131|delete|<key>
   237  	//    131|success
   238  	//    131|error|<message>
   239  
   240  	parts := bytes.SplitN(msg, []byte("|"), 3)
   241  
   242  	// Handle special command "cancel"
   243  	if len(parts) == 2 && string(parts[1]) == "cancel" {
   244  		// 124|cancel
   245  		// 125|cancel
   246  		// 127|cancel
   247  		go api.handleCancel(parts[0])
   248  		return
   249  	}
   250  
   251  	if len(parts) != 3 {
   252  		api.send(nil, dbMsgTypeError, "bad request: malformed message", nil)
   253  		return
   254  	}
   255  
   256  	switch string(parts[1]) {
   257  	case "get":
   258  		// 123|get|<key>
   259  		go api.handleGet(parts[0], string(parts[2]))
   260  	case "query":
   261  		// 124|query|<query>
   262  		go api.handleQuery(parts[0], string(parts[2]))
   263  	case "sub":
   264  		// 125|sub|<query>
   265  		go api.handleSub(parts[0], string(parts[2]))
   266  	case "qsub":
   267  		// 127|qsub|<query>
   268  		go api.handleQsub(parts[0], string(parts[2]))
   269  	case "create", "update", "insert":
   270  		// split key and payload
   271  		dataParts := bytes.SplitN(parts[2], []byte("|"), 2)
   272  		if len(dataParts) != 2 {
   273  			api.send(nil, dbMsgTypeError, "bad request: malformed message", nil)
   274  			return
   275  		}
   276  
   277  		switch string(parts[1]) {
   278  		case "create":
   279  			// 128|create|<key>|<data>
   280  			go api.handlePut(parts[0], string(dataParts[0]), dataParts[1], true)
   281  		case "update":
   282  			// 129|update|<key>|<data>
   283  			go api.handlePut(parts[0], string(dataParts[0]), dataParts[1], false)
   284  		case "insert":
   285  			// 130|insert|<key>|<data>
   286  			go api.handleInsert(parts[0], string(dataParts[0]), dataParts[1])
   287  		}
   288  	case "delete":
   289  		// 131|delete|<key>
   290  		go api.handleDelete(parts[0], string(parts[2]))
   291  	default:
   292  		api.send(parts[0], dbMsgTypeError, "bad request: unknown method", nil)
   293  	}
   294  }
   295  
   296  func (api *DatabaseAPI) send(opID []byte, msgType string, msgOrKey string, data []byte) {
   297  	c := container.New(opID)
   298  	c.Append(dbAPISeperatorBytes)
   299  	c.Append([]byte(msgType))
   300  
   301  	if msgOrKey != emptyString {
   302  		c.Append(dbAPISeperatorBytes)
   303  		c.Append([]byte(msgOrKey))
   304  	}
   305  
   306  	if len(data) > 0 {
   307  		c.Append(dbAPISeperatorBytes)
   308  		c.Append(data)
   309  	}
   310  
   311  	api.sendBytes(c.CompileData())
   312  }
   313  
   314  func (api *DatabaseAPI) handleGet(opID []byte, key string) {
   315  	// 123|get|<key>
   316  	//    123|ok|<key>|<data>
   317  	//    123|error|<message>
   318  
   319  	var data []byte
   320  
   321  	r, err := api.db.Get(key)
   322  	if err == nil {
   323  		data, err = MarshalRecord(r, true)
   324  	}
   325  	if err != nil {
   326  		api.send(opID, dbMsgTypeError, err.Error(), nil)
   327  		return
   328  	}
   329  	api.send(opID, dbMsgTypeOk, r.Key(), data)
   330  }
   331  
   332  func (api *DatabaseAPI) handleQuery(opID []byte, queryText string) {
   333  	// 124|query|<query>
   334  	//    124|ok|<key>|<data>
   335  	//    124|done
   336  	//    124|warning|<message>
   337  	//    124|error|<message>
   338  	//    124|warning|<message> // error with single record, operation continues
   339  	// 124|cancel
   340  
   341  	var err error
   342  
   343  	q, err := query.ParseQuery(queryText)
   344  	if err != nil {
   345  		api.send(opID, dbMsgTypeError, err.Error(), nil)
   346  		return
   347  	}
   348  
   349  	api.processQuery(opID, q)
   350  }
   351  
   352  func (api *DatabaseAPI) processQuery(opID []byte, q *query.Query) (ok bool) {
   353  	it, err := api.db.Query(q)
   354  	if err != nil {
   355  		api.send(opID, dbMsgTypeError, err.Error(), nil)
   356  		return false
   357  	}
   358  
   359  	// Save query iterator.
   360  	api.queriesLock.Lock()
   361  	api.queries[string(opID)] = it
   362  	api.queriesLock.Unlock()
   363  
   364  	// Remove query iterator after it ended.
   365  	defer func() {
   366  		api.queriesLock.Lock()
   367  		defer api.queriesLock.Unlock()
   368  		delete(api.queries, string(opID))
   369  	}()
   370  
   371  	for {
   372  		select {
   373  		case <-api.shutdownSignal:
   374  			// cancel query and return
   375  			it.Cancel()
   376  			return false
   377  		case r := <-it.Next:
   378  			// process query feed
   379  			if r != nil {
   380  				// process record
   381  				data, err := MarshalRecord(r, true)
   382  				if err != nil {
   383  					api.send(opID, dbMsgTypeWarning, err.Error(), nil)
   384  					continue
   385  				}
   386  				api.send(opID, dbMsgTypeOk, r.Key(), data)
   387  			} else {
   388  				// sub feed ended
   389  				if it.Err() != nil {
   390  					api.send(opID, dbMsgTypeError, it.Err().Error(), nil)
   391  					return false
   392  				}
   393  				api.send(opID, dbMsgTypeDone, emptyString, nil)
   394  				return true
   395  			}
   396  		}
   397  	}
   398  }
   399  
   400  // func (api *DatabaseWebsocketAPI) runQuery()
   401  
   402  func (api *DatabaseAPI) handleSub(opID []byte, queryText string) {
   403  	// 125|sub|<query>
   404  	//    125|upd|<key>|<data>
   405  	//    125|new|<key>|<data>
   406  	//    125|delete|<key>
   407  	//    125|warning|<message> // error with single record, operation continues
   408  	// 125|cancel
   409  	var err error
   410  
   411  	q, err := query.ParseQuery(queryText)
   412  	if err != nil {
   413  		api.send(opID, dbMsgTypeError, err.Error(), nil)
   414  		return
   415  	}
   416  
   417  	sub, ok := api.registerSub(opID, q)
   418  	if !ok {
   419  		return
   420  	}
   421  	api.processSub(opID, sub)
   422  }
   423  
   424  func (api *DatabaseAPI) registerSub(opID []byte, q *query.Query) (sub *database.Subscription, ok bool) {
   425  	var err error
   426  	sub, err = api.db.Subscribe(q)
   427  	if err != nil {
   428  		api.send(opID, dbMsgTypeError, err.Error(), nil)
   429  		return nil, false
   430  	}
   431  
   432  	return sub, true
   433  }
   434  
   435  func (api *DatabaseAPI) processSub(opID []byte, sub *database.Subscription) {
   436  	// Save subscription.
   437  	api.subsLock.Lock()
   438  	api.subs[string(opID)] = sub
   439  	api.subsLock.Unlock()
   440  
   441  	// Remove subscription after it ended.
   442  	defer func() {
   443  		api.subsLock.Lock()
   444  		defer api.subsLock.Unlock()
   445  		delete(api.subs, string(opID))
   446  	}()
   447  
   448  	for {
   449  		select {
   450  		case <-api.shutdownSignal:
   451  			// cancel sub and return
   452  			_ = sub.Cancel()
   453  			return
   454  		case r := <-sub.Feed:
   455  			// process sub feed
   456  			if r != nil {
   457  				// process record
   458  				data, err := MarshalRecord(r, true)
   459  				if err != nil {
   460  					api.send(opID, dbMsgTypeWarning, err.Error(), nil)
   461  					continue
   462  				}
   463  				// TODO: use upd, new and delete msgTypes
   464  				r.Lock()
   465  				isDeleted := r.Meta().IsDeleted()
   466  				isNew := r.Meta().Created == r.Meta().Modified
   467  				r.Unlock()
   468  				switch {
   469  				case isDeleted:
   470  					api.send(opID, dbMsgTypeDel, r.Key(), nil)
   471  				case isNew:
   472  					api.send(opID, dbMsgTypeNew, r.Key(), data)
   473  				default:
   474  					api.send(opID, dbMsgTypeUpd, r.Key(), data)
   475  				}
   476  			} else {
   477  				// sub feed ended
   478  				api.send(opID, dbMsgTypeDone, "", nil)
   479  				return
   480  			}
   481  		}
   482  	}
   483  }
   484  
   485  func (api *DatabaseAPI) handleQsub(opID []byte, queryText string) {
   486  	// 127|qsub|<query>
   487  	//    127|ok|<key>|<data>
   488  	//    127|done
   489  	//    127|error|<message>
   490  	//    127|upd|<key>|<data>
   491  	//    127|new|<key>|<data>
   492  	//    127|delete|<key>
   493  	//    127|warning|<message> // error with single record, operation continues
   494  	// 127|cancel
   495  
   496  	var err error
   497  
   498  	q, err := query.ParseQuery(queryText)
   499  	if err != nil {
   500  		api.send(opID, dbMsgTypeError, err.Error(), nil)
   501  		return
   502  	}
   503  
   504  	sub, ok := api.registerSub(opID, q)
   505  	if !ok {
   506  		return
   507  	}
   508  	ok = api.processQuery(opID, q)
   509  	if !ok {
   510  		return
   511  	}
   512  	api.processSub(opID, sub)
   513  }
   514  
// handleCancel cancels the query and then the subscription running under opID
// ("<opID>|cancel").
//
// NOTE(review): cancelSub sends "could not find subscription" whenever no
// subscription is stored for opID, so canceling a plain query also produces
// an error message. For a qsub that is still in its query phase, the
// subscription is not yet stored in api.subs (processSub stores it after
// processQuery returns), so it is not canceled here. Confirm whether both
// behaviors are intended.
func (api *DatabaseAPI) handleCancel(opID []byte) {
	api.cancelQuery(opID)
	api.cancelSub(opID)
}
   519  
   520  func (api *DatabaseAPI) cancelQuery(opID []byte) {
   521  	api.queriesLock.Lock()
   522  	defer api.queriesLock.Unlock()
   523  
   524  	// Get subscription from api.
   525  	it, ok := api.queries[string(opID)]
   526  	if !ok {
   527  		// Fail silently as quries end by themselves when finished.
   528  		return
   529  	}
   530  
   531  	// End query.
   532  	it.Cancel()
   533  
   534  	// The query handler will end the communication with a done message.
   535  }
   536  
   537  func (api *DatabaseAPI) cancelSub(opID []byte) {
   538  	api.subsLock.Lock()
   539  	defer api.subsLock.Unlock()
   540  
   541  	// Get subscription from api.
   542  	sub, ok := api.subs[string(opID)]
   543  	if !ok {
   544  		api.send(opID, dbMsgTypeError, "could not find subscription", nil)
   545  		return
   546  	}
   547  
   548  	// End subscription.
   549  	err := sub.Cancel()
   550  	if err != nil {
   551  		api.send(opID, dbMsgTypeError, fmt.Sprintf("failed to cancel subscription: %s", err), nil)
   552  	}
   553  
   554  	// The subscription handler will end the communication with a done message.
   555  }
   556  
   557  func (api *DatabaseAPI) handlePut(opID []byte, key string, data []byte, create bool) {
   558  	// 128|create|<key>|<data>
   559  	//    128|success
   560  	//    128|error|<message>
   561  
   562  	// 129|update|<key>|<data>
   563  	//    129|success
   564  	//    129|error|<message>
   565  
   566  	if len(data) < 2 {
   567  		api.send(opID, dbMsgTypeError, "bad request: malformed message", nil)
   568  		return
   569  	}
   570  
   571  	// TODO - staged for deletion: remove transition code
   572  	// if data[0] != dsd.JSON {
   573  	// 	typedData := make([]byte, len(data)+1)
   574  	// 	typedData[0] = dsd.JSON
   575  	// 	copy(typedData[1:], data)
   576  	// 	data = typedData
   577  	// }
   578  
   579  	r, err := record.NewWrapper(key, nil, data[0], data[1:])
   580  	if err != nil {
   581  		api.send(opID, dbMsgTypeError, err.Error(), nil)
   582  		return
   583  	}
   584  
   585  	if create {
   586  		err = api.db.PutNew(r)
   587  	} else {
   588  		err = api.db.Put(r)
   589  	}
   590  	if err != nil {
   591  		api.send(opID, dbMsgTypeError, err.Error(), nil)
   592  		return
   593  	}
   594  	api.send(opID, dbMsgTypeSuccess, emptyString, nil)
   595  }
   596  
   597  func (api *DatabaseAPI) handleInsert(opID []byte, key string, data []byte) {
   598  	// 130|insert|<key>|<data>
   599  	//    130|success
   600  	//    130|error|<message>
   601  
   602  	r, err := api.db.Get(key)
   603  	if err != nil {
   604  		api.send(opID, dbMsgTypeError, err.Error(), nil)
   605  		return
   606  	}
   607  
   608  	acc := r.GetAccessor(r)
   609  
   610  	result := gjson.ParseBytes(data)
   611  	anythingPresent := false
   612  	var insertError error
   613  	result.ForEach(func(key gjson.Result, value gjson.Result) bool {
   614  		anythingPresent = true
   615  		if !key.Exists() {
   616  			insertError = errors.New("values must be in a map")
   617  			return false
   618  		}
   619  		if key.Type != gjson.String {
   620  			insertError = errors.New("keys must be strings")
   621  			return false
   622  		}
   623  		if !value.Exists() {
   624  			insertError = errors.New("non-existent value")
   625  			return false
   626  		}
   627  		insertError = acc.Set(key.String(), value.Value())
   628  		return insertError == nil
   629  	})
   630  
   631  	if insertError != nil {
   632  		api.send(opID, dbMsgTypeError, insertError.Error(), nil)
   633  		return
   634  	}
   635  	if !anythingPresent {
   636  		api.send(opID, dbMsgTypeError, "could not find any valid values", nil)
   637  		return
   638  	}
   639  
   640  	err = api.db.Put(r)
   641  	if err != nil {
   642  		api.send(opID, dbMsgTypeError, err.Error(), nil)
   643  		return
   644  	}
   645  
   646  	api.send(opID, dbMsgTypeSuccess, emptyString, nil)
   647  }
   648  
   649  func (api *DatabaseAPI) handleDelete(opID []byte, key string) {
   650  	// 131|delete|<key>
   651  	//    131|success
   652  	//    131|error|<message>
   653  
   654  	err := api.db.Delete(key)
   655  	if err != nil {
   656  		api.send(opID, dbMsgTypeError, err.Error(), nil)
   657  		return
   658  	}
   659  	api.send(opID, dbMsgTypeSuccess, emptyString, nil)
   660  }
   661  
   662  // MarshalRecord locks and marshals the given record, additionally adding
   663  // metadata and returning it as json.
   664  func MarshalRecord(r record.Record, withDSDIdentifier bool) ([]byte, error) {
   665  	r.Lock()
   666  	defer r.Unlock()
   667  
   668  	// Pour record into JSON.
   669  	jsonData, err := r.Marshal(r, dsd.JSON)
   670  	if err != nil {
   671  		return nil, err
   672  	}
   673  
   674  	// Remove JSON identifier for manual editing.
   675  	jsonData = bytes.TrimPrefix(jsonData, varint.Pack8(dsd.JSON))
   676  
   677  	// Add metadata.
   678  	jsonData, err = sjson.SetBytes(jsonData, "_meta", r.Meta())
   679  	if err != nil {
   680  		return nil, err
   681  	}
   682  
   683  	// Add database key.
   684  	jsonData, err = sjson.SetBytes(jsonData, "_meta.Key", r.Key())
   685  	if err != nil {
   686  		return nil, err
   687  	}
   688  
   689  	// Add JSON identifier again.
   690  	if withDSDIdentifier {
   691  		formatID := varint.Pack8(dsd.JSON)
   692  		finalData := make([]byte, 0, len(formatID)+len(jsonData))
   693  		finalData = append(finalData, formatID...)
   694  		finalData = append(finalData, jsonData...)
   695  		return finalData, nil
   696  	}
   697  	return jsonData, nil
   698  }