github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/caveats/types/basic.go (about)

     1  package types
     2  
import (
	"encoding/base64"
	"fmt"
	"math"
	"math/big"
	"time"

	"github.com/authzed/spicedb/pkg/spiceerrors"

	"github.com/authzed/cel-go/cel"
)
    13  
    14  func requireType[T any](value any) (any, error) {
    15  	vle, ok := value.(T)
    16  	if !ok {
    17  		return nil, fmt.Errorf("a %T value is required, but found %T `%v`", *new(T), value, value)
    18  	}
    19  	return vle, nil
    20  }
    21  
    22  func convertNumericType[T int64 | uint64 | float64](value any) (any, error) {
    23  	directValue, ok := value.(T)
    24  	if ok {
    25  		return directValue, nil
    26  	}
    27  
    28  	floatValue, ok := value.(float64)
    29  	bigFloat := big.NewFloat(floatValue)
    30  	if !ok {
    31  		stringValue, ok := value.(string)
    32  		if !ok {
    33  			return nil, fmt.Errorf("a %T value is required, but found %T `%v`", *new(T), value, value)
    34  		}
    35  
    36  		f, _, err := big.ParseFloat(stringValue, 10, 64, 0)
    37  		if err != nil {
    38  			return nil, fmt.Errorf("a %T value is required, but found invalid string value `%v`", *new(T), value)
    39  		}
    40  
    41  		bigFloat = f
    42  	}
    43  
    44  	// Convert the float to the int or uint if necessary.
    45  	n := *new(T)
    46  	switch any(n).(type) {
    47  	case int64:
    48  		if !bigFloat.IsInt() {
    49  			return nil, fmt.Errorf("a int value is required, but found numeric value `%s`", bigFloat.String())
    50  		}
    51  
    52  		numericValue, _ := bigFloat.Int64()
    53  		return numericValue, nil
    54  
    55  	case uint64:
    56  		if !bigFloat.IsInt() {
    57  			return nil, fmt.Errorf("a uint value is required, but found numeric value `%s`", bigFloat.String())
    58  		}
    59  
    60  		numericValue, _ := bigFloat.Int64()
    61  		if numericValue < 0 {
    62  			return nil, fmt.Errorf("a uint value is required, but found int64 value `%s`", bigFloat.String())
    63  		}
    64  		return uint64(numericValue), nil
    65  
    66  	case float64:
    67  		numericValue, _ := bigFloat.Float64()
    68  		return numericValue, nil
    69  
    70  	default:
    71  		return nil, spiceerrors.MustBugf("unsupported numeric type in caveat number type conversion: %T", n)
    72  	}
    73  }
    74  
    75  var (
    76  	AnyType     = registerBasicType("any", cel.DynType, func(value any) (any, error) { return value, nil })
    77  	BooleanType = registerBasicType("bool", cel.BoolType, requireType[bool])
    78  	StringType  = registerBasicType("string", cel.StringType, requireType[string])
    79  	IntType     = registerBasicType("int", cel.IntType, convertNumericType[int64])
    80  	UIntType    = registerBasicType("uint", cel.IntType, convertNumericType[uint64])
    81  	DoubleType  = registerBasicType("double", cel.DoubleType, convertNumericType[float64])
    82  
    83  	BytesType = registerBasicType("bytes", cel.BytesType, func(value any) (any, error) {
    84  		vle, ok := value.(string)
    85  		if !ok {
    86  			return nil, fmt.Errorf("bytes requires a base64 unicode string, found: %T `%v`", value, value)
    87  		}
    88  
    89  		decoded, err := base64.StdEncoding.DecodeString(vle)
    90  		if err != nil {
    91  			return nil, fmt.Errorf("bytes requires a base64 encoded string: %w", err)
    92  		}
    93  
    94  		return decoded, nil
    95  	})
    96  
    97  	DurationType = registerBasicType("duration", cel.DurationType, func(value any) (any, error) {
    98  		vle, ok := value.(string)
    99  		if !ok {
   100  			return nil, fmt.Errorf("durations requires a duration string, found: %T", value)
   101  		}
   102  
   103  		d, err := time.ParseDuration(vle)
   104  		if err != nil {
   105  			return nil, fmt.Errorf("could not parse duration string `%s`: %w", vle, err)
   106  		}
   107  
   108  		return d, nil
   109  	})
   110  
   111  	TimestampType = registerBasicType("timestamp", cel.TimestampType, func(value any) (any, error) {
   112  		vle, ok := value.(string)
   113  		if !ok {
   114  			return nil, fmt.Errorf("timestamps requires a RFC 3339 formatted timestamp string, found: %T `%v`", value, value)
   115  		}
   116  
   117  		d, err := time.Parse(time.RFC3339, vle)
   118  		if err != nil {
   119  			return nil, fmt.Errorf("could not parse RFC 3339 formatted timestamp string `%s`: %w", vle, err)
   120  		}
   121  
   122  		return d, nil
   123  	})
   124  
   125  	ListType = registerGenericType("list", 1,
   126  		func(childTypes []VariableType) VariableType {
   127  			return VariableType{
   128  				localName:  "list",
   129  				celType:    cel.ListType(childTypes[0].celType),
   130  				childTypes: childTypes,
   131  				converter: func(value any) (any, error) {
   132  					vle, ok := value.([]any)
   133  					if !ok {
   134  						return nil, fmt.Errorf("list requires a list, found: %T", value)
   135  					}
   136  
   137  					converted := make([]any, 0, len(vle))
   138  					for index, item := range vle {
   139  						convertedItem, err := childTypes[0].ConvertValue(item)
   140  						if err != nil {
   141  							return nil, fmt.Errorf("found an invalid value for item at index %d: %w", index, err)
   142  						}
   143  						converted = append(converted, convertedItem)
   144  					}
   145  
   146  					return converted, nil
   147  				},
   148  			}
   149  		})
   150  
   151  	MapType = registerGenericType("map", 1,
   152  		func(childTypes []VariableType) VariableType {
   153  			return VariableType{
   154  				localName:  "map",
   155  				celType:    cel.MapType(cel.StringType, childTypes[0].celType),
   156  				childTypes: childTypes,
   157  				converter: func(value any) (any, error) {
   158  					vle, ok := value.(map[string]any)
   159  					if !ok {
   160  						return nil, fmt.Errorf("map requires a map, found: %T", value)
   161  					}
   162  
   163  					converted := make(map[string]any, len(vle))
   164  					for key, item := range vle {
   165  						convertedItem, err := childTypes[0].ConvertValue(item)
   166  						if err != nil {
   167  							return nil, fmt.Errorf("found an invalid value for key `%s`: %w", key, err)
   168  						}
   169  
   170  						converted[key] = convertedItem
   171  					}
   172  
   173  					return converted, nil
   174  				},
   175  			}
   176  		},
   177  	)
   178  )
   179  
   180  func MustListType(childTypes ...VariableType) VariableType {
   181  	t, err := ListType(childTypes...)
   182  	if err != nil {
   183  		panic(err)
   184  	}
   185  	return t
   186  }
   187  
   188  func MustMapType(childTypes ...VariableType) VariableType {
   189  	t, err := MapType(childTypes...)
   190  	if err != nil {
   191  		panic(err)
   192  	}
   193  	return t
   194  }