github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/ext/etl/communicator.go (about)

     1  // Package etl provides utilities to initialize and use transformation pods.
     2  /*
     3   * Copyright (c) 2018-2024, NVIDIA CORPORATION. All rights reserved.
     4   */
     5  package etl
     6  
     7  import (
     8  	"context"
     9  	"fmt"
    10  	"io"
    11  	"net/http"
    12  	"net/http/httputil"
    13  	"net/url"
    14  	"strings"
    15  	"time"
    16  
    17  	"github.com/NVIDIA/aistore/api/apc"
    18  	"github.com/NVIDIA/aistore/cmn"
    19  	"github.com/NVIDIA/aistore/cmn/cos"
    20  	"github.com/NVIDIA/aistore/cmn/debug"
    21  	"github.com/NVIDIA/aistore/cmn/nlog"
    22  	"github.com/NVIDIA/aistore/core"
    23  	"github.com/NVIDIA/aistore/core/meta"
    24  	"github.com/NVIDIA/aistore/memsys"
    25  )
    26  
    27  type (
    28  	CommStats interface {
    29  		ObjCount() int64
    30  		InBytes() int64
    31  		OutBytes() int64
    32  	}
    33  
    34  	// Communicator is responsible for managing communications with local ETL container.
    35  	// It listens to cluster membership changes and terminates ETL container, if need be.
    36  	Communicator interface {
    37  		meta.Slistener
    38  
    39  		Name() string
    40  		Xact() core.Xact
    41  		PodName() string
    42  		SvcName() string
    43  
    44  		String() string
    45  
    46  		// InlineTransform uses one of the two ETL container endpoints:
    47  		//  - Method "PUT", Path "/"
    48  		//  - Method "GET", Path "/bucket/object"
    49  		InlineTransform(w http.ResponseWriter, r *http.Request, bck *meta.Bck, objName string) error
    50  
    51  		// OfflineTransform interface implementations realize offline ETL.
    52  		// OfflineTransform is driven by `OfflineDP` - not to confuse
    53  		// with GET requests from users (such as training models and apps)
    54  		// to perform on-the-fly transformation.
    55  		OfflineTransform(bck *meta.Bck, objName string, timeout time.Duration) (cos.ReadCloseSizer, error)
    56  		Stop()
    57  
    58  		CommStats
    59  	}
    60  
    61  	baseComm struct {
    62  		listener meta.Slistener
    63  		boot     *etlBootstrapper
    64  	}
    65  	pushComm struct {
    66  		baseComm
    67  		command []string
    68  	}
    69  	redirectComm struct {
    70  		baseComm
    71  	}
    72  	revProxyComm struct {
    73  		baseComm
    74  		rp *httputil.ReverseProxy
    75  	}
    76  
    77  	// TODO: Generalize and move to `cos` package
    78  	cbWriter struct {
    79  		w       io.Writer
    80  		writeCb func(int)
    81  	}
    82  )
    83  
    84  // interface guard
    85  var (
    86  	_ Communicator = (*pushComm)(nil)
    87  	_ Communicator = (*redirectComm)(nil)
    88  	_ Communicator = (*revProxyComm)(nil)
    89  
    90  	_ io.Writer = (*cbWriter)(nil)
    91  )
    92  
    93  //////////////
    94  // baseComm //
    95  //////////////
    96  
    97  func newCommunicator(listener meta.Slistener, boot *etlBootstrapper) Communicator {
    98  	switch boot.msg.CommTypeX {
    99  	case Hpush, HpushStdin:
   100  		pc := &pushComm{}
   101  		pc.listener, pc.boot = listener, boot
   102  		if boot.msg.CommTypeX == HpushStdin { // io://
   103  			pc.command = boot.originalCommand
   104  		}
   105  		return pc
   106  	case Hpull:
   107  		rc := &redirectComm{}
   108  		rc.listener, rc.boot = listener, boot
   109  		return rc
   110  	case Hrev:
   111  		rp := &revProxyComm{}
   112  		rp.listener, rp.boot = listener, boot
   113  
   114  		transformerURL, err := url.Parse(boot.uri)
   115  		debug.AssertNoErr(err)
   116  		revProxy := &httputil.ReverseProxy{
   117  			Director: func(req *http.Request) {
   118  				// Replacing the `req.URL` host with ETL container host
   119  				req.URL.Scheme = transformerURL.Scheme
   120  				req.URL.Host = transformerURL.Host
   121  				req.URL.RawQuery = pruneQuery(req.URL.RawQuery)
   122  				if _, ok := req.Header["User-Agent"]; !ok {
   123  					// Explicitly disable `User-Agent` so it's not set to default value.
   124  					req.Header.Set("User-Agent", "")
   125  				}
   126  			},
   127  		}
   128  		rp.rp = revProxy
   129  		return rp
   130  	}
   131  
   132  	debug.Assert(false, "unknown comm-type '"+boot.msg.CommTypeX+"'")
   133  	return nil
   134  }
   135  
   136  func (c *baseComm) Name() string    { return c.boot.originalPodName }
   137  func (c *baseComm) PodName() string { return c.boot.pod.Name }
   138  func (c *baseComm) SvcName() string { return c.boot.pod.Name /*same as pod name*/ }
   139  
   140  func (c *baseComm) ListenSmapChanged() { c.listener.ListenSmapChanged() }
   141  
   142  func (c *baseComm) String() string {
   143  	return fmt.Sprintf("%s[%s]-%s", c.boot.originalPodName, c.boot.xctn.ID(), c.boot.msg.CommTypeX)
   144  }
   145  
   146  func (c *baseComm) Xact() core.Xact { return c.boot.xctn }
   147  func (c *baseComm) ObjCount() int64 { return c.boot.xctn.Objs() }
   148  func (c *baseComm) InBytes() int64  { return c.boot.xctn.InBytes() }
   149  func (c *baseComm) OutBytes() int64 { return c.boot.xctn.OutBytes() }
   150  
   151  func (c *baseComm) Stop() { c.boot.xctn.Finish() }
   152  
   153  func (c *baseComm) getWithTimeout(url string, size int64, timeout time.Duration) (r cos.ReadCloseSizer, err error) {
   154  	if err := c.boot.xctn.AbortErr(); err != nil {
   155  		return nil, err
   156  	}
   157  
   158  	var (
   159  		req    *http.Request
   160  		resp   *http.Response
   161  		cancel func()
   162  	)
   163  	if timeout != 0 {
   164  		var ctx context.Context
   165  		ctx, cancel = context.WithTimeout(context.Background(), timeout)
   166  		req, err = http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
   167  	} else {
   168  		req, err = http.NewRequest(http.MethodGet, url, http.NoBody)
   169  	}
   170  	if err == nil {
   171  		resp, err = core.T.DataClient().Do(req) //nolint:bodyclose // Closed by the caller.
   172  	}
   173  	if err != nil {
   174  		if cancel != nil {
   175  			cancel()
   176  		}
   177  		return nil, err
   178  	}
   179  
   180  	return cos.NewReaderWithArgs(cos.ReaderArgs{
   181  		R:      resp.Body,
   182  		Size:   resp.ContentLength,
   183  		ReadCb: func(n int, _ error) { c.boot.xctn.InObjsAdd(0, int64(n)) },
   184  		DeferCb: func() {
   185  			if cancel != nil {
   186  				cancel()
   187  			}
   188  			c.boot.xctn.InObjsAdd(1, 0)
   189  			c.boot.xctn.OutObjsAdd(1, size) // see also: `coi.objsAdd`
   190  		},
   191  	}), nil
   192  }
   193  
   194  //////////////
   195  // pushComm: implements (Hpush | HpushStdin)
   196  //////////////
   197  
   198  func (pc *pushComm) doRequest(bck *meta.Bck, lom *core.LOM, timeout time.Duration) (r cos.ReadCloseSizer, err error) {
   199  	var ecode int
   200  	if err := lom.InitBck(bck.Bucket()); err != nil {
   201  		return nil, err
   202  	}
   203  
   204  	lom.Lock(false)
   205  	r, ecode, err = pc.do(lom, timeout)
   206  	lom.Unlock(false)
   207  
   208  	if err != nil && cos.IsNotExist(err, ecode) && bck.IsRemote() {
   209  		_, err = core.T.GetCold(context.Background(), lom, cmn.OwtGetLock)
   210  		if err != nil {
   211  			return nil, err
   212  		}
   213  		lom.Lock(false)
   214  		r, _, err = pc.do(lom, timeout)
   215  		lom.Unlock(false)
   216  	}
   217  	return
   218  }
   219  
   220  func (pc *pushComm) do(lom *core.LOM, timeout time.Duration) (_ cos.ReadCloseSizer, ecode int, err error) {
   221  	var (
   222  		body   io.ReadCloser
   223  		cancel func()
   224  		req    *http.Request
   225  		resp   *http.Response
   226  		u      string
   227  	)
   228  	if err := pc.boot.xctn.AbortErr(); err != nil {
   229  		return nil, 0, err
   230  	}
   231  	if err := lom.Load(false /*cache it*/, true /*locked*/); err != nil {
   232  		return nil, 0, err
   233  	}
   234  	size := lom.SizeBytes()
   235  
   236  	switch pc.boot.msg.ArgTypeX {
   237  	case ArgTypeDefault, ArgTypeURL:
   238  		// to remove the following assert (and the corresponding limitation):
   239  		// - container must be ready to receive complete bucket name including namespace
   240  		// - see `bck.AddToQuery` and api/bucket.go for numerous examples
   241  		debug.Assertf(lom.Bck().Ns.IsGlobal(), lom.Bck().Cname("")+" - bucket with namespace")
   242  		u = pc.boot.uri + "/" + lom.Bck().Name + "/" + lom.ObjName
   243  
   244  		fh, err := cos.NewFileHandle(lom.FQN)
   245  		if err != nil {
   246  			return nil, 0, err
   247  		}
   248  		body = fh
   249  	case ArgTypeFQN:
   250  		body = http.NoBody
   251  		u = cos.JoinPath(pc.boot.uri, url.PathEscape(lom.FQN)) // compare w/ rc.redirectURL()
   252  	default:
   253  		debug.Assert(false, "unexpected msg type:", pc.boot.msg.ArgTypeX) // is validated at construction time
   254  	}
   255  
   256  	if timeout != 0 {
   257  		var ctx context.Context
   258  		ctx, cancel = context.WithTimeout(context.Background(), timeout)
   259  		req, err = http.NewRequestWithContext(ctx, http.MethodPut, u, body)
   260  	} else {
   261  		req, err = http.NewRequest(http.MethodPut, u, body)
   262  	}
   263  	if err != nil {
   264  		cos.Close(body)
   265  		goto finish
   266  	}
   267  
   268  	if len(pc.command) != 0 {
   269  		// HpushStdin case
   270  		q := req.URL.Query()
   271  		q["command"] = []string{"bash", "-c", strings.Join(pc.command, " ")}
   272  		req.URL.RawQuery = q.Encode()
   273  	}
   274  	req.ContentLength = size
   275  	req.Header.Set(cos.HdrContentType, cos.ContentBinary)
   276  
   277  	//
   278  	// Do it
   279  	//
   280  	resp, err = core.T.DataClient().Do(req) //nolint:bodyclose // Closed by the caller.
   281  
   282  finish:
   283  	if err != nil {
   284  		if cancel != nil {
   285  			cancel()
   286  		}
   287  		if resp != nil {
   288  			ecode = resp.StatusCode
   289  		}
   290  		return nil, ecode, err
   291  	}
   292  	args := cos.ReaderArgs{
   293  		R:      resp.Body,
   294  		Size:   resp.ContentLength,
   295  		ReadCb: func(n int, _ error) { pc.boot.xctn.InObjsAdd(0, int64(n)) },
   296  		DeferCb: func() {
   297  			if cancel != nil {
   298  				cancel()
   299  			}
   300  			pc.boot.xctn.InObjsAdd(1, 0)
   301  			pc.boot.xctn.OutObjsAdd(1, size) // see also: `coi.objsAdd`
   302  		},
   303  	}
   304  	return cos.NewReaderWithArgs(args), 0, nil
   305  }
   306  
   307  func (pc *pushComm) InlineTransform(w http.ResponseWriter, _ *http.Request, bck *meta.Bck, objName string) error {
   308  	lom := core.AllocLOM(objName)
   309  	r, err := pc.doRequest(bck, lom, 0 /*timeout*/)
   310  	core.FreeLOM(lom)
   311  	if err != nil {
   312  		return err
   313  	}
   314  	if cmn.Rom.FastV(5, cos.SmoduleETL) {
   315  		nlog.Infoln(Hpush, lom.Cname(), err)
   316  	}
   317  
   318  	size := r.Size()
   319  	if size < 0 {
   320  		size = memsys.DefaultBufSize // TODO: track an average
   321  	}
   322  	buf, slab := core.T.PageMM().AllocSize(size)
   323  	_, err = io.CopyBuffer(w, r, buf)
   324  
   325  	slab.Free(buf)
   326  	r.Close()
   327  	return err
   328  }
   329  
   330  func (pc *pushComm) OfflineTransform(bck *meta.Bck, objName string, timeout time.Duration) (r cos.ReadCloseSizer, err error) {
   331  	lom := core.AllocLOM(objName)
   332  	r, err = pc.doRequest(bck, lom, timeout)
   333  	if err == nil && cmn.Rom.FastV(5, cos.SmoduleETL) {
   334  		nlog.Infoln(Hpush, lom.Cname(), err)
   335  	}
   336  	core.FreeLOM(lom)
   337  	return
   338  }
   339  
   340  //////////////////
   341  // redirectComm: implements Hpull
   342  //////////////////
   343  
   344  func (rc *redirectComm) InlineTransform(w http.ResponseWriter, r *http.Request, bck *meta.Bck, objName string) error {
   345  	if err := rc.boot.xctn.AbortErr(); err != nil {
   346  		return err
   347  	}
   348  
   349  	lom := core.AllocLOM(objName)
   350  	size, err := lomLoad(lom, bck)
   351  	if err != nil {
   352  		core.FreeLOM(lom)
   353  		return err
   354  	}
   355  	if size > 0 {
   356  		rc.boot.xctn.OutObjsAdd(1, size)
   357  	}
   358  
   359  	http.Redirect(w, r, rc.redirectURL(lom), http.StatusTemporaryRedirect)
   360  
   361  	if cmn.Rom.FastV(5, cos.SmoduleETL) {
   362  		nlog.Infoln(Hpull, lom.Cname())
   363  	}
   364  	core.FreeLOM(lom)
   365  	return nil
   366  }
   367  
   368  func (rc *redirectComm) redirectURL(lom *core.LOM) string {
   369  	switch rc.boot.msg.ArgTypeX {
   370  	case ArgTypeDefault, ArgTypeURL:
   371  		return cos.JoinPath(rc.boot.uri, transformerPath(lom.Bck(), lom.ObjName))
   372  	case ArgTypeFQN:
   373  		return cos.JoinPath(rc.boot.uri, url.PathEscape(lom.FQN))
   374  	}
   375  	cos.Assert(false) // is validated at construction time
   376  	return ""
   377  }
   378  
   379  func (rc *redirectComm) OfflineTransform(bck *meta.Bck, objName string, timeout time.Duration) (cos.ReadCloseSizer, error) {
   380  	lom := core.AllocLOM(objName)
   381  	size, errV := lomLoad(lom, bck)
   382  	if errV != nil {
   383  		core.FreeLOM(lom)
   384  		return nil, errV
   385  	}
   386  
   387  	etlURL := rc.redirectURL(lom)
   388  	r, err := rc.getWithTimeout(etlURL, size, timeout)
   389  
   390  	if cmn.Rom.FastV(5, cos.SmoduleETL) {
   391  		nlog.Infoln(Hpull, lom.Cname(), err)
   392  	}
   393  	core.FreeLOM(lom)
   394  	return r, err
   395  }
   396  
   397  //////////////////
   398  // revProxyComm: implements Hrev
   399  //////////////////
   400  
   401  func (rp *revProxyComm) InlineTransform(w http.ResponseWriter, r *http.Request, bck *meta.Bck, objName string) error {
   402  	lom := core.AllocLOM(objName)
   403  	size, err := lomLoad(lom, bck)
   404  	if err != nil {
   405  		core.FreeLOM(lom)
   406  		return err
   407  	}
   408  	if size > 0 {
   409  		rp.boot.xctn.OutObjsAdd(1, size)
   410  	}
   411  	path := transformerPath(bck, objName)
   412  	core.FreeLOM(lom)
   413  
   414  	r.URL.Path, _ = url.PathUnescape(path) // `Path` must be unescaped otherwise it will be escaped again.
   415  	r.URL.RawPath = path                   // `RawPath` should be escaped version of `Path`.
   416  	rp.rp.ServeHTTP(w, r)
   417  
   418  	return nil
   419  }
   420  
   421  func (rp *revProxyComm) OfflineTransform(bck *meta.Bck, objName string, timeout time.Duration) (cos.ReadCloseSizer, error) {
   422  	lom := core.AllocLOM(objName)
   423  	size, errV := lomLoad(lom, bck)
   424  	if errV != nil {
   425  		core.FreeLOM(lom)
   426  		return nil, errV
   427  	}
   428  	etlURL := cos.JoinPath(rp.boot.uri, transformerPath(bck, objName))
   429  	r, err := rp.getWithTimeout(etlURL, size, timeout)
   430  
   431  	if cmn.Rom.FastV(5, cos.SmoduleETL) {
   432  		nlog.Infoln(Hrev, lom.Cname(), err)
   433  	}
   434  	core.FreeLOM(lom)
   435  	return r, err
   436  }
   437  
   438  //////////////
   439  // cbWriter //
   440  //////////////
   441  
   442  func (cw *cbWriter) Write(b []byte) (n int, err error) {
   443  	n, err = cw.w.Write(b)
   444  	cw.writeCb(n)
   445  	return
   446  }
   447  
   448  //
   449  // utils
   450  //
   451  
   452  // prune query (received from AIS proxy) prior to reverse-proxying the request to/from container -
   453  // not removing apc.QparamETLName, for instance, would cause infinite loop.
   454  func pruneQuery(rawQuery string) string {
   455  	vals, err := url.ParseQuery(rawQuery)
   456  	if err != nil {
   457  		nlog.Errorf("failed to parse raw query %q, err: %v", rawQuery, err)
   458  		return ""
   459  	}
   460  	for _, filtered := range []string{apc.QparamETLName, apc.QparamProxyID, apc.QparamUnixTime} {
   461  		vals.Del(filtered)
   462  	}
   463  	return vals.Encode()
   464  }
   465  
   466  // TODO -- FIXME: unify the way we encode bucket/object:
   467  // - url.PathEscape(uname) - see below - versus
   468  // - Bck().Name + "/" + lom.ObjName - see pushComm above - versus
   469  // - bck.AddToQuery() elsewhere
   470  func transformerPath(bck *meta.Bck, objName string) string {
   471  	return "/" + url.PathEscape(bck.MakeUname(objName))
   472  }
   473  
   474  func lomLoad(lom *core.LOM, bck *meta.Bck) (size int64, err error) {
   475  	if err = lom.InitBck(bck.Bucket()); err != nil {
   476  		return
   477  	}
   478  	if err = lom.Load(true /*cacheIt*/, false /*locked*/); err != nil {
   479  		if cos.IsNotExist(err, 0) && bck.IsRemote() {
   480  			err = nil // NOTE: size == 0
   481  		}
   482  	} else {
   483  		size = lom.SizeBytes()
   484  	}
   485  	return
   486  }