github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/ais/test/etl_tar2tf_test.go (about)

     1  // Package integration_test.
     2  /*
     3   * Copyright (c) 2018-2023, NVIDIA CORPORATION. All rights reserved.
     4   */
     5  package integration_test
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"net/http"
    11  	"net/url"
    12  	"os"
    13  	"path/filepath"
    14  	"testing"
    15  
    16  	"github.com/NVIDIA/aistore/api"
    17  	"github.com/NVIDIA/aistore/api/apc"
    18  	"github.com/NVIDIA/aistore/cmn"
    19  	"github.com/NVIDIA/aistore/cmn/cos"
    20  	"github.com/NVIDIA/aistore/ext/etl"
    21  	"github.com/NVIDIA/aistore/tools"
    22  	"github.com/NVIDIA/aistore/tools/readers"
    23  	"github.com/NVIDIA/aistore/tools/tassert"
    24  	"github.com/NVIDIA/aistore/tools/tetl"
    25  	"github.com/NVIDIA/aistore/tools/tlog"
    26  	"github.com/NVIDIA/go-tfdata/tfdata/core"
    27  )
    28  
    29  func startTar2TfTransformer(t *testing.T) (etlName string) {
    30  	etlName = tetl.Tar2TF // TODO: add more
    31  
    32  	spec, err := tetl.GetTransformYaml(etlName)
    33  	tassert.CheckError(t, err)
    34  
    35  	msg := &etl.InitSpecMsg{}
    36  	{
    37  		msg.IDX = etlName
    38  		msg.CommTypeX = etl.Hpull
    39  		msg.Spec = spec
    40  	}
    41  	tassert.CheckError(t, msg.Validate())
    42  	tassert.Fatalf(t, msg.Name() == tetl.Tar2TF, "%q vs %q", msg.Name(), tetl.Tar2TF)
    43  
    44  	// Starting transformer
    45  	xid, err := api.ETLInit(baseParams, msg)
    46  	tassert.CheckFatal(t, err)
    47  
    48  	tlog.Logf("ETL %q: running x-etl-spec[%s]\n", etlName, xid)
    49  	return
    50  }
    51  
    52  func TestETLTar2TFS3(t *testing.T) {
    53  	tools.CheckSkip(t, &tools.SkipTestArgs{RequiredDeployment: tools.ClusterTypeK8s})
    54  
    55  	const (
    56  		tarObjName   = "small-mnist-3.tar"
    57  		tfRecordFile = "small-mnist-3.record"
    58  	)
    59  
    60  	var (
    61  		tarPath      = filepath.Join("data", tarObjName)
    62  		tfRecordPath = filepath.Join("data", tfRecordFile)
    63  		proxyURL     = tools.RandomProxyURL()
    64  		bck          = cmn.Bck{
    65  			Name:     testBucketName,
    66  			Provider: apc.AIS,
    67  		}
    68  		baseParams = tools.BaseAPIParams(proxyURL)
    69  	)
    70  
    71  	tools.CreateBucket(t, proxyURL, bck, nil, true /*cleanup*/)
    72  
    73  	// PUT TAR to the cluster
    74  	f, err := readers.NewExistingFile(tarPath, cos.ChecksumXXHash)
    75  	tassert.CheckFatal(t, err)
    76  	putArgs := api.PutArgs{
    77  		BaseParams: baseParams,
    78  		Bck:        bck,
    79  		ObjName:    tarObjName,
    80  		Cksum:      f.Cksum(),
    81  		Reader:     f,
    82  	}
    83  	_, err = api.PutObject(&putArgs)
    84  	tassert.CheckFatal(t, err)
    85  	defer api.DeleteObject(baseParams, bck, tarObjName)
    86  
    87  	etlName := startTar2TfTransformer(t)
    88  	t.Cleanup(func() { tetl.StopAndDeleteETL(t, baseParams, etlName) })
    89  	// GET TFRecord from TAR
    90  	outFileBuffer := bytes.NewBuffer(nil)
    91  
    92  	// This is to mimic external S3 clients like Tensorflow
    93  	bck.Provider = ""
    94  
    95  	_, err = api.GetObjectS3(
    96  		baseParams,
    97  		bck,
    98  		tarObjName,
    99  		api.GetArgs{
   100  			Writer: outFileBuffer,
   101  			Query:  url.Values{apc.QparamETLName: {etlName}},
   102  		})
   103  	tassert.CheckFatal(t, err)
   104  
   105  	// Comparing actual vs expected
   106  	tfRecord, err := os.Open(tfRecordPath)
   107  	tassert.CheckFatal(t, err)
   108  	defer tfRecord.Close()
   109  
   110  	expectedRecords, err := core.NewTFRecordReader(tfRecord).ReadAllExamples()
   111  	tassert.CheckFatal(t, err)
   112  	actualRecords, err := core.NewTFRecordReader(outFileBuffer).ReadAllExamples()
   113  	tassert.CheckFatal(t, err)
   114  
   115  	equal, err := tfRecordsEqual(expectedRecords, actualRecords)
   116  	tassert.CheckFatal(t, err)
   117  	tassert.Errorf(t, equal == true, "actual and expected records different")
   118  }
   119  
   120  func TestETLTar2TFRanges(t *testing.T) {
   121  	// TestETLTar2TFS3 already runs in short tests, no need for short here as well.
   122  	tools.CheckSkip(t, &tools.SkipTestArgs{RequiredDeployment: tools.ClusterTypeK8s, Long: true})
   123  
   124  	type testCase struct {
   125  		start, end int64
   126  	}
   127  
   128  	var (
   129  		tarObjName = "small-mnist-3.tar"
   130  		tarPath    = filepath.Join("data", tarObjName)
   131  		proxyURL   = tools.RandomProxyURL()
   132  		bck        = cmn.Bck{
   133  			Name:     testBucketName,
   134  			Provider: apc.AIS,
   135  		}
   136  		baseParams     = tools.BaseAPIParams(proxyURL)
   137  		rangeBytesBuff = bytes.NewBuffer(nil)
   138  
   139  		tcs = []testCase{
   140  			{start: 0, end: 1},
   141  			{start: 0, end: 50},
   142  			{start: 1, end: 20},
   143  			{start: 15, end: 100},
   144  			{start: 120, end: 240},
   145  			{start: 123, end: 1234},
   146  		}
   147  	)
   148  
   149  	tools.CreateBucket(t, proxyURL, bck, nil, true /*cleanup*/)
   150  
   151  	// PUT TAR to the cluster
   152  	f, err := readers.NewExistingFile(tarPath, cos.ChecksumXXHash)
   153  	tassert.CheckFatal(t, err)
   154  	putArgs := api.PutArgs{
   155  		BaseParams: baseParams,
   156  		Bck:        bck,
   157  		ObjName:    tarObjName,
   158  		Cksum:      f.Cksum(),
   159  		Reader:     f,
   160  	}
   161  	_, err = api.PutObject(&putArgs)
   162  	tassert.CheckFatal(t, err)
   163  
   164  	etlName := startTar2TfTransformer(t)
   165  	t.Cleanup(func() { tetl.StopAndDeleteETL(t, baseParams, etlName) })
   166  
   167  	// This is to mimic external S3 clients like Tensorflow
   168  	bck.Provider = ""
   169  
   170  	// GET TFRecord from TAR
   171  	wholeTFRecord := bytes.NewBuffer(nil)
   172  	_, err = api.GetObjectS3(
   173  		baseParams,
   174  		bck,
   175  		tarObjName,
   176  		api.GetArgs{
   177  			Writer: wholeTFRecord,
   178  			Query:  url.Values{apc.QparamETLName: {etlName}},
   179  		})
   180  	tassert.CheckFatal(t, err)
   181  
   182  	for _, tc := range tcs {
   183  		rangeBytesBuff.Reset()
   184  
   185  		// Request only a subset of bytes
   186  		header := http.Header{}
   187  		header.Set(cos.HdrRange, fmt.Sprintf("bytes=%d-%d", tc.start, tc.end))
   188  		_, err = api.GetObjectS3(
   189  			baseParams,
   190  			bck,
   191  			tarObjName,
   192  			api.GetArgs{
   193  				Writer: rangeBytesBuff,
   194  				Header: header,
   195  				Query:  url.Values{apc.QparamETLName: {etlName}},
   196  			})
   197  		tassert.CheckFatal(t, err)
   198  
   199  		tassert.Errorf(t, bytes.Equal(rangeBytesBuff.Bytes(),
   200  			wholeTFRecord.Bytes()[tc.start:tc.end+1]), "[start: %d, end: %d] bytes different", tc.start, tc.end)
   201  	}
   202  }