Skip to content

Commit

Permalink
address pr comments
Browse files Browse the repository at this point in the history
Signed-off-by: zazulam <[email protected]>
  • Loading branch information
zazulam committed Jul 15, 2024
1 parent 08b6210 commit 31035ae
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 16 deletions.
3 changes: 1 addition & 2 deletions backend/src/v2/compiler/argocompiler/argo.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,15 +262,14 @@ func (c *workflowCompiler) argumentsPlaceholder(componentName string) (string, e
return workflowParameter(componentName), nil
}

// extractBaseComponentName removes the iteration suffix that the IR compiler
// ExtractBaseComponentName removes the iteration suffix that the IR compiler
// adds to the component name.
func ExtractBaseComponentName(componentName string) string {
baseComponentName := componentName
componentNameArray := strings.Split(componentName, "-")

if _, err := strconv.Atoi(componentNameArray[len(componentNameArray)-1]); err == nil {
baseComponentName = strings.Join(componentNameArray[:len(componentNameArray)-1], "-")

}

return baseComponentName
Expand Down
15 changes: 15 additions & 0 deletions backend/src/v2/compiler/argocompiler/argo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,21 @@ func Test_extractBaseComponentName(t *testing.T) {
componentName: "component",
expectedBaseName: "component",
},
{
name: "Last char is int",
componentName: "component-v2",
expectedBaseName: "component-v2",
},
{
name: "Multiple dashes, ends with int",
componentName: "service-api-v2",
expectedBaseName: "service-api-v2",
},
{
name: "Multiple dashes and ints",
componentName: "module-1-2-3",
expectedBaseName: "module-1-2",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down
28 changes: 14 additions & 14 deletions sdk/python/kfp/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def comp():


@dsl.component
def return_1() -> int:
def return_one() -> int:
return 1


Expand Down Expand Up @@ -3369,43 +3369,43 @@ def test_cpu_memory_optional(self):

@dsl.pipeline
def simple_pipeline():
return_1()
return_1().set_cpu_limit('5')
return_1().set_memory_limit('50G')
return_1().set_cpu_request('2').set_cpu_limit(
return_one()
return_one().set_cpu_limit('5')
return_one().set_memory_limit('50G')
return_one().set_cpu_request('2').set_cpu_limit(
'5').set_memory_request('4G').set_memory_limit('50G')

dict_format = json_format.MessageToDict(simple_pipeline.pipeline_spec)

self.assertNotIn(
'resources', dict_format['deploymentSpec']['executors']
['exec-return-1']['container'])
['exec-return-one']['container'])

self.assertEqual(
5, dict_format['deploymentSpec']['executors']['exec-return-1-2']
5, dict_format['deploymentSpec']['executors']['exec-return-one-2']
['container']['resources']['cpuLimit'])
self.assertNotIn(
'memoryLimit', dict_format['deploymentSpec']['executors']
['exec-return-1-2']['container']['resources'])
['exec-return-one-2']['container']['resources'])

self.assertEqual(
50, dict_format['deploymentSpec']['executors']['exec-return-1-3']
50, dict_format['deploymentSpec']['executors']['exec-return-one-3']
['container']['resources']['memoryLimit'])
self.assertNotIn(
'cpuLimit', dict_format['deploymentSpec']['executors']
['exec-return-1-3']['container']['resources'])
['exec-return-one-3']['container']['resources'])

self.assertEqual(
2, dict_format['deploymentSpec']['executors']['exec-return-1-4']
2, dict_format['deploymentSpec']['executors']['exec-return-one-4']
['container']['resources']['cpuRequest'])
self.assertEqual(
5, dict_format['deploymentSpec']['executors']['exec-return-1-4']
5, dict_format['deploymentSpec']['executors']['exec-return-one-4']
['container']['resources']['cpuLimit'])
self.assertEqual(
4, dict_format['deploymentSpec']['executors']['exec-return-1-4']
4, dict_format['deploymentSpec']['executors']['exec-return-one-4']
['container']['resources']['memoryRequest'])
self.assertEqual(
50, dict_format['deploymentSpec']['executors']['exec-return-1-4']
50, dict_format['deploymentSpec']['executors']['exec-return-one-4']
['container']['resources']['memoryLimit'])


Expand Down
7 changes: 7 additions & 0 deletions sdk/python/kfp/dsl/component_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@ class ComponentInfo():

def _python_function_name_to_component_name(name):
name_with_spaces = re.sub(' +', ' ', name.replace('_', ' ')).strip(' ')
name_list = name_with_spaces.split(' ')

if name_list[-1].isdigit():
raise ValueError(
f'Invalid function name "{name}". The function name must not end in `_<int>`.'
)

return name_with_spaces[0].upper() + name_with_spaces[1:]


Expand Down
24 changes: 24 additions & 0 deletions sdk/python/kfp/dsl/component_factory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,30 @@ def comp(Output: OutputPath(str), text: str) -> str:
pass


class TestPythonFunctionName(unittest.TestCase):

def test_invalid_function_name(self):

with self.assertRaisesRegex(
ValueError,
r'Invalid function name "comp_2". The function name must not end in `_<int>`.'
):

@component
def comp_2(text: str) -> str:
pass

def test_valid_function_name(self):

@component
def comp_v2(text: str) -> str:
pass

@component
def comp_(text: str) -> str:
pass


class TestExtractComponentInterfaceListofArtifacts(unittest.TestCase):

def test_python_component_input(self):
Expand Down

0 comments on commit 31035ae

Please sign in to comment.