Performance question and batching #1
Hey Tim
This is in line with what I got on egnn. The Torch timing is probably off though, and quickly rerunning it got me to ~13ms. Thanks for pointing it out, I'll update the readmes once I have time to try it properly.

As for batching, one could definitely improve it, for example by solving a special knapsack problem. Right now the batches are padded to 1.3 times the worst case (here):

```python
max_batch_nodes = int(
    1.3 * max(sum(d["_n_atoms"]) for d in dataset.val_dataloader())
)
```

which is clearly not very smart. On the other hand, the point of these experiments was not about performance, but more about validating PaiNN and confirming it does what it should. On top of this, GPU utilization during training hardly goes above 30% on QM9. If you look around, especially in the QM9 code, you'll see how hacky it is.
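(A minimal sketch of why the padding is there at all, with illustrative names like `pad_batch` and `toy_forward` that are not from this repo: JAX recompiles a jitted function for every new input shape, so padding each batch up to a fixed `max_batch_nodes` keeps a single compiled version, at the cost of wasted work on the padding.)

```python
import numpy as np
import jax
import jax.numpy as jnp

def pad_batch(node_features: np.ndarray, max_batch_nodes: int) -> jnp.ndarray:
    """Zero-pad a [num_nodes, num_features] array up to max_batch_nodes rows
    so every batch presented to a jitted function has the same shape."""
    missing = max_batch_nodes - node_features.shape[0]
    return jnp.asarray(np.pad(node_features, ((0, missing), (0, 0))))

@jax.jit
def toy_forward(nodes: jnp.ndarray) -> jnp.ndarray:
    # Stand-in for the model forward; with padded inputs the shape is fixed,
    # so this compiles once instead of once per distinct batch size.
    return nodes.sum(axis=0)
```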
I did the advanced profiling technique of putting a print statement in the jitted functions, and they were all compiled just once.
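(A minimal sketch of that trick, not taken from the repo: a Python `print` inside a jitted function only executes while JAX traces it, so seeing the message more than once per input shape means the function is being recompiled.)

```python
import jax
import jax.numpy as jnp

@jax.jit
def forward(x):
    # Runs at trace/compile time only, not on every call.
    print("compiling forward for shape", x.shape)
    return (x ** 2).sum()

forward(jnp.ones(5))  # prints: compiling forward for shape (5,)
forward(jnp.ones(5))  # same shape, cached executable -> no print
forward(jnp.ones(7))  # new shape -> prints again (recompilation)
```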
Weird. What's not normal about your times is that the inference run at epoch 1 takes about the same as the second one at epoch 11. This should not be the case, since jitting alone will move the average runtime upwards (you can see this from the output I printed). Did you modify the inference function? The runtimes I report are for the model forward only, not the graph transform and dataloading.
I seem to get different timings when I change the val-freq argument. When I evaluate after every epoch, I get much better times:
Infer time is the time per graph, right? I did not change the infer function.
Time is per batch of 100 graphs (the default), after padding. This has gotten even stranger. The time I put in the readme is from evaluating every graph, but at the moment I honestly don't know why the times are like this.
Hi there!
Thanks for this great repository and sharing your implementation of PaiNN!
I have a question/issue regarding the advertised performance:
Is the inference time mentioned in the readme the timing on the validation split, when you run validate.py?
When I ran validate.py myself, I did not get quite the same performance on a newer GPU (A100) on a cluster:
```
Starting 100 epochs with 587137 parameters.
Jitting...
...
[Epoch 11] train loss 0.019037, epoch 227907.39ms - val loss 0.079965 (best), infer 18.01ms
```
I haven't managed to check GPU usage yet, but this timing is actually similar to what I got with the original.
To make sure the GPU was found by JAX, I printed:

```
Jax devices: [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
```
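(For reference, the usual way to do that check is something like the following; `jax.devices()` is the real API, the surrounding print is just illustrative.)

```python
import jax

# An output containing only CpuDevice entries would mean the model
# is silently running on CPU despite a GPU being present.
print("Jax devices:", jax.devices())
```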
What has sped up my training in the past was using dynamic batching, i.e. collecting graphs for a batch until a maximum number of edges, nodes or graphs is reached, as sketched below. It would be interesting to see if this speeds up training in this minimal example as well. If I understand correctly, the batches from schnetpack.QM9 come naively batched, with the same number of graphs per batch?
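(A minimal sketch of the dynamic batching idea, assuming graph objects that expose `n_nodes`/`n_edges` attributes; none of these names come from this repo or from schnetpack.)

```python
from typing import Iterable, Iterator, List

def dynamic_batches(
    graphs: Iterable,   # any iterable of graph objects with .n_nodes / .n_edges
    max_nodes: int,
    max_edges: int,
    max_graphs: int,
) -> Iterator[List]:
    """Greedily pack graphs into a batch until the node, edge, or graph budget
    would be exceeded, then emit the batch and start a new one."""
    batch, n_nodes, n_edges = [], 0, 0
    for g in graphs:
        over_budget = batch and (
            n_nodes + g.n_nodes > max_nodes
            or n_edges + g.n_edges > max_edges
            or len(batch) >= max_graphs
        )
        if over_budget:
            yield batch
            batch, n_nodes, n_edges = [], 0, 0
        batch.append(g)
        n_nodes += g.n_nodes
        n_edges += g.n_edges
    if batch:
        yield batch
```

Under jit one would still pad each emitted batch up to the node/edge budgets, but since the budgets are constants there is only one compiled shape, and typically far less padding than with a global worst case.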