github.com/wbrown/gpt_bpe@v0.0.0-20250709161131-1571a6e8ad2d/cmd/tokens_transformer/tokens_transformer.go

package main

import (
	"flag"
	"io"
	"log"
	"os"

	"github.com/wbrown/gpt_bpe"
)

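// tokens_transformer reads a file of binary-packed token IDs produced with one
// tokenizer, decodes each fixed-size context back to text, re-encodes it with a
// second tokenizer, and writes the retokenized contexts to an output file.
//
// Example invocation (the input file name here is hypothetical):
//
//	tokens_transformer -input_tokenizer gpt2 -output_tokenizer pile \
//	    -input dataset.tokens -output retokenized.tokens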
func main() {
	inputTokenizerId := flag.String("input_tokenizer", "gpt2",
		"input tokenizer id [gpt2, pile, clip, huggingface-id]")
	outputTokenizerId := flag.String("output_tokenizer", "gpt2",
		"output tokenizer id [gpt2, pile, clip, huggingface-id]")
	contextSize := flag.Int("context_size", 2048,
		"number of tokens to use as context")
	showContexts := flag.Bool("show_contexts", false,
		"show contexts as they are retokenized")
	noUnitrim := flag.Bool("no_unitrim", false,
		"do not trim retokenized contexts to valid unicode boundaries")
	in32 := flag.Bool("in32", false,
		"force input tokens to be read as 32-bit")
	out32 := flag.Bool("out32", false,
		"force output tokens to be written as 32-bit")
	inputFile := flag.String("input", "",
		"input file to retokenize")
	outputFile := flag.String("output", "retokenized.tokens",
		"output file to write retokenized data")
	flag.Parse()
	if *inputFile == "" {
		flag.Usage()
		log.Fatal("Must provide -input")
	}
	if *inputTokenizerId == "" {
		flag.Usage()
		log.Fatal("Must provide -input_tokenizer")
	}
	if *outputTokenizerId == "" {
		flag.Usage()
		log.Fatal("Must provide -output_tokenizer")
	}
	if *contextSize < 1 {
		flag.Usage()
		log.Fatal("Context size must be greater than 0")
	}
	// check if input and output tokenizers are the same
	if *inputTokenizerId == *outputTokenizerId {
		log.Fatal("Input and output tokenizers must be different")
	}
	// check if input and output files are the same
	if *inputFile == *outputFile {
		log.Fatal("Input and output files must be different")
	}
	// check if input file exists
	if _, err := os.Stat(*inputFile); os.IsNotExist(err) {
		log.Fatal("Input file does not exist")
	}

	// Check if it's an internal reference. If not, it's a file path.
	inputTokenizer, inputErr := gpt_bpe.NewEncoder(
		*inputTokenizerId + "-tokenizer")
	if inputErr != nil {
		// Fall back to path-like.
		inputTokenizer, inputErr = gpt_bpe.NewEncoder(*inputTokenizerId)
		if inputErr != nil {
			log.Fatal(inputErr)
		}
	}
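	// Vocabularies with more than 65,536 entries need token IDs wider than
	// 16 bits, so they are read (and, below, written) as 32-bit values; the
	// -in32/-out32 flags force the wider width regardless of vocabulary size.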
	input32Bit := *in32 || len(inputTokenizer.Encoder) > 65536

	outputTokenizer, outputErr := gpt_bpe.NewEncoder(
		*outputTokenizerId + "-tokenizer")
	if outputErr != nil {
		// Fall back to path-like.
		outputTokenizer, outputErr = gpt_bpe.NewEncoder(*outputTokenizerId)
		if outputErr != nil {
			log.Fatal(outputErr)
		}
	}
	output32Bit := *out32 || len(outputTokenizer.Encoder) > 65536

	// open input file
	inputFileHandle, inputOpenErr := os.Open(*inputFile)
	if inputOpenErr != nil {
		log.Fatal(inputOpenErr)
	}
	defer inputFileHandle.Close()
	// open output file
	outputFileHandle, outputOpenErr := os.Create(*outputFile)
	if outputOpenErr != nil {
		log.Fatal(outputOpenErr)
	}
	defer outputFileHandle.Close()
	// create a context buffer large enough for contextSize tokens at the
	// input token width
	inputTokenWidth := 2
	if input32Bit {
		inputTokenWidth = 4
	}
	contextBuffer := make([]byte, *contextSize*inputTokenWidth)

	// read input context by context
	for {
		// read next context
		bytesRead, readErr := inputFileHandle.Read(contextBuffer)
		if bytesRead <= 0 {
			break
		}
		if readErr != nil {
			// Check if we reached the end of the file
			if readErr == io.EOF {
				break
			}
			log.Fatal(readErr)
		}

		if input32Bit {
			log.Println("Reading as 32-bit")
		} else {
			log.Println("Reading as 16-bit")
		}
		context := contextBuffer[:bytesRead]
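		// Decode the raw token bytes back to text with the input tokenizer,
		// then re-encode that text with the output tokenizer.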
		decoded := inputTokenizer.DecodeBuffer(&context, input32Bit)
		encoded := outputTokenizer.Encode(&decoded)
		// trim encoded tokens to context size
		if len(*encoded) > *contextSize {
			trimmed := (*encoded)[:*contextSize]
			encoded = &trimmed
		}
		// unless -no_unitrim was given, trim the context to a valid
		// unicode boundary
		if !*noUnitrim {
			encoded = outputTokenizer.TrimTokens(encoded)
		}
		// pad out context
		if len(*encoded) < *contextSize {
			padded := make(gpt_bpe.Tokens, *contextSize)
			copy(padded, *encoded)
			for i := len(*encoded); i < *contextSize; i++ {
				padded[i] = outputTokenizer.PadToken
			}
			encoded = &padded
		}
		// write encoded context to output file
		if output32Bit {
			log.Println("Writing as 32-bit")
		} else {
			log.Println("Writing as 16-bit")
		}
		bytesToWrite, binErr := encoded.ToBin(output32Bit)
		if binErr != nil {
			log.Fatal(binErr)
		}
		bytesWritten, writeErr := outputFileHandle.Write(*bytesToWrite)

		if writeErr != nil {
			log.Fatal(writeErr)
		}
		if bytesWritten != len(*bytesToWrite) {
			log.Fatal("Could not write full context")
		}
		if *showContexts {
			log.Printf("Input: %s", decoded)
			log.Printf("Output: %s", outputTokenizer.Decode(encoded))
		}
	}
}