
     1  package tkn
import (
	"encoding/binary"
	"fmt"
	"math"

	pairing ""
)
    10  const (
    11  	bkAttribute   = "internal-boneh-katz-transform-attribute"
    12  	attributeSize = pairing.ScalarSize + 1
    13  )
// Wire is one input literal of a Policy. It compares the attribute stored
// under Label against Value. RawValue is the textual value printed by
// String. Positive is false for a negated literal, which String renders
// as "not label:value".
type Wire struct {
	Label    string
	RawValue string
	Value    *pairing.Scalar
	Positive bool
}
    22  func (w *Wire) String() string {
    23  	if w.Positive {
    24  		return fmt.Sprintf("%s:%s", w.Label, w.RawValue)
    25  	}
    26  	return fmt.Sprintf("not %s:%s", w.Label, w.RawValue)
    27  }
// Policy is an access policy. It holds a list of input wires and the
// boolean formula F that combines them. Formula wire indices below
// len(F.Gates)+1 refer to entries of Inputs (see printWire).
type Policy struct {
	Inputs []Wire
	F      Formula // monotonic boolean formula
}
// Attribute is a single attribute value. Policy.Satisfaction treats a wild
// attribute as satisfying any wire with the same label, whatever its Value.
// The field order is also the encoding order used by marshalBinary: flag
// byte, then scalar.
type Attribute struct {
	wild  bool // false if tame
	Value *pairing.Scalar
}
    39  func (a *Attribute) marshalBinary() ([]byte, error) {
    40  	ret := make([]byte, 1)
    41  	if a.wild {
    42  		ret[0] = 1
    43  	}
    44  	aBytes, err := a.Value.MarshalBinary()
    45  	if err != nil {
    46  		return nil, err
    47  	}
    49  	return append(ret, aBytes...), nil
    50  }
    52  func (a *Attribute) unmarshalBinary(data []byte) error {
    53  	if len(data) != attributeSize {
    54  		return fmt.Errorf("unmarshalling Attribute failed: invalid input length, expected: %d, received: %d",
    55  			attributeSize,
    56  			len(data))
    57  	}
    58  	a.wild = false
    59  	if data[0] == 1 {
    60  		a.wild = true
    61  	}
    62  	a.Value = &pairing.Scalar{}
    63  	err := a.Value.UnmarshalBinary(data[1:])
    64  	if err != nil {
    65  		return fmt.Errorf("unmarshalling Attribute failed: %w", err)
    66  	}
    67  	return nil
    68  }
    70  func (a *Attribute) Equal(b *Attribute) bool {
    71  	return a.wild == b.wild && a.Value.IsEqual(b.Value) == 1
    72  }
