github.com/wbrown/gpt_bpe@v0.0.0-20250709161131-1571a6e8ad2d/cmd/model_downloader/model_resolver.go (about)

     1  package main
     2  
     3  import (
     4  	"flag"
     5  	"github.com/wbrown/gpt_bpe/resources"
     6  	"log"
     7  	"os"
     8  )
     9  
    10  func main() {
    11  	modelId := flag.String("model", "",
    12  		"model URL, path, or HuggingFace id to fetch")
    13  	destPath := flag.String("dest", "./",
    14  		"where to download the model to")
    15  	modelType := flag.String("type", "transformers",
    16  		"model type (transformers or diffusers)")
    17  	tokenizerOnly := flag.Bool("tokenizer-only", false,
    18  		"only download the tokenizer")
    19  	flag.Parse()
    20  	if *modelId == "" {
    21  		flag.Usage()
    22  		log.Fatal("Must provide -model")
    23  	}
    24  
    25  	// map modelType to resource type enum
    26  	var rsrcType resources.ResourceType
    27  	switch *modelType {
    28  	case "transformers":
    29  		rsrcType = resources.RESOURCETYPE_TRANSFORMERS
    30  	case "diffusers":
    31  		rsrcType = resources.RESOURCETYPE_DIFFUSERS
    32  	default:
    33  		flag.Usage()
    34  		log.Fatalf("Invalid model type: %s", *modelType)
    35  	}
    36  
    37  	var rsrcLvl resources.ResourceFlag
    38  	if *tokenizerOnly {
    39  		rsrcLvl = resources.RESOURCE_DERIVED
    40  	} else {
    41  		rsrcLvl = resources.RESOURCE_MODEL
    42  	}
    43  
    44  	// get HF_API_TOKEN from env for huggingface auth
    45  	hfApiToken := os.Getenv("HF_API_TOKEN")
    46  
    47  	if mkdirErr := os.MkdirAll(*destPath, 0755); mkdirErr != nil {
    48  		log.Fatalf("Error creating output directory: %s", mkdirErr)
    49  	}
    50  	_, rsrcErr := resources.ResolveResources(*modelId, destPath,
    51  		rsrcLvl, rsrcType, hfApiToken)
    52  	if rsrcErr != nil {
    53  		log.Fatalf("Error downloading model resources: %s", rsrcErr)
    54  	}
    55  }