Skip to content

Commit

Permalink
Add blob typechecker (#5519)
Browse files Browse the repository at this point in the history
Signed-off-by: ddl-rliu <[email protected]>
  • Loading branch information
ddl-rliu authored Jul 17, 2024
1 parent 03a1c9e commit 9638db0
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 0 deletions.
24 changes: 24 additions & 0 deletions flytepropeller/pkg/compiler/validators/typing.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,26 @@ func (t mapTypeChecker) CastsFrom(upstreamType *flyte.LiteralType) bool {
return false
}

type blobTypeChecker struct {
literalType *flyte.LiteralType
}

// CastsFrom checks that the target blob type can be cast to the current blob type. When the blob has no format
// specified, it accepts all blob inputs since it is generic.
func (t blobTypeChecker) CastsFrom(upstreamType *flyte.LiteralType) bool {
blobType := upstreamType.GetBlob()
if blobType == nil {
return false
}

// Empty blobs should match any blob.
if blobType.GetFormat() == "" || t.literalType.GetBlob().GetFormat() == "" {
return true
}

return blobType.GetFormat() == t.literalType.GetBlob().GetFormat()
}

type collectionTypeChecker struct {
literalType *flyte.LiteralType
}
Expand Down Expand Up @@ -333,6 +353,10 @@ func getTypeChecker(t *flyte.LiteralType) typeChecker {
return mapTypeChecker{
literalType: t,
}
case *flyte.LiteralType_Blob:
return blobTypeChecker{
literalType: t,
}
case *flyte.LiteralType_Schema:
return schemaTypeChecker{
literalType: t,
Expand Down
55 changes: 55 additions & 0 deletions flytepropeller/pkg/compiler/validators/typing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -893,3 +893,58 @@ func TestStructuredDatasetCasting(t *testing.T) {
assert.True(t, castable, "StructuredDataset are nullable")
})
}

func TestBlobCasting(t *testing.T) {
emptyBlob := &core.LiteralType{
Type: &core.LiteralType_Blob{
Blob: &core.BlobType{
Format: "",
},
},
}
genericBlob := &core.LiteralType{
Type: &core.LiteralType_Blob{
Blob: &core.BlobType{
Format: "csv",
},
},
}
mismatchedFormatBlob := &core.LiteralType{
Type: &core.LiteralType_Blob{
Blob: &core.BlobType{
Format: "pdf",
},
},
}

t.Run("BaseCase_GenericBlob", func(t *testing.T) {
castable := AreTypesCastable(genericBlob, genericBlob)
assert.True(t, castable, "Blob() should be castable to Blob()")
})

t.Run("GenericToEmptyFormat", func(t *testing.T) {
castable := AreTypesCastable(genericBlob, emptyBlob)
assert.True(t, castable, "Blob(format='csv') should be castable to Blob()")
})

t.Run("EmptyFormatToGeneric", func(t *testing.T) {
castable := AreTypesCastable(genericBlob, emptyBlob)
assert.True(t, castable, "Blob() should be castable to Blob(format='csv')")
})

t.Run("MismatchedFormat", func(t *testing.T) {
castable := AreTypesCastable(genericBlob, mismatchedFormatBlob)
assert.False(t, castable, "Blob(format='csv') should not be castable to Blob(format='pdf')")
})

t.Run("BlobsAreNullable", func(t *testing.T) {
castable := AreTypesCastable(
&core.LiteralType{
Type: &core.LiteralType_Simple{
Simple: core.SimpleType_NONE,
},
},
genericBlob)
assert.False(t, castable, "Blob is not nullable")
})
}

0 comments on commit 9638db0

Please sign in to comment.