package huggingface

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
)

var (
	// recommendedModels caches the recommended model for each task, keyed by
	// task name. It is populated lazily on first use by fetchRecommendedModels.
	// NOTE(review): this package-level cache is read and written without any
	// synchronization — a concurrent first use from multiple goroutines is a
	// data race; confirm whether callers are expected to serialize access.
	recommendedModels map[string]string
)

// HTTPClient is an interface representing an HTTP client.
// It is satisfied by *http.Client and lets callers supply a custom
// implementation (e.g. with timeouts, retries, or a test double).
type HTTPClient interface {
	// Do executes a single HTTP request and returns the response.
	Do(req *http.Request) (*http.Response, error)
}

// InferenceClientOptions represents options for the InferenceClient.
type InferenceClientOptions struct {
	// Model is the default model ID (or full URL) used when a call does not
	// specify one explicitly.
	Model string
	// Endpoint is the Hugging Face Hub base URL (used e.g. for /api/tasks).
	Endpoint string
	// InferenceEndpoint is the base URL of the hosted inference API.
	InferenceEndpoint string
	// HTTPClient performs the HTTP requests; defaults to http.DefaultClient.
	HTTPClient HTTPClient
}

// InferenceClient is a client for performing inference using Hugging Face models.
type InferenceClient struct {
	httpClient HTTPClient // transport used for all requests
	token      string     // bearer token; empty means unauthenticated requests
	opts       InferenceClientOptions
}

// NewInferenceClient creates a new InferenceClient instance with the specified token.
// Unset options fall back to the public Hugging Face endpoints and
// http.DefaultClient.
// NOTE(review): http.DefaultClient has no timeout — a stalled inference
// request blocks until its context (if any) is cancelled.
func NewInferenceClient(token string, optFns ...func(o *InferenceClientOptions)) *InferenceClient {
	opts := InferenceClientOptions{
		Endpoint:          "https://huggingface.co",
		InferenceEndpoint: "https://api-inference.huggingface.co",
	}

	for _, fn := range optFns {
		fn(&opts)
	}

	if opts.HTTPClient == nil {
		opts.HTTPClient = http.DefaultClient
	}

	return &InferenceClient{
		httpClient: opts.HTTPClient,
		token:      token,
		opts:       opts,
	}
}

// SetModel sets the default model used when a request does not name one.
func (ic *InferenceClient) SetModel(model string) {
	ic.opts.Model = model
}

// post sends a POST request to the specified model and task with the provided payload.
// It returns the response body or an error if the request fails.
66 func (ic *InferenceClient) post(ctx context.Context, model, task string, payload any) ([]byte, error) { 67 url, err := ic.resolveURL(ctx, model, task) 68 if err != nil { 69 return nil, err 70 } 71 72 body, err := json.Marshal(payload) 73 if err != nil { 74 return nil, err 75 } 76 77 httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) 78 if err != nil { 79 return nil, err 80 } 81 82 httpReq.Header.Set("Content-Type", "application/json") 83 httpReq.Header.Set("Accept", "application/json") 84 85 if ic.token != "" { 86 httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ic.token)) 87 } 88 89 res, err := ic.httpClient.Do(httpReq) 90 if err != nil { 91 return nil, err 92 } 93 94 defer res.Body.Close() 95 96 resBody, err := io.ReadAll(res.Body) 97 if err != nil { 98 return nil, err 99 } 100 101 if res.StatusCode != http.StatusOK { 102 errResp := ErrorResponse{} 103 if err := json.Unmarshal(resBody, &errResp); err != nil { 104 return nil, fmt.Errorf("huggingfaces error: %s", resBody) 105 } 106 107 return nil, fmt.Errorf("huggingfaces error: %s", errResp.Error) 108 } 109 110 return resBody, nil 111 } 112 113 // resolveURL resolves the URL for the specified model and task. 114 // It returns the resolved URL or an error if resolution fails. 
115 func (ic *InferenceClient) resolveURL(ctx context.Context, model, task string) (string, error) { 116 if model == "" { 117 model = ic.opts.Model 118 } 119 120 // If model is already a URL, ignore `task` and return directly 121 if model != "" && (strings.HasPrefix(model, "http://") || strings.HasPrefix(model, "https://")) { 122 return model, nil 123 } 124 125 if model == "" { 126 var err error 127 128 model, err = ic.getRecommendedModel(ctx, task) 129 if err != nil { 130 return "", err 131 } 132 } 133 134 // Feature-extraction and sentence-similarity are the only cases where models support multiple tasks 135 if contains([]string{"feature-extraction", "sentence-similarity"}, task) { 136 return fmt.Sprintf("%s/pipeline/%s/%s", ic.opts.InferenceEndpoint, task, model), nil 137 } 138 139 return fmt.Sprintf("%s/models/%s", ic.opts.InferenceEndpoint, model), nil 140 } 141 142 // getRecommendedModel retrieves the recommended model for the specified task. 143 // It returns the recommended model or an error if retrieval fails. 144 func (ic *InferenceClient) getRecommendedModel(ctx context.Context, task string) (string, error) { 145 rModels, err := ic.fetchRecommendedModels(ctx) 146 if err != nil { 147 return "", err 148 } 149 150 model, ok := rModels[task] 151 if !ok { 152 return "", fmt.Errorf("task %s has no recommended model", task) 153 } 154 155 return model, nil 156 } 157 158 // fetchRecommendedModels retrieves the recommended models for all available tasks. 159 // It returns a map of task names to recommended models or an error if retrieval fails. 
160 func (ic *InferenceClient) fetchRecommendedModels(ctx context.Context) (map[string]string, error) { 161 if recommendedModels == nil { 162 req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s/api/tasks", ic.opts.Endpoint), nil) 163 if err != nil { 164 return nil, err 165 } 166 167 res, err := ic.httpClient.Do(req) 168 if err != nil { 169 return nil, err 170 } 171 defer res.Body.Close() 172 173 var jsonResponse map[string]interface{} 174 175 err = json.NewDecoder(res.Body).Decode(&jsonResponse) 176 if err != nil { 177 return nil, err 178 } 179 180 recommendedModels = make(map[string]string) 181 182 for task, details := range jsonResponse { 183 widgetModels, ok := details.(map[string]interface{})["widgetModels"].([]interface{}) 184 if !ok || len(widgetModels) == 0 { 185 recommendedModels[task] = "" 186 } else { 187 firstModel, _ := widgetModels[0].(string) 188 recommendedModels[task] = firstModel 189 } 190 } 191 } 192 193 return recommendedModels, nil 194 } 195 196 // Contains checks if the given element is present in the collection. 197 func contains[T comparable](collection []T, element T) bool { 198 for _, item := range collection { 199 if item == element { 200 return true 201 } 202 } 203 204 return false 205 } 206 207 func PTR[T any](input T) *T { 208 return &input 209 }