github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/caveats/types/basic.go (about) 1 package types 2 3 import ( 4 "encoding/base64" 5 "fmt" 6 "math/big" 7 "time" 8 9 "github.com/authzed/spicedb/pkg/spiceerrors" 10 11 "github.com/authzed/cel-go/cel" 12 ) 13 14 func requireType[T any](value any) (any, error) { 15 vle, ok := value.(T) 16 if !ok { 17 return nil, fmt.Errorf("a %T value is required, but found %T `%v`", *new(T), value, value) 18 } 19 return vle, nil 20 } 21 22 func convertNumericType[T int64 | uint64 | float64](value any) (any, error) { 23 directValue, ok := value.(T) 24 if ok { 25 return directValue, nil 26 } 27 28 floatValue, ok := value.(float64) 29 bigFloat := big.NewFloat(floatValue) 30 if !ok { 31 stringValue, ok := value.(string) 32 if !ok { 33 return nil, fmt.Errorf("a %T value is required, but found %T `%v`", *new(T), value, value) 34 } 35 36 f, _, err := big.ParseFloat(stringValue, 10, 64, 0) 37 if err != nil { 38 return nil, fmt.Errorf("a %T value is required, but found invalid string value `%v`", *new(T), value) 39 } 40 41 bigFloat = f 42 } 43 44 // Convert the float to the int or uint if necessary. 45 n := *new(T) 46 switch any(n).(type) { 47 case int64: 48 if !bigFloat.IsInt() { 49 return nil, fmt.Errorf("a int value is required, but found numeric value `%s`", bigFloat.String()) 50 } 51 52 numericValue, _ := bigFloat.Int64() 53 return numericValue, nil 54 55 case uint64: 56 if !bigFloat.IsInt() { 57 return nil, fmt.Errorf("a uint value is required, but found numeric value `%s`", bigFloat.String()) 58 } 59 60 numericValue, _ := bigFloat.Int64() 61 if numericValue < 0 { 62 return nil, fmt.Errorf("a uint value is required, but found int64 value `%s`", bigFloat.String()) 63 } 64 return uint64(numericValue), nil 65 66 case float64: 67 numericValue, _ := bigFloat.Float64() 68 return numericValue, nil 69 70 default: 71 return nil, spiceerrors.MustBugf("unsupported numeric type in caveat number type conversion: %T", n) 72 } 73 } 74 75 var ( 76 AnyType = registerBasicType("any", cel.DynType, func(value any) (any, error) { return value, nil }) 77 BooleanType = registerBasicType("bool", cel.BoolType, requireType[bool]) 78 StringType = registerBasicType("string", cel.StringType, requireType[string]) 79 IntType = registerBasicType("int", cel.IntType, convertNumericType[int64]) 80 UIntType = registerBasicType("uint", cel.IntType, convertNumericType[uint64]) 81 DoubleType = registerBasicType("double", cel.DoubleType, convertNumericType[float64]) 82 83 BytesType = registerBasicType("bytes", cel.BytesType, func(value any) (any, error) { 84 vle, ok := value.(string) 85 if !ok { 86 return nil, fmt.Errorf("bytes requires a base64 unicode string, found: %T `%v`", value, value) 87 } 88 89 decoded, err := base64.StdEncoding.DecodeString(vle) 90 if err != nil { 91 return nil, fmt.Errorf("bytes requires a base64 encoded string: %w", err) 92 } 93 94 return decoded, nil 95 }) 96 97 DurationType = registerBasicType("duration", cel.DurationType, func(value any) (any, error) { 98 vle, ok := value.(string) 99 if !ok { 100 return nil, fmt.Errorf("durations requires a duration string, found: %T", value) 101 } 102 103 d, err := time.ParseDuration(vle) 104 if err != nil { 105 return nil, fmt.Errorf("could not parse duration string `%s`: %w", vle, err) 106 } 107 108 return d, nil 109 }) 110 111 TimestampType = registerBasicType("timestamp", cel.TimestampType, func(value any) (any, error) { 112 vle, ok := value.(string) 113 if !ok { 114 return nil, fmt.Errorf("timestamps requires a RFC 3339 formatted timestamp string, found: %T `%v`", value, value) 115 } 116 117 d, err := time.Parse(time.RFC3339, vle) 118 if err != nil { 119 return nil, fmt.Errorf("could not parse RFC 3339 formatted timestamp string `%s`: %w", vle, err) 120 } 121 122 return d, nil 123 }) 124 125 ListType = registerGenericType("list", 1, 126 func(childTypes []VariableType) VariableType { 127 return VariableType{ 128 localName: "list", 129 celType: cel.ListType(childTypes[0].celType), 130 childTypes: childTypes, 131 converter: func(value any) (any, error) { 132 vle, ok := value.([]any) 133 if !ok { 134 return nil, fmt.Errorf("list requires a list, found: %T", value) 135 } 136 137 converted := make([]any, 0, len(vle)) 138 for index, item := range vle { 139 convertedItem, err := childTypes[0].ConvertValue(item) 140 if err != nil { 141 return nil, fmt.Errorf("found an invalid value for item at index %d: %w", index, err) 142 } 143 converted = append(converted, convertedItem) 144 } 145 146 return converted, nil 147 }, 148 } 149 }) 150 151 MapType = registerGenericType("map", 1, 152 func(childTypes []VariableType) VariableType { 153 return VariableType{ 154 localName: "map", 155 celType: cel.MapType(cel.StringType, childTypes[0].celType), 156 childTypes: childTypes, 157 converter: func(value any) (any, error) { 158 vle, ok := value.(map[string]any) 159 if !ok { 160 return nil, fmt.Errorf("map requires a map, found: %T", value) 161 } 162 163 converted := make(map[string]any, len(vle)) 164 for key, item := range vle { 165 convertedItem, err := childTypes[0].ConvertValue(item) 166 if err != nil { 167 return nil, fmt.Errorf("found an invalid value for key `%s`: %w", key, err) 168 } 169 170 converted[key] = convertedItem 171 } 172 173 return converted, nil 174 }, 175 } 176 }, 177 ) 178 ) 179 180 func MustListType(childTypes ...VariableType) VariableType { 181 t, err := ListType(childTypes...) 182 if err != nil { 183 panic(err) 184 } 185 return t 186 } 187 188 func MustMapType(childTypes ...VariableType) VariableType { 189 t, err := MapType(childTypes...) 190 if err != nil { 191 panic(err) 192 } 193 return t 194 }