github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/ais/test/etl_tar2tf_test.go

// Package integration_test.
/*
 * Copyright (c) 2018-2023, NVIDIA CORPORATION. All rights reserved.
 */
package integration_test

import (
	"bytes"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"testing"

	"github.com/NVIDIA/aistore/api"
	"github.com/NVIDIA/aistore/api/apc"
	"github.com/NVIDIA/aistore/cmn"
	"github.com/NVIDIA/aistore/cmn/cos"
	"github.com/NVIDIA/aistore/ext/etl"
	"github.com/NVIDIA/aistore/tools"
	"github.com/NVIDIA/aistore/tools/readers"
	"github.com/NVIDIA/aistore/tools/tassert"
	"github.com/NVIDIA/aistore/tools/tetl"
	"github.com/NVIDIA/aistore/tools/tlog"
	"github.com/NVIDIA/go-tfdata/tfdata/core"
)

func startTar2TfTransformer(t *testing.T) (etlName string) {
	etlName = tetl.Tar2TF // TODO: add more

	spec, err := tetl.GetTransformYaml(etlName)
	tassert.CheckError(t, err)

	msg := &etl.InitSpecMsg{}
	{
		msg.IDX = etlName
		msg.CommTypeX = etl.Hpull
		msg.Spec = spec
	}
	tassert.CheckError(t, msg.Validate())
	tassert.Fatalf(t, msg.Name() == tetl.Tar2TF, "%q vs %q", msg.Name(), tetl.Tar2TF)

	// Start the transformer.
	xid, err := api.ETLInit(baseParams, msg)
	tassert.CheckFatal(t, err)

	tlog.Logf("ETL %q: running x-etl-spec[%s]\n", etlName, xid)
	return
}

func TestETLTar2TFS3(t *testing.T) {
	tools.CheckSkip(t, &tools.SkipTestArgs{RequiredDeployment: tools.ClusterTypeK8s})

	const (
		tarObjName   = "small-mnist-3.tar"
		tfRecordFile = "small-mnist-3.record"
	)

	var (
		tarPath      = filepath.Join("data", tarObjName)
		tfRecordPath = filepath.Join("data", tfRecordFile)
		proxyURL     = tools.RandomProxyURL()
		bck          = cmn.Bck{
			Name:     testBucketName,
			Provider: apc.AIS,
		}
		baseParams = tools.BaseAPIParams(proxyURL)
	)

	tools.CreateBucket(t, proxyURL, bck, nil, true /*cleanup*/)

	// PUT the TAR to the cluster.
	f, err := readers.NewExistingFile(tarPath, cos.ChecksumXXHash)
	tassert.CheckFatal(t, err)
	putArgs := api.PutArgs{
		BaseParams: baseParams,
		Bck:        bck,
		ObjName:    tarObjName,
		Cksum:      f.Cksum(),
		Reader:     f,
	}
	_, err = api.PutObject(&putArgs)
	tassert.CheckFatal(t, err)
	defer api.DeleteObject(baseParams, bck, tarObjName)

	etlName := startTar2TfTransformer(t)
	t.Cleanup(func() { tetl.StopAndDeleteETL(t, baseParams, etlName) })

	// GET the TFRecord produced from the TAR.
	outFileBuffer := bytes.NewBuffer(nil)

	// Clear the provider to mimic external S3 clients such as TensorFlow.
	bck.Provider = ""

	_, err = api.GetObjectS3(
		baseParams,
		bck,
		tarObjName,
		api.GetArgs{
			Writer: outFileBuffer,
			Query:  url.Values{apc.QparamETLName: {etlName}},
		})
	tassert.CheckFatal(t, err)

	// Compare actual vs. expected records.
	tfRecord, err := os.Open(tfRecordPath)
	tassert.CheckFatal(t, err)
	defer tfRecord.Close()

	expectedRecords, err := core.NewTFRecordReader(tfRecord).ReadAllExamples()
	tassert.CheckFatal(t, err)
	actualRecords, err := core.NewTFRecordReader(outFileBuffer).ReadAllExamples()
	tassert.CheckFatal(t, err)

	equal, err := tfRecordsEqual(expectedRecords, actualRecords)
	tassert.CheckFatal(t, err)
	tassert.Errorf(t, equal, "actual and expected records differ")
}

func TestETLTar2TFRanges(t *testing.T) {
	// TestETLTar2TFS3 already runs in short tests; no need to run this one short as well.
	tools.CheckSkip(t, &tools.SkipTestArgs{RequiredDeployment: tools.ClusterTypeK8s, Long: true})

	type testCase struct {
		start, end int64
	}

	var (
		tarObjName = "small-mnist-3.tar"
		tarPath    = filepath.Join("data", tarObjName)
		proxyURL   = tools.RandomProxyURL()
		bck        = cmn.Bck{
			Name:     testBucketName,
			Provider: apc.AIS,
		}
		baseParams     = tools.BaseAPIParams(proxyURL)
		rangeBytesBuff = bytes.NewBuffer(nil)

		tcs = []testCase{
			{start: 0, end: 1},
			{start: 0, end: 50},
			{start: 1, end: 20},
			{start: 15, end: 100},
			{start: 120, end: 240},
			{start: 123, end: 1234},
		}
	)

	tools.CreateBucket(t, proxyURL, bck, nil, true /*cleanup*/)

	// PUT the TAR to the cluster.
	f, err := readers.NewExistingFile(tarPath, cos.ChecksumXXHash)
	tassert.CheckFatal(t, err)
	putArgs := api.PutArgs{
		BaseParams: baseParams,
		Bck:        bck,
		ObjName:    tarObjName,
		Cksum:      f.Cksum(),
		Reader:     f,
	}
	_, err = api.PutObject(&putArgs)
	tassert.CheckFatal(t, err)

	etlName := startTar2TfTransformer(t)
	t.Cleanup(func() { tetl.StopAndDeleteETL(t, baseParams, etlName) })

	// Clear the provider to mimic external S3 clients such as TensorFlow.
	bck.Provider = ""

	// GET the whole TFRecord produced from the TAR.
	wholeTFRecord := bytes.NewBuffer(nil)
	_, err = api.GetObjectS3(
		baseParams,
		bck,
		tarObjName,
		api.GetArgs{
			Writer: wholeTFRecord,
			Query:  url.Values{apc.QparamETLName: {etlName}},
		})
	tassert.CheckFatal(t, err)

	for _, tc := range tcs {
		rangeBytesBuff.Reset()

		// Request only a subset of the bytes.
		header := http.Header{}
		header.Set(cos.HdrRange, fmt.Sprintf("bytes=%d-%d", tc.start, tc.end))
		_, err = api.GetObjectS3(
			baseParams,
			bck,
			tarObjName,
			api.GetArgs{
				Writer: rangeBytesBuff,
				Header: header,
				Query:  url.Values{apc.QparamETLName: {etlName}},
			})
		tassert.CheckFatal(t, err)

		// HTTP byte ranges are inclusive on both ends, hence end+1 when slicing.
		tassert.Errorf(t, bytes.Equal(rangeBytesBuff.Bytes(),
			wholeTFRecord.Bytes()[tc.start:tc.end+1]), "[start: %d, end: %d] bytes differ", tc.start, tc.end)
	}
}