github.com/weaviate/weaviate@v1.24.6/modules/text2vec-contextionary/extensions/rest_storage.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package extensions
    13  
import (
	"io"
	"net/http"
	"strings"
)
    18  
    19  type RESTHandlers struct {
    20  	ls    LoaderStorer
    21  	proxy Proxy
    22  }
    23  
    24  func NewRESTHandlers(ls LoaderStorer, proxy Proxy) *RESTHandlers {
    25  	return &RESTHandlers{
    26  		ls:    ls,
    27  		proxy: proxy,
    28  	}
    29  }
    30  
    31  type RESTStorageHandlers struct {
    32  	ls LoaderStorer
    33  }
    34  
    35  func newRESTStorageHandlers(ls LoaderStorer) *RESTStorageHandlers {
    36  	return &RESTStorageHandlers{
    37  		ls: ls,
    38  	}
    39  }
    40  
    41  func (h *RESTHandlers) StorageHandler() http.Handler {
    42  	return newRESTStorageHandlers(h.ls).Handler()
    43  }
    44  
    45  func (h *RESTStorageHandlers) Handler() http.Handler {
    46  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    47  		switch r.Method {
    48  		case http.MethodGet:
    49  			h.get(w, r)
    50  		case http.MethodPut:
    51  			h.put(w, r)
    52  		default:
    53  			w.WriteHeader(http.StatusMethodNotAllowed)
    54  		}
    55  	})
    56  }
    57  
    58  func (h *RESTStorageHandlers) get(w http.ResponseWriter, r *http.Request) {
    59  	if len(r.URL.String()) == 0 || h.extractConcept(r) == "" {
    60  		h.getAll(w, r)
    61  		return
    62  	}
    63  
    64  	h.getOne(w, r)
    65  }
    66  
    67  func (h *RESTStorageHandlers) getOne(w http.ResponseWriter, r *http.Request) {
    68  	concept := h.extractConcept(r)
    69  	if concept == "" {
    70  		w.WriteHeader(http.StatusNotFound)
    71  		return
    72  	}
    73  
    74  	res, err := h.ls.Load(concept)
    75  	if err != nil {
    76  		w.WriteHeader(http.StatusInternalServerError)
    77  		w.Write([]byte(err.Error()))
    78  		return
    79  	}
    80  
    81  	if res == nil {
    82  		w.WriteHeader(http.StatusNotFound)
    83  		return
    84  	}
    85  
    86  	w.Write(res)
    87  }
    88  
    89  func (h *RESTStorageHandlers) getAll(w http.ResponseWriter, r *http.Request) {
    90  	res, err := h.ls.LoadAll()
    91  	if err != nil {
    92  		w.WriteHeader(http.StatusInternalServerError)
    93  		w.Write([]byte(err.Error()))
    94  		return
    95  	}
    96  
    97  	w.Write(res)
    98  }
    99  
   100  func (h *RESTStorageHandlers) put(w http.ResponseWriter, r *http.Request) {
   101  	defer r.Body.Close()
   102  	concept := h.extractConcept(r)
   103  	if len(concept) == 0 {
   104  		w.WriteHeader(http.StatusNotFound)
   105  		return
   106  	}
   107  
   108  	body, err := io.ReadAll(r.Body)
   109  	if err != nil {
   110  		w.WriteHeader(http.StatusInternalServerError)
   111  		w.Write([]byte(err.Error()))
   112  	}
   113  
   114  	err = h.ls.Store(concept, body)
   115  	if err != nil {
   116  		w.WriteHeader(http.StatusInternalServerError)
   117  		w.Write([]byte(err.Error()))
   118  	}
   119  }
   120  
   121  func (h *RESTStorageHandlers) extractConcept(r *http.Request) string {
   122  	// cutoff leading slash, consider the rest the concept
   123  	return r.URL.String()[1:]
   124  }
   125  
   126  type Storer interface {
   127  	Store(concept string, value []byte) error
   128  }
   129  
   130  type Loader interface {
   131  	Load(concept string) ([]byte, error)
   132  }
   133  
   134  type LoaderAller interface {
   135  	LoadAll() ([]byte, error)
   136  }
   137  
   138  type LoaderStorer interface {
   139  	Storer
   140  	Loader
   141  	LoaderAller
   142  }