
     1  /*
 * Copyright (c) 2024 The GoPlus Authors (goplus.org). All rights reserved.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
 *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    17  package llama2
    19  import (
    20  	_ "unsafe"
    22  	""
    23  )
    25  const (
    26  	LLGoPackage = "link"
    27  )
    29  // -----------------------------------------------------------------------------
// TokenIndex pairs a C string with an integer id. Tokenizer.SortedVocab
// points to values of this type.
// NOTE(review): presumably mirrors a struct of the linked C code, so field
// order and types would have to stay in sync with it — confirm.
//
// llgo:type C
type TokenIndex struct {
	Str *c.Char // C string of the entry
	Id  c.Int   // integer id paired with Str
}
// Tokenizer is the tokenizer state handed to BuildTokenizer, FreeTokenizer,
// Generate and Chat.
// NOTE(review): the vocabulary arrays presumably hold VocabSize entries each
// and are owned by the C side — confirm against the linked C implementation.
//
// llgo:type C
type Tokenizer struct {
	Vocab          **c.Char    // pointer to C string pointers
	VocabScores    *c.Float    // pointer to C floats
	SortedVocab    *TokenIndex // pointer to TokenIndex values
	VocabSize      c.Int       // vocabulary size as a C int
	MaxTokenLength c.Uint      // maximum token length as a C unsigned int
	BytePieces     [512]uint8  // stores all single-byte strings
}
// BuildTokenizer is bound to the C function build_tokenizer; it has no Go
// body. It receives the Tokenizer t, the C string tokenizerPath and the C int
// vocabSize.
// NOTE(review): presumably loads the vocabulary from the file at
// tokenizerPath into t — confirm against the C implementation.
//
//go:linkname BuildTokenizer C.build_tokenizer
func BuildTokenizer(t *Tokenizer, tokenizerPath *c.Char, vocabSize c.Int)
// FreeTokenizer is bound to the C function free_tokenizer; it has no Go body.
// NOTE(review): presumably releases the memory referenced by t that
// BuildTokenizer set up — confirm against the C implementation.
//
//go:linkname FreeTokenizer C.free_tokenizer
func FreeTokenizer(t *Tokenizer)
    53  // -----------------------------------------------------------------------------
// Config holds the model hyperparameters as C ints; it is stored by value in
// Transformer.Config.
//
// llgo:type C
type Config struct {
	Dim       c.Int // transformer dimension
	HiddenDim c.Int // for ffn layers
	NLayers   c.Int // number of layers
	NHeads    c.Int // number of query heads
	NKVHeads  c.Int // number of key/value heads (can be < query heads because of multiquery)
	VocabSize c.Int // vocabulary size, usually 256 (byte-level)
	SeqLen    c.Int // max sequence length
}
// TransformerWeights groups the pointers to the model's weight arrays; it is
// stored by value in Transformer.Weights. The shape noted next to each field
// is the one given by the field's own annotation.
//
// llgo:type C
type TransformerWeights struct {
	// token embedding table
	TokenEmbeddingTable *c.Float // (vocab_size, dim)
	// weights for rmsnorms
	RmsAttWeight *c.Float // (layer, dim) rmsnorm weights
	RmsFfnWeight *c.Float // (layer, dim)
	// weights for matmuls. note dim == n_heads * head_size
	Wq *c.Float // (layer, dim, n_heads * head_size)
	Wk *c.Float // (layer, dim, n_kv_heads * head_size)
	Wv *c.Float // (layer, dim, n_kv_heads * head_size)
	Wo *c.Float // (layer, n_heads * head_size, dim)
	// weights for ffn
	W1 *c.Float // (layer, hidden_dim, dim)
	W2 *c.Float // (layer, dim, hidden_dim)
	W3 *c.Float // (layer, hidden_dim, dim)
	// final rmsnorm
	RmsFinalWeight *c.Float // (dim,)
	// (optional) classifier weights for the logits, on the last layer
	Wcls *c.Float
}
// RunState groups the activation buffers and the key/value cache; it is
// stored by value in Transformer.State. Every field is a pointer to C floats,
// with the shape given by the field's own annotation.
//
// llgo:type C
type RunState struct {
	// current wave of activations
	X      *c.Float // activation at current time stamp (dim,)
	Xb     *c.Float // same, but inside a residual branch (dim,)
	Xb2    *c.Float // an additional buffer just for convenience (dim,)
	Hb     *c.Float // buffer for hidden dimension in the ffn (hidden_dim,)
	Hb2    *c.Float // buffer for hidden dimension in the ffn (hidden_dim,)
	Q      *c.Float // query (dim,)
	K      *c.Float // key (dim,)
	V      *c.Float // value (dim,)
	Att    *c.Float // buffer for scores/attention values (n_heads, seq_len)
	Logits *c.Float // output logits
	// kv cache
	KeyCache   *c.Float // (layer, seq_len, dim)
	ValueCache *c.Float // (layer, seq_len, dim)
}
// Transformer bundles the model's Config, TransformerWeights and RunState
// together with the bookkeeping fields for the memory-mapped checkpoint. It is
// the value handed to BuildTransformer, FreeTransformer, Generate and Chat.
//
// llgo:type C
type Transformer struct {
	Config  Config             // the hyperparameters of the architecture (the blueprint)
	Weights TransformerWeights // the weights of the model
	State   RunState           // buffers for the "wave" of activations in the forward pass
	// some more state needed to properly clean up the memory mapping (sigh)
	Fd       c.Int    // file descriptor for memory mapping
	Data     *c.Float // memory mapped data pointer
	FileSize uintptr  // size of the checkpoint file in bytes
}
// BuildTransformer is bound to the C function build_transformer; it has no Go
// body. It receives the Transformer t and the C string checkpointPath.
// NOTE(review): presumably maps the checkpoint file at checkpointPath and
// initializes t from it — confirm against the C implementation.
//
//go:linkname BuildTransformer C.build_transformer
func BuildTransformer(t *Transformer, checkpointPath *c.Char)
// FreeTransformer is bound to the C function free_transformer; it has no Go
// body.
// NOTE(review): presumably releases the buffers and the memory mapping
// referenced by t — confirm against the C implementation.
//
//go:linkname FreeTransformer C.free_transformer
func FreeTransformer(t *Transformer)
   124  // -----------------------------------------------------------------------------
// ProbIndex pairs a probability with an index; it is the struct used when
// sorting probabilities during top-p sampling. Sampler.Probindex points to
// values of this type.
//
// llgo:type C
type ProbIndex struct {
	Prob  c.Float // probability value
	Index c.Int   // index paired with Prob
}
// Sampler is the sampler state handed to BuildSampler, FreeSampler, Generate
// and Chat.
// NOTE(review): Temperature, Topp and RngState presumably hold the values
// passed to BuildSampler as temperature, topp and rngSeed — confirm against
// the C implementation.
//
// llgo:type C
type Sampler struct {
	VocabSize   c.Int      // vocabulary size as a C int
	Probindex   *ProbIndex // buffer used in top-p sampling
	Temperature c.Float    // sampling temperature as a C float
	Topp        c.Float    // top-p value as a C float
	RngState    uint64     // 64-bit random number generator state
}
// BuildSampler is bound to the C function build_sampler; it has no Go body.
// It receives the Sampler sampler, the C int vocabSize, the C floats
// temperature and topp, and the 64-bit rngSeed.
// NOTE(review): presumably stores these settings in sampler and allocates its
// top-p buffer — confirm against the C implementation.
//
//go:linkname BuildSampler C.build_sampler
func BuildSampler(sampler *Sampler, vocabSize c.Int, temperature c.Float, topp c.Float, rngSeed uint64)
// FreeSampler is bound to the C function free_sampler; it has no Go body.
// NOTE(review): presumably releases the buffer referenced by
// sampler.Probindex — confirm against the C implementation.
//
//go:linkname FreeSampler C.free_sampler
func FreeSampler(sampler *Sampler)
   147  // -----------------------------------------------------------------------------
// Generate is bound to the C function generate; it has no Go body. It
// receives a Transformer, a Tokenizer and a Sampler together with the C
// string prompt and the C int steps.
// NOTE(review): presumably produces text continuing prompt for at most steps
// steps and writes it to standard output — confirm against the C
// implementation.
//
//go:linkname Generate C.generate
func Generate(
	transformer *Transformer, tokenizer *Tokenizer, sampler *Sampler,
	prompt *c.Char, steps c.Int)
   154  //go:linkname Chat
   155  func Chat(
   156  	transformer *Transformer, tokenizer *Tokenizer, sampler *Sampler,
   157  	cliUserPrompt *c.Char, cliSystemPrompt *c.Char, steps c.Int)
   159  // -----------------------------------------------------------------------------