github.com/MontFerret/ferret@v0.18.0/pkg/drivers/headers.go (about)

     1  package drivers
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"hash/fnv"
     8  	"net/textproto"
     9  	"sort"
    10  	"strings"
    11  
    12  	"github.com/wI2L/jettison"
    13  
    14  	"github.com/MontFerret/ferret/pkg/runtime/core"
    15  	"github.com/MontFerret/ferret/pkg/runtime/values"
    16  )
    17  
    18  // HTTPHeaders HTTP header object
    19  type HTTPHeaders struct {
    20  	values map[string][]string
    21  }
    22  
    23  func NewHTTPHeaders() *HTTPHeaders {
    24  	return NewHTTPHeadersWith(make(map[string][]string))
    25  }
    26  
    27  func NewHTTPHeadersWith(values map[string][]string) *HTTPHeaders {
    28  	return &HTTPHeaders{values}
    29  }
    30  
    31  func (h *HTTPHeaders) Length() values.Int {
    32  	return values.NewInt(len(h.values))
    33  }
    34  
    35  func (h *HTTPHeaders) Type() core.Type {
    36  	return HTTPHeaderType
    37  }
    38  
    39  func (h *HTTPHeaders) String() string {
    40  	var buf bytes.Buffer
    41  
    42  	for k := range h.values {
    43  		buf.WriteString(fmt.Sprintf("%s=%s;", k, h.Get(k)))
    44  	}
    45  
    46  	return buf.String()
    47  }
    48  
    49  func (h *HTTPHeaders) Compare(other core.Value) int64 {
    50  	if other.Type() != HTTPHeaderType {
    51  		return Compare(HTTPHeaderType, other.Type())
    52  	}
    53  
    54  	oh := other.(*HTTPHeaders)
    55  
    56  	if len(h.values) > len(oh.values) {
    57  		return 1
    58  	} else if len(h.values) < len(oh.values) {
    59  		return -1
    60  	}
    61  
    62  	for k := range h.values {
    63  		c := strings.Compare(h.Get(k), oh.Get(k))
    64  
    65  		if c != 0 {
    66  			return int64(c)
    67  		}
    68  	}
    69  
    70  	return 0
    71  }
    72  
    73  func (h *HTTPHeaders) Unwrap() interface{} {
    74  	return h.values
    75  }
    76  
    77  func (h *HTTPHeaders) Hash() uint64 {
    78  	hash := fnv.New64a()
    79  
    80  	hash.Write([]byte(h.Type().String()))
    81  	hash.Write([]byte(":"))
    82  	hash.Write([]byte("{"))
    83  
    84  	keys := make([]string, 0, len(h.values))
    85  
    86  	for key := range h.values {
    87  		keys = append(keys, key)
    88  	}
    89  
    90  	// order does not really matter
    91  	// but it will give us a consistent hash sum
    92  	sort.Strings(keys)
    93  	endIndex := len(keys) - 1
    94  
    95  	for idx, key := range keys {
    96  		hash.Write([]byte(key))
    97  		hash.Write([]byte(":"))
    98  
    99  		value := h.Get(key)
   100  
   101  		hash.Write([]byte(value))
   102  
   103  		if idx != endIndex {
   104  			hash.Write([]byte(","))
   105  		}
   106  	}
   107  
   108  	hash.Write([]byte("}"))
   109  
   110  	return hash.Sum64()
   111  }
   112  
   113  func (h *HTTPHeaders) Copy() core.Value {
   114  	return &HTTPHeaders{h.values}
   115  }
   116  
   117  func (h *HTTPHeaders) Clone() core.Cloneable {
   118  	cp := make(map[string][]string)
   119  
   120  	for k, v := range h.values {
   121  		cp[k] = v
   122  	}
   123  
   124  	return &HTTPHeaders{cp}
   125  }
   126  
   127  func (h *HTTPHeaders) MarshalJSON() ([]byte, error) {
   128  	headers := map[string]string{}
   129  
   130  	for key, val := range h.values {
   131  		headers[key] = strings.Join(val, ", ")
   132  	}
   133  
   134  	out, err := jettison.MarshalOpts(headers)
   135  
   136  	if err != nil {
   137  		return nil, err
   138  	}
   139  
   140  	return out, err
   141  }
   142  
   143  func (h *HTTPHeaders) Set(key, value string) {
   144  	textproto.MIMEHeader(h.values).Set(key, value)
   145  }
   146  
   147  func (h *HTTPHeaders) SetArr(key string, value []string) {
   148  	h.values[key] = value
   149  }
   150  
   151  func (h *HTTPHeaders) Get(key string) string {
   152  	_, found := h.values[key]
   153  
   154  	if !found {
   155  		return ""
   156  	}
   157  
   158  	return textproto.MIMEHeader(h.values).Get(key)
   159  }
   160  
   161  func (h *HTTPHeaders) GetIn(_ context.Context, path []core.Value) (core.Value, core.PathError) {
   162  	if len(path) == 0 {
   163  		return values.None, nil
   164  	}
   165  
   166  	segmentIx := 0
   167  	segment := path[segmentIx]
   168  
   169  	return values.NewString(h.Get(string(values.ToString(segment)))), nil
   170  }
   171  
   172  func (h *HTTPHeaders) ForEach(predicate func(value []string, key string) bool) {
   173  	for key, val := range h.values {
   174  		if !predicate(val, key) {
   175  			break
   176  		}
   177  	}
   178  }