
Commit

feat(workflows): track model metrics with results summary
Signed-off-by: Cameron Smith <[email protected]>
cameronraysmith committed Aug 15, 2024
1 parent 02f4897 commit 98f269f
Showing 1 changed file with 62 additions and 7 deletions.
69 changes: 62 additions & 7 deletions src/pyrovelocity/workflows/main_workflow.py
@@ -1,6 +1,7 @@
import os
from dataclasses import asdict
from datetime import timedelta
from pathlib import Path

from flytekit import Resources, current_context, dynamic, task, workflow
from flytekit.extras.accelerators import T4, GPUAccelerator
@@ -13,8 +14,12 @@
PreprocessDataInterface,
PyroVelocityTrainInterface,
)
from pyrovelocity.io.archive import create_tarball_from_filtered_dir
from pyrovelocity.io.archive import (
copy_files_to_directory,
create_tarball_from_filtered_dir,
)
from pyrovelocity.io.gcs import upload_file_concurrently
from pyrovelocity.io.json import add_duration_to_run_info, combine_json_files
from pyrovelocity.logging import configure_logging
from pyrovelocity.tasks.data import download_dataset
from pyrovelocity.tasks.postprocess import postprocess_dataset
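
The newly imported io helpers are used later in summarize_data: add_duration_to_run_info is assumed to compute the elapsed run time and write it back into run_info.json, while copy_files_to_directory and combine_json_files appear to return returns-style Result values (see the isinstance(..., Failure) checks below). Neither implementation is part of this diff; a minimal sketch of the behavior assumed by the calls in this commit:

import json
import shutil
from pathlib import Path

from returns.result import Failure, Result, Success


def copy_files_to_directory(files_to_copy, target_directory) -> Result:
    # Copy each source file into target_directory, creating it if needed.
    try:
        target = Path(target_directory)
        target.mkdir(parents=True, exist_ok=True)
        return Success([shutil.copy2(f, target) for f in files_to_copy])
    except Exception as e:
        return Failure(e)


def combine_json_files(file1, file2, output_file) -> Result:
    # Merge two JSON objects (keys from file2 win on conflict) into output_file.
    try:
        combined = {
            **json.loads(Path(file1).read_text()),
            **json.loads(Path(file2).read_text()),
        }
        Path(output_file).write_text(json.dumps(combined, indent=2))
        return Success(Path(output_file))
    except Exception as e:
        return Failure(e)
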
@@ -143,6 +148,7 @@ def train_model(
metrics_path,
run_info_path,
loss_plot_path,
loss_csv_path,
) = train_dataset(
**asdict(train_model_configuration),
)
@@ -157,6 +163,7 @@
f"\tmetrics path: {metrics_path}\n"
f"\trun info path: {run_info_path}\n"
f"\tloss plot path: {loss_plot_path}\n\n"
f"\tloss csv path: {loss_csv_path}\n\n"
)

return TrainingOutputs(
@@ -168,6 +175,7 @@
metrics_path=FlyteFile(path=str(metrics_path)),
run_info_path=FlyteFile(path=str(run_info_path)),
loss_plot_path=FlyteFile(path=str(loss_plot_path)),
loss_csv_path=FlyteFile(path=str(loss_csv_path)),
)
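
The TrainingOutputs dataclass is not shown in the visible portion of this diff; for the constructor call above to work, it must expose a matching loss_csv_path field alongside the others. A rough sketch of the assumed shape (only fields visible here are listed, and the actual definition may contain more):

@dataclass
class TrainingOutputs:
    # Assumed shape, inferred from the constructor call above.
    ...
    metrics_path: FlyteFile
    run_info_path: FlyteFile
    loss_plot_path: FlyteFile
    loss_csv_path: FlyteFile  # new in this commit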


@@ -209,6 +217,7 @@ def postprocess_data(
return PostprocessOutputs(
pyrovelocity_data=FlyteFile(path=str(pyrovelocity_data_path)),
postprocessed_data=FlyteFile(path=str(postprocessed_data_path)),
metrics_path=FlyteFile(path=str(metrics_path)),
)


@@ -232,6 +241,10 @@ def summarize_data(
postprocessed_data_path = (
postprocessing_outputs.postprocessed_data.download()
)
run_info_path = training_outputs.run_info_path.download()
metrics_path = postprocessing_outputs.metrics_path.download()
loss_plot_path = training_outputs.loss_plot_path.download()
loss_csv_path = training_outputs.loss_csv_path.download()

print(
f"\nmodel_path: {model_path}\n\n",
@@ -252,9 +265,45 @@
f"\ndata_model_reports_path: {data_model_reports_path}\n",
f"\ndataframe_path: {dataframe_path}\n\n",
)

data_model_metrics_path = Path(data_model_reports_path) / "metrics"
copy_files_to_reports = [
run_info_path,
metrics_path,
loss_plot_path,
loss_csv_path,
]

copy_files_result = copy_files_to_directory(
files_to_copy=copy_files_to_reports,
target_directory=data_model_metrics_path,
)
if isinstance(copy_files_result, Failure):
print(
f"\nError copying files to {data_model_reports_path}: {copy_files_result.failure()}\n\n"
)

add_duration_to_run_info(run_info_path)
combined_metrics_path = Path(data_model_reports_path) / "metrics.json"
combine_json_result = combine_json_files(
file1=run_info_path,
file2=metrics_path,
output_file=combined_metrics_path,
)
if isinstance(combine_json_result, Failure):
print(
f"\nError combining metrics and run info files in {combined_metrics_path}:\n",
f"{combine_json_result.failure()}\n\n",
)

return SummarizeOutputs(
data_model_reports=FlyteDirectory(path=str(data_model_reports_path)),
dataframe=FlyteFile(path=str(dataframe_path)),
run_metrics_path=FlyteFile(path=str(metrics_path)),
run_info_path=FlyteFile(path=str(run_info_path)),
loss_plot_path=FlyteFile(path=str(loss_plot_path)),
loss_csv_path=FlyteFile(path=str(loss_csv_path)),
combined_metrics_path=FlyteFile(path=str(combined_metrics_path)),
)
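
After this change, each run's reports directory carries a metrics/ subdirectory with the copied artifacts plus a combined metrics.json merging run info and postprocessing metrics. Once the reports directory is available locally, the combined file can be inspected directly; a small usage sketch (the path is a hypothetical example, not one produced by this commit):

import json
from pathlib import Path

# Hypothetical local copy of a data_model_reports directory.
reports_dir = Path("reports/pancreas_model2")
combined_metrics = json.loads((reports_dir / "metrics.json").read_text())
print(json.dumps(combined_metrics, indent=2))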


@@ -268,7 +317,8 @@ def summarize_data(
limits=Resources(cpu="8", mem="16Gi", ephemeral_storage="200Gi"),
)
def upload_summary(
summarize_outputs: SummarizeOutputs, training_outputs: TrainingOutputs
summarize_outputs: SummarizeOutputs,
training_outputs: TrainingOutputs,
) -> FlyteFile:
data_model_reports = summarize_outputs.data_model_reports
reports_path = data_model_reports.download()
@@ -283,7 +333,12 @@ def upload_summary(
create_tarball_from_filtered_dir(
src_dir=reports_path,
output_filename=archive_name,
extensions=(".png", ".pdf", ".csv"),
extensions=(
".csv",
".json",
".pdf",
".png",
),
)

upload_result = upload_file_concurrently(
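
The extension tuple passed to create_tarball_from_filtered_dir above now includes .csv and .json, so the loss CSV and the combined metrics.json produced in summarize_data are captured in the uploaded archive. The filtering behavior assumed of that helper (its implementation lives in pyrovelocity.io.archive, not in this diff) is roughly:

import tarfile
from pathlib import Path


def create_tarball_from_filtered_dir_sketch(src_dir, output_filename, extensions):
    # Add every file under src_dir whose suffix is in `extensions`
    # to a gzipped tarball, preserving paths relative to src_dir.
    with tarfile.open(output_filename, "w:gz") as tar:
        for path in sorted(Path(src_dir).rglob("*")):
            if path.is_file() and path.suffix in extensions:
                tar.add(path, arcname=path.relative_to(src_dir))
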
@@ -416,14 +471,14 @@ def training_workflow(
results = []
configurations = [
(simulated_configuration, "simulated"),
(pancreas_configuration, "pancreas"),
]

if not PYROVELOCITY_DATA_SUBSET:
configurations += [
(pbmc68k_configuration, "pbmc68k"),
(pons_configuration, "pons"),
(larry_configuration, "larry"),
(pancreas_configuration, "pancreas"),
# (pbmc68k_configuration, "pbmc68k"),
# (pons_configuration, "pons"),
# (larry_configuration, "larry"),
]

for config, _ in configurations:
