Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat(wfm): rnasum rerun api #713

Merged
merged 7 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions lib/workload/stateless/stacks/workflow-manager/deploy/stack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ export class WorkflowManagerStack extends Stack {
vpcSubnets: { subnets: this.vpc.privateSubnets },
role: this.lambdaRole,
architecture: Architecture.ARM_64,
memorySize: 1024,
...props,
});
}
Expand All @@ -111,6 +112,7 @@ export class WorkflowManagerStack extends Stack {
}

private createApiHandlerAndIntegration(props: WorkflowManagerStackProps) {
const API_VERSION = 'v1';
const apiFn: PythonFunction = this.createPythonFunction('Api', {
index: 'api.py',
handler: 'handler',
Expand Down Expand Up @@ -145,6 +147,18 @@ export class WorkflowManagerStack extends Stack {
integration: apiIntegration,
routeKey: HttpRouteKey.with('/{proxy+}', HttpMethod.DELETE),
});

// Route and permission for rerun requests, where the API handler needs to put events onto mainBus
this.mainBus.grantPutEventsTo(apiFn);
new HttpRoute(this, 'PostRerunHttpRoute', {
httpApi: httpApi,
integration: apiIntegration,
authorizer: wfmApi.authStackHttpLambdaAuthorizer,
routeKey: HttpRouteKey.with(
`/api/${API_VERSION}/workflowrun/{orcabusId}/rerun/{proxy+}`,
HttpMethod.POST
),
});
}

private createHandleServiceWrscEventHandler() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import os
import logging
import json
from libumccr.aws import libeb
import workflow_manager.aws_event_bridge.workflowmanager.workflowrunstatechange as wfm

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def emit_wrsc_api_event(event):
    """
    Publish a WorkflowRunStateChange event, sourced from the workflow manager API,
    to the event bus named by the EVENT_BUS_NAME environment variable.

    :param event: the event object; it is marshalled with ``wfm.Marshaller`` and
        JSON-encoded to form the event detail.
    :return: whatever ``libeb.emit_event`` returns.
    :raises ValueError: if EVENT_BUS_NAME is not set in the environment.
    """
    bus_name = os.environ.get("EVENT_BUS_NAME")
    if bus_name is None:
        raise ValueError("EVENT_BUS_NAME environment variable is not set.")

    logger.info(f"Emitting event: {event}")

    # Build the entry first so the emit call itself stays a one-liner.
    entry = {
        'Source': "orcabus.workflowmanagerapi",
        'DetailType': wfm.WorkflowRunStateChange.__name__,
        'Detail': json.dumps(wfm.Marshaller.marshall(event)),
        'EventBusName': bus_name,
    }
    result = libeb.emit_event(entry)

    logger.info(f"Sent a WRSC event to event bus {bus_name}:")
    logger.info(event)
    logger.info(f"{__name__} done.")
    return result
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
import re
import six
from workflow_manager_proc.domain.workflowmanager import workflowrunstatechange
from workflow_manager.aws_event_bridge.executionservice import workflowrunstatechange

class Marshaller:
PRIMITIVE_TYPES = (float, bool, bytes, six.text_type) + six.integer_types
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
import re
import six
from workflow_manager_proc.domain.executionservice import workflowrunstatechange
from workflow_manager.aws_event_bridge.workflowmanager import workflowrunstatechange

class Marshaller:
PRIMITIVE_TYPES = (float, bool, bytes, six.text_type) + six.integer_types
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from datetime import timedelta
import uuid
from datetime import timedelta, datetime, timezone
from typing import List

from workflow_manager.models import Status, State, WorkflowRun
Expand Down Expand Up @@ -142,3 +143,7 @@ def get_latest_state(states: List[State]) -> State:
last = s
return last


def create_portal_run_id() -> str:
date = datetime.now(timezone.utc)
return f"{date.year}{date.month}{date.day}{str(uuid.uuid4())[:8]}"
Original file line number Diff line number Diff line change
@@ -1,14 +1,28 @@
import re
from rest_framework import serializers


def to_camel_case(snake_str):
    """
    Convert a string whose words are separated by underscores, hyphens or
    whitespace into camelCase.

    The first word is lower-cased; every following word is title-cased and
    appended without a separator.
    """
    # re.split always yields at least one element, so the unpacking is safe.
    head, *tail = re.split(r'[_\-\s]', snake_str)
    return head.lower() + ''.join(map(str.title, tail))


class SerializersBase(serializers.ModelSerializer):
    """
    Base model serializer that prepends ``prefix`` to the ``orcabus_id`` in the
    serialized output and can optionally return camelCase keys.
    """
    # Prepended to the serialized orcabus_id; subclasses override it.
    prefix = ''

    def __init__(self, *args, camel_case_data=False, **kwargs):
        """
        :param camel_case_data: when true, ``to_representation`` returns a dict
            whose keys have been converted with ``to_camel_case``.
        """
        super().__init__(*args, **kwargs)
        self.use_camel_case = camel_case_data

    def to_representation(self, instance):
        """Serialize ``instance``, prefixing its ``orcabus_id`` value."""
        data = super().to_representation(instance)
        data['orcabus_id'] = self.prefix + str(data['orcabus_id'])

        if not self.use_camel_case:
            return data
        return {to_camel_case(name): val for name, val in data.items()}


class OptionalFieldsMixin:
def make_fields_optional(self):
# Make all fields optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
class LibraryBaseSerializer(SerializersBase):
    """Base serializer for ``Library`` that sets ``prefix`` to the model's orcabus id prefix."""
    prefix = Library.orcabus_id_prefix


class LibraryListParamSerializer(OptionalFieldsMixin, LibraryBaseSerializer):
    """
    Serializer over all ``Library`` model fields, combined with ``OptionalFieldsMixin``.

    NOTE(review): presumably used to describe list/query parameters with every
    field optional -- confirm how the mixin is applied.
    """
    class Meta:
        model = Library
        fields = "__all__"


class LibrarySerializer(LibraryBaseSerializer):
class Meta:
model = Library
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from enum import StrEnum
from typing import Type

from rest_framework import serializers


class AllowedRerunWorkflow(StrEnum):
    """Workflow names that have a rerun input serializer registered in this module."""
    # Currently the only member; its value is the lower-case name "rnasum".
    RNASUM = "rnasum"


class BaseRerunInputSerializer(serializers.Serializer):
    """
    Common base class for the per-workflow rerun input serializers.

    ``update`` and ``create`` are implemented as no-ops that return ``None``, so
    subclasses only need to declare their fields.
    NOTE(review): presumably these serializers are used for input validation
    only and never persist anything -- confirm against the rerun viewset.
    """

    def update(self, instance, validated_data):
        # No-op: nothing is updated and None is returned.
        pass

    def create(self, validated_data):
        # No-op: nothing is created and None is returned.
        pass


class RnasumRerunInputSerializer(BaseRerunInputSerializer):
    """
    For 'rnasum' workflow rerun only allow dataset to be overridden.

    The single ``dataset`` field is required and must be one of the values in
    ``allowed_dataset_choice``.
    """

    # Accepted dataset codes, grouped as in the RNAsum project summary:
    # https://github.com/umccr/RNAsum/blob/master/TCGA_projects_summary.md
    allowed_dataset_choice = [
        # PRIMARY_DATASETS_OPTION
        "BRCA", "THCA", "HNSC", "LGG", "KIRC", "LUSC", "LUAD", "PRAD", "STAD", "LIHC", "COAD", "KIRP",
        "BLCA", "OV", "SARC", "PCPG", "CESC", "UCEC", "PAAD", "TGCT", "LAML", "ESCA", "GBM", "THYM",
        "SKCM", "READ", "UVM", "ACC", "MESO", "KICH", "UCS", "DLBC", "CHOL",
        # EXTENDED_DATASETS_OPTION
        "LUAD-LCNEC", "BLCA-NET",
        "PAAD-IPMN", "PAAD-NET", "PAAD-ACC",
        # PAN_CANCER_DATASETS_OPTION
        "PANCAN"
    ]

    # Required; any value outside allowed_dataset_choice fails validation.
    dataset = serializers.ChoiceField(choices=allowed_dataset_choice, required=True)


# Registry mapping each allowed rerun workflow to the serializer class that
# describes its rerun input.
RERUN_INPUT_SERIALIZERS: dict[AllowedRerunWorkflow, Type[BaseRerunInputSerializer]] = {
    AllowedRerunWorkflow.RNASUM: RnasumRerunInputSerializer,
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,5 @@
'drf_spectacular.contrib.djangorestframework_camel_case.camelize_serializer_fields',
'drf_spectacular.hooks.postprocess_schema_enums'
],
'SCHEMA_PATH_PREFIX': f'/api/{API_VERSION}/',
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from enum import Enum
import uuid
from datetime import datetime
from datetime import datetime, timedelta
from zoneinfo import ZoneInfo

import factory
from django.utils.timezone import make_aware

from workflow_manager.models import Workflow, WorkflowRun, Payload, Library, State
from workflow_manager.models import Workflow, WorkflowRun, Payload, Library, State, LibraryAssociation


class TestConstant(Enum):
Expand Down Expand Up @@ -72,3 +72,72 @@ class Meta:
comment = "Comment"
payload = None
workflow_run = factory.SubFactory(WorkflowRunFactory)


class PrimaryTestData:
    """
    Builds a small fixture set: one workflow with two workflow runs, each with
    four states and four active library associations.
    """
    WORKFLOW_NAME = "TestWorkflow"

    STATUS_DRAFT = "DRAFT"
    STATUS_START = "READY"
    STATUS_RUNNING = "RUNNING"
    STATUS_END = "SUCCEEDED"
    STATUS_FAIL = "FAILED"
    STATUS_RESOLVED = "RESOLVED"

    def _create_run(self, workflow, run_name, portal_run_id, statuses, payload, libraries):
        """Create one workflow run with its states (one hour apart) and its first four library links."""
        run: WorkflowRun = WorkflowRunFactory(
            workflow_run_name=run_name,
            portal_run_id=portal_run_id,
            workflow=workflow
        )

        for hour_offset, status in enumerate(statuses):
            StateFactory(
                workflow_run=run,
                status=status,
                payload=payload,
                timestamp=make_aware(datetime.now() + timedelta(hours=hour_offset)),
            )

        # Index access on purpose: exactly the first four libraries are linked.
        for lib_index in range(4):
            LibraryAssociation.objects.create(
                workflow_run=run,
                library=libraries[lib_index],
                association_date=make_aware(datetime.now()),
                status="ACTIVE",
            )
        return run

    def create_primary(self, generic_payload, libraries):
        """
        Case: a primary workflow with two executions linked to 4 libraries
        The first execution failed and led to a repetition that succeeded
        """
        workflow = WorkflowFactory(workflow_name=self.WORKFLOW_NAME + "Primary")
        opening_statuses = [self.STATUS_DRAFT, self.STATUS_START, self.STATUS_RUNNING]

        # The first execution (workflow run 1) ends in failure
        self._create_run(
            workflow,
            self.WORKFLOW_NAME + "PrimaryRun1",
            "1234",
            opening_statuses + [self.STATUS_FAIL],
            generic_payload,
            libraries,
        )

        # The second execution (workflow run 2) ends in success
        self._create_run(
            workflow,
            self.WORKFLOW_NAME + "PrimaryRun2",
            "1235",
            opening_statuses + [self.STATUS_END],
            generic_payload,
            libraries,
        )

    def setup(self):
        """Create the shared payload and libraries, then the primary workflow case."""
        # Common components: payload and libraries
        shared_payload = PayloadFactory()  # Payload content is not important for now
        library_specs = (
            ("01J5M2JFE1JPYV62RYQEG99CP1", "L000001"),
            ("02J5M2JFE1JPYV62RYQEG99CP2", "L000002"),
            ("03J5M2JFE1JPYV62RYQEG99CP3", "L000003"),
            ("04J5M2JFE1JPYV62RYQEG99CP4", "L000004"),
        )
        libraries = [
            LibraryFactory(orcabus_id=orcabus_id, library_id=library_id)
            for orcabus_id, library_id in library_specs
        ]

        self.create_primary(shared_payload, libraries)
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import logging
import os
from unittest import skip
from unittest.mock import MagicMock

from django.test import TestCase
from libumccr.aws import libeb

from workflow_manager.models import WorkflowRun
from workflow_manager.models.workflow import Workflow
from workflow_manager.tests.factories import PrimaryTestData
from workflow_manager.urls.base import api_base


logger = logging.getLogger()
logger.setLevel(logging.INFO)

Expand All @@ -28,3 +32,47 @@ def test_get_api(self):
response = self.client.get(f"{self.endpoint}/")
logger.info(response)
self.assertEqual(response.status_code, 200, 'Ok status response is expected')


class WorkflowRunRerunViewSetTestCase(TestCase):
    """Exercise the workflow run ``rerun`` endpoint with the event bus emit call mocked out."""
    endpoint = f"/{api_base}workflowrun"

    def setUp(self):
        os.environ["EVENT_BUS_NAME"] = "mock-bus"
        PrimaryTestData().setup()
        # Mock the same attribute that is saved here and restored in tearDown.
        # Assigning to a differently named attribute (e.g. ``emit_events``) would
        # leave the real ``emit_event`` in place during the test.
        self._real_emit_event = libeb.emit_event
        libeb.emit_event = MagicMock()

    def tearDown(self) -> None:
        libeb.emit_event = self._real_emit_event

    def test_rerun_api(self):
        """
        python manage.py test workflow_manager.tests.test_viewsets.WorkflowRunRerunViewSetTestCase.test_rerun_api
        """
        wfl_run = WorkflowRun.objects.all().first()
        payload = wfl_run.states.get(status='READY').payload
        payload.data = {
            "inputs": {
                "someUri": "s3://random/prefix/"
            },
            "engineParameters": {
                "sourceUri": f"s3:/bucket/{wfl_run.portal_run_id}/",
            }
        }
        payload.save()

        # The fixture workflow is not named 'rnasum', so the rerun must be rejected
        response = self.client.post(f"{self.endpoint}/{wfl_run.orcabus_id}/rerun")
        self.assertEqual(response.status_code, 400, 'Workflow name associated with the workflow run is not allowed')

        # Change the workflow name to 'rnasum' as this is the only allowed workflow name for rerun
        wfl = Workflow.objects.all().first()
        wfl.workflow_name = "rnasum"
        wfl.save()

        response = self.client.post(f"{self.endpoint}/{wfl_run.orcabus_id}/rerun", data={"dataset": "INVALID_CHOICE"})
        self.assertEqual(response.status_code, 400, 'Invalid payload expected')

        response = self.client.post(f"{self.endpoint}/{wfl_run.orcabus_id}/rerun", data={"dataset": "BRCA"})
        self.assertEqual(response.status_code, 200, 'Expected a successful response')
        self.assertNotIn(wfl_run.portal_run_id, str(response.content), 'expect old portal_run_id replaced')
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from workflow_manager.viewsets.payload import PayloadViewSet
from workflow_manager.viewsets.analysis_context import AnalysisContextViewSet
from workflow_manager.viewsets.state import StateViewSet
from workflow_manager.viewsets.workflow_run_action import WorkflowRunActionViewSet
# from workflow_manager.viewsets.library import LibraryViewSet
from workflow_manager.viewsets.workflow_run_comment import WorkflowRunCommentViewSet
from workflow_manager.settings.base import API_VERSION
Expand All @@ -22,6 +23,7 @@
router.register(r"analysiscontext", AnalysisContextViewSet, basename="analysiscontext")
router.register(r"workflow", WorkflowViewSet, basename="workflow")
router.register(r"workflowrun", WorkflowRunViewSet, basename="workflowrun")
router.register(r"workflowrun", WorkflowRunActionViewSet, basename="workflowrun-action")
router.register(r"payload", PayloadViewSet, basename="payload")

# may no longer need this as it's currently included in the detail response for an individual WorkflowRun record
Expand Down
Loading