Skip to content

Commit

Permalink
Support Chakra traces in execution_trace.py (#89)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #89

Support Chakra traces in execution_trace.py

Reviewed By: briancoutinho

Differential Revision: D48409819

fbshipit-source-id: 653642539c68d783be6f06624551dae8f3908e5b
  • Loading branch information
TaekyungHeo authored and facebook-github-bot committed Nov 29, 2023
1 parent 16c64ea commit fd5ad3a
Showing 1 changed file with 97 additions and 21 deletions.
118 changes: 97 additions & 21 deletions train/compute/python/tools/execution_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import logging
import sys
from enum import Enum
from typing import Any, Iterable, List, Optional, Set, TextIO
from typing import Any, Dict, Iterable, List, Optional, Set, TextIO, Tuple

import pydot

Expand Down Expand Up @@ -122,6 +122,7 @@ def __init__(
id: int,
parent_id: int,
fw_parent_id: int,
seq_id: int,
pid: int,
tid: int,
fw_tid: int,
Expand All @@ -146,6 +147,7 @@ def __init__(
self.fw_tid: int = fw_tid
self.op_schema: str = op_schema
self.fw_parent_id: int = fw_parent_id
self.seq_id: int = seq_id
self.scope: int = scope
self.type: NodeType = self.detect_type(name, inputs, outputs)
# self.inputs: List[Any] = [tuple(i) if isinstance(i, list) else i for i in inputs]
Expand Down Expand Up @@ -284,36 +286,32 @@ def __init__(self, json):
self.proc_group = {}
# list of node ids that start an iteration
self.iteration_ids = []
self.schema: str = json["schema"]
pid = json["pid"]
self.proc_group = {pid: {}}
nodes_list = json["nodes"]
for x in nodes_list:
id = x["id"]
tid = x["tid"]
self.nodes[id] = Node(
x["name"],
id,
x["parent"],
x["fw_parent"],
pid,
tid,
x["fw_tid"],
x["op_schema"] if "op_schema" in x.keys() else "",
x["scope"],
x["inputs"],
x["input_types"],
x["input_shapes"],
x["outputs"],
x["output_types"],
x["output_shapes"],
x.get("rf_id", None),

# Depending on schema, call the right method
node_creation_func = {
"1.0.1": ExecutionTrace._create_node_v1_0_1,
"1.0.2-chakra.0.0.4": ExecutionTrace._create_node_v1_0_2_chakra_0_0_4,
# Add future versions here
}
create_node = node_creation_func.get(self.schema, None)
if create_node is None:
raise ValueError(
f"No corresponding node creation function found for schema version {self.schema}"
)

for x in nodes_list:
id = x["id"]
self.nodes[id] = create_node(pid, x)
input_tensors = self.nodes[id].get_input_tensors()
output_tensors = self.nodes[id].get_output_tensors()

# track the various process and threads we have
if x["name"] == "__ROOT_THREAD__":
tid = self.nodes[id].tid
self.proc_group[pid][tid] = id

# build tensor reference table
Expand Down Expand Up @@ -349,6 +347,84 @@ def __init__(self, json):
# remove all dataloader ops
self.remove_dataloader_ops()

@staticmethod
def _read_attrs(node: Dict[str, Any]) -> Tuple:
attr_types = {
"fw_parent": int,
"seq_id": int,
"fw_tid": int,
"op_schema": str,
"rf_id": int,
"scope": int,
"tid": int,
}
attr_dict = {
attr["name"]: attr_types[attr["name"]](attr["value"])
for attr in node["attrs"]
if attr["name"] in attr_types.keys()
}

# Ensure all keys have values
if attr_dict.keys() != attr_types.keys():
raise ValueError(
"Not all keys in attr_dict have updated values. Node:" + str(node)
)
return tuple(attr_dict[key] for key in attr_types.keys())

@staticmethod
def _create_node_v1_0_1(pid, x: Dict[str, Any]) -> Node:
return Node(
x["name"],
x["id"],
x["parent"],
x["fw_parent"],
x["seq_id"],
pid,
x["tid"],
x["fw_tid"],
x.get("op_schema", ""),
x["scope"],
x["inputs"],
x["input_types"],
x["input_shapes"],
x["outputs"],
x["output_types"],
x["output_shapes"],
x.get("rf_id", None),
)

@staticmethod
def _create_node_v1_0_2_chakra_0_0_4(pid, x: Dict[str, Any]) -> Node:
(
fw_parent,
seq_id,
fw_tid,
op_schema,
rf_id,
scope,
tid,
) = ExecutionTrace._read_attrs(x)

return Node(
x["name"],
x["id"],
x["ctrl_deps"],
fw_parent,
seq_id,
pid,
tid,
fw_tid,
op_schema,
scope,
x["inputs"]["values"],
x["inputs"]["types"],
x["inputs"]["shapes"],
x["outputs"]["values"],
x["outputs"]["types"],
x["outputs"]["shapes"],
rf_id,
)

def get_nodes(self, clean: bool = False):
if clean:
return self.clean_nodes
Expand Down

0 comments on commit fd5ad3a

Please sign in to comment.