Skip to content
This repository has been archived by the owner on Jul 10, 2024. It is now read-only.

Commit

Permalink
Fixed tests
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw committed Jun 28, 2022
1 parent e30c6e0 commit edc1da7
Showing 1 changed file with 38 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@
import org.apache.submarine.server.submitter.k8s.model.notebook.NotebookCR;
import org.apache.submarine.server.submitter.k8s.model.tfjob.TFJob;
import org.apache.submarine.server.submitter.k8s.model.pytorchjob.PyTorchJob;
import org.apache.submarine.server.submitter.k8s.model.xgboostjob.XGBoostJob;
import org.apache.submarine.server.submitter.k8s.parser.ExperimentSpecParser;
import org.apache.submarine.server.submitter.k8s.util.MLJobConverter;
import org.apache.submarine.server.submitter.k8s.util.NotebookUtils;
Expand All @@ -91,7 +90,6 @@ public class K8sSubmitter implements Submitter {

private static final String TF_JOB_SELECTOR_KEY = "tf-job-name=";
private static final String PYTORCH_JOB_SELECTOR_KEY = "pytorch-job-name=";
private static final String XGBoost_JOB_SELECTOR_KEY = "xgboost-job-name=";

// Exception handler for the case where a delete operation targets a resource that does not exist (404)
public static final Function<ApiException, Object> API_EXCEPTION_404_CONSUMER = e -> {
Expand Down Expand Up @@ -195,39 +193,26 @@ public Experiment createExperiment(ExperimentSpec spec) throws SubmarineRuntimeE
MLJob mlJob = ExperimentSpecParser.parseJob(spec);
mlJob.getMetadata().setNamespace(getServerNamespace());
mlJob.getMetadata().setOwnerReferences(OwnerReferenceUtils.getOwnerReference());

CustomResourceType customResourceType;
if (mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)) {
customResourceType = CustomResourceType.TFJob;
} else if (mlJob.getPlural().equals(XGBoostJob.CRD_XGBOOST_PLURAL_V1)) {
customResourceType = CustomResourceType.XGBoost;
} else {
customResourceType = CustomResourceType.PyTorchJob;
}

AgentPod agentPod = new AgentPod(getServerNamespace(), spec.getMeta().getName(), customResourceType,
AgentPod agentPod = new AgentPod(getServerNamespace(), spec.getMeta().getName(),
mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)
? CustomResourceType.TFJob : CustomResourceType.PyTorchJob,
spec.getMeta().getExperimentId());

Object object;
if (mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)) {
object = k8sClient.getTfJobClient().create(getServerNamespace(), (TFJob) mlJob,
new CreateOptions()).throwsApiException().getObject();
} else if (mlJob.getPlural().equals(XGBoostJob.CRD_XGBOOST_PLURAL_V1)) {
object = k8sClient.getXGBoostJobClient().create(getServerNamespace(), (XGBoostJob) mlJob,
new CreateOptions()).throwsApiException().getObject();
} else {
object = k8sClient.getPyTorchJobClient().create(getServerNamespace(), (PyTorchJob) mlJob,
new CreateOptions()).throwsApiException().getObject();
}

k8sClient.getPodClient().create(agentPod).throwsApiException().getObject();
Object object = mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)
? k8sClient.getTfJobClient().create(getServerNamespace(), (TFJob) mlJob,
new CreateOptions()).throwsApiException().getObject()
: k8sClient.getPyTorchJobClient().create(getServerNamespace(), (PyTorchJob) mlJob,
new CreateOptions()).throwsApiException().getObject();

V1Pod agentPodResult = k8sClient.getPodClient().create(agentPod).throwsApiException().getObject();
experiment = parseExperimentResponseObject(object, ParseOp.PARSE_OP_RESULT);
} catch (InvalidSpecException e) {
LOG.error("K8s submitter: parse Job object failed by " + e.getMessage(), e);
throw new SubmarineRuntimeException(400, e.getMessage());
} catch (ApiException e) {
LOG.error("K8s submitter: failed to create pod " + e.getMessage(), e);
throw new SubmarineRuntimeException(e.getCode(), "K8s submitter: failed to create pod " +
LOG.error("K8s submitter: parse Job object failed by " + e.getMessage(), e);
throw new SubmarineRuntimeException(e.getCode(), "K8s submitter: parse Job object failed by " +
e.getMessage());
}
return experiment;
Expand All @@ -240,18 +225,11 @@ public Experiment findExperiment(ExperimentSpec spec) throws SubmarineRuntimeExc

MLJob mlJob = ExperimentSpecParser.parseJob(spec);
mlJob.getMetadata().setNamespace(getServerNamespace());

Object object;
if (mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)) {
object = k8sClient.getTfJobClient().get(getServerNamespace(),
mlJob.getMetadata().getName()).throwsApiException().getObject();
} else if (mlJob.getPlural().equals(XGBoostJob.CRD_XGBOOST_PLURAL_V1)) {
object = k8sClient.getXGBoostJobClient().get(getServerNamespace(),
mlJob.getMetadata().getName()).throwsApiException().getObject();
} else {
object = k8sClient.getPyTorchJobClient().get(getServerNamespace(),
mlJob.getMetadata().getName()).throwsApiException().getObject();
}
Object object = mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)
? k8sClient.getTfJobClient().get(getServerNamespace(), mlJob.getMetadata().getName())
.throwsApiException().getObject()
: k8sClient.getPyTorchJobClient().get(getServerNamespace(), mlJob.getMetadata().getName())
.throwsApiException().getObject();

