github.com/goplus/llgo@v0.8.3/c/llama2/llama2.go (about) 1 /* 2 * Copyright (c) 2024 The GoPlus Authors (goplus.org). All rights reserved. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package llama2 18 19 import ( 20 _ "unsafe" 21 22 "github.com/goplus/llgo/c" 23 ) 24 25 const ( 26 LLGoPackage = "link" 27 ) 28 29 // ----------------------------------------------------------------------------- 30 31 // llgo:type C 32 type TokenIndex struct { 33 Str *c.Char 34 Id c.Int 35 } 36 37 // llgo:type C 38 type Tokenizer struct { 39 Vocab **c.Char 40 VocabScores *c.Float 41 SortedVocab *TokenIndex 42 VocabSize c.Int 43 MaxTokenLength c.Uint 44 BytePieces [512]uint8 // stores all single-byte strings 45 } 46 47 //go:linkname BuildTokenizer C.build_tokenizer 48 func BuildTokenizer(t *Tokenizer, tokenizerPath *c.Char, vocabSize c.Int) 49 50 //go:linkname FreeTokenizer C.free_tokenizer 51 func FreeTokenizer(t *Tokenizer) 52 53 // ----------------------------------------------------------------------------- 54 55 // llgo:type C 56 type Config struct { 57 Dim c.Int // transformer dimension 58 HiddenDim c.Int // for ffn layers 59 NLayers c.Int // number of layers 60 NHeads c.Int // number of query heads 61 NKVHeads c.Int // number of key/value heads (can be < query heads because of multiquery) 62 VocabSize c.Int // vocabulary size, usually 256 (byte-level) 63 SeqLen c.Int // max sequence length 64 } 65 66 // llgo:type C 67 type TransformerWeights struct { 68 // token embedding table 69 TokenEmbeddingTable *c.Float // (vocab_size, dim) 70 // weights for rmsnorms 71 RmsAttWeight *c.Float // (layer, dim) rmsnorm weights 72 RmsFfnWeight *c.Float // (layer, dim) 73 // weights for matmuls. note dim == n_heads * head_size 74 Wq *c.Float // (layer, dim, n_heads * head_size) 75 Wk *c.Float // (layer, dim, n_kv_heads * head_size) 76 Wv *c.Float // (layer, dim, n_kv_heads * head_size) 77 Wo *c.Float // (layer, n_heads * head_size, dim) 78 // weights for ffn 79 W1 *c.Float // (layer, hidden_dim, dim) 80 W2 *c.Float // (layer, dim, hidden_dim) 81 W3 *c.Float // (layer, hidden_dim, dim) 82 // final rmsnorm 83 RmsFinalWeight *c.Float // (dim,) 84 // (optional) classifier weights for the logits, on the last layer 85 Wcls *c.Float 86 } 87 88 // llgo:type C 89 type RunState struct { 90 // current wave of activations 91 X *c.Float // activation at current time stamp (dim,) 92 Xb *c.Float // same, but inside a residual branch (dim,) 93 Xb2 *c.Float // an additional buffer just for convenience (dim,) 94 Hb *c.Float // buffer for hidden dimension in the ffn (hidden_dim,) 95 Hb2 *c.Float // buffer for hidden dimension in the ffn (hidden_dim,) 96 Q *c.Float // query (dim,) 97 K *c.Float // key (dim,) 98 V *c.Float // value (dim,) 99 Att *c.Float // buffer for scores/attention values (n_heads, seq_len) 100 Logits *c.Float // output logits 101 // kv cache 102 KeyCache *c.Float // (layer, seq_len, dim) 103 ValueCache *c.Float // (layer, seq_len, dim) 104 } 105 106 // llgo:type C 107 type Transformer struct { 108 Config Config // the hyperparameters of the architecture (the blueprint) 109 Weights TransformerWeights // the weights of the model 110 State RunState // buffers for the "wave" of activations in the forward pass 111 112 // some more state needed to properly clean up the memory mapping (sigh) 113 Fd c.Int // file descriptor for memory mapping 114 Data *c.Float // memory mapped data pointer 115 FileSize uintptr // size of the checkpoint file in bytes 116 } 117 118 //go:linkname BuildTransformer C.build_transformer 119 func BuildTransformer(t *Transformer, checkpointPath *c.Char) 120 121 //go:linkname FreeTransformer C.free_transformer 122 func FreeTransformer(t *Transformer) 123 124 // ----------------------------------------------------------------------------- 125 126 // llgo:type C 127 type ProbIndex struct { 128 Prob c.Float 129 Index c.Int 130 } // struct used when sorting probabilities during top-p sampling 131 132 // llgo:type C 133 type Sampler struct { 134 VocabSize c.Int 135 Probindex *ProbIndex // buffer used in top-p sampling 136 Temperature c.Float 137 Topp c.Float 138 RngState uint64 139 } 140 141 //go:linkname BuildSampler C.build_sampler 142 func BuildSampler(sampler *Sampler, vocabSize c.Int, temperature c.Float, topp c.Float, rngSeed uint64) 143 144 //go:linkname FreeSampler C.free_sampler 145 func FreeSampler(sampler *Sampler) 146 147 // ----------------------------------------------------------------------------- 148 149 //go:linkname Generate C.generate 150 func Generate( 151 transformer *Transformer, tokenizer *Tokenizer, sampler *Sampler, 152 prompt *c.Char, steps c.Int) 153 154 //go:linkname Chat C.chat 155 func Chat( 156 transformer *Transformer, tokenizer *Tokenizer, sampler *Sampler, 157 cliUserPrompt *c.Char, cliSystemPrompt *c.Char, steps c.Int) 158 159 // -----------------------------------------------------------------------------