github.com/cloudflare/circl@v1.5.0/abe/cpabe/tkn20/internal/tkn/formula.go (about)

     1  package tkn
     2  
     3  import (
     4  	"encoding/binary"
     5  	"fmt"
     6  	"io"
     7  )
     8  
     9  const (
    10  	Andgate = iota
    11  	Orgate
    12  )
    13  
    14  // Gate is a Gate in a monotone boolean formula.
    15  type Gate struct {
    16  	Class int // either Andgate or Orgate
    17  	In0   int // numbering of wires
    18  	In1   int
    19  	Out   int
    20  }
    21  
    22  func (g Gate) operator() string {
    23  	switch g.Class {
    24  	case Andgate:
    25  		return "and"
    26  	case Orgate:
    27  		return "or"
    28  	default:
    29  		return "unknown"
    30  	}
    31  }
    32  
    33  // Formula represents a monotone boolean circuit with Inputs not
    34  // repeated.  The representation is as follows: for n Gates there n+1
    35  // input wires, 1 output Wire, and n-1 intermediate wires.  That's
    36  // because there are 2n Inputs to all Gates and n outputs since every
    37  // Gate is 2:1.
    38  //
    39  // The wires are conceptually in an array. Wires 0 through n are
    40  // the input wires, and Wire 2n is the output Wire. If there are wires
    41  // between n and 2n they are intermediate wires.
    42  //
    43  // All intermediate and input wires must be used exactly once as Inputs.
    44  type Formula struct {
    45  	Gates []Gate
    46  }
    47  
    48  func (g Gate) Equal(g2 Gate) bool {
    49  	if (g.Class != g2.Class) || (g.Out != g2.Out) {
    50  		return false
    51  	}
    52  	if g.In0 == g2.In0 && g.In1 == g2.In1 {
    53  		return true
    54  	}
    55  	if g.In0 == g2.In1 && g.In1 == g2.In0 {
    56  		return true
    57  	}
    58  	return false
    59  }
    60  
    61  func (f *Formula) MarshalBinary() ([]byte, error) {
    62  	n := len(f.Gates)
    63  	ret := make([]byte, 2+7*n)
    64  	binary.LittleEndian.PutUint16(ret, uint16(len(f.Gates)))
    65  	for i := 0; i < n; i++ {
    66  		ret[7*i+2] = byte(f.Gates[i].Class)
    67  		binary.LittleEndian.PutUint16(ret[7*i+2+1:], uint16(f.Gates[i].In0))
    68  		binary.LittleEndian.PutUint16(ret[7*i+2+3:], uint16(f.Gates[i].In1))
    69  		binary.LittleEndian.PutUint16(ret[7*i+2+5:], uint16(f.Gates[i].Out))
    70  	}
    71  	return ret, nil
    72  }
    73  
    74  func (f *Formula) UnmarshalBinary(data []byte) error {
    75  	if len(data) < 2 {
    76  		return fmt.Errorf("too short data")
    77  	}
    78  	n := int(binary.LittleEndian.Uint16(data[0:2]))
    79  	f.Gates = make([]Gate, n)
    80  	for i := 0; i < n; i++ {
    81  		f.Gates[i].Class = int(data[7*i+2])
    82  		f.Gates[i].In0 = int(binary.LittleEndian.Uint16(data[7*i+2+1:]))
    83  		f.Gates[i].In1 = int(binary.LittleEndian.Uint16(data[7*i+2+3:]))
    84  		f.Gates[i].Out = int(binary.LittleEndian.Uint16(data[7*i+2+5:]))
    85  	}
    86  	return nil
    87  }
    88  
    89  func (f *Formula) wellformed() error {
    90  	// Check every Wire used once
    91  	n := len(f.Gates)
    92  	inputs := make([]bool, 2*n) // n+1 already, n-1 intermediates
    93  	outputs := make([]bool, n)
    94  	for i, gate := range f.Gates {
    95  		if gate.In0 > 2*n-1 || gate.In0 < 0 {
    96  			return fmt.Errorf("Gate %d has an Out of range In0", i)
    97  		}
    98  		if inputs[gate.In0] {
    99  			return fmt.Errorf("Gate %d has In0 that is already used", i)
   100  		}
   101  		inputs[gate.In0] = true
   102  		if gate.In1 > 2*n-1 || gate.In1 < 0 {
   103  			return fmt.Errorf("Gate %d has an Out of range In1", i)
   104  		}
   105  		if inputs[gate.In1] {
   106  			return fmt.Errorf("Gate %d has In1 that is already used", i)
   107  		}
   108  		inputs[gate.In1] = true
   109  		if gate.Out > 2*n || gate.Out < n+1 {
   110  			return fmt.Errorf("Gate %d has an Out of range Out", i)
   111  		}
   112  		outputs[gate.Out-(n+1)] = true
   113  	}
   114  	for i, wire := range inputs {
   115  		if !wire {
   116  			return fmt.Errorf("unused input Wire %d", i)
   117  		}
   118  	}
   119  	for i, wire := range outputs {
   120  		if !wire {
   121  			return fmt.Errorf("unused output Wire %d", i+(n+1))
   122  		}
   123  	}
   124  	return nil
   125  }
   126  
   127  // Sort the Gates so that Inputs are set before outputs.
   128  func (f *Formula) toposort() error {
   129  	err := f.wellformed()
   130  	if err != nil {
   131  		return err
   132  	}
   133  	n := len(f.Gates)
   134  	if n == 0 {
   135  		return nil
   136  	}
   137  	// Intermediate wires are indexed after subtracting n+1
   138  	outputGate := make([]int, n) // the Gate that sets this Wire
   139  	inputGate := make([]int, n)  // the Gate that uses this intermediate Wire.
   140  	counts := make([]int, n)     // the number of Inputs no yet output
   141  	queue := make([]int, 0, n)
   142  	reordered := make([]Gate, 0, n)
   143  	inputGate[n-1] = -1 // No Gate uses the output as input
   144  
   145  	for i, gate := range f.Gates {
   146  		outputGate[gate.Out-(n+1)] = i
   147  		if gate.In0 > n {
   148  			inputGate[gate.In0-(n+1)] = i
   149  			counts[i]++
   150  		}
   151  		if gate.In1 > n {
   152  			inputGate[gate.In1-(n+1)] = i
   153  			counts[i]++
   154  		}
   155  	}
   156  	for i := 0; i < n; i++ {
   157  		if counts[i] == 0 {
   158  			queue = append(queue, i)
   159  		}
   160  	}
   161  	if len(queue) == 0 {
   162  		return fmt.Errorf("no starting gates")
   163  	}
   164  	for len(queue) > 0 {
   165  		reordered = append(reordered, f.Gates[queue[0]])
   166  		next := inputGate[f.Gates[queue[0]].Out-(n+1)]
   167  		if next >= 0 {
   168  			counts[next]--
   169  			if counts[next] == 0 {
   170  				queue = append(queue, next)
   171  			}
   172  		}
   173  		queue = queue[1:]
   174  	}
   175  	if len(reordered) != n {
   176  		return fmt.Errorf("not all gates were extracted. check for loops")
   177  	}
   178  
   179  	f.Gates = reordered
   180  	return nil
   181  }
   182  
   183  // Given a set of possible Inputs (not necessarily in order!)
   184  // return a subset that satisfy the formula with no extras.
   185  func (f *Formula) satisfaction(available []match) ([]match, error) {
   186  	err := f.toposort()
   187  	if err != nil {
   188  		return nil, err
   189  	}
   190  	n := len(f.Gates)
   191  	assignments := make([][]int, 2*n+1)
   192  	for _, match := range available {
   193  		assignments[match.wire] = []int{match.wire}
   194  	}
   195  	for _, gate := range f.Gates {
   196  		switch gate.Class {
   197  		case Andgate:
   198  			if assignments[gate.In0] == nil || assignments[gate.In1] == nil {
   199  				assignments[gate.Out] = nil
   200  			} else {
   201  				assignments[gate.Out] = make([]int, 0, len(assignments[gate.In0])+len(assignments[gate.In1]))
   202  				assignments[gate.Out] = append(assignments[gate.Out], assignments[gate.In0]...)
   203  				assignments[gate.Out] = append(assignments[gate.Out], assignments[gate.In1]...)
   204  			}
   205  		case Orgate:
   206  			if assignments[gate.In0] == nil && assignments[gate.In1] == nil {
   207  				assignments[gate.Out] = nil
   208  			} else {
   209  				assignments[gate.Out] = assignments[gate.In0]
   210  				if assignments[gate.Out] == nil {
   211  					assignments[gate.Out] = assignments[gate.In1]
   212  				}
   213  				if (len(assignments[gate.In1]) < len(assignments[gate.Out])) && assignments[gate.In1] != nil {
   214  					assignments[gate.Out] = assignments[gate.In1]
   215  				}
   216  			}
   217  		default:
   218  			return nil, fmt.Errorf("unmatched case")
   219  		}
   220  	}
   221  	if assignments[2*n] == nil {
   222  		return nil, fmt.Errorf("no satisfying assignment")
   223  	}
   224  	ret := make([]match, 0)
   225  	for _, wire := range assignments[2*n] {
   226  		for _, match := range available {
   227  			if match.wire == wire {
   228  				ret = append(ret, match)
   229  			}
   230  		}
   231  	}
   232  	return ret, nil
   233  }
   234  
   235  // share distributes an input into shares for a secret sharing system
   236  // for the formula: the original vector can be recovered from shares
   237  // that satisfy the formula, by adding them all up.
   238  func (f *Formula) share(rand io.Reader, k *matrixZp) ([]*matrixZp, error) {
   239  	err := f.toposort()
   240  	if err != nil {
   241  		return nil, err
   242  	}
   243  	n := len(f.Gates)
   244  	shares := make([]*matrixZp, 2*n+1)
   245  	// Reverse order: we want to set the share of the output ahead of the Inputs
   246  	shares[2*n] = k
   247  	for i := len(f.Gates) - 1; i >= 0; i-- {
   248  		gate := f.Gates[i]
   249  		switch gate.Class {
   250  		case Andgate:
   251  			shares[gate.In0], err = randomMatrixZp(rand, k.rows, k.cols)
   252  			if err != nil {
   253  				return nil, err
   254  			}
   255  			shares[gate.In1] = newMatrixZp(k.rows, k.cols)
   256  			shares[gate.In0].sub(shares[gate.Out], shares[gate.In1])
   257  
   258  		case Orgate:
   259  			shares[gate.In0] = newMatrixZp(k.rows, k.cols)
   260  			shares[gate.In0].set(shares[gate.Out])
   261  			shares[gate.In1] = newMatrixZp(k.rows, k.cols)
   262  			shares[gate.In1].set(shares[gate.Out])
   263  		}
   264  	}
   265  	return shares[0 : n+1], nil
   266  }
   267  
   268  // insertAnd adds and Gate for a new input
   269  func (f *Formula) insertAnd() Formula {
   270  	// Let n=3
   271  	// The old Inputs are 0, 1, 2, 3.
   272  	// Old intermediates 4, 5,
   273  	// Old output 6.
   274  	// The old Inputs are Inputs 0,1,2,3 and new input 4
   275  	// Intermediates are all shifted up by 1: 5, 6
   276  	// Old output is also shifted up but is the intermediate 7
   277  	// New output 8.
   278  	n := len(f.Gates)
   279  	gates := make([]Gate, len(f.Gates)+1)
   280  	newInput := func(in int) int {
   281  		if in > n {
   282  			return in + 1
   283  		} else {
   284  			return in
   285  		}
   286  	}
   287  
   288  	for i := 0; i < n; i++ {
   289  		gates[i].Class = f.Gates[i].Class
   290  		gates[i].In0 = newInput(f.Gates[i].In0)
   291  		gates[i].In1 = newInput(f.Gates[i].In1)
   292  		gates[i].Out = f.Gates[i].Out + 1
   293  	}
   294  	gates[n].Class = Andgate
   295  	// if there were zero gates, then In0 = 0, In1 = 1, Out = 2
   296  	if n == 0 {
   297  		gates[n].In0 = n
   298  	} else {
   299  		gates[n].In0 = n + 1
   300  	}
   301  	gates[n].In1 = 2*n + 1
   302  	gates[n].Out = 2*n + 2
   303  	return Formula{
   304  		Gates: gates,
   305  	}
   306  }
   307  
   308  func (f *Formula) Equal(g Formula) bool {
   309  	if len(f.Gates) != len(g.Gates) {
   310  		return false
   311  	}
   312  	for i := 0; i < len(f.Gates); i++ {
   313  		if !f.Gates[i].Equal(g.Gates[i]) {
   314  			return false
   315  		}
   316  	}
   317  	return true
   318  }