experiment = parseExperimentResponseObject(object, ParseOp.PARSE_OP_RESULT);

Expand All @@ -275,24 +253,16 @@ public Experiment patchExperiment(ExperimentSpec spec) throws SubmarineRuntimeEx
PatchOptions patchOptions = new PatchOptions();
patchOptions.setFieldManager(spec.getMeta().getExperimentId());
patchOptions.setForce(true);
Object object;
if (mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)) {
object = k8sClient.getTfJobClient().patch(getServerNamespace(), mlJob.getMetadata().getName(),
V1Patch.PATCH_FORMAT_APPLY_YAML,
new V1Patch(new Gson().toJson(mlJob)),
patchOptions).throwsApiException().getObject();
} else if (mlJob.getPlural().equals(XGBoostJob.CRD_XGBOOST_PLURAL_V1)) {
object = k8sClient.getXGBoostJobClient().patch(getServerNamespace(), mlJob.getMetadata().getName(),
V1Patch.PATCH_FORMAT_APPLY_YAML,
new V1Patch(new Gson().toJson(mlJob)),
patchOptions).throwsApiException().getObject();
} else {
object = k8sClient.getPyTorchJobClient().patch(getServerNamespace(), mlJob.getMetadata().getName(),
V1Patch.PATCH_FORMAT_APPLY_YAML,
new V1Patch(new Gson().toJson(mlJob)),
patchOptions).throwsApiException().getObject();
}

Object object = mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)
? k8sClient.getTfJobClient().patch(getServerNamespace(), mlJob.getMetadata().getName(),
V1Patch.PATCH_FORMAT_APPLY_YAML,
new V1Patch(new Gson().toJson(mlJob)),
patchOptions).throwsApiException().getObject()
: k8sClient.getPyTorchJobClient().patch(getServerNamespace(), mlJob.getMetadata().getName(),
V1Patch.PATCH_FORMAT_APPLY_YAML,
new V1Patch(new Gson().toJson(mlJob)),
patchOptions).throwsApiException().getObject()
;
experiment = parseExperimentResponseObject(object, ParseOp.PARSE_OP_RESULT);
} catch (InvalidSpecException e) {
throw new SubmarineRuntimeException(409, e.getMessage());
Expand All @@ -311,31 +281,18 @@ public Experiment deleteExperiment(ExperimentSpec spec) throws SubmarineRuntimeE
MLJob mlJob = ExperimentSpecParser.parseJob(spec);
mlJob.getMetadata().setNamespace(getServerNamespace());

CustomResourceType customResourceType;
if (mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)) {
customResourceType = CustomResourceType.TFJob;
} else if (mlJob.getPlural().equals(XGBoostJob.CRD_XGBOOST_PLURAL_V1)) {
customResourceType = CustomResourceType.XGBoost;
} else {
customResourceType = CustomResourceType.PyTorchJob;
}

AgentPod agentPod = new AgentPod(getServerNamespace(), spec.getMeta().getName(), customResourceType,
AgentPod agentPod = new AgentPod(getServerNamespace(), spec.getMeta().getName(),
mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)
? CustomResourceType.TFJob : CustomResourceType.PyTorchJob,
spec.getMeta().getExperimentId());

Object object;
if (mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)) {
object = k8sClient.getTfJobClient().delete(getServerNamespace(), mlJob.getMetadata().getName(),
MLJobConverter.toDeleteOptionsFromMLJob(mlJob));
} else if (mlJob.getPlural().equals(XGBoostJob.CRD_XGBOOST_PLURAL_V1)) {
object = k8sClient.getXGBoostJobClient().delete(getServerNamespace(), mlJob.getMetadata().getName(),
MLJobConverter.toDeleteOptionsFromMLJob(mlJob))
.throwsApiException().getStatus();
} else {
object = k8sClient.getPyTorchJobClient().delete(getServerNamespace(), mlJob.getMetadata().getName(),
MLJobConverter.toDeleteOptionsFromMLJob(mlJob))
.throwsApiException().getStatus();
}
Object object = mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)
? k8sClient.getTfJobClient().delete(getServerNamespace(), mlJob.getMetadata().getName(),
MLJobConverter.toDeleteOptionsFromMLJob(mlJob))
.throwsApiException().getStatus()
: k8sClient.getPyTorchJobClient().delete(getServerNamespace(), mlJob.getMetadata().getName(),
MLJobConverter.toDeleteOptionsFromMLJob(mlJob))
.throwsApiException().getStatus();

LOG.info(String.format("Experiment:%s had been deleted, start to delete agent pod:%s",
spec.getMeta().getName(), agentPod.getMetadata().getName()));
Expand Down Expand Up @@ -582,11 +539,7 @@ private String getJobLabelSelector(ExperimentSpec experimentSpec) {
if (experimentSpec.getMeta().getFramework()
.equalsIgnoreCase(ExperimentMeta.SupportedMLFramework.TENSORFLOW.getName())) {
return TF_JOB_SELECTOR_KEY + experimentSpec.getMeta().getExperimentId();
} else if (experimentSpec.getMeta().getFramework()
.equalsIgnoreCase(ExperimentMeta.SupportedMLFramework.XGBOOST.getName())) {
return XGBoost_JOB_SELECTOR_KEY + experimentSpec.getMeta().getExperimentId();
}
else {
} else {
return PYTORCH_JOB_SELECTOR_KEY + experimentSpec.getMeta().getExperimentId();
}
}
Expand Down

0 comments on commit edc1da7

Please sign in to comment.