// Attributes maps an attribute label to its Attribute. Policy.Satisfaction
// looks up each policy wire by the wire's Label in this map.
type Attributes map[string]Attribute
    76  func (a *Attributes) marshalBinary() ([]byte, error) {
    77  	ret := make([]byte, 2)
    78  	binary.LittleEndian.PutUint16(ret[0:], uint16(len(*a)))
    80  	aBytes, err := marshalBinarySortedMapAttribute(*a)
    81  	if err != nil {
    82  		return nil, fmt.Errorf("marshalling Attributes failed: %w", err)
    83  	}
    84  	ret = append(ret, aBytes...)
    86  	return ret, nil
    87  }
    89  func (a *Attributes) unmarshalBinary(data []byte) error {
    90  	if len(data) < 2 {
    91  		return fmt.Errorf("unmarshalling Attributes failed: data too short")
    92  	}
    93  	n := int(binary.LittleEndian.Uint16(data))
    94  	data = data[2:]
    95  	*a = make(map[string]Attribute, n)
    96  	for i := 0; i < n; i++ {
    97  		labelBytes, rem, err := removeLenPrefixed(data)
    98  		if err != nil {
    99  			return fmt.Errorf("unmarshalling Attributes failed: %w", err)
   100  		}
   101  		if len(rem) < attributeSize {
   102  			return fmt.Errorf("unmarshalling Attributes failed: data too short")
   103  		}
   104  		attr := Attribute{}
   105  		err = attr.unmarshalBinary(rem[:attributeSize])
   106  		if err != nil {
   107  			return fmt.Errorf("unmarshalling Attributes failed: %w", err)
   108  		}
   109  		(*a)[string(labelBytes)] = attr
   110  		data = rem[attributeSize:]
   111  	}
   112  	if len(data) != 0 {
   113  		return fmt.Errorf("unmarshalling Attributes failed: excess bytes remain in data")
   114  	}
   115  	return nil
   116  }
   118  func (a *Attributes) Equal(b *Attributes) bool {
   119  	if len(*a) != len(*b) {
   120  		return false
   121  	}
   122  	for k := range *a {
   123  		v := (*a)[k]
   124  		if v2, ok := (*b)[k]; !(ok && v2.Equal(&v)) {
   125  			return false
   126  		}
   127  	}
   128  	return true
   129  }
   131  func (w *Wire) MarshalBinary() ([]byte, error) {
   132  	strBytes := []byte(w.Label)
   133  	valBytes := []byte(w.RawValue)
   134  	intBytes, err := w.Value.MarshalBinary()
   135  	if err != nil {
   136  		return nil, err
   137  	}
   138  	totalLen := len(strBytes) + len(valBytes) + len(intBytes) + 2 + 2 + 2 + 1
   139  	ret := make([]byte, totalLen)
   140  	where := 0
   141  	binary.LittleEndian.PutUint16(ret[where:], uint16(len(strBytes)))
   142  	where += 2
   143  	where += copy(ret[where:], strBytes)
   144  	binary.LittleEndian.PutUint16(ret[where:], uint16(len(valBytes)))
   145  	where += 2
   146  	where += copy(ret[where:], valBytes)
   147  	binary.LittleEndian.PutUint16(ret[where:], uint16(len(intBytes)))
   148  	where += 2
   149  	where += copy(ret[where:], intBytes)
   150  	if w.Positive {
   151  		ret[where] = 1
   152  	} else {
   153  		ret[where] = 0
   154  	}
   155  	return ret, nil
   156  }
   158  func (w *Wire) UnmarshalBinary(data []byte) error {
   159  	where := 0
   160  	if len(data) < 2 {
   161  		return fmt.Errorf("data not long enough")
   162  	}
   163  	strLen := int(binary.LittleEndian.Uint16(data[where:]))
   164  	where += 2
   165  	if len(data[where:]) < strLen {
   166  		return fmt.Errorf("data not long enough")
   167  	}
   168  	w.Label = string(data[where : where+strLen])
   169  	where += strLen
   171  	if len(data[where:]) < 2 {
   172  		return fmt.Errorf("data not long enough")
   173  	}
   174  	valLen := int(binary.LittleEndian.Uint16(data[where:]))
   175  	where += 2
   176  	if len(data[where:]) < valLen {
   177  		return fmt.Errorf("data not long enough")
   178  	}
   179  	w.RawValue = string(data[where : where+valLen])
   180  	where += valLen
   182  	if len(data[where:]) < 2 {
   183  		return fmt.Errorf("data not long enough")
   184  	}
   185  	intLen := int(binary.LittleEndian.Uint16(data[where:]))
   186  	where += 2
   187  	if len(data[where:]) < intLen {
   188  		return fmt.Errorf("data not long enough")
   189  	}
   190  	w.Value = &pairing.Scalar{}
   191  	w.Value.SetBytes(data[where : where+intLen])
   192  	where += intLen
   193  	if len(data[where:]) < 1 {
   194  		return fmt.Errorf("data not long enough")
   195  	}
   196  	if data[where] == 1 {
   197  		w.Positive = true
   198  	} else {
   199  		w.Positive = false
   200  	}
   201  	return nil
   202  }
   204  func (w *Wire) Equal(w2 *Wire) bool {
   205  	return w.Label == w2.Label && w.RawValue == w2.RawValue && w.Positive == w2.Positive && w.Value.IsEqual(w2.Value) == 1
   206  }
   208  func (p *Policy) MarshalBinary() ([]byte, error) {
   209  	ret := make([]byte, 2)
   210  	fBytes, err := p.F.MarshalBinary()
   211  	if err != nil {
   212  		return nil, err
   213  	}
   214  	binary.LittleEndian.PutUint16(ret[0:2], uint16(len(fBytes)))
   215  	ret = append(ret, fBytes...)
   216  	ret = append(ret, 0, 0)
   217  	binary.LittleEndian.PutUint16(ret[len(ret)-2:], uint16(len(p.Inputs)))
   218  	for i := 0; i < len(p.Inputs); i++ {
   219  		input, err := p.Inputs[i].MarshalBinary()
   220  		if err != nil {
   221  			return nil, err
   222  		}
   223  		ret = append(ret, 0, 0)
   224  		binary.LittleEndian.PutUint16(ret[len(ret)-2:], uint16(len(input)))
   225  		ret = append(ret, input...)
   226  	}
   227  	return ret, nil
   228  }
   230  func (p *Policy) UnmarshalBinary(data []byte) error {
   231  	// Extract formula
   232  	if len(data) < 2 {
   233  		return fmt.Errorf("data not long enough")
   234  	}
   235  	fLen := uint(binary.LittleEndian.Uint16(data))
   236  	data = data[2:]
   237  	err := p.F.UnmarshalBinary(data)
   238  	if err != nil {
   239  		return err
   240  	}
   241  	data = data[fLen:]
   242  	// Extract wires
   243  	if len(data) < 2 {
   244  		return fmt.Errorf("data not long enough")
   245  	}
   246  	nWires := int(binary.LittleEndian.Uint16(data))
   247  	data = data[2:]
   248  	p.Inputs = make([]Wire, nWires)
   249  	for i := 0; i < nWires; i++ {
   250  		wireLen := uint(binary.LittleEndian.Uint16(data))
   251  		data = data[2:]
   252  		err = p.Inputs[i].UnmarshalBinary(data)
   253  		data = data[wireLen:]
   254  		if err != nil {
   255  			return fmt.Errorf("data not long enough")
   256  		}
   257  	}
   258  	return nil
   259  }
   261  func (p *Policy) Equal(p2 *Policy) bool {
   262  	if len(p.Inputs) != len(p2.Inputs) {
   263  		return false
   264  	}
   265  	if !p.F.Equal(p2.F) {
   266  		return false
   267  	}
   268  	for i := range p.Inputs {
   269  		if !p.Inputs[i].Equal(&p2.Inputs[i]) {
   270  			return false
   271  		}
   272  	}
   273  	return true
   274  }
   276  func (p *Policy) String() string {
   277  	// gateAssign takes n wires (intermediates and outputs) and maps to the gate
   278  	// that set them. For details, refer to [Formula].
   279  	offset := len(p.F.Gates) + 1
   280  	gateAssign := make([]int, len(p.F.Gates))
   281  	for i, gate := range p.F.Gates {
   282  		gateAssign[gate.Out-offset] = i
   283  	}
   284  	return p.printWire(gateAssign, 2*len(p.F.Gates))
   285  }
   287  func (p *Policy) printWire(gateAssign []int, wire int) string {
   288  	n := len(p.F.Gates)
   289  	if wire < n+1 {
   290  		return p.Inputs[wire].String()
   291  	}
   292  	gate := p.F.Gates[gateAssign[wire-n-1]]
   293  	return fmt.Sprintf("(%s %s %s)", p.printWire(gateAssign, gate.In0), gate.operator(), p.printWire(gateAssign, gate.In1))
   294  }
