Skip to content

Commit 3d95f8d

Browse files
committed Apr 14, 2024
FORK: fix bucket root query strings (upstream kubeflow#10319)
Signed-off-by: Mathew Wicks <5735406+thesuperzapper@users.noreply.github.com>
1 parent 3963f34 commit 3d95f8d

File tree

3 files changed

+79
-11
lines changed

3 files changed

+79
-11
lines changed
 

Diff for: ‎backend/src/v2/driver/driver.go

+3-9
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@ import (
1717
"context"
1818
"encoding/json"
1919
"fmt"
20-
"path"
2120
"strconv"
22-
"strings"
2321
"time"
2422

2523
"github.com/golang/glog"
@@ -1062,7 +1060,9 @@ func provisionOutputs(pipelineRoot, taskName string, outputsSpec *pipelinespec.C
10621060
outputs.Artifacts[name] = &pipelinespec.ArtifactList{
10631061
Artifacts: []*pipelinespec.RuntimeArtifact{
10641062
{
1065-
Uri: generateOutputURI(pipelineRoot, name, taskName),
1063+
// Do not preserve the query string for output artifacts, as otherwise
1064+
// they'd appear in file and artifact names.
1065+
Uri: metadata.GenerateOutputURI(pipelineRoot, []string{taskName, name}, false),
10661066
Type: artifact.GetArtifactType(),
10671067
Metadata: artifact.GetMetadata(),
10681068
},
@@ -1078,12 +1078,6 @@ func provisionOutputs(pipelineRoot, taskName string, outputsSpec *pipelinespec.C
10781078
return outputs
10791079
}
10801080

1081-
// generateOutputURI builds the URI of an output artifact: root with any
// trailing "/" removed, followed by "/" and path.Join(taskName, artifactName).
// Note the argument order: artifactName is passed before taskName, but the
// task name comes first in the resulting path.
func generateOutputURI(root, artifactName string, taskName string) string {
1082-
// We cannot path.Join(root, taskName, artifactName), because root
1083-
// contains a scheme like gs:// and path.Join would clean the scheme up to gs:/
1084-
return fmt.Sprintf("%s/%s", strings.TrimRight(root, "/"), path.Join(taskName, artifactName))
1085-
}
1086-
10871081
var accessModeMap = map[string]k8score.PersistentVolumeAccessMode{
10881082
"ReadWriteOnce": k8score.ReadWriteOnce,
10891083
"ReadOnlyMany": k8score.ReadOnlyMany,

Diff for: ‎backend/src/v2/metadata/client.go

+21-1
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,26 @@ func (e *Execution) FingerPrint() string {
260260
return e.execution.GetCustomProperties()[keyCacheFingerPrint].GetStringValue()
261261
}
262262

263+
// GenerateOutputURI appends the specified paths to the pipeline root.
//
// pipelineRoot may carry a query string (for example
// "minio://bucket/root?region=x"). Everything from the first "?" onwards is
// treated as the query: it is split off before the paths are appended and,
// when preserveQueryString is true, re-attached verbatim to the end of the
// resulting URI. When preserveQueryString is false the query is dropped.
//
// Only the first "?" acts as the separator. Later "?" characters are legal
// data inside a query (RFC 3986, section 3.4), so they stay part of the query
// instead of leaking into the path portion of the result.
//
// Trailing "/" characters on the root are trimmed, and paths are combined with
// path.Join, so the result contains exactly one "/" between the root and the
// joined paths.
func GenerateOutputURI(pipelineRoot string, paths []string, preserveQueryString bool) string {
	root, query := pipelineRoot, ""
	if i := strings.Index(pipelineRoot, "?"); i >= 0 {
		root = pipelineRoot[:i]
		if preserveQueryString {
			// Keeps the leading "?" so the query can be appended as-is.
			query = pipelineRoot[i:]
		}
	}
	// We cannot path.Join(root, paths...), because root contains a scheme
	// like gs:// and path.Join would clean the scheme up to gs:/
	return fmt.Sprintf("%s/%s%s", strings.TrimRight(root, "/"), path.Join(paths...), query)
}
282+
263283
// GetPipeline returns the current pipeline represented by the specified
264284
// pipeline name and run ID.
265285
func (c *Client) GetPipeline(ctx context.Context, pipelineName, runID, namespace, runResource, pipelineRoot string) (*Pipeline, error) {
@@ -272,7 +292,7 @@ func (c *Client) GetPipeline(ctx context.Context, pipelineName, runID, namespace
272292
keyNamespace: stringValue(namespace),
273293
keyResourceName: stringValue(runResource),
274294
// pipeline root of this run
275-
keyPipelineRoot: stringValue(strings.TrimRight(pipelineRoot, "/") + "/" + path.Join(pipelineName, runID)),
295+
keyPipelineRoot: stringValue(GenerateOutputURI(pipelineRoot, []string{pipelineName, runID}, true)),
276296
}
277297
runContext, err := c.getOrInsertContext(ctx, runID, pipelineRunContextType, metadata)
278298
glog.Infof("Pipeline Run Context: %+v", runContext)

Diff for: ‎backend/src/v2/metadata/client_test.go

+55-1
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ func Test_GetPipeline_Twice(t *testing.T) {
143143
// The second call to GetPipeline won't fail because it avoid inserting to MLMD again.
144144
samePipeline, err := client.GetPipeline(ctx, "get-pipeline-test", runId, namespace, runResource, pipelineRoot)
145145
fatalIf(err)
146-
if (pipeline.GetCtxID() != samePipeline.GetCtxID()) {
146+
if pipeline.GetCtxID() != samePipeline.GetCtxID() {
147147
t.Errorf("Expect pipeline context ID %d, actual is %d", pipeline.GetCtxID(), samePipeline.GetCtxID())
148148
}
149149
}
@@ -214,6 +214,60 @@ func Test_GetPipelineConcurrently(t *testing.T) {
214214
wg.Wait()
215215
}
216216

217+
func Test_GenerateOutputURI(t *testing.T) {
218+
// Const define the artifact name
219+
const (
220+
pipelineName = "my-pipeline-name"
221+
runID = "my-run-id"
222+
pipelineRoot = "minio://mlpipeline/v2/artifacts"
223+
pipelineRootQuery = "?query=string&another=query"
224+
)
225+
tests := []struct {
226+
name string
227+
queryString string
228+
paths []string
229+
preserveQueryString bool
230+
want string
231+
}{
232+
{
233+
name: "plain pipeline root without preserveQueryString",
234+
queryString: "",
235+
paths: []string{pipelineName, runID},
236+
preserveQueryString: false,
237+
want: fmt.Sprintf("%s/%s/%s", pipelineRoot, pipelineName, runID),
238+
},
239+
{
240+
name: "plain pipeline root with preserveQueryString",
241+
queryString: "",
242+
paths: []string{pipelineName, runID},
243+
preserveQueryString: true,
244+
want: fmt.Sprintf("%s/%s/%s", pipelineRoot, pipelineName, runID),
245+
},
246+
{
247+
name: "pipeline root with query string without preserveQueryString",
248+
queryString: pipelineRootQuery,
249+
paths: []string{pipelineName, runID},
250+
preserveQueryString: false,
251+
want: fmt.Sprintf("%s/%s/%s", pipelineRoot, pipelineName, runID),
252+
},
253+
{
254+
name: "pipeline root with query string with preserveQueryString",
255+
queryString: pipelineRootQuery,
256+
paths: []string{pipelineName, runID},
257+
preserveQueryString: true,
258+
want: fmt.Sprintf("%s/%s/%s%s", pipelineRoot, pipelineName, runID, pipelineRootQuery),
259+
},
260+
}
261+
for _, tt := range tests {
262+
t.Run(tt.name, func(t *testing.T) {
263+
got := metadata.GenerateOutputURI(fmt.Sprintf("%s%s", pipelineRoot, tt.queryString), tt.paths, tt.preserveQueryString)
264+
if diff := cmp.Diff(got, tt.want); diff != "" {
265+
t.Errorf("GenerateOutputURI() = %v, want %v\nDiff (-want, +got)\n%s", got, tt.want, diff)
266+
}
267+
})
268+
}
269+
}
270+
217271
func Test_DAG(t *testing.T) {
218272
t.Skip("Temporarily disable the test that requires cluster connection.")
219273

0 commit comments

Comments
 (0)