RFC-0018: torch.monitor - events and counters for PyTorch #30

Open · wants to merge 1 commit into `master`

RFC-0018-torch-monitor.md (253 additions, 0 deletions)

# torch.monitor - Standardizing events and counters for core PyTorch

This RFC proposes a new PyTorch package called `torch.monitor`. It would provide a standardized interface for logging events as well as always-on counters to enable monitoring of training and inference jobs.

## Motivation

Generally, PyTorch jobs have two systems for logging information about the job:

1. A time-series logger such as TensorBoard, used during training, which requires the user to manually log the metrics they're interested in.
2. `torch.profiler`, which records operator-level stats and can be selectively turned on when a user is experiencing performance issues and needs deeper insight into their model. The profiler has performance overhead, so it is typically only used when a performance issue arises.

Both of these systems require proactive, manual user intervention to debug a job, either by adding new metrics to TensorBoard or by enabling profiling to get low-level performance details.

For distributed jobs and post-mortem investigations into modelling issues, it helps to have always-on metrics and events that can be referred to after the fact to understand model behavior during training or inference. These always-on metrics can also be used to add high-level monitoring that proactively detects issues with training or inference.

There are a number of existing metrics and event logging systems in PyTorch, but there is no standardization across the subprojects. This RFC aims to provide a shared interface that all PyTorch projects can use in a consistent way.

## Description

The goal of this RFC is to provide an interface to track high-level events (`events`) and summary statistics (`stats`). This document defines the high-level Python interface for these, though the core implementation will be done in C++ and tied to the existing PyTorch profiler where possible.

Future work will add integrations into PyTorch core and distributed to log key events and metrics for users. It will also include out-of-the-box integrations for commonly logged stats and logging destinations.

### Events
**Reviewer comment:**

The MLflow maintainers are very interested in this event abstraction for the purposes of automatically recording PyTorch model training sessions and their metadata (parameters, metrics, and files) as MLflow Runs. We currently have an autologging integration for PyTorch-Lightning (docs: https://mlflow.org/docs/latest/tracking.html#pytorch-experimental, examples: https://github.com/mlflow/mlflow/blob/master/examples/pytorch/MNIST/mnist_autolog_example.py). However, we've had difficulty implementing this functionality for general PyTorch models due to the lack of: 1. a clear interface for determining when training begins and ends, and 2. instrumentation providing a standardized, meaningful set of parameters / metrics for model training sessions.

In this vein, I have a couple questions:

  1. When events are integrated into PyTorch, is the vision to record an event when model training starts and when it ends? If so, for the "training start" event, will model parameters be recorded? Similarly, for the "training end" event, can the event payload refer to the trained model? This would allow us to record model files to MLflow.

  2. Is there a plan to define a standardized set of metrics for PyTorch models and emit them via the API, or is this up to the user?

Thank you in advance for your feedback, and please let us know if / how we can be involved in the design process here.

**Member reply:**

This proposal is primarily intended to standardize performance/system metrics and events for now. How this is exposed to users for tracking modelling metrics hasn't been fully fleshed out, and there are lots of existing solutions for that already (e.g. Lightning, TensorBoard, etc.). That would be a good follow-up RFC to this, though it may be much more involved since it'll likely play a bigger role in how people use PyTorch.

  1. As you rightly pointed out, PyTorch doesn't have a concept of training start or end, and this RFC doesn't change that. There may be some ways to add heuristics to track that (i.e. when an optimizer is set up on a model's params), but that isn't covered by this RFC.
  2. There are plans to have standardized metrics for things like system util, memory, dataloaders, etc. The full set isn't complete, and the intention is that it'll be easy to add new ones as the need arises (PRs welcome).
    Things like dataloader latency will likely always be available, but anything that requires background processing, such as system util, will likely require a user to set it up.

If you have design ideas for what a good model logging interface would look like for users, feel free to post them as an RFC here, or you can set up a meeting with folks to discuss. Always happy to chat :)


Events have a generic metadata dict to store arbitrary information. These events are intended to have a relatively low QPS compared to the existing `RECORD_FUNCTION` profiler interface and should be emitted at roughly the same frequency as one might use stderr logging.

**Reviewer comment:**

would something like "starting training loop iteration" be an event or is it too fine grained?

**Member reply:**

That's a good question. That seems like it would be fine from a performance standpoint, though it may be fairly spammy in the logs, so probably not a great default. Similar to how you normally log metrics during training, you would probably only want a constant number of events per epoch, not per batch.

Some of this needs to be experimented with more to get a good estimate of what is acceptable and with which loggers.


**Interfaces**

```py
from dataclasses import dataclass
from typing import Dict, Optional, Union

@dataclass
class Event:

**Reviewer comment:**

Should Event also contain a "severity" field (i.e. INFO, WARN, ERROR) similar to a conventional logger? I believe one can use the metadata attribute for that, but would it make sense to have first-class support for it?

**Author reply:**

We debated including this field, but decided against it. For some events, severity might not make sense. E.g. DeprecationWarning events would all have the same level of severity (i.e. this API will break if you don't migrate off of it soon). Since severity can be added with helper functions, we decided to keep the Event interface as simple as possible.

    # type is a globally unique type name
    # Ex: torch.elastic.RdzvEvent
    type: str

**Reviewer comment:**

Assuming this is used by multiple torch libraries, does the event need to include a notion of namespace to reduce the chance of collisions around the type? Or would that be handled by naming conventions around the type?

    # message is a human readable string describing the event at a
    # high level
    message: Optional[str]
    # timestamp is the unix wall clock time of the event in seconds
    timestamp: Optional[float]
**Reviewer comment:**

Wall clock time is an underspecified clock that is frail, with issues such as drift and rollbacks.

This should use a monotonic clock with millisecond resolution for it to be more broadly useful.

    # metadata can contain any event specific fields and is intended
    # for fields to later be aggregated or filtered on.
    metadata: Dict[str, Union[str, float, int, bool]]
**Reviewer comment:**

I think we should model events closely on what the distributed tracing community has come up with.
The OpenTelemetry Concepts page covers it pretty well.

In particular, we're missing the following:

Event ID; 8 bytes is more than enough.
An event ID is important if we want to do any forensics on event streams, especially across ranks.

Notion of trace/span. They make it easy to understand the execution structure, in particular for events that cover a time range (an epoch, a forward pass, etc.). For the purpose of this RFC, this information can be encoded with explicit start/stop flags and a set of parent events.

```
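The severity and timestamp threads above both point toward keeping `Event` minimal and layering conveniences on top. A minimal sketch of that idea, assuming a hypothetical `make_event` helper (not part of the proposed interface) that stamps the event at construction time and records severity in `metadata`:

```py
import time
from typing import Dict, Union

def make_event(type: str,
               message: str,
               severity: str = "INFO",
               **metadata: Union[str, float, int, bool]) -> Event:
    # Hypothetical convenience: record severity as ordinary metadata and
    # stamp the event at construction time so callers don't have to.
    md: Dict[str, Union[str, float, int, bool]] = {"severity": severity}
    md.update(metadata)
    return Event(type=type, message=message, timestamp=time.time(), metadata=md)

# Usage (log() is defined in the next section):
# log(make_event("torch.data.DataLoaderWarning", "num_workers=0 may be slow", severity="WARN"))
```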

All events are sent to all registered event handlers. Event handlers can ignore events that are not relevant to them. The backend for these events will be implemented in C++ and merged with the existing profiler where possible.

Since handlers can be registered in Python, the frequency of logged events must be relatively low to avoid performance impact.

```py
from abc import ABC, abstractmethod
from typing import List

event_handlers: List["EventHandler"] = []

class EventHandler(ABC):
    @abstractmethod
    def handle(self, event: Event) -> None:
        ...

def register(handler: EventHandler) -> None:
    event_handlers.append(handler)

def log(event: Event) -> None:

**Reviewer comment:**

Nit: How about a convenience function `def log(type: str, message: Optional[str] = None)`? The syntax would be cleaner for simple use cases compared to `log(Event(type=...))`.

**Member reply:**

How would you envision filtering/using that data from the aggregator? For text logs, is it better to just use the standard Python `logging.Logger`?

    for handler in event_handlers:
        handler.handle(event)
```

**Example Usage**

```py
from dataclasses import dataclass

from torch.monitor import Event, log

@dataclass
class RdzvEvent(Event):

**Reviewer comment:**

Who is responsible for setting the timestamp? Event itself or the log() function? In my opinion an event "materializes" by the time it gets logged, so log() makes more sense. In that case, I'm not sure whether timestamp should be part of the Event struct or something that is internally recorded along with it.

**Member reply:**

It depends on the event. For something like a GPU event, the time you log the event is not the same as the time the event occurred, and thus an external time tracker is necessary. Though it would be pretty easy to have Event auto-populate that field on creation and thus make it optional to provide the timestamp.

    def __init__(self, run_id: str, rank: int, message: str) -> None:
        super().__init__(
            type="elastic_rdzv",
            metadata={"run_id": run_id, "rank": rank},
            message=message,
        )

event = RdzvEvent(run_id="1234", rank=0, message="rendezvous started")
**Reviewer comment:**

Event logging should never happen unconditionally.
The API must offer an efficient way to test whether an event should be emitted and make it easy for users to filter/sample event emission.

log(event)
```
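To address the filter/sample concern raised just above, callers could consult a cheap guard before even constructing the event. This is only a sketch, assuming a hypothetical per-type sampling registry (`set_sample_rate` / `should_log` are not part of the proposal):

```py
import random
from typing import Dict

# Hypothetical per-type sampling rates; 1.0 logs everything, 0.0 drops everything.
_sample_rates: Dict[str, float] = {}

def set_sample_rate(event_type: str, rate: float) -> None:
    _sample_rates[event_type] = rate

def should_log(event_type: str) -> bool:
    # Cheap check so callers can skip building the Event entirely when it
    # would be dropped anyway.
    return random.random() < _sample_rates.get(event_type, 1.0)

# Only construct and log the event if it passes the sampling check.
if should_log("elastic_rdzv"):
    log(RdzvEvent(run_id="1234", rank=0, message="rendezvous started"))
```

A production implementation would likely perform this check in C++ so that disabled event types cost close to nothing.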

```py
import dataclasses
import json

class ExampleJSONLHandler(EventHandler):
    def __init__(self, file_name: str) -> None:
        self.file_name = file_name

    def handle(self, event: Event) -> None:
        # Append one JSON object per line (JSONL).
        with open(self.file_name, 'a') as f:
            json.dump(dataclasses.asdict(event), f)
            f.write("\n")

register(ExampleJSONLHandler('logs.jsonl'))
```
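For plain text logs, one alternative raised in the discussion above is to forward events to the standard Python `logging` module instead of a custom sink. A minimal sketch (illustrative only, not part of the proposal):

```py
import logging

class PythonLoggingHandler(EventHandler):
    def __init__(self, logger: logging.Logger) -> None:
        self.logger = logger

    def handle(self, event: Event) -> None:
        # Forward each event to a standard Python logger so existing log
        # aggregation pipelines keep working unchanged.
        self.logger.info("%s: %s %s", event.type, event.message, event.metadata)

register(PythonLoggingHandler(logging.getLogger("torch.monitor")))
```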

### Stats

These stats are designed to be always-on, tracking various key metrics that can be used for monitoring and debugging the performance of your training jobs or inference systems.

These are defined in Python for readability purposes. The core would be implemented in C++ for minimal overhead tracking of metrics and the ability to log more fine grained metrics automatically from things such as autograd or the optimizers.

**Interfaces**

```py
from abc import ABC
from dataclasses import dataclass
from enum import Enum
from typing import Set

class StatType(Enum):

**Reviewer comment:**

With this API my understanding is that we are performing aggregations on the fly within each time window and sending them to the sink (e.g. a time-series database)? Would it make sense to send raw values to the sink and expect it to do the aggregations offline? For instance, with this API how can you retrieve the mean of a metric for an arbitrary time frame? I see no way of recording individual items. The closest is VALUE, but it only records the most recent item in the window.

**Member reply:**

The current design allows for implementing a handler to either do aggregations or output raw values. The aggregation is intended to be similar to https://github.com/facebook/fb303 and be able to do high-performance logging from C++.

The design does allow for a handler to log the raw values, though from a performance standpoint there's a lot to consider. That would likely require some careful sampling or batching to be comparable in performance to the aggregations.

For time frames, that's up to the handler/user. The out-of-the-box implementation I'm planning to write is only going to support fixed time frames for now, since that seems like it would be sufficient for most users. From my own experience, fixed 60s windows are good enough for production use cases.

    # VALUE exports the most recently set value.
    VALUE = "value"
    # MEAN computes the mean of the set values within the window.
    MEAN = "mean"
    # COUNT tracks the number of times a value is set within the window.
    COUNT = "count"
    # SUM computes the sum of the values set within the window.
    SUM = "sum"

    MAX = "max"
    MIN = "min"

    # These may not be present in the initial implementation:

    # HISTOGRAM computes summary statistics such as P50, P90, P95, P99.
    HISTOGRAM = "histogram"
    # STDDEV computes the standard deviation of the values set within
    # the window.
    STDDEV = "stddev"

_collectors: Set["StatCollector"] = set()

@dataclass
class Stat:
    # key is the name of the stat.
**Reviewer comment:**

Isn't counter a better name for this?

Shouldn't it include help and unit fields as well to make them more discoverable?

    # Each type of stat should have a globally unique key.
    key: str
    # aggregations is how this stat should be aggregated.
    aggregations: Set[StatType]

    def add(self, v: float) -> None:
**Reviewer comment:**

A model where counters only accumulate data and have the collectors sample them is much more efficient and predictable. This significantly reduces the impact of badly implemented collectors and lowers the collection cost.

Another thing that we should provide is passive counters that produce data on demand through a callback.
An example of that would be a counter for CUDA memory usage, where we only want to measure it once per collector cycle.

        for collector in _collectors:
            collector.handle(self, v)

class StatCollector(ABC):
    def handle(self, stat: Stat, v: float) -> None:
        ...

def register_stat_collector(collector: StatCollector) -> None:
    _collectors.add(collector)
```
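The passive-counter suggestion above (counters that are sampled on demand via a callback, e.g. CUDA memory usage) could be layered on this interface roughly as follows. `register_passive_stat` and `poll_passive_stats` are hypothetical names used only for illustration:

```py
from typing import Callable, List, Tuple

# Hypothetical registry of (stat, callback) pairs; the callback produces the
# value on demand instead of application code pushing values into the stat.
_passive_stats: List[Tuple[Stat, Callable[[], float]]] = []

def register_passive_stat(stat: Stat, read: Callable[[], float]) -> None:
    _passive_stats.append((stat, read))

def poll_passive_stats() -> None:
    # Called once per collection cycle (e.g. from the reporting thread) so
    # expensive reads such as CUDA memory queries happen at most once per window.
    for stat, read in _passive_stats:
        stat.add(read())

# Example (assumes torch is imported): sample allocated CUDA memory once per cycle.
# register_passive_stat(
#     Stat(key="cuda.memory_allocated", aggregations={StatType.VALUE}),
#     lambda: float(torch.cuda.memory_allocated()),
# )
```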

**Example Usage**

```py
import time

from torch.monitor import Stat, StatType

BATCH_LATENCY = Stat(
**Reviewer comment:**

Stats are indirectly registered with collectors when they produce data.
This model is problematic because it doesn't allow users to query and filter which counters they are interested in.

Another issue is that we can't record the fact that no data was collected for a specific Stat. An explicit count of zero is a lot easier to handle than a missing entry in the logs.

Another thing is that we should decouple collection from aggregation so that users can override them on a per-counter basis.

Funny enough, this example is a great case for that. Means and histograms are not a great summary for operations as they are not robust. Some users would rather have P95 and P99 collected.

key="training.batch.latency",
aggregations={StatType.MEAN, StatType.HISTOGRAM},
)
EPOCH_LATENCY = Stat(

**Reviewer comment:**

do we support resource metrics like GPU utilization or CPU/memory usage?

key="training.epoch.latency",
aggregations={StatType.MEAN, StatType.HISTOGRAM},
)

def train(...):
    for i in range(epochs):
        epoch_start = time.time()
        for x, y in dataloader:
            batch_start = time.time()
            y_pred = model(x)

**Reviewer comment:**

Could we log model-related metrics like loss and accuracy via the torch.monitor package?

            ...
            BATCH_LATENCY.add(time.time() - batch_start)
**Reviewer comment:**

This looks quite fragile and quite common.
Could we have a context manager that simplifies it?
Something like:

with BATCH_LATENCY.add_time():
   y_pred = model(x)

        EPOCH_LATENCY.add(time.time() - epoch_start)
**Reviewer comment (on lines +166 to +167):**

Thanks for the example!
One potential use case for torch.monitor would be logging real-time energy and carbon consumption. For example, one might want to monitor real-time CPU utilization and energy counters to ensure accurate accounting (something along the lines of this project: https://github.com/Breakend/experiment-impact-tracker). This sort of logging would ideally happen on a relatively short interval (e.g., every few seconds) instead of being tied to how the code is structured.
Ideally there would be a background thread that queries these metrics instead of having users manually add them to the application code. In fact, the best setup would require very little user intervention to monitor these metrics. This use case seems to be in line with the motivation of the RFC, but it is not clear to me from the examples that it is possible to achieve with the current design. Could you share more details on whether this sort of interaction-less/background monitoring is covered by torch.monitor?

**Member reply:**

Hi, yes, this is definitely within the scope of torch.monitor. We have some tentative plans to provide an out-of-the-box system metrics monitor (i.e. call a method and get system metrics logged via a background thread), and something like energy accounting would be right in line with that.

**Reviewer comment:**

@d4l3k thanks for letting me know. Will that be included in this RFC then?

**Author reply:**

@jieru-hu the RFC focuses on APIs we're introducing rather than the specifics of which metrics we plan on counting. I think it would be better suited for a PR after torch.monitor is introduced. I do think it's a compelling use case and a good reason why we should add this functionality.

**Reviewer comment (@jieru-hu, Oct 19, 2021):**

Thanks @edward-io. Sorry if my question was not clear. To clarify, I'm not asking about adding the energy metrics. I'm more curious about the mentioned tentative plan for providing an out-of-the-box system metrics monitor. The example here isn't clear on how metrics with no user interaction will be handled, since the one here focuses on monitoring epoch run time (which is a user-specified metric). It would be nice to have an example showing how system metrics with no user interaction will be handled (e.g., CPU utilization).

```
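The context-manager convenience suggested above could be written as a small free function rather than a `Stat` method; `timed` here is a hypothetical helper, not part of the proposed interface:

```py
import time
from contextlib import contextmanager

@contextmanager
def timed(stat: Stat):
    # Record the elapsed wall-clock time of the wrapped block into the stat.
    start = time.perf_counter()
    try:
        yield
    finally:
        stat.add(time.perf_counter() - start)

# Roughly equivalent to the manual time.time() bookkeeping above:
# with timed(BATCH_LATENCY):
#     y_pred = model(x)
```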

Collectors:

```py
from typing import DefaultDict, Dict, Set

class AggregatingStatCollector(StatCollector):
    stats: Set[Stat]

    count: DefaultDict[str, float]
    sum: DefaultDict[str, float]
    value: Dict[str, float]

    def handle(self, stat: Stat, v: float) -> None:
        self.stats.add(stat)

        if (StatType.MEAN in stat.aggregations
                or StatType.COUNT in stat.aggregations):
            self.count[stat.key] += 1

        if (StatType.MEAN in stat.aggregations
                or StatType.SUM in stat.aggregations):
            self.sum[stat.key] += v

        if StatType.VALUE in stat.aggregations:
            self.value[stat.key] = v

        ...

    def report(self) -> Dict[str, float]:
        out = {}
        for stat in self.stats:
            for type in stat.aggregations:
                if type == StatType.MEAN:
                    out[stat.key + ".mean"] = (
                        self.sum[stat.key] / self.count[stat.key]
                    )
                ...
        return out

    def reset(self) -> None:
        ...

collector = AggregatingStatCollector()
register_stat_collector(collector)

# In a background thread:
while True:
    stats = collector.report()
    collector.reset()
    for k, v in stats.items():
        tensorboard_writer.add_scalar(k, v)
**Reviewer comment (on lines +212 to +217):**

Should the user write this code in a background thread themselves, or is this part of the torch.monitor internal implementation?

    time.sleep(60)
**Reviewer comment (on lines +212 to +218):**

I would suggest against using a time-based interval (i.e. aggregate after every 60 seconds) to log events. I think a count-based approach (e.g. aggregate after every 20 items) is more common among users and would be better as an official example.

When logging events such as the average of losses or the average of batch time, it is important for the user to know that they are aggregated from the same number of samples so that the logged numbers are comparable. Some concrete examples where aggregation every N seconds is problematic:

  - Considering the "STDDEV" aggregation, you would really want to compute stddev from the same number of samples, because otherwise stddev is not comparable at all.
  - Considering the "average of batch time": if we log the "average of batch time" in every N-second window, the logged values will be biased towards larger values, because larger batch times will be averaged over fewer samples.

**Member reply (@d4l3k, Oct 15, 2021):**

For a count-based approach the author would specify it on the metric itself? I.e. the author would set that for per-epoch data we log every 1 sample, but for distributed collectives we log every, say, 1000. It does add an extra value the author has to configure, but it seems feasible, though potentially susceptible to performance issues depending on the log aggregator.

We likely need to support time-based aggregation sinks such as Prometheus/Grafana anyway. When logging to time-series-based systems, count-based doesn't match super well since we may have stale values if the events stop happening. If we default to zero when there are no items between the Prometheus polls, we're likely to get spiky outputs.

For the aggregation design, we can export the count as well to help the user handle that behavior.

**Member reply:**

I think it's feasible to support both, though it does add some more complexity to the aggregation layer. We can have an optional field as part of the stat definition with the # of samples required.

For conversion to time-based loggers we'll likely want to support a "no data" indication, so if there aren't enough data points we just won't log anything and there will be a gap in the time series -- I believe this should be possible with Prometheus and fb303, though it might be fairly hacky with registering and unregistering series.

```
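A rough sketch of the count-based alternative discussed in the thread above, where a collector reports after a fixed number of samples per key so that each reported mean comes from the same sample count; `report_every` and the flush-to-print behavior are assumptions for illustration:

```py
from collections import defaultdict
from typing import DefaultDict

class CountWindowCollector(StatCollector):
    """Hypothetical collector that reports after every N samples per key."""

    def __init__(self, report_every: int = 20) -> None:
        self.report_every = report_every
        self.count: DefaultDict[str, int] = defaultdict(int)
        self.sum: DefaultDict[str, float] = defaultdict(float)

    def handle(self, stat: Stat, v: float) -> None:
        key = stat.key
        self.count[key] += 1
        self.sum[key] += v
        if self.count[key] >= self.report_every:
            self._flush(key)

    def _flush(self, key: str) -> None:
        mean = self.sum[key] / self.count[key]
        # Emit to whichever sink is configured; printing keeps the sketch self-contained.
        print(f"{key}.mean={mean:.6f} over {self.count[key]} samples")
        self.count[key] = 0
        self.sum[key] = 0.0

register_stat_collector(CountWindowCollector(report_every=20))
```

Either approach can be expressed through the same `StatCollector` hook, which keeps the choice of windowing a per-collector decision rather than part of the core API.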

## FAQ

- How does this differ from torch.profiler? Why do we need new ways to track performance?

  - The current profiler is typically only turned on when there is an issue. These metrics and events are intended to be always-on to help monitor production training jobs and inference.
  - We plan to extend the profiler on the C++ side to be able to track these events as well, though the user-facing interface for defining the stats and events will differ from the existing RECORD_FUNCTION interface.

- Why not log to TensorBoard?

  - These events and metrics can and probably will be logged to TensorBoard for many users. This defines a high-level interface that core and related projects can use to surface common information in a standardized way. The pluggable collectors enable logging to any location.

- Where will this be used?

  - The events system will be immediately used to replace:

    - torch.distributed.elastic.events
    - torchx.events
    - PyTorch Lightning -- LightningDeprecationEvent
    - Tentatively, collective operations from NCCL/Gloo

  - The counters system will be immediately used to unify a number of existing metrics libraries within PyTorch, such as:

    - torch.distributed.elastic.metrics
    - torch.distributed.rpc.metrics
    - torch.distributed.rpc.parameter_server.metrics
    - Metrics from torch.jit.runtime.static may also be provided through this interface.

- How are system stats tracked?

  - If the user has an existing system metrics tracking setup as part of their loggers, no additional system stats tracking is required.
  - If they don't have one, we plan to provide an out-of-the-box SystemStatsProvider that can be enabled with a couple of lines of code in their main method and provides common stats such as CPU/GPU util and memory usage.

- How fast can we log?

  - For events, we expect that they will be used at about the same frequency as the built-in Python logger. These are intended for rich structured event logging and thus, for performance reasons, need to be relatively few.
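
As a rough illustration of the out-of-the-box system stats idea mentioned above (and the interaction-less monitoring question in the comments), a background thread could periodically sample host metrics into stats without any changes to the training loop. `psutil`, the thread setup, and the stat keys below are assumptions for illustration, not part of the proposal:

```py
import threading
import time

import psutil  # assumed dependency for host-level metrics

from torch.monitor import Stat, StatType  # as proposed in this RFC

CPU_UTIL = Stat(key="system.cpu.utilization", aggregations={StatType.MEAN, StatType.VALUE})
MEM_USED = Stat(key="system.memory.used_bytes", aggregations={StatType.VALUE})

def _poll_system_stats(interval_s: float = 5.0) -> None:
    # Sample host-level metrics on a fixed interval so no per-iteration
    # instrumentation is needed in user code.
    while True:
        CPU_UTIL.add(psutil.cpu_percent(interval=None))
        MEM_USED.add(float(psutil.virtual_memory().used))
        time.sleep(interval_s)

# A single line in the user's main() starts background collection.
threading.Thread(target=_poll_system_stats, daemon=True, name="torch.monitor.system").start()
```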