RFC-0018: torch.monitor - events and counters for PyTorch #30
# torch.monitor - Standardizing events and counters for core PyTorch

This RFC proposes a new PyTorch package called `torch.monitor`. It would provide a standardized interface for logging events as well as always-on counters to enable monitoring of training and inference jobs.

## Motivation

Generally, PyTorch jobs have two systems that are used to log information about the job:

1. A time-series logger such as TensorBoard, used during training, which requires the user to manually log the metrics they're interested in.
2. `torch.profiler`, which records operator-level stats and can be selectively turned on if a user is experiencing performance issues and needs deeper insight into their model. The profiler has performance overhead, so it is typically only used when a performance issue arises.

Both of these systems require proactive, manual user intervention to debug a job, either by adding new metrics to TensorBoard or by manually enabling profiling to get low-level performance details.

For distributed jobs and post-mortem investigation into modelling issues, it helps to have always-on metrics and events that can be referred to after the fact to understand model behavior during training or inference. These always-on metrics can also be used to add high-level monitoring to proactively detect issues with training or inference.

There are a number of existing metrics and event logging systems in PyTorch, but there is no standardization across the subprojects. This RFC aims to provide a shared interface that all PyTorch projects can use in a consistent way.

## Description

The goal of this RFC is to provide an interface to track high-level events (`events`) and summary statistics (`stats`). This document defines the high-level Python interface for these, though the core implementation will be done in C++ and tied to the existing PyTorch profiler where possible.

Future work will add integrations into PyTorch core and distributed to log key events and metrics for users. It will also include out-of-the-box integrations for commonly logged stats and logging destinations.

### Events

Events have a generic metadata dict to store arbitrary information. These events are intended to have a relatively low QPS compared to the existing `RECORD_FUNCTION` profiler interface and should occur at roughly the same frequency as one might use stderr logging.

> **Reviewer:** Would something like "starting training loop iteration" be an event, or is it too fine-grained?
>
> **Author:** That's a good question. It seems like it would be fine from a performance standpoint, though it may be fairly spammy in the logs, so probably not a great default. Similar to how you normally log metrics during training, you probably only want a constant number of events per epoch, not per batch. Some of this needs more experimentation to get a good estimate of what is acceptable and with which loggers.

**Interfaces**

```py
from dataclasses import dataclass
from typing import Dict, Optional, Union


@dataclass
class Event:
    # type is a globally unique type name
    # Ex: torch.elastic.RdzvEvent
    type: str

    # message is a human readable string describing the event at a
    # high level
    message: Optional[str]

    # Timestamp is the unix wall clock time of the event in seconds
    timestamp: Optional[float]

    # Metadata can contain any event specific fields and is intended
    # for fields to later be aggregated or filtered on.
    metadata: Dict[str, Union[str, float, int, bool]]
```

> **Reviewer (on `Event`):** Should `Event` include a severity field?
>
> **Author:** We debated including this field but decided against it. For some events, severity might not make sense; e.g. DeprecationWarning events would all have the same severity (i.e. "this API will break if you don't migrate off of it soon"). Since severity can be added with helper functions, we decided to keep the `Event` interface as simple as possible.

> **Reviewer (on `type`):** Assuming this is used by multiple torch libraries, does the event need to include a notion of …

> **Reviewer (on `timestamp`):** Wall clock time is an underspecified clock that is frail to issues such as drift and rollbacks. This should use a monotonic clock with millisecond resolution to be more broadly useful.

> **Reviewer (on `metadata`):** I think we should model events closely on what the distributed tracing community has come up with. In particular, we're missing an event ID (8 bytes is more than enough) and a notion of trace/span, which make it easy to understand the execution structure, in particular for events that cover a time range (an epoch, a forward pass, etc.). For the purposes of this RFC, this information could be encoded with explicit start/stop flags and a set of parent events.

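As the author notes above, severity can be layered on top of the generic metadata dict with helper functions. A minimal sketch of what such a helper might look like; the `warning_event` name and the `severity` metadata key are illustrative assumptions, not part of the proposed API:

```py
import time

# Illustrative helper only: severity is carried in the generic metadata dict
# rather than as a field on Event itself.
def warning_event(type: str, message: str, **metadata) -> Event:
    return Event(
        type=type,
        message=message,
        timestamp=time.time(),
        metadata={"severity": "WARNING", **metadata},
    )
```
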
All events are sent to all registered event handlers. Event handlers can ignore events that are not relevant to them. The backend for these events will be implemented in C++ and merged with the existing profiler where possible.

Since handlers can be registered in Python, the frequency of logged events must be relatively low to avoid performance impact.

```py
from abc import ABC, abstractmethod

event_handlers = []


class EventHandler(ABC):
    @abstractmethod
    def handle(self, event: Event) -> None:
        ...


def register(handler: EventHandler) -> None:
    global event_handlers
    event_handlers.append(handler)


def log(event: Event) -> None:
    for handler in event_handlers:
        handler.handle(event)
```

> **Reviewer (on `log`):** Nit: how about a convenience function …
>
> **Author:** How would you envision filtering/using that data from the aggregator? For text logs, is it better to just use the standard Python `logging.Logger`?

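Since all events go to all registered handlers, a handler is expected to filter out events it does not care about. A minimal sketch of such a handler; the event type string and the stderr output format are illustrative assumptions:

```py
import sys

class RdzvPrintHandler(EventHandler):
    def handle(self, event: Event) -> None:
        # Ignore everything except rendezvous events.
        if event.type != "torch.elastic.RdzvEvent":
            return
        print(f"[{event.timestamp}] {event.type}: {event.message}", file=sys.stderr)

register(RdzvPrintHandler())
```
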
**Example Usage**

```py
import time
from dataclasses import dataclass

from torch.monitor import Event, log


@dataclass
class RdzvEvent(Event):
    def __init__(self, run_id: str, rank: int, message: str) -> None:
        super().__init__(
            type="elastic_rdzv",
            timestamp=time.time(),
            metadata={"run_id": run_id, "rank": rank},
            message=message,
        )


event = RdzvEvent(run_id="1234", rank=0, message="rendezvous started")
log(event)
```

> **Reviewer (on `RdzvEvent`):** Who is responsible for setting the timestamp?
>
> **Author:** It depends on the event. For something like a GPU, the time you log an event is not the same as when the event occurred, so an external time tracker is necessary. That said, it would be pretty easy to have `Event` auto-populate that field on creation and make providing the timestamp optional.

> **Reviewer (on `log(event)`):** Event logging should never happen unconditionally.

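The author suggests above that `Event` could auto-populate the timestamp on creation. A minimal sketch of that idea, assuming a dataclass with defaults; the class name is illustrative and this is not part of the proposed API:

```py
import time
from dataclasses import dataclass, field
from typing import Dict, Optional, Union


@dataclass
class AutoTimestampEvent:
    type: str
    message: Optional[str] = None
    timestamp: Optional[float] = None
    metadata: Dict[str, Union[str, float, int, bool]] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Fill in the wall clock time unless the caller provided one explicitly.
        if self.timestamp is None:
            self.timestamp = time.time()
```
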
```py
import json
from dataclasses import asdict


class ExampleJSONLHandler(EventHandler):
    def __init__(self, file_name: str) -> None:
        self.file_name = file_name

    def handle(self, event: Event) -> None:
        with open(self.file_name, "a") as f:
            json.dump(asdict(event), f)
            f.write("\n")


register(ExampleJSONLHandler("logs.jsonl"))
```

### Stats

These stats are designed to be always-on metrics tracking key values that can be used for monitoring and debugging the performance of training jobs or inference systems.

These are defined in Python here for readability. The core would be implemented in C++ for minimal-overhead tracking of metrics and the ability to automatically log more fine-grained metrics from components such as autograd or the optimizers.

**Interfaces**

```py
from abc import ABC
from dataclasses import dataclass
from enum import Enum
from typing import Set


class StatType(Enum):
    # VALUE exports the most recently set value.
    VALUE = "value"
    # MEAN computes the mean of the set values within the window.
    MEAN = "mean"
    # COUNT tracks the number of times a value is set within the window.
    COUNT = "count"
    # SUM computes the sum of the values set within the window.
    SUM = "sum"
    MAX = "max"
    MIN = "min"

    # These may not be present in the initial implementation:

    # HISTOGRAM computes summary statistics such as P50, P90, P95, P99.
    HISTOGRAM = "histogram"
    # STDDEV computes the standard deviation of the values set within
    # the window.
    STDDEV = "stddev"


_collectors: Set["StatCollector"] = set()


@dataclass
class Stat:
    # key is the name of the stat.
    # Each type of stat should have a globally unique key.
    key: str
    # aggregations is how this stat should be aggregated.
    aggregations: Set[StatType]

    def add(self, v: float) -> None:
        for collector in _collectors:
            collector.handle(self, v)


class StatCollector(ABC):
    def handle(self, stat: Stat, v: float) -> None:
        ...


def register_stat_collector(collector: StatCollector) -> None:
    _collectors.add(collector)
```

> **Reviewer (on `StatType`):** With this API, my understanding is that we perform aggregations on the fly within each time window and send them to the sink (e.g. a time-series database)? Would it make sense to send raw values to the sink and expect it to do the aggregations offline? For instance, with this API how can you retrieve the mean of a metric for an arbitrary time frame? I see no way of recording individual items. The closest is …
>
> **Author:** The current design allows a handler to either do aggregations or output raw values. The aggregation is intended to be similar to https://github.com/facebook/fb303 and to enable high-performance logging from C++. The design does allow a handler to log raw values, though from a performance standpoint there's a lot to consider; that would likely require some careful sampling or batching to be comparable in performance to the aggregations. For time frames, that's up to the handler/user. The out-of-the-box implementation I'm planning to write will only support fixed time frames for now, since that seems sufficient for most users. From my own experience, fixed 60s windows are good enough for production use cases.

> **Reviewer (on `key`):** Isn't "counter" a better name for this? Shouldn't it include help and unit fields as well to make them more discoverable?

> **Reviewer (on `add`):** A model where counters only accumulate data and have the collectors sample them is much more efficient and predictable. This significantly reduces the impact of badly implemented collectors and lowers the collection cost. Another thing we should provide are passive counters that produce data on demand through a callback.

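One reviewer suggests passive counters that produce data on demand through a callback. A minimal sketch of that idea; the `CallbackStat` name and the example stat are illustrative assumptions, not part of the proposed API:

```py
import time
from typing import Callable


class CallbackStat:
    def __init__(self, key: str, fn: Callable[[], float]) -> None:
        self.key = key
        self.fn = fn

    def sample(self) -> float:
        # Collectors would call this when they flush, instead of the
        # application pushing values into the stat.
        return self.fn()


# Example: sample process CPU time with no instrumentation in the training loop.
cpu_time = CallbackStat("process.cpu_time_s", time.process_time)
```
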
**Example Usage**

```py
import time

from torch.monitor import Stat, StatType

BATCH_LATENCY = Stat(
    key="training.batch.latency",
    aggregations={StatType.MEAN, StatType.HISTOGRAM},
)
EPOCH_LATENCY = Stat(
    key="training.epoch.latency",
    aggregations={StatType.MEAN, StatType.HISTOGRAM},
)


def train(model, dataloader, epochs):
    for i in range(epochs):
        epoch_start = time.time()
        for x, y in dataloader:
            batch_start = time.time()
            y_pred = model(x)
            ...
            BATCH_LATENCY.add(time.time() - batch_start)
        EPOCH_LATENCY.add(time.time() - epoch_start)
```

> **Reviewer (on `BATCH_LATENCY`):** Stats are indirectly registered with collectors when they produce data. Another issue is that we can't record the fact that no data was collected for a specific `Stat`; an explicit count of zero is a lot easier to handle than a missing entry in the logs. We should also decouple collection from aggregation so that users can override them on a per-counter basis. Funny enough, this example is a great case for that: means and histograms are not a great summary for operations, as they are not robust. Some users would rather have P95 and P99 collected.

> **Reviewer (on `EPOCH_LATENCY`):** Do we support resource metrics like GPU utilization or CPU/memory usage?

> **Reviewer (on `y_pred = model(x)`):** Could we log model-related metrics like loss and accuracy via the `torch.monitor` package?

> **Reviewer (on `BATCH_LATENCY.add`):** This looks quite fragile and quite common.

> **Reviewer (on the latency example):** Thanks for the example! …
>
> **Author:** Hi, yes this is definitely within the scope of …
>
> **Reviewer:** @d4l3k thanks for letting me know. Will that be included in this RFC then?
>
> **Author:** @jieru-hu the RFC focuses on the APIs we're introducing rather than the specifics of which metrics we plan on counting. I think it would be better suited for a PR after …
>
> **Reviewer:** Thanks @edward-io. Sorry if my question was not clear; I'm not asking about adding the energy metrics. I'm more curious about the mentioned tentative plan for providing an out-of-the-box system metrics monitor. The example here isn't clear on how metrics with no user interaction will be handled, since it focuses on monitoring epoch run time (a user-specified metric). It would be nice to have an example showing how system metrics with no user interaction will be handled (e.g., CPU utilization).

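A reviewer notes above that the manual start/stop timing pattern is fragile and very common. A minimal sketch of a convenience wrapper for it; `timed` is an illustrative helper, not part of the proposed API:

```py
import time
from contextlib import contextmanager


@contextmanager
def timed(stat: Stat):
    start = time.perf_counter()
    try:
        yield
    finally:
        # Record the elapsed time even if the body raises.
        stat.add(time.perf_counter() - start)


# Usage:
#   with timed(BATCH_LATENCY):
#       y_pred = model(x)
#       ...
```
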
Collectors:

```py
import time
from collections import defaultdict
from typing import DefaultDict, Dict, Set


class AggregatingStatCollector(StatCollector):
    def __init__(self) -> None:
        self.stats: Set[Stat] = set()

        self.count: DefaultDict[str, float] = defaultdict(float)
        self.sum: DefaultDict[str, float] = defaultdict(float)
        self.value: Dict[str, float] = {}

    def handle(self, stat: Stat, v: float) -> None:
        self.stats.add(stat)

        if (StatType.MEAN in stat.aggregations
                or StatType.COUNT in stat.aggregations):
            self.count[stat.key] += 1

        if (StatType.MEAN in stat.aggregations
                or StatType.SUM in stat.aggregations):
            self.sum[stat.key] += v

        if StatType.VALUE in stat.aggregations:
            self.value[stat.key] = v

        ...

    def report(self) -> Dict[str, float]:
        out = {}
        for stat in self.stats:
            for type in stat.aggregations:
                if type == StatType.MEAN:
                    out[stat.key + ".mean"] = (
                        self.sum[stat.key] / self.count[stat.key]
                    )
                ...
        return out

    def reset(self) -> None:
        ...


collector = AggregatingStatCollector()
register_stat_collector(collector)

# in a background thread
while True:
    stats = collector.report()
    collector.reset()
    for k, v in stats.items():
        tensorboard_writer.add_scalar(k, v)
    time.sleep(60)
```

> **Reviewer (on the background-thread loop):** Should the user write this code in a background thread, or is this part of the `torch.monitor` internal implementation?

> **Reviewer (on `time.sleep(60)`):** I would suggest against using a time-based interval (i.e. aggregate every 60 seconds) to log events. A count-based approach (e.g. aggregate after every 20 items) is more common among users and would be a better official example. When logging values such as the average loss or average batch time, it is important for the user to know that they are aggregated from the same number of samples so that the logged numbers are comparable. Some concrete examples where aggregating every N seconds is problematic: …
>
> **Author:** For a count-based approach, would the author specify it on the metric itself? I.e. the author would set that per-epoch data is logged every 1 sample but distributed collectives every, say, 1000. It does add an extra value the author has to configure, but it seems feasible, though potentially susceptible to performance issues depending on the log aggregator. We likely need to support time-based aggregation sinks such as Prometheus/Grafana anyway. When logging to time-series systems, count-based doesn't match super well, since we may have stale values if the events stop happening; if we default to zero when there are no items between Prometheus polls, we're likely to get spiky outputs. For the aggregation design, we can export the count as well to help the user handle that behavior.
>
> **Author:** I think it's feasible to support both, though it adds some complexity to the aggregation layer. We can have an optional field as part of the stat definition with the number of samples required. For conversion to time-based loggers we'll likely want to support a "no data" indication, so if there aren't enough data points we just won't log anything and there will be a gap in the time series -- I believe this should be possible with Prometheus and fb303, though it might be fairly hacky with registering and unregistering series.

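The discussion above contrasts time-based and count-based aggregation windows. A minimal sketch of the count-based alternative; the window size and flush destination are illustrative assumptions, not part of the proposed API:

```py
from collections import defaultdict
from typing import DefaultDict, List


class CountWindowCollector(StatCollector):
    def __init__(self, window: int = 20) -> None:
        self.window = window
        self.samples: DefaultDict[str, List[float]] = defaultdict(list)

    def handle(self, stat: Stat, v: float) -> None:
        buf = self.samples[stat.key]
        buf.append(v)
        # Flush after a fixed number of samples rather than on a timer.
        if len(buf) >= self.window:
            self._flush(stat.key, buf)
            buf.clear()

    def _flush(self, key: str, values: List[float]) -> None:
        # e.g. tensorboard_writer.add_scalar(key + ".mean", sum(values) / len(values))
        ...
```
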
## FAQ

- How does this differ from torch.profiler? Why do we need new ways to track performance?

  - The current profiler is typically only turned on when there is an issue. These metrics and events are intended to be always-on to help monitor production training jobs and inference.
  - We plan to extend the profiler on the C++ side to be able to track these events as well, though the user-facing interface for defining the stats and events will differ from the existing `RECORD_FUNCTION` interface.

- Why not log to TensorBoard?

  - These events and metrics can and probably will be logged to TensorBoard for many users. This defines a high-level interface that core and related projects can use to surface common information in a standardized way. The pluggable collectors enable logging to any location.

- Where will this be used?

  - The events system will be immediately used to replace:
    - torch.distributed.elastic.events
    - torchx.events
    - PyTorch Lightning -- LightningDeprecationEvent
    - Tentatively, collective operations from NCCL/Gloo
  - The counters system will be immediately used to unify a number of existing metrics libraries within PyTorch, such as:
    - torch.distributed.elastic.metrics
    - torch.distributed.rpc.metrics
    - torch.distributed.rpc.parameter_server.metrics
  - Metrics from torch.jit.runtime.static may also be provided through this interface.

- How are system stats tracked?

  - If the user has an existing system metrics tracking system as part of their loggers, there's no required system stats tracking.
  - If they don't have one, we plan to provide an out-of-the-box SystemStatsProvider that can be enabled with a couple of lines of code in their main method and provides common stats such as CPU/GPU utilization and memory usage (see the sketch after this FAQ).

- How fast can we log?

  - For events, we expect they will be used at about the same frequency as the built-in Python logger. These are intended for rich structured event logging and thus, for performance reasons, need to be relatively few.

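Regarding the system stats question above, a minimal sketch of one possible shape for the planned out-of-the-box system stats provider, built on the `Stat` interface; the stat key, 60-second interval, background thread, and psutil dependency are illustrative assumptions, not part of this RFC:

```py
import threading
import time

from torch.monitor import Stat, StatType

# Illustrative stat key; the planned provider's actual keys are not specified here.
CPU_UTIL = Stat(key="system.cpu.percent", aggregations={StatType.MEAN})


def _sample_system_stats(interval_s: float = 60.0) -> None:
    import psutil  # assumed dependency, used only for this sketch

    while True:
        CPU_UTIL.add(psutil.cpu_percent())
        time.sleep(interval_s)


def start_system_stats_provider() -> None:
    # "A couple of lines of code in their main method": call this once at startup.
    threading.Thread(target=_sample_system_stats, daemon=True).start()
```
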
> **Reviewer (MLflow maintainers):** The MLflow maintainers are very interested in this event abstraction for the purposes of automatically recording PyTorch model training sessions and their metadata (parameters, metrics, and files) as MLflow Runs. We currently have an autologging integration for PyTorch-Lightning (docs: https://mlflow.org/docs/latest/tracking.html#pytorch-experimental, examples: https://github.com/mlflow/mlflow/blob/master/examples/pytorch/MNIST/mnist_autolog_example.py). However, we've had difficulty implementing this functionality for general PyTorch models due to the lack of: 1. a clear interface for determining when training begins and ends, and 2. instrumentation providing a standardized, meaningful set of parameters/metrics for model training sessions.
>
> In this vein, I have a couple of questions:
>
> 1. When events are integrated into PyTorch, is the vision to record an event when model training starts and when it ends? If so, for the "training start" event, will model parameters be recorded? Similarly, for the "training end" event, can the event payload refer to the trained model? This would allow us to record model files to MLflow.
> 2. Is there a plan to define a standardized set of metrics for PyTorch models and emit them via the API, or is this up to the user?
>
> Thank you in advance for your feedback, and please let us know if / how we can be involved in the design process here.
> **Author:** This proposal is primarily intended to standardize performance/system metrics and events for now. How this is exposed to users for tracking modelling metrics hasn't been fully fleshed out, and there are lots of existing solutions for that already (i.e. Lightning, TensorBoard, etc.). That would be a good follow-up RFC to this, though it may be much more involved since it'll likely play a bigger role in how people use PyTorch.
>
> Things like dataloader latency will likely always be available, but anything that requires background processing, such as system util, will likely require the user to set it up.
>
> If you have design ideas for what a good model logging interface would look like for users, feel free to post them as an RFC here, or you can set up a meeting with folks to discuss. Always happy to chat :)