github.com/xmidt-org/webpa-common@v1.11.9/device/devicegate/filterHandler.go (about)

     1  package devicegate
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"io/ioutil"
     9  	"net/http"
    10  
    11  	"github.com/go-kit/kit/log"
    12  	"github.com/go-kit/kit/log/level"
    13  	"github.com/xmidt-org/webpa-common/logging"
    14  	"github.com/xmidt-org/webpa-common/xhttp"
    15  )
    16  
    17  // ContextKey is a custom type for setting keys in a request's context
    18  type ContextKey string
    19  
    20  const gateKey ContextKey = "gate"
    21  
    22  // FilterHandler is an http.Handler that can get, add, and delete filters from a devicegate Interface
    23  type FilterHandler struct {
    24  	Gate Interface
    25  }
    26  
    27  // GateLogger is used to log extra details about the gate
    28  type GateLogger struct {
    29  	Logger log.Logger
    30  }
    31  
    32  // GetFilters is a handler function that gets all of the filters set on a gate
    33  func (fh *FilterHandler) GetFilters(response http.ResponseWriter, request *http.Request) {
    34  	response.Header().Set("Content-Type", "application/json")
    35  	JSON, _ := json.Marshal(fh.Gate)
    36  	fmt.Fprintf(response, `%s`, JSON)
    37  }
    38  
    39  // UpdateFilters is a handler function that updates the filters stored in a gate
    40  func (fh *FilterHandler) UpdateFilters(response http.ResponseWriter, request *http.Request) {
    41  	logger := logging.GetLogger(request.Context())
    42  
    43  	message, err := validateRequestBody(request)
    44  
    45  	if err != nil {
    46  		logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), "error with request body", logging.ErrorKey(), err)
    47  		xhttp.WriteError(response, http.StatusBadRequest, err)
    48  		return
    49  	}
    50  
    51  	if allow, err := checkRequestDetails(message, fh.Gate, true); !allow {
    52  		logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), err)
    53  		xhttp.WriteError(response, http.StatusBadRequest, err)
    54  		return
    55  	}
    56  
    57  	if _, created := fh.Gate.SetFilter(message.Key, message.Values); created {
    58  		response.WriteHeader(http.StatusCreated)
    59  	} else {
    60  		response.WriteHeader(http.StatusOK)
    61  	}
    62  
    63  	newCtx := context.WithValue(request.Context(), gateKey, fh.Gate)
    64  	*request = *request.WithContext(newCtx)
    65  }
    66  
    67  // DeleteFilter is a handler function used to delete a particular filter stored in the gate
    68  func (fh *FilterHandler) DeleteFilter(response http.ResponseWriter, request *http.Request) {
    69  	logger := logging.GetLogger(request.Context())
    70  
    71  	message, err := validateRequestBody(request)
    72  
    73  	if err != nil {
    74  		logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), "error with request body", logging.ErrorKey(), err)
    75  		xhttp.WriteError(response, http.StatusBadRequest, err)
    76  		return
    77  	}
    78  
    79  	if allow, err := checkRequestDetails(message, fh.Gate, false); !allow {
    80  		logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), err)
    81  		xhttp.WriteError(response, http.StatusBadRequest, err)
    82  		return
    83  	}
    84  
    85  	fh.Gate.DeleteFilter(message.Key)
    86  	response.WriteHeader(http.StatusOK)
    87  
    88  	newCtx := context.WithValue(request.Context(), gateKey, fh.Gate)
    89  	*request = *request.WithContext(newCtx)
    90  }
    91  
    92  // LogFilters is a decorator that logs the updated filters list and writes the updated list in the response body
    93  func (gl GateLogger) LogFilters(next http.Handler) http.Handler {
    94  	return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
    95  		next.ServeHTTP(response, request)
    96  
    97  		if gate, ok := request.Context().Value(gateKey).(Interface); ok {
    98  			if filtersJSON, err := json.Marshal(gate); err == nil {
    99  				response.Header().Set("Content-Type", "application/json")
   100  				fmt.Fprintf(response, `%s`, filtersJSON)
   101  				gl.Logger.Log(level.Key(), level.InfoValue(), logging.MessageKey(), "gate filters updated", "filters", string(filtersJSON))
   102  			} else {
   103  				gl.Logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), "error with unmarshalling gate", logging.ErrorKey(), err)
   104  			}
   105  		} else {
   106  			gl.Logger.Log(level.Key(), level.InfoValue(), logging.MessageKey(), "gate not found in request context")
   107  		}
   108  
   109  	})
   110  
   111  }
   112  
   113  // check that a message body is can be read and unmarshalled
   114  func validateRequestBody(request *http.Request) (FilterRequest, error) {
   115  	var message FilterRequest
   116  	msgBytes, err := ioutil.ReadAll(request.Body)
   117  	request.Body.Close()
   118  
   119  	if err != nil {
   120  		return message, err
   121  	}
   122  
   123  	if e := json.Unmarshal(msgBytes, &message); e != nil {
   124  		return message, e
   125  	}
   126  
   127  	return message, nil
   128  
   129  }
   130  
   131  // validate content of request body
   132  func checkRequestDetails(f FilterRequest, gate Interface, checkFilterValues bool) (bool, error) {
   133  	if len(f.Key) == 0 {
   134  		return false, errors.New("missing filter key")
   135  	}
   136  
   137  	if checkFilterValues {
   138  		if len(f.Values) == 0 {
   139  			return false, errors.New("missing filter values")
   140  		}
   141  
   142  		if allowedFilters, allowedFiltersFound := gate.GetAllowedFilters(); allowedFiltersFound {
   143  			if !allowedFilters.Has(f.Key) {
   144  				allowedFiltersJSON, _ := json.Marshal(allowedFilters)
   145  				return false, fmt.Errorf("filter key %s is not allowed. Allowed filters: %s", f.Key, allowedFiltersJSON)
   146  			}
   147  		}
   148  	}
   149  
   150  	return true, nil
   151  }