github.com/wbrown/gpt_bpe@v0.0.0-20250709161131-1571a6e8ad2d/types/methods.go (about)

     1  package types
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"fmt"
     7  )
     8  
     9  func (tokens *Tokens) ToBin(useUint32 bool) (*[]byte, error) {
    10  	if useUint32 {
    11  		return tokens.ToBinUint32()
    12  	} else {
    13  		return tokens.ToBinUint16()
    14  	}
    15  }
    16  
    17  func (tokens *Tokens) ToBinUint16() (*[]byte, error) {
    18  	buf := bytes.NewBuffer(make([]byte, 0, len(*tokens)*2))
    19  	for idx := range *tokens {
    20  		bs := (*tokens)[idx]
    21  		if bs > 65535 {
    22  			return nil, fmt.Errorf("integer overflow: tried to write token ID %d as unsigned 16-bit", bs)
    23  		}
    24  		err := binary.Write(buf, binary.LittleEndian, uint16(bs))
    25  		if err != nil {
    26  			return nil, err
    27  		}
    28  	}
    29  	byt := buf.Bytes()
    30  	return &byt, nil
    31  }
    32  
    33  func (tokens *Tokens) ToBinUint32() (*[]byte, error) {
    34  	buf := bytes.NewBuffer(make([]byte, 0, len(*tokens)*4))
    35  	for idx := range *tokens {
    36  		bs := (*tokens)[idx]
    37  		err := binary.Write(buf, binary.LittleEndian, uint32(bs))
    38  		if err != nil {
    39  			return nil, err
    40  		}
    41  	}
    42  	byt := buf.Bytes()
    43  	return &byt, nil
    44  }
    45  
    46  func TokensFromBin(bin *[]byte) *Tokens {
    47  	tokens := make(Tokens, 0, len(*bin)/2)
    48  	buf := bytes.NewReader(*bin)
    49  	for {
    50  		var token uint16
    51  		if err := binary.Read(buf, binary.LittleEndian, &token); err != nil {
    52  			break
    53  		}
    54  		tokens = append(tokens, Token(token))
    55  	}
    56  	return &tokens
    57  }
    58  
    59  func TokensFromBin32(bin *[]byte) *Tokens {
    60  	tokens := make(Tokens, 0, len(*bin)/4)
    61  	buf := bytes.NewReader(*bin)
    62  	for {
    63  		var token Token
    64  		if err := binary.Read(buf, binary.LittleEndian, &token); err != nil {
    65  			break
    66  		}
    67  		tokens = append(tokens, token)
    68  	}
    69  	return &tokens
    70  }