github.com/cloudflare/circl@v1.5.0/abe/cpabe/tkn20/internal/tkn/policy.go (about) 1 package tkn 2 3 import ( 4 "encoding/binary" 5 "fmt" 6 7 pairing "github.com/cloudflare/circl/ecc/bls12381" 8 ) 9 10 const ( 11 bkAttribute = "internal-boneh-katz-transform-attribute" 12 attributeSize = pairing.ScalarSize + 1 13 ) 14 15 type Wire struct { 16 Label string 17 RawValue string 18 Value *pairing.Scalar 19 Positive bool 20 } 21 22 func (w *Wire) String() string { 23 if w.Positive { 24 return fmt.Sprintf("%s:%s", w.Label, w.RawValue) 25 } 26 return fmt.Sprintf("not %s:%s", w.Label, w.RawValue) 27 } 28 29 type Policy struct { 30 Inputs []Wire 31 F Formula // monotonic boolean formula 32 } 33 34 type Attribute struct { 35 wild bool // false if tame 36 Value *pairing.Scalar 37 } 38 39 func (a *Attribute) marshalBinary() ([]byte, error) { 40 ret := make([]byte, 1) 41 if a.wild { 42 ret[0] = 1 43 } 44 aBytes, err := a.Value.MarshalBinary() 45 if err != nil { 46 return nil, err 47 } 48 49 return append(ret, aBytes...), nil 50 } 51 52 func (a *Attribute) unmarshalBinary(data []byte) error { 53 if len(data) != attributeSize { 54 return fmt.Errorf("unmarshalling Attribute failed: invalid input length, expected: %d, received: %d", 55 attributeSize, 56 len(data)) 57 } 58 a.wild = false 59 if data[0] == 1 { 60 a.wild = true 61 } 62 a.Value = &pairing.Scalar{} 63 err := a.Value.UnmarshalBinary(data[1:]) 64 if err != nil { 65 return fmt.Errorf("unmarshalling Attribute failed: %w", err) 66 } 67 return nil 68 } 69 70 func (a *Attribute) Equal(b *Attribute) bool { 71 return a.wild == b.wild && a.Value.IsEqual(b.Value) == 1 72 } 73 74 type Attributes map[string]Attribute 75 76 func (a *Attributes) marshalBinary() ([]byte, error) { 77 ret := make([]byte, 2) 78 binary.LittleEndian.PutUint16(ret[0:], uint16(len(*a))) 79 80 aBytes, err := marshalBinarySortedMapAttribute(*a) 81 if err != nil { 82 return nil, fmt.Errorf("marshalling Attributes failed: %w", err) 83 } 84 ret = append(ret, aBytes...) 85 86 return ret, nil 87 } 88 89 func (a *Attributes) unmarshalBinary(data []byte) error { 90 if len(data) < 2 { 91 return fmt.Errorf("unmarshalling Attributes failed: data too short") 92 } 93 n := int(binary.LittleEndian.Uint16(data)) 94 data = data[2:] 95 *a = make(map[string]Attribute, n) 96 for i := 0; i < n; i++ { 97 labelBytes, rem, err := removeLenPrefixed(data) 98 if err != nil { 99 return fmt.Errorf("unmarshalling Attributes failed: %w", err) 100 } 101 if len(rem) < attributeSize { 102 return fmt.Errorf("unmarshalling Attributes failed: data too short") 103 } 104 attr := Attribute{} 105 err = attr.unmarshalBinary(rem[:attributeSize]) 106 if err != nil { 107 return fmt.Errorf("unmarshalling Attributes failed: %w", err) 108 } 109 (*a)[string(labelBytes)] = attr 110 data = rem[attributeSize:] 111 } 112 if len(data) != 0 { 113 return fmt.Errorf("unmarshalling Attributes failed: excess bytes remain in data") 114 } 115 return nil 116 } 117 118 func (a *Attributes) Equal(b *Attributes) bool { 119 if len(*a) != len(*b) { 120 return false 121 } 122 for k := range *a { 123 v := (*a)[k] 124 if v2, ok := (*b)[k]; !(ok && v2.Equal(&v)) { 125 return false 126 } 127 } 128 return true 129 } 130 131 func (w *Wire) MarshalBinary() ([]byte, error) { 132 strBytes := []byte(w.Label) 133 valBytes := []byte(w.RawValue) 134 intBytes, err := w.Value.MarshalBinary() 135 if err != nil { 136 return nil, err 137 } 138 totalLen := len(strBytes) + len(valBytes) + len(intBytes) + 2 + 2 + 2 + 1 139 ret := make([]byte, totalLen) 140 where := 0 141 binary.LittleEndian.PutUint16(ret[where:], uint16(len(strBytes))) 142 where += 2 143 where += copy(ret[where:], strBytes) 144 binary.LittleEndian.PutUint16(ret[where:], uint16(len(valBytes))) 145 where += 2 146 where += copy(ret[where:], valBytes) 147 binary.LittleEndian.PutUint16(ret[where:], uint16(len(intBytes))) 148 where += 2 149 where += copy(ret[where:], intBytes) 150 if w.Positive { 151 ret[where] = 1 152 } else { 153 ret[where] = 0 154 } 155 return ret, nil 156 } 157 158 func (w *Wire) UnmarshalBinary(data []byte) error { 159 where := 0 160 if len(data) < 2 { 161 return fmt.Errorf("data not long enough") 162 } 163 strLen := int(binary.LittleEndian.Uint16(data[where:])) 164 where += 2 165 if len(data[where:]) < strLen { 166 return fmt.Errorf("data not long enough") 167 } 168 w.Label = string(data[where : where+strLen]) 169 where += strLen 170 171 if len(data[where:]) < 2 { 172 return fmt.Errorf("data not long enough") 173 } 174 valLen := int(binary.LittleEndian.Uint16(data[where:])) 175 where += 2 176 if len(data[where:]) < valLen { 177 return fmt.Errorf("data not long enough") 178 } 179 w.RawValue = string(data[where : where+valLen]) 180 where += valLen 181 182 if len(data[where:]) < 2 { 183 return fmt.Errorf("data not long enough") 184 } 185 intLen := int(binary.LittleEndian.Uint16(data[where:])) 186 where += 2 187 if len(data[where:]) < intLen { 188 return fmt.Errorf("data not long enough") 189 } 190 w.Value = &pairing.Scalar{} 191 w.Value.SetBytes(data[where : where+intLen]) 192 where += intLen 193 if len(data[where:]) < 1 { 194 return fmt.Errorf("data not long enough") 195 } 196 if data[where] == 1 { 197 w.Positive = true 198 } else { 199 w.Positive = false 200 } 201 return nil 202 } 203 204 func (w *Wire) Equal(w2 *Wire) bool { 205 return w.Label == w2.Label && w.RawValue == w2.RawValue && w.Positive == w2.Positive && w.Value.IsEqual(w2.Value) == 1 206 } 207 208 func (p *Policy) MarshalBinary() ([]byte, error) { 209 ret := make([]byte, 2) 210 fBytes, err := p.F.MarshalBinary() 211 if err != nil { 212 return nil, err 213 } 214 binary.LittleEndian.PutUint16(ret[0:2], uint16(len(fBytes))) 215 ret = append(ret, fBytes...) 216 ret = append(ret, 0, 0) 217 binary.LittleEndian.PutUint16(ret[len(ret)-2:], uint16(len(p.Inputs))) 218 for i := 0; i < len(p.Inputs); i++ { 219 input, err := p.Inputs[i].MarshalBinary() 220 if err != nil { 221 return nil, err 222 } 223 ret = append(ret, 0, 0) 224 binary.LittleEndian.PutUint16(ret[len(ret)-2:], uint16(len(input))) 225 ret = append(ret, input...) 226 } 227 return ret, nil 228 } 229 230 func (p *Policy) UnmarshalBinary(data []byte) error { 231 // Extract formula 232 if len(data) < 2 { 233 return fmt.Errorf("data not long enough") 234 } 235 fLen := uint(binary.LittleEndian.Uint16(data)) 236 data = data[2:] 237 err := p.F.UnmarshalBinary(data) 238 if err != nil { 239 return err 240 } 241 data = data[fLen:] 242 // Extract wires 243 if len(data) < 2 { 244 return fmt.Errorf("data not long enough") 245 } 246 nWires := int(binary.LittleEndian.Uint16(data)) 247 data = data[2:] 248 p.Inputs = make([]Wire, nWires) 249 for i := 0; i < nWires; i++ { 250 wireLen := uint(binary.LittleEndian.Uint16(data)) 251 data = data[2:] 252 err = p.Inputs[i].UnmarshalBinary(data) 253 data = data[wireLen:] 254 if err != nil { 255 return fmt.Errorf("data not long enough") 256 } 257 } 258 return nil 259 } 260 261 func (p *Policy) Equal(p2 *Policy) bool { 262 if len(p.Inputs) != len(p2.Inputs) { 263 return false 264 } 265 if !p.F.Equal(p2.F) { 266 return false 267 } 268 for i := range p.Inputs { 269 if !p.Inputs[i].Equal(&p2.Inputs[i]) { 270 return false 271 } 272 } 273 return true 274 } 275 276 func (p *Policy) String() string { 277 // gateAssign takes n wires (intermediates and outputs) and maps to the gate 278 // that set them. For details, refer to [Formula]. 279 offset := len(p.F.Gates) + 1 280 gateAssign := make([]int, len(p.F.Gates)) 281 for i, gate := range p.F.Gates { 282 gateAssign[gate.Out-offset] = i 283 } 284 return p.printWire(gateAssign, 2*len(p.F.Gates)) 285 } 286 287 func (p *Policy) printWire(gateAssign []int, wire int) string { 288 n := len(p.F.Gates) 289 if wire < n+1 { 290 return p.Inputs[wire].String() 291 } 292 gate := p.F.Gates[gateAssign[wire-n-1]] 293 return fmt.Sprintf("(%s %s %s)", p.printWire(gateAssign, gate.In0), gate.operator(), p.printWire(gateAssign, gate.In1)) 294 } 295 296 type match struct { 297 wire int 298 label string 299 } 300 301 type Satisfaction struct { 302 matches []match 303 } 304 305 func (p *Policy) pi() []int { 306 ret := make([]int, len(p.Inputs)) 307 counts := make(map[string]int) 308 for i := 0; i < len(p.Inputs); i++ { 309 // Paper would have us put a +1 here 310 // we change the indexing instead 311 ret[i] = counts[p.Inputs[i].Label] 312 counts[p.Inputs[i].Label]++ 313 } 314 return ret 315 } 316 317 func (p *Policy) Satisfaction(attr *Attributes) (*Satisfaction, error) { 318 // For now its all of the wires, so we don't need to look at the formula. 319 var matches []match 320 for i := 0; i < len(p.Inputs); i++ { 321 wire := p.Inputs[i] 322 at, ok := (*attr)[wire.Label] 323 if !ok { 324 continue // missing Attribute might not be needed 325 } 326 if wire.Positive { 327 if (wire.Value.IsEqual(at.Value) == 1) || at.wild { 328 matches = append(matches, match{i, wire.Label}) 329 } 330 } else { 331 if (wire.Value.IsEqual(at.Value) == 0) || at.wild { 332 matches = append(matches, match{i, wire.Label}) 333 } 334 } 335 } 336 matches, err := p.F.satisfaction(matches) 337 if err != nil { 338 return nil, err 339 } 340 341 return &Satisfaction{ 342 matches, 343 }, nil 344 } 345 346 // Carry Out the augmentation under the BK transform 347 func (p *Policy) transformBK(val *pairing.Scalar) *Policy { 348 ret := new(Policy) 349 for i := 0; i < len(p.Inputs); i++ { 350 ret.Inputs = append(ret.Inputs, p.Inputs[i]) 351 } 352 ret.Inputs = append(ret.Inputs, Wire{ 353 Label: bkAttribute, 354 Value: val, 355 Positive: true, 356 }) 357 ret.F = p.F.insertAnd() 358 return ret 359 } 360 361 func transformAttrsBK(attr *Attributes) *Attributes { 362 ret := make(map[string]Attribute) 363 for name, val := range *attr { 364 ret[name] = val 365 } 366 ret[bkAttribute] = Attribute{ 367 wild: true, 368 Value: &pairing.Scalar{}, 369 } 370 return (*Attributes)(&ret) 371 }