github.com/cloudflare/circl@v1.5.0/abe/cpabe/tkn20/internal/tkn/matrixGT.go (about)

     1  package tkn
     2  
     3  import (
     4  	"encoding/binary"
     5  	"fmt"
     6  	"io"
     7  
     8  	pairing "github.com/cloudflare/circl/ecc/bls12381"
     9  )
    10  
    11  // matrixGT represents a matrix of GT elements. They are stored in row-major order.
    12  type matrixGT struct {
    13  	rows    int
    14  	cols    int
    15  	entries []pairing.Gt
    16  }
    17  
    18  func (m *matrixGT) marshalBinary() ([]byte, error) {
    19  	ret := make([]byte, 4+pairing.GtSize*m.rows*m.cols)
    20  	binary.LittleEndian.PutUint16(ret[0:], uint16(m.rows))
    21  	binary.LittleEndian.PutUint16(ret[2:], uint16(m.cols))
    22  	for i := 0; i < m.rows*m.cols; i++ {
    23  		pt, err := m.entries[i].MarshalBinary()
    24  		if err != nil {
    25  			return nil, err
    26  		}
    27  		if len(pt) != pairing.GtSize {
    28  			panic("matrixGT: incorrect assumption of size")
    29  		}
    30  		copy(ret[pairing.GtSize*i+4:], pt)
    31  	}
    32  	return ret, nil
    33  }
    34  
    35  func (m *matrixGT) unmarshalBinary(data []byte) error {
    36  	if len(data) < 4 {
    37  		return fmt.Errorf("matrixGT deserialization failure: input too short")
    38  	}
    39  	m.rows = int(binary.LittleEndian.Uint16(data[0:]))
    40  	m.cols = int(binary.LittleEndian.Uint16(data[2:]))
    41  	data = data[4:]
    42  	if len(data) != pairing.GtSize*m.rows*m.cols {
    43  		return fmt.Errorf("matrixGT deserialization failure: invalid entries length: expected %d, actual %d",
    44  			pairing.GtSize*m.cols*m.rows,
    45  			len(data))
    46  	}
    47  	m.entries = make([]pairing.Gt, m.rows*m.cols)
    48  	var err error
    49  	for i := 0; i < m.rows*m.cols; i++ {
    50  		err = m.entries[i].UnmarshalBinary(data[pairing.GtSize*i : pairing.GtSize*(i+1)])
    51  		if err != nil {
    52  			return fmt.Errorf("matrixGT deserialization failure: error from bytes %v: %w",
    53  				data[pairing.GtSize*i:pairing.GtSize*(i+1)],
    54  				err)
    55  		}
    56  	}
    57  	return nil
    58  }
    59  
    60  // exp computes the naive matrix exponential of a with respect to the basepoint.
    61  func (m *matrixGT) exp(a *matrixZp) {
    62  	basepoint := gtBaseVal
    63  	m.resize(a.rows, a.cols)
    64  	for i := 0; i < m.rows*m.cols; i++ {
    65  		m.entries[i].Exp(basepoint, &a.entries[i])
    66  	}
    67  }
    68  
    69  // resize sets up m to be r x c
    70  func (m *matrixGT) resize(r int, c int) {
    71  	if m.rows != r || m.cols != c {
    72  		m.rows = r
    73  		m.cols = c
    74  		m.entries = make([]pairing.Gt, m.rows*m.cols)
    75  	}
    76  }
    77  
    78  // clear sets m to be the "zero" matrix
    79  func (m *matrixGT) clear() {
    80  	for i := 0; i < len(m.entries); i++ {
    81  		m.entries[i].SetIdentity()
    82  	}
    83  }
    84  
    85  // conformal returns true iff m and a have the same dimensions.
    86  func (m *matrixGT) conformal(a *matrixGT) bool {
    87  	return a.rows == m.rows && a.cols == m.cols
    88  }
    89  
    90  // Equal returns true if m == b.
    91  func (m *matrixGT) Equal(b *matrixGT) bool {
    92  	if !m.conformal(b) {
    93  		return false
    94  	}
    95  	for i := 0; i < m.rows; i++ {
    96  		for j := 0; j < m.cols; j++ {
    97  			if !m.entries[i*m.cols+j].IsEqual(&b.entries[i*b.cols+j]) {
    98  				return false
    99  			}
   100  		}
   101  	}
   102  	return true
   103  }
   104  
   105  // set sets m to b.
   106  func (m *matrixGT) set(b *matrixGT) {
   107  	m.resize(b.rows, b.cols)
   108  	copy(m.entries, b.entries)
   109  }
   110  
   111  // add sets m to a+b.
   112  func (m *matrixGT) add(a *matrixGT, b *matrixGT) {
   113  	if !a.conformal(b) {
   114  		panic(errBadMatrixSize)
   115  	}
   116  	m.resize(a.rows, a.cols)
   117  	for i := 0; i < m.rows*m.cols; i++ {
   118  		m.entries[i].Mul(&a.entries[i], &b.entries[i])
   119  	}
   120  }
   121  
   122  // leftMult multiples a*b with a matrixZp, b matrixGT.
   123  func (m *matrixGT) leftMult(a *matrixZp, b *matrixGT) {
   124  	if a.cols != b.rows {
   125  		panic(errBadMatrixSize)
   126  	}
   127  	if m == b {
   128  		c := newMatrixGT(a.rows, a.cols)
   129  		c.set(b)
   130  		b = c
   131  	}
   132  	m.resize(a.rows, b.cols)
   133  	m.clear()
   134  	tmp := &pairing.Gt{}
   135  	for i := 0; i < m.rows; i++ {
   136  		for j := 0; j < m.cols; j++ {
   137  			for k := 0; k < a.cols; k++ {
   138  				tmp.Exp(&b.entries[k*b.cols+j], &a.entries[i*a.cols+k])
   139  				m.entries[i*m.cols+j].Mul(&m.entries[i*m.cols+j], tmp)
   140  			}
   141  		}
   142  	}
   143  }
   144  
   145  // rightMult multiplies a*b with a matrixG1, b matrixZp.
   146  func (m *matrixGT) rightMult(a *matrixGT, b *matrixZp) {
   147  	if a.cols != b.rows {
   148  		panic(errBadMatrixSize)
   149  	}
   150  	if m == a {
   151  		c := newMatrixGT(a.rows, a.cols)
   152  		c.set(a)
   153  		a = c
   154  	}
   155  	m.resize(a.rows, b.cols)
   156  	m.clear()
   157  	tmp := &pairing.Gt{}
   158  	// to transpose can index bt[i,j] as b.entries[j*b.rows+i]
   159  	for i := 0; i < m.rows; i++ {
   160  		for j := 0; j < m.cols; j++ {
   161  			for k := 0; k < a.cols; k++ {
   162  				tmp.Exp(&a.entries[i*a.cols+k], &b.entries[k*b.cols+j])
   163  				m.entries[i*m.cols+j].Mul(&m.entries[i*m.cols+j], tmp)
   164  			}
   165  		}
   166  	}
   167  }
   168  
   169  func newMatrixGT(r int, c int) *matrixGT {
   170  	ret := new(matrixGT)
   171  	ret.resize(r, c)
   172  	ret.clear()
   173  	return ret
   174  }
   175  
   176  func randomMatrixGT(rand io.Reader, r int, c int) (*matrixGT, error) {
   177  	a, err := randomMatrixZp(rand, r, c)
   178  	if err != nil {
   179  		return nil, err
   180  	}
   181  	ret := newMatrixGT(r, c)
   182  	ret.exp(a)
   183  	return ret, nil
   184  }