Performance question and batching #1

Open
tisabe opened this issue Feb 26, 2024 · 5 comments

tisabe commented Feb 26, 2024

Hi there!

Thanks for this great repository and sharing your implementation of PaiNN!
I have a question/issue regarding the advertised performance:
Is the inference time mentioned in the readme the timing on the validation split, when you run validate.py?
When I ran validate.py myself, I did not get quite the same performance on a newer GPU (an A100) on a cluster:
Starting 100 epochs with 587137 parameters.
Jitting...
...
[Epoch 11] train loss 0.019037, epoch 227907.39ms - val loss 0.079965 (best), infer 18.01ms
I haven't managed to check GPU usage yet, but this timing is actually similar to what I got with the original implementation.
To make sure the GPU was found by JAX, I printed:
Jax devices: [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]

What has sped up my training in the past was using dynamic batching, i.e. collecting graphs for a batch until a maximum number of edges, nodes or graphs is reached. It would be interesting to see if this speeds up training in this minimal example as well. If I understand correctly, the batches from schnetpack.QM9 come naively batched, with the same number of graphs per batch?
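
A minimal sketch of this kind of budget-based dynamic batching, assuming the samples are already available as single-graph jraph.GraphsTuples (the iterator name and the budget values below are made up for illustration; jraph ships a helper that does exactly this):

    import jraph

    # Hypothetical iterator over single-graph jraph.GraphsTuples (e.g. QM9 molecules).
    graphs = iter(single_graph_dataset)

    # Collect graphs into a batch until one of the node/edge/graph budgets would be
    # exceeded, then pad the batch up to those budgets so jitted shapes stay static.
    batched = jraph.dynamically_batch(
        graphs,
        n_node=1024,   # node budget per batch (includes padding)
        n_edge=8192,   # edge budget per batch
        n_graph=128,   # graph budget per batch (one slot is reserved for the padding graph)
    )

    for batch in batched:
        ...  # forward pass on a fixed-shape, padded batch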

gerkone commented Feb 26, 2024

Hey Tim,
Glad you found it useful. Looking at it now, the timings seem a bit sketchy, especially the torch one. I reran the validation, but after jitting I still get similar values for JAX: around 8ms on my GPU (RTX 4000), so your result is unexpected. It should not be compiling again after the first epoch, as the validation loader is not shuffled. Just to be sure, could you check if that's the case?

Starting 100 epochs with 587137 parameters.
Jitting...
[Epoch    1] train loss 0.334538, epoch 357033.30ms - val loss 0.232655 (best), infer 33.78ms
[Epoch    2] train loss 0.053471, epoch 287106.63ms - val loss 0.163645 (best), infer 8.08ms
[Epoch    3] train loss 0.035966, epoch 277099.84ms - val loss 0.106103 (best), infer 8.10ms

This is in line with what I get on egnn. The torch number is probably off though, and quickly rerunning it got me to ~13ms. Thanks for pointing it out, I'll update the readmes once I have time to try it properly.

As for batching, it could definitely be improved, for example by solving a special knapsack problem. Right now the batches are padded to 1.3 times the worst case (here):

  max_batch_nodes = int(
      1.3 * max(sum(d["_n_atoms"]) for d in dataset.val_dataloader())
  )

which is clearly not very smart. On the other hand, the point of these experiments was not performance, but rather validating PaiNN and confirming it does what it should. On top of this, GPU utilization during training hardly goes above 30% on QM9. If you look around, especially in the QM9 code, you'll see how hacky it is.
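
For context, the padding step itself is ordinary fixed-size graph padding; a sketch of what it amounts to, assuming jraph-style GraphsTuples (the function name and arguments here are illustrative, not the repo's actual code):

    import jraph

    def pad_batch(batch: jraph.GraphsTuple, max_nodes: int, max_edges: int, max_graphs: int):
        # Pad with dummy nodes/edges and one dummy graph up to the same precomputed
        # budget for every batch (e.g. the 1.3 * worst-case value above), so the
        # jitted forward pass always sees identical shapes.
        return jraph.pad_with_graphs(
            batch,
            n_node=max_nodes,
            n_edge=max_edges,
            n_graph=max_graphs + 1,  # +1 slot for the padding graph
        )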

tisabe commented Feb 27, 2024

I did the advanced profiling technique of putting a print statement in the jitted functions, and they were all compiled just once:

Jax devices:  [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
Target:  mu
Starting 100 epochs with 587137 parameters.
Jitting...
unjitted call to update
unjitted call to train_mse
unjitted call to predict
[Epoch    1] train loss 0.315345, epoch 359965.70msunjitted call to eval_mae
unjitted call to eval_mae
unjitted call to predict
 - val loss 0.196682 (best), infer 16.31ms
[Epoch    2] train loss 0.052425, epoch 281621.66ms
[Epoch    3] train loss 0.038055, epoch 252686.31ms
[Epoch    4] train loss 0.026014, epoch 248861.45ms
[Epoch    5] train loss 0.022775, epoch 232528.28ms
[Epoch    6] train loss 0.019563, epoch 233256.68ms
[Epoch    7] train loss 0.017503, epoch 243905.89ms
[Epoch    8] train loss 0.013741, epoch 236907.86ms
[Epoch    9] train loss 0.015446, epoch 243664.94ms
[Epoch   10] train loss 0.012798, epoch 248185.17ms
[Epoch   11] train loss 0.012427, epoch 233851.51ms - val loss 0.088224 (best), infer 17.92ms
[Epoch   12] train loss 0.009694, epoch 236405.61ms
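
For anyone reading along, this print trick works because Python-side side effects only run while JAX traces a function: a jitted function prints once per distinct input shape/dtype and stays silent when the cached compilation is reused. A tiny standalone example (not from the repo):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def predict(scale, x):
        # Executes only during tracing, i.e. once per new input shape/dtype.
        print("unjitted call to predict")
        return scale * x

    predict(2.0, jnp.ones((8,)))   # prints: traced and compiled for shape (8,)
    predict(3.0, jnp.ones((8,)))   # silent: the compiled version is reused
    predict(2.0, jnp.ones((16,)))  # prints again: new shape triggers retracing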

gerkone commented Feb 27, 2024

Weird. What's not normal about your times is that the inference run at epoch 1 takes about the same as the second one at epoch 11. This should not be the case, since jitting alone will pull the average runtime upwards (you can see this from the output I printed). Did you modify the inference function? The runtimes I report are for the model forward only, not the graph transform and dataloading.

tisabe commented Feb 29, 2024

I seem to get different timings when I change the val-freq argument. When I evaluate after every epoch, I get much better times:

[Epoch    2] train loss 0.052915, epoch 258117.03ms - val loss 0.124259 (best), infer 2.76ms
[Epoch    3] train loss 0.035268, epoch 240552.22ms - val loss 0.132137, infer 2.77ms
[Epoch    4] train loss 0.028790, epoch 233492.50ms - val loss 0.138793, infer 2.79ms

Infer time is the time per graph, right? I did not change the infer function.

gerkone commented Mar 1, 2024

The time is per batch of 100 graphs (the default), after padding. This has gotten even stranger. The time I put in the readme is from evaluating every graph, but at the moment I honestly don't know why the times are like this.
I still think it's somehow jitting again every time if you only evaluate every 10th epoch. Of course this should not happen, since the shapes are always the same. A small check would be to discard the first runtime in the evaluate function, for example by dry-running the model on the first batch from next(iter(loader)).
There is likely something wrong in the validation experiments, so don't take them as a good starting point. I might spend some time investigating this next week; I'll let you know if I find something.
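
Something along these lines, as a rough sketch with made-up names (timed_validation, apply_fn) rather than the repo's actual evaluate function, and assuming the loader yields (graph, target) pairs:

    import time
    import jax

    def timed_validation(apply_fn, params, loader):
        # Dry-run on the first batch so any (re)compilation is excluded from the timing.
        first_batch, _ = next(iter(loader))
        jax.block_until_ready(apply_fn(params, first_batch))

        times_ms = []
        for batch, _ in loader:
            start = time.perf_counter()
            jax.block_until_ready(apply_fn(params, batch))  # wait for async dispatch to finish
            times_ms.append((time.perf_counter() - start) * 1000.0)
        return sum(times_ms) / len(times_ms)  # average ms per (padded) batch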
