go.temporal.io/server@v1.23.0/common/searchattribute/encode_value.go (about)

     1  // The MIT License
     2  //
     3  // Copyright (c) 2020 Temporal Technologies Inc.  All rights reserved.
     4  //
     5  // Copyright (c) 2020 Uber Technologies, Inc.
     6  //
     7  // Permission is hereby granted, free of charge, to any person obtaining a copy
     8  // of this software and associated documentation files (the "Software"), to deal
     9  // in the Software without restriction, including without limitation the rights
    10  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    11  // copies of the Software, and to permit persons to whom the Software is
    12  // furnished to do so, subject to the following conditions:
    13  //
    14  // The above copyright notice and this permission notice shall be included in
    15  // all copies or substantial portions of the Software.
    16  //
    17  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    18  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    19  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    20  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    21  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    22  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    23  // THE SOFTWARE.
    24  
    25  package searchattribute
    26  
    27  import (
    28  	"errors"
    29  	"fmt"
    30  	"time"
    31  	"unicode/utf8"
    32  
    33  	commonpb "go.temporal.io/api/common/v1"
    34  	enumspb "go.temporal.io/api/enums/v1"
    35  
    36  	"go.temporal.io/server/common/payload"
    37  )
    38  
    39  var ErrInvalidString = errors.New("SearchAttribute value is not a valid UTF-8 string")
    40  
    41  // EncodeValue encodes search attribute value and IndexedValueType to Payload.
    42  func EncodeValue(val interface{}, t enumspb.IndexedValueType) (*commonpb.Payload, error) {
    43  	valPayload, err := payload.Encode(val)
    44  	if err != nil {
    45  		return nil, err
    46  	}
    47  
    48  	setMetadataType(valPayload, t)
    49  	return valPayload, nil
    50  }
    51  
    52  // DecodeValue decodes search attribute value from Payload using (in order):
    53  // 1. passed type t.
    54  // 2. type from MetadataType field, if t is not specified.
    55  // allowList allows list of values when it's not keyword list type.
    56  func DecodeValue(
    57  	value *commonpb.Payload,
    58  	t enumspb.IndexedValueType,
    59  	allowList bool,
    60  ) (any, error) {
    61  	if t == enumspb.INDEXED_VALUE_TYPE_UNSPECIFIED {
    62  		var err error
    63  		t, err = enumspb.IndexedValueTypeFromString(string(value.Metadata[MetadataType]))
    64  		if err != nil {
    65  			return nil, fmt.Errorf("%w: %v", ErrInvalidType, t)
    66  		}
    67  	}
    68  
    69  	switch t {
    70  	case enumspb.INDEXED_VALUE_TYPE_BOOL:
    71  		return decodeValueTyped[bool](value, allowList)
    72  	case enumspb.INDEXED_VALUE_TYPE_DATETIME:
    73  		return decodeValueTyped[time.Time](value, allowList)
    74  	case enumspb.INDEXED_VALUE_TYPE_DOUBLE:
    75  		return decodeValueTyped[float64](value, allowList)
    76  	case enumspb.INDEXED_VALUE_TYPE_INT:
    77  		return decodeValueTyped[int64](value, allowList)
    78  	case enumspb.INDEXED_VALUE_TYPE_KEYWORD:
    79  		return validateStrings(decodeValueTyped[string](value, allowList))
    80  	case enumspb.INDEXED_VALUE_TYPE_TEXT:
    81  		return validateStrings(decodeValueTyped[string](value, allowList))
    82  	case enumspb.INDEXED_VALUE_TYPE_KEYWORD_LIST:
    83  		return validateStrings(decodeValueTyped[[]string](value, false))
    84  	default:
    85  		return nil, fmt.Errorf("%w: %v", ErrInvalidType, t)
    86  	}
    87  }
    88  
    89  func validateStrings(anyValue any, err error) (any, error) {
    90  	if err != nil {
    91  		return anyValue, err
    92  	}
    93  
    94  	// validate strings
    95  	switch value := anyValue.(type) {
    96  	case string:
    97  		if !utf8.ValidString(value) {
    98  			return nil, fmt.Errorf("%w: %s", ErrInvalidString, value)
    99  		}
   100  	case []string:
   101  		for _, item := range value {
   102  			if !utf8.ValidString(item) {
   103  				return nil, fmt.Errorf("%w: %s", ErrInvalidString, item)
   104  			}
   105  		}
   106  	}
   107  	return anyValue, err
   108  }
   109  
   110  // decodeValueTyped tries to decode to the given type.
   111  // If the input is a list and allowList is false, then it will return only the first element.
   112  // If the input is a list and allowList is true, then it will return the decoded list.
   113  //
   114  //nolint:revive // allowList is a control flag
   115  func decodeValueTyped[T any](value *commonpb.Payload, allowList bool) (any, error) {
   116  	// At first, it tries to decode to pointer of actual type (i.e. `*string` for `string`).
   117  	// This is to ensure that `nil` values are decoded back as `nil` using `NilPayloadConverter`.
   118  	// If value is not `nil` but some value of expected type, the code relies on the fact that
   119  	// search attributes are always encoded with `JsonPayloadConverter`, which uses standard
   120  	// `json.Unmarshal` function, which works fine with pointer types when decoding values.
   121  	// If decoding to pointer type fails, it tries to decode to array of the same type because
   122  	// search attributes support polymorphism: field of specific type may also have an array of that type.
   123  	// If resulting slice has zero length, it gets substitute with `nil` to treat nils and empty slices equally.
   124  	// If allowList is true, it returns the list as it is. If allowList is false and the list has
   125  	// only one element, then return it. Otherwise, return an error.
   126  	// If search attribute value is `nil`, it means that search attribute needs to be removed from the document.
   127  	var val *T
   128  	if err := payload.Decode(value, &val); err != nil {
   129  		var listVal []T
   130  		if err := payload.Decode(value, &listVal); err != nil {
   131  			return nil, err
   132  		}
   133  		if len(listVal) == 0 {
   134  			return nil, nil
   135  		}
   136  		if allowList {
   137  			return listVal, nil
   138  		}
   139  		if len(listVal) == 1 {
   140  			return listVal[0], nil
   141  		}
   142  		return nil, fmt.Errorf("list of values not allowed for type %T", listVal[0])
   143  	}
   144  	if val == nil {
   145  		return nil, nil
   146  	}
   147  	return *val, nil
   148  }