Skip to content

Commit

Permalink
Aesthetic + small bug fixes to Vizier service
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 646600873
  • Loading branch information
xingyousong authored and copybara-github committed Jun 25, 2024
1 parent 086ab0a commit 22609f7
Showing 1 changed file with 30 additions and 45 deletions.
75 changes: 30 additions & 45 deletions vizier/_src/service/vizier_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

"""RPC functions implemented from vizier_service.proto."""

import collections
import datetime
import threading
Expand All @@ -24,7 +25,6 @@
import grpc
import numpy as np
import sqlalchemy as sqla

from vizier import pythia
from vizier import pyvizier as vz
from vizier._src.service import constants
Expand Down Expand Up @@ -55,6 +55,10 @@ def _get_current_time() -> timestamp_pb2.Timestamp:
return now


StudyResource = resources.StudyResource
TrialResource = resources.TrialResource


# TODO: remove context = None
# TODO: remove context = None
class VizierServicer(vizier_service_pb2_grpc.VizierServiceServicer):
Expand Down Expand Up @@ -196,7 +200,7 @@ def CreateStudy(
study_id = study.display_name

# Finally create study in database and return it.
study.name = resources.StudyResource(owner_id, study_id).name
study.name = StudyResource(owner_id, study_id).name
self.datastore.create_study(study)
return study

Expand All @@ -214,8 +218,8 @@ def ListStudies(
context: Optional[grpc.ServicerContext] = None,
) -> vizier_service_pb2.ListStudiesResponse:
"""Lists all the studies in a region for an associated project."""
list_of_studies = self.datastore.list_studies(request.parent)
return vizier_service_pb2.ListStudiesResponse(studies=list_of_studies)
studies = self.datastore.list_studies(request.parent)
return vizier_service_pb2.ListStudiesResponse(studies=studies)

def DeleteStudy(
self,
Expand Down Expand Up @@ -283,7 +287,7 @@ def SuggestTrials(
)
grpc_util.handle_exception(e, context)

study_resource = resources.StudyResource.from_name(study_name)
study_resource = StudyResource.from_name(study_name)
study_id = study_resource.study_id
owner_id = study_resource.owner_id

Expand All @@ -306,14 +310,12 @@ def SuggestTrials(
start_time = _get_current_time()
# Create a new Op if there aren't any active (not done) ops.
try:
new_op_number = (
self.datastore.max_suggestion_operation_number(
study_name, request.client_id
)
+ 1
old_op_number = self.datastore.max_suggestion_operation_number(
study_name, request.client_id
)
except custom_errors.NotFoundError:
new_op_number = 1
old_op_number = 0
new_op_number = old_op_number + 1
new_op_name = resources.SuggestionOperationResource(
owner_id, study_id, request.client_id, new_op_number
).name
Expand Down Expand Up @@ -441,9 +443,7 @@ def SuggestTrials(
new_trial = new_trials.pop()
trial_id = self.datastore.max_trial_id(request.parent) + 1
new_trial.id = str(trial_id)
new_trial.name = resources.TrialResource(
owner_id, study_id, trial_id
).name
new_trial.name = TrialResource(owner_id, study_id, trial_id).name
new_trial.state = study_pb2.Trial.State.ACTIVE
new_trial.start_time.CopyFrom(start_time)
new_trial.client_id = request.client_id
Expand All @@ -455,14 +455,12 @@ def SuggestTrials(
).SerializeToString()

# Store remaining trials as REQUESTED if Pythia over-delivered.
for remaining_trial in new_trials:
for remain_trial in new_trials:
trial_id = self.datastore.max_trial_id(request.parent) + 1
remaining_trial.id = str(trial_id)
remaining_trial.name = resources.TrialResource(
owner_id, study_id, trial_id
).name
remaining_trial.state = study_pb2.Trial.State.REQUESTED
self.datastore.create_trial(new_trial)
remain_trial.id = str(trial_id)
remain_trial.name = TrialResource(owner_id, study_id, trial_id).name
remain_trial.state = study_pb2.Trial.State.REQUESTED
self.datastore.create_trial(remain_trial)

output_op.done = True
self.datastore.update_suggestion_operation(output_op)
Expand Down Expand Up @@ -491,11 +489,8 @@ def CreateTrial(
trial = request.trial
with self._study_name_to_lock[request.parent]:
trial.id = str(self.datastore.max_trial_id(request.parent) + 1)
trial.name = (
resources.StudyResource.from_name(request.parent).trial_resource(
trial_id=trial.id
)
).name
study_resource = StudyResource.from_name(request.parent)
trial.name = (study_resource.trial_resource(trial.id)).name

if trial.state != study_pb2.Trial.State.SUCCEEDED:
trial.state = study_pb2.Trial.State.REQUESTED
Expand Down Expand Up @@ -543,9 +538,7 @@ def AddTrialMeasurement(
ImmutableStudyError: If study was already immutable.
ImmutableTrialError: If the trial cannot be modified.
"""
study_name = resources.TrialResource.from_name(
request.trial_name
).study_resource.name
study_name = TrialResource.from_name(request.trial_name).study_resource.name
if self._study_is_immutable(study_name):
e = custom_errors.ImmutableStudyError(
'Study {} is immutable. Cannot add measurement.'.format(study_name)
Expand Down Expand Up @@ -577,9 +570,7 @@ def CompleteTrial(
context: Optional[grpc.ServicerContext] = None,
) -> study_pb2.Trial:
"""Marks a Trial as complete."""
study_name = resources.TrialResource.from_name(
request.name
).study_resource.name
study_name = TrialResource.from_name(request.name).study_resource.name
if self._study_is_immutable(study_name):
e = custom_errors.ImmutableStudyError(
'Study {} is immutable. Cannot complete trial.'.format(study_name)
Expand Down Expand Up @@ -625,9 +616,7 @@ def DeleteTrial(
context: Optional[grpc.ServicerContext] = None,
) -> empty_pb2.Empty:
"""Deletes a Trial."""
study_name = resources.TrialResource.from_name(
request.name
).study_resource.name
study_name = TrialResource.from_name(request.name).study_resource.name
if self._study_is_immutable(study_name):
e = custom_errors.ImmutableStudyError(
'Study {} is immutable. Cannot delete trial.'.format(study_name)
Expand Down Expand Up @@ -679,7 +668,7 @@ def CheckTrialEarlyStoppingState(
ImmutableStudyError: If study was already immutable.
ImmutableTrialError: If the trial cannot be modified.
"""
trial_resource = resources.TrialResource.from_name(request.trial_name)
trial_resource = TrialResource.from_name(request.trial_name)
study_name = trial_resource.study_resource.name
if self._study_is_immutable(study_name):
e = custom_errors.ImmutableStudyError(
Expand Down Expand Up @@ -841,9 +830,7 @@ def StopTrial(
ImmutableStudyError: If study was already immutable.
ImmutableTrialError: If the trial cannot be modified.
"""
study_name = resources.TrialResource.from_name(
request.name
).study_resource.name
study_name = TrialResource.from_name(request.name).study_resource.name
if self._study_is_immutable(study_name):
e = custom_errors.ImmutableStudyError(
'Study {} is immutable. Cannot stop trial.'.format(study_name)
Expand Down Expand Up @@ -926,12 +913,10 @@ def ListOptimalTrials(
# Find Pareto optimal trials.
ys = np.array(considered_trial_objective_vectors)
n = ys.shape[0]
dominated = np.asarray(
[
[np.all(ys[i] <= ys[j]) & np.any(ys[j] > ys[i]) for i in range(n)]
for j in range(n)
]
)
dominated = np.asarray([
[np.all(ys[i] <= ys[j]) & np.any(ys[j] > ys[i]) for i in range(n)]
for j in range(n)
])
optimal_booleans = np.logical_not(np.any(dominated, axis=0))
optimal_trials = []
for i, boolean in enumerate(list(optimal_booleans)):
Expand Down

0 comments on commit 22609f7

Please sign in to comment.