github.com/wbrown/gpt_bpe@v0.0.0-20250709161131-1571a6e8ad2d/lib/library.go (about)

     1  package main
     2  
     3  /*
     4  #include "library.h"
     5  */
     6  import "C"
     7  import (
     8  	"fmt"
     9  	"os"
    10  	"reflect"
    11  	"time"
    12  	"unsafe"
    13  
    14  	"github.com/wbrown/gpt_bpe"
    15  	"github.com/wbrown/gpt_bpe/types"
    16  )
    17  
    18  var tokenizers map[string]*gpt_bpe.GPTEncoder
    19  
    20  func init() {
    21  	tokenizers = make(map[string]*gpt_bpe.GPTEncoder)
    22  }
    23  
    24  // initTokenizer accepts a vocabulary id as a C string, and if it does not
    25  // exist in the global tokenizers map, initializes a tokenizer for that
    26  // vocabulary.
    27  //
    28  //export initTokenizer
    29  func initTokenizer(vocab_id *C.char) bool {
    30  	vocab_id_str := C.GoString(vocab_id)
    31  	if encoder, err := gpt_bpe.NewEncoder(vocab_id_str); err != nil {
    32  		panic(err)
    33  	} else {
    34  		tokenizers[vocab_id_str] = encoder
    35  		return true
    36  	}
    37  }
    38  
    39  // create a byte array using C memory for internal use
    40  func createBuffer(buf unsafe.Pointer, size int) *[]byte {
    41  	var res []byte
    42  	hdr := (*reflect.SliceHeader)(unsafe.Pointer(&res))
    43  	hdr.Data = uintptr(unsafe.Pointer(buf))
    44  	hdr.Len = size
    45  	hdr.Cap = size
    46  	return &res
    47  }
    48  
// tokenizeBuffer tokenizes a raw byte buffer (pointer + length) using the
// encoder for vocabIdStr, lazily initializing that encoder on first use.
// The returned C.Tokens holds a malloc'ed array (via C.CBytes) that the
// caller must release with freeTokens.
//
// NOTE(review): goBuf aliases the C memory without copying, so buf must
// stay valid for the duration of EncodeBuffer. The *encoded value passed
// to C.CBytes is assumed to already be the binary uint32 token stream —
// confirm against gpt_bpe.EncodeBuffer's contract.
//
//export tokenizeBuffer
func tokenizeBuffer(vocabIdStr *C.char, buf *C.char, sz C.size_t) C.Tokens {
	tokenizerId := C.GoString(vocabIdStr)
	encoder, ok := tokenizers[tokenizerId]
	if !ok {
		// initTokenizer panics on failure, so the map lookup below
		// succeeds whenever this line is passed.
		initTokenizer(vocabIdStr)
		encoder = tokenizers[tokenizerId]
	}
	goBuf := createBuffer(unsafe.Pointer(buf), int(sz))
	encoded, tokenCount := encoder.EncodeBuffer(goBuf)
	// Copy the token bytes into C-owned memory; ownership transfers to
	// the caller.
	tokensArr := C.CBytes(*encoded)
	tokens := C.Tokens{
		tokens: (*C.uint32_t)(tokensArr),
		len:    (C.size_t)(tokenCount),
	}
	return tokens
}
    66  
    67  // tokenize accepts a vocabulary and text as a C string, and returns a C.Tokens
    68  // that contains a malloc'ed array of little-endian uint32_t tokens along with
    69  // the number of tokens.
    70  //
    71  //export tokenize
    72  func tokenize(vocabIdStr *C.char, str *C.char) C.Tokens {
    73  	tokenizerId := C.GoString(vocabIdStr)
    74  	encoder, ok := tokenizers[tokenizerId]
    75  	if !ok {
    76  		initTokenizer(vocabIdStr)
    77  		encoder = tokenizers[tokenizerId]
    78  	}
    79  	s := C.GoString(str)
    80  	fmt.Printf("input: %s\n", s)
    81  	encoded := *encoder.Encode(&s)
    82  	fmt.Printf("Tokens: %v\n", encoded)
    83  	encodedBinary, err := encoded.ToBin(true)
    84  	if err == nil || encodedBinary == nil {
    85  		_, _ = fmt.Fprintf(os.Stderr, "tokenize: failed to write tokens as uint32_t")
    86  		return C.Tokens{tokens: nil, len: 0}
    87  	}
    88  	tokensArr := C.CBytes(*encodedBinary)
    89  	tokens := C.Tokens{
    90  		tokens: (*C.uint32_t)(tokensArr),
    91  		len:    C.size_t(len(encoded)),
    92  	}
    93  	fmt.Printf("tokens: %p\n", &tokens)
    94  	fmt.Printf("tokens.tokens: %p\n", tokens.tokens)
    95  	fmt.Printf("*tokens.tokens: %v\n", tokens.tokens)
    96  	fmt.Printf("tokens.len: %v\n", len(encoded))
    97  	return tokens
    98  }
    99  
   100  // decode accepts a vocabulary id and a C.Tokens struct, and returns a malloc'ed
   101  // C.char* containing the decoded string.
   102  //
   103  //export decode
   104  func decode(vocabIdStr *C.char, tokens C.Tokens) *C.char {
   105  	tokenizerId := C.GoString(vocabIdStr)
   106  	encoder, ok := tokenizers[tokenizerId]
   107  	if !ok {
   108  		initTokenizer(vocabIdStr)
   109  		encoder = tokenizers[tokenizerId]
   110  	}
   111  	tokensArr := C.GoBytes(unsafe.Pointer(tokens.tokens), C.int(tokens.len)*4)
   112  	goTokens := types.TokensFromBin32(&tokensArr)
   113  	fmt.Printf("goTokens: %v\n", goTokens)
   114  	decoded := encoder.Decode(goTokens)
   115  	fmt.Printf("Decoded: %s\n", decoded)
   116  	return C.CString(decoded)
   117  }
   118  
// freeTokens releases the malloc'ed token array inside a C.Tokens value
// produced by tokenize or tokenizeBuffer. Call at most once per Tokens.
//
//export freeTokens
func freeTokens(tokens C.Tokens) {
	C.free(unsafe.Pointer(tokens.tokens))
}
   123  
   124  // testBuffer tests the C interface to the tokenizer, and is here rather than
   125  // in the test package as the test package is incompatible with CGo.
   126  func testBuffer(vocab string, buf []byte) (time.Duration, uint64) {
   127  	vocabC := C.CString(vocab)
   128  	corpusBuff := (*C.char)(C.CBytes(buf))
   129  	start := time.Now()
   130  	tokens := tokenizeBuffer(vocabC, corpusBuff, C.size_t(len(buf)))
   131  	duration := time.Now().Sub(start)
   132  	return duration, uint64(tokens.len)
   133  }
   134  
   135  // wrapInitTokenizer is a wrapper around initTokenizer that simulates a C call
   136  // from golang.
   137  func wrapInitTokenizer(vocab_id string) bool {
   138  	vocab_id_str := C.CString(vocab_id)
   139  	return initTokenizer(vocab_id_str)
   140  }
   141  
   142  func main() {}