github.com/consensys/gnark-crypto@v0.14.0/field/generator/internal/templates/element/vector.go (about)

     1  package element
     2  
     3  const Vector = `
     4  import (
     5  	"io"
     6  	"encoding/binary"
     7  	"strings"
     8  	"bytes"
     9  	"runtime"
    10  	"unsafe"
    11  	"sync"
    12  	"sync/atomic"
    13  	"fmt"
    14  )
    15  
// Vector represents a slice of {{.ElementName}}.
//
// It implements the following interfaces:
//   - Stringer
//   - io.WriterTo
//   - io.ReaderFrom
//   - encoding.BinaryMarshaler
//   - encoding.BinaryUnmarshaler
//   - sort.Interface
//
// The serialization methods (WriteTo, ReadFrom, AsyncReadFrom, MarshalBinary,
// UnmarshalBinary) are declared on *Vector; String, Len, Less and Swap are
// declared on Vector.
type Vector []{{.ElementName}}
    26  
    27  // MarshalBinary implements encoding.BinaryMarshaler
    28  func (vector *Vector) MarshalBinary() (data []byte, err error) {
    29  	var buf bytes.Buffer
    30  
    31  	if _, err = vector.WriteTo(&buf); err != nil {
    32  		return
    33  	}
    34  	return buf.Bytes(), nil
    35  }
    36  
    37  
    38  // UnmarshalBinary implements encoding.BinaryUnmarshaler
    39  func (vector *Vector) UnmarshalBinary(data []byte) error {
    40  	r := bytes.NewReader(data)
    41  	_, err := vector.ReadFrom(r)
    42  	return err
    43  }
    44  
    45  // WriteTo implements io.WriterTo and writes a vector of big endian encoded {{.ElementName}}.
    46  // Length of the vector is encoded as a uint32 on the first 4 bytes.
    47  func (vector *Vector) WriteTo(w io.Writer) (int64, error) {
    48      // encode slice length
    49      if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil {
    50          return 0, err 
    51      }
    52  
    53  	n := int64(4)
    54  
    55  	var buf [Bytes]byte 
    56  	for i := 0; i < len(*vector); i++ {
    57  		BigEndian.PutElement(&buf, (*vector)[i])
    58  		m, err := w.Write(buf[:])
    59  		n += int64(m)
    60  		if err != nil {
    61  			return n, err 
    62  		} 
    63  	}
    64  	return n, nil
    65  }
    66  
    67  // AsyncReadFrom reads a vector of big endian encoded {{.ElementName}}.
    68  // Length of the vector must be encoded as a uint32 on the first 4 bytes.
    69  // It consumes the needed bytes from the reader and returns the number of bytes read and an error if any.
    70  // It also returns a channel that will be closed when the validation is done.
    71  // The validation consist of checking that the elements are smaller than the modulus, and
    72  // converting them to montgomery form.
    73  func (vector *Vector) AsyncReadFrom(r io.Reader) (int64, error, chan error) {
    74  	chErr := make(chan error, 1)
    75  	var buf [Bytes]byte 
    76  	if read, err := io.ReadFull(r, buf[:4]); err != nil {
    77  		close(chErr)
    78          return int64(read), err, chErr
    79      }
    80  	sliceLen := binary.BigEndian.Uint32(buf[:4])
    81  
    82      n := int64(4)
    83  	(*vector) = make(Vector, sliceLen)
    84  	if sliceLen == 0 {
    85  		close(chErr)
    86  		return n, nil, chErr
    87  	}
    88  
    89  	bSlice := unsafe.Slice((*byte)(unsafe.Pointer(&(*vector)[0])), sliceLen*Bytes)
    90  	read, err := io.ReadFull(r, bSlice)
    91  	n += int64(read)
    92  	if err != nil {
    93  		close(chErr)
    94  		return n, err, chErr
    95  	}
    96  
    97  
    98  	go func() {
    99  		var cptErrors uint64
   100  		// process the elements in parallel
   101  		execute(int(sliceLen), func(start, end int) {
   102  			
   103  			var z {{.ElementName}}
   104  			for i:=start; i < end; i++ {
   105  				// we have to set vector[i]
   106  				bstart := i*Bytes
   107  				bend := bstart + Bytes
   108  				b := bSlice[bstart:bend]
   109  				{{- range $i := reverse .NbWordsIndexesFull}}
   110  					{{- $j := mul $i 8}}
   111  					{{- $k := sub $.NbWords 1}}
   112  					{{- $k := sub $k $i}}
   113  					{{- $jj := add $j 8}}
   114  					z[{{$k}}] = binary.BigEndian.Uint64(b[{{$j}}:{{$jj}}])
   115  				{{- end}}
   116  
   117  				if !z.smallerThanModulus() {
   118  					atomic.AddUint64(&cptErrors, 1)
   119  					return
   120  				}
   121  				z.toMont()
   122  				(*vector)[i] = z
   123  			}
   124  		})
   125  
   126  		if cptErrors > 0 {
   127  			chErr <- fmt.Errorf("async read: %d elements failed validation", cptErrors)
   128  		}
   129  		close(chErr)
   130  	}()
   131  	return n, nil, chErr
   132  }
   133  
   134  // ReadFrom implements io.ReaderFrom and reads a vector of big endian encoded {{.ElementName}}.
   135  // Length of the vector must be encoded as a uint32 on the first 4 bytes.
   136  func (vector *Vector) ReadFrom(r io.Reader) (int64, error) {
   137  
   138  	var buf [Bytes]byte 
   139  	if read, err := io.ReadFull(r, buf[:4]); err != nil {
   140          return int64(read), err 
   141      }
   142  	sliceLen := binary.BigEndian.Uint32(buf[:4])
   143  
   144      n := int64(4)
   145  	(*vector) = make(Vector, sliceLen)
   146  
   147      for i:=0; i < int(sliceLen); i++ {
   148          read, err := io.ReadFull(r, buf[:])
   149          n += int64(read)
   150          if err != nil {
   151              return n, err
   152          }
   153  		(*vector)[i], err = BigEndian.Element(&buf)
   154  		if err != nil {
   155  			return n, err
   156  		}
   157      }
   158  	
   159  
   160      return n, nil 
   161  }
   162  
   163  // String implements fmt.Stringer interface
   164  func (vector Vector) String() string {
   165      var sbb strings.Builder
   166      sbb.WriteByte('[')
   167      for i:=0; i < len(vector); i++ {
   168          sbb.WriteString(vector[i].String())
   169  		if i != len(vector) - 1 {
   170  			sbb.WriteByte(',')
   171  		}
   172      }
   173      sbb.WriteByte(']')
   174      return sbb.String()
   175  }
   176  
   177  
// Len is the number of elements in the collection.
// Together with Less and Swap it makes Vector usable as a sort.Interface.
func (vector Vector) Len() int {
	return len(vector)
}
   182  
   183  // Less reports whether the element with
   184  // index i should sort before the element with index j.
   185  func (vector Vector) Less(i, j int) bool {
   186  	return vector[i].Cmp(&vector[j]) == -1
   187  }
   188  
   189  // Swap swaps the elements with indexes i and j.
   190  func (vector Vector) Swap(i, j int) {
   191  	vector[i], vector[j] = vector[j], vector[i]
   192  }
   193  
   194  
   195  // TODO @gbotrel make a public package out of that.
   196  // execute executes the work function in parallel.
   197  // this is copy paste from internal/parallel/parallel.go
   198  // as we don't want to generate code importing internal/ 
   199  func execute(nbIterations int, work func(int, int), maxCpus ...int) {
   200  
   201  	nbTasks := runtime.NumCPU()
   202  	if len(maxCpus) == 1 {
   203  		nbTasks = maxCpus[0]
   204  		if nbTasks < 1 {
   205  			nbTasks = 1
   206  		} else if nbTasks > 512 {
   207  			nbTasks = 512
   208  		}
   209  	}
   210  
   211  	if nbTasks == 1 {
   212  		// no go routines
   213  		work(0, nbIterations)
   214  		return
   215  	}
   216  
   217  	nbIterationsPerCpus := nbIterations / nbTasks
   218  
   219  	// more CPUs than tasks: a CPU will work on exactly one iteration
   220  	if nbIterationsPerCpus < 1 {
   221  		nbIterationsPerCpus = 1
   222  		nbTasks = nbIterations
   223  	}
   224  
   225  	var wg sync.WaitGroup
   226  
   227  	extraTasks := nbIterations - (nbTasks * nbIterationsPerCpus)
   228  	extraTasksOffset := 0
   229  
   230  	for i := 0; i < nbTasks; i++ {
   231  		wg.Add(1)
   232  		_start := i*nbIterationsPerCpus + extraTasksOffset
   233  		_end := _start + nbIterationsPerCpus
   234  		if extraTasks > 0 {
   235  			_end++
   236  			extraTasks--
   237  			extraTasksOffset++
   238  		}
   239  		go func() {
   240  			work(_start, _end)
   241  			wg.Done()
   242  		}()
   243  	}
   244  
   245  	wg.Wait()
   246  }
   247  
   248  
   249  `