Skip to content

Commit

Permalink
loading of saved model traj
Browse files Browse the repository at this point in the history
  • Loading branch information
waymao committed Jun 6, 2024
1 parent 9ff72c7 commit a160f29
Showing 1 changed file with 71 additions and 3 deletions.
74 changes: 71 additions & 3 deletions compute_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from tqdm import tqdm
from datetime import datetime
import pandas as pd
import psutil
import pickle

from ai2thor.server import MultiAgentEvent

Expand Down Expand Up @@ -36,7 +36,7 @@ def __init__(self, gt_dataset_path: str):

# initialize the metrics with pre-computed info
self.metrics: List[Metric] = [
AreaCoverage(),
# AreaCoverage(),
CLIP_SemanticUnderstanding(scene_to_cmds=self.scene_to_cmd),
RootMSE(),
DeltaDist(),
Expand All @@ -55,6 +55,49 @@ def _build_traj_index(self):
scene = json_data_step0['scene']
self.cmd_to_traj_name[nl_cmd] = traj_name
self.scene_to_cmd[scene].append(nl_cmd)


def read_eval_traj(self, traj_path) -> Tuple[TrajData, str, str, dict]:
"""
Read a trajectory from a pickle file.
"""
with open(traj_path, 'rb') as f:
traj_data = pickle.load(f)
scene = ""
cmd = ""
img_history = []
xyz_body_history = []
xyz_ee_history = []
yaw_body_history = []
error_history = []
action_history = []
num_steps = 0

# keys are ['task', 'scene', 'img', 'xyz_body', 'xyz_body_delta', 'yaw_body', 'yaw_body_delta', 'pitch_body', 'xyz_ee', 'xyz_ee_delta', 'pickup_dropoff', 'holding_obj', 'control_mode', 'action', 'terminate', 'step', 'timeout', 'error']
for entry in traj_data['trajectory_data']:
scene = entry['scene']
cmd = entry['task']

while len(entry['img'].shape) > 3:
entry['img'] = entry['img'][0]
img_history.append(entry['img'])

xyz_body_history.append(entry['xyz_body'])
yaw_body_history.append(entry['yaw_body'])
xyz_ee_history.append(entry['xyz_ee'])

if type(entry['error']) == list:
entry['error'] = entry['error'][0]
error_history.append(entry['error'])
action_history.append(entry['action'])
num_steps += 1

return TrajData(
img=np.array(img_history), xyz_body=np.array(xyz_body_history), yaw_body=np.array(yaw_body_history),
xyz_ee=np.array(xyz_ee_history), steps=num_steps,
errors=error_history, action=action_history
), cmd, scene, traj_data['final_state']



def convert_gt_hdf5_entry(self, traj_hdf_group: h5py.Group, desired_len: int) -> Tuple[TrajData, str, str]:
Expand Down Expand Up @@ -179,13 +222,39 @@ def main():
evaluator = Evaluator(args.gt_traj_path)
if args.use_gt_for_eval:
eval_gt(evaluator, args.gt_traj_path, args.save_csv_file, print_every_step=args.print_every_step)
else:
eval_model_traj(evaluator, args.eval_traj_path, args.save_csv_file, print_every_step=args.print_every_step)


def ensure_path_exists(filename: str):
# src: https://stackoverflow.com/questions/12517451
if not os.path.exists(os.path.dirname(filename)):
os.makedirs(os.path.dirname(filename))

def eval_model_traj(evaluator: Evaluator, traj_path: str, save_csv_file: str, print_every_step: bool=False):
"""
Run evaluation on a model trajectory.
"""
result_list = []
files = os.listdir(traj_path)
for file in tqdm(files):
file_name = os.path.join(traj_path, file)
print(file_name)
traj_data, cmd, scene, end_inf_state = evaluator.read_eval_traj(file_name)
breakpoint()
result = evaluator.evaluate_one_traj(scene, cmd, traj_data, end_inf_state)
result['cmd'] = cmd
result['scene'] = scene
result['model_name'] = "model"
result_list.append(result)
if print_every_step:
print(result)

df = pd.DataFrame(result_list)
ensure_path_exists(save_csv_file)
df.to_csv(save_csv_file, index=False)
print("Results written to", save_csv_file)

def eval_gt(evaluator: Evaluator, gt_path: str, save_csv_file: str, print_every_step: bool=False):
"""
Run evaluation on ground truth dataset.
Expand All @@ -204,7 +273,6 @@ def eval_gt(evaluator: Evaluator, gt_path: str, save_csv_file: str, print_every_
result_list.append(result)
if print_every_step:
print(result)
print("Memory Use:", psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2)

df = pd.DataFrame(result_list)
ensure_path_exists(save_csv_file)
Expand Down

0 comments on commit a160f29

Please sign in to comment.