// match records that the policy input at index wire (with the given label)
// is satisfied by the supplied attributes; see Policy.Satisfaction.
// The field order matters: positional literals match{i, label} rely on it.
type match struct {
	wire  int
	label string
}
// Satisfaction holds the wire matches returned by Policy.Satisfaction, i.e.
// the candidate matches after they have been passed through the formula's
// satisfaction step.
type Satisfaction struct {
	matches []match
}
   305  func (p *Policy) pi() []int {
   306  	ret := make([]int, len(p.Inputs))
   307  	counts := make(map[string]int)
   308  	for i := 0; i < len(p.Inputs); i++ {
   309  		// Paper would have us put a +1 here
   310  		// we change the indexing instead
   311  		ret[i] = counts[p.Inputs[i].Label]
   312  		counts[p.Inputs[i].Label]++
   313  	}
   314  	return ret
   315  }
   317  func (p *Policy) Satisfaction(attr *Attributes) (*Satisfaction, error) {
   318  	// For now its all of the wires, so we don't need to look at the formula.
   319  	var matches []match
   320  	for i := 0; i < len(p.Inputs); i++ {
   321  		wire := p.Inputs[i]
   322  		at, ok := (*attr)[wire.Label]
   323  		if !ok {
   324  			continue // missing Attribute might not be needed
   325  		}
   326  		if wire.Positive {
   327  			if (wire.Value.IsEqual(at.Value) == 1) || at.wild {
   328  				matches = append(matches, match{i, wire.Label})
   329  			}
   330  		} else {
   331  			if (wire.Value.IsEqual(at.Value) == 0) || at.wild {
   332  				matches = append(matches, match{i, wire.Label})
   333  			}
   334  		}
   335  	}
   336  	matches, err := p.F.satisfaction(matches)
   337  	if err != nil {
   338  		return nil, err
   339  	}
   341  	return &Satisfaction{
   342  		matches,
   343  	}, nil
   344  }
   346  // Carry Out the augmentation under the BK transform
   347  func (p *Policy) transformBK(val *pairing.Scalar) *Policy {
   348  	ret := new(Policy)
   349  	for i := 0; i < len(p.Inputs); i++ {
   350  		ret.Inputs = append(ret.Inputs, p.Inputs[i])
   351  	}
   352  	ret.Inputs = append(ret.Inputs, Wire{
   353  		Label:    bkAttribute,
   354  		Value:    val,
   355  		Positive: true,
   356  	})
   357  	ret.F = p.F.insertAnd()
   358  	return ret
   359  }
   361  func transformAttrsBK(attr *Attributes) *Attributes {
   362  	ret := make(map[string]Attribute)
   363  	for name, val := range *attr {
   364  		ret[name] = val
   365  	}
   366  	ret[bkAttribute] = Attribute{
   367  		wild:  true,
   368  		Value: &pairing.Scalar{},
   369  	}
   370  	return (*Attributes)(&ret)
   371  }