github.com/hupe1980/go-huggingface@v0.0.15/huggingface.go (about)

     1  package huggingface
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"fmt"
     8  	"io"
     9  	"net/http"
    10  	"strings"
    11  )
    12  
    13  var (
    14  	// recommendedModels stores the recommended models for each task.
    15  	recommendedModels map[string]string
    16  )
    17  
    18  // HTTPClient is an interface representing an HTTP client.
    19  type HTTPClient interface {
    20  	Do(req *http.Request) (*http.Response, error)
    21  }
    22  
    23  // InferenceClientOptions represents options for the InferenceClient.
    24  type InferenceClientOptions struct {
    25  	Model             string
    26  	Endpoint          string
    27  	InferenceEndpoint string
    28  	HTTPClient        HTTPClient
    29  }
    30  
    31  // InferenceClient is a client for performing inference using Hugging Face models.
    32  type InferenceClient struct {
    33  	httpClient HTTPClient
    34  	token      string
    35  	opts       InferenceClientOptions
    36  }
    37  
    38  // NewInferenceClient creates a new InferenceClient instance with the specified token.
    39  func NewInferenceClient(token string, optFns ...func(o *InferenceClientOptions)) *InferenceClient {
    40  	opts := InferenceClientOptions{
    41  		Endpoint:          "https://huggingface.co",
    42  		InferenceEndpoint: "https://api-inference.huggingface.co",
    43  	}
    44  
    45  	for _, fn := range optFns {
    46  		fn(&opts)
    47  	}
    48  
    49  	if opts.HTTPClient == nil {
    50  		opts.HTTPClient = http.DefaultClient
    51  	}
    52  
    53  	return &InferenceClient{
    54  		httpClient: opts.HTTPClient,
    55  		token:      token,
    56  		opts:       opts,
    57  	}
    58  }
    59  
    60  func (ic *InferenceClient) SetModel(model string) {
    61  	ic.opts.Model = model
    62  }
    63  
    64  // post sends a POST request to the specified model and task with the provided payload.
    65  // It returns the response body or an error if the request fails.
    66  func (ic *InferenceClient) post(ctx context.Context, model, task string, payload any) ([]byte, error) {
    67  	url, err := ic.resolveURL(ctx, model, task)
    68  	if err != nil {
    69  		return nil, err
    70  	}
    71  
    72  	body, err := json.Marshal(payload)
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  
    77  	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
    78  	if err != nil {
    79  		return nil, err
    80  	}
    81  
    82  	httpReq.Header.Set("Content-Type", "application/json")
    83  	httpReq.Header.Set("Accept", "application/json")
    84  
    85  	if ic.token != "" {
    86  		httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ic.token))
    87  	}
    88  
    89  	res, err := ic.httpClient.Do(httpReq)
    90  	if err != nil {
    91  		return nil, err
    92  	}
    93  
    94  	defer res.Body.Close()
    95  
    96  	resBody, err := io.ReadAll(res.Body)
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  
   101  	if res.StatusCode != http.StatusOK {
   102  		errResp := ErrorResponse{}
   103  		if err := json.Unmarshal(resBody, &errResp); err != nil {
   104  			return nil, fmt.Errorf("huggingfaces error: %s", resBody)
   105  		}
   106  
   107  		return nil, fmt.Errorf("huggingfaces error: %s", errResp.Error)
   108  	}
   109  
   110  	return resBody, nil
   111  }
   112  
   113  // resolveURL resolves the URL for the specified model and task.
   114  // It returns the resolved URL or an error if resolution fails.
   115  func (ic *InferenceClient) resolveURL(ctx context.Context, model, task string) (string, error) {
   116  	if model == "" {
   117  		model = ic.opts.Model
   118  	}
   119  
   120  	// If model is already a URL, ignore `task` and return directly
   121  	if model != "" && (strings.HasPrefix(model, "http://") || strings.HasPrefix(model, "https://")) {
   122  		return model, nil
   123  	}
   124  
   125  	if model == "" {
   126  		var err error
   127  
   128  		model, err = ic.getRecommendedModel(ctx, task)
   129  		if err != nil {
   130  			return "", err
   131  		}
   132  	}
   133  
   134  	// Feature-extraction and sentence-similarity are the only cases where models support multiple tasks
   135  	if contains([]string{"feature-extraction", "sentence-similarity"}, task) {
   136  		return fmt.Sprintf("%s/pipeline/%s/%s", ic.opts.InferenceEndpoint, task, model), nil
   137  	}
   138  
   139  	return fmt.Sprintf("%s/models/%s", ic.opts.InferenceEndpoint, model), nil
   140  }
   141  
   142  // getRecommendedModel retrieves the recommended model for the specified task.
   143  // It returns the recommended model or an error if retrieval fails.
   144  func (ic *InferenceClient) getRecommendedModel(ctx context.Context, task string) (string, error) {
   145  	rModels, err := ic.fetchRecommendedModels(ctx)
   146  	if err != nil {
   147  		return "", err
   148  	}
   149  
   150  	model, ok := rModels[task]
   151  	if !ok {
   152  		return "", fmt.Errorf("task %s has no recommended model", task)
   153  	}
   154  
   155  	return model, nil
   156  }
   157  
   158  // fetchRecommendedModels retrieves the recommended models for all available tasks.
   159  // It returns a map of task names to recommended models or an error if retrieval fails.
   160  func (ic *InferenceClient) fetchRecommendedModels(ctx context.Context) (map[string]string, error) {
   161  	if recommendedModels == nil {
   162  		req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s/api/tasks", ic.opts.Endpoint), nil)
   163  		if err != nil {
   164  			return nil, err
   165  		}
   166  
   167  		res, err := ic.httpClient.Do(req)
   168  		if err != nil {
   169  			return nil, err
   170  		}
   171  		defer res.Body.Close()
   172  
   173  		var jsonResponse map[string]interface{}
   174  
   175  		err = json.NewDecoder(res.Body).Decode(&jsonResponse)
   176  		if err != nil {
   177  			return nil, err
   178  		}
   179  
   180  		recommendedModels = make(map[string]string)
   181  
   182  		for task, details := range jsonResponse {
   183  			widgetModels, ok := details.(map[string]interface{})["widgetModels"].([]interface{})
   184  			if !ok || len(widgetModels) == 0 {
   185  				recommendedModels[task] = ""
   186  			} else {
   187  				firstModel, _ := widgetModels[0].(string)
   188  				recommendedModels[task] = firstModel
   189  			}
   190  		}
   191  	}
   192  
   193  	return recommendedModels, nil
   194  }
   195  
   196  // Contains checks if the given element is present in the collection.
   197  func contains[T comparable](collection []T, element T) bool {
   198  	for _, item := range collection {
   199  		if item == element {
   200  			return true
   201  		}
   202  	}
   203  
   204  	return false
   205  }
   206  
   207  func PTR[T any](input T) *T {
   208  	return &input
   209  }