Sharpwaves performance fixes #288
Conversation
Since Python for-loops are infamously slow, I vectorized as much as I could of the

HUGE DISCLAIMER: I did my best to keep the math intact, but I did this like 10 days ago, so I'm not 100% sure the results are identical. Tests are passing, but I haven't checked what they're doing.

It's also interesting that this opens the door to vectorizing the whole
Fixed the test errors, caused by doing a `[list] - [list]` operation. I ran the tests on my computer and did not get the error, probably because `self.sw_settings["sharpwave_features"]["width"]` was not set.
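The `[list] - [list]` failure mode is easy to reproduce in isolation. A minimal illustrative snippet (not the project's actual code; the variable names are made up):

```python
import numpy as np

widths = [1, 2, 3]
offsets = [0, 1, 1]

# Plain Python lists do not support elementwise arithmetic:
try:
    widths - offsets
except TypeError as e:
    print(e)  # unsupported operand type(s) for -: 'list' and 'list'

# Converting to NumPy arrays makes the subtraction elementwise:
result = np.asarray(widths) - np.asarray(offsets)
print(result)  # [1 1 2]
```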
Wow, those are great additions and an impressive speed-up! Also, the tests in /tests are called by tox, so the GitHub Actions tests should be identical.
Update on steepness: I don't think this is possible to vectorize, because each peak-to-trough / trough-to-peak distance is different. Using `ndimage` is possible but would require building masks out of the indexes, which needs extra computation, and using a view on NumPy arrays, as I have it currently, is already pretty efficient anyway. Btw, I wanted to ask: is there a possibility that there is a bug here?

rise_steepness = np.max(np.diff(self.data_process_sw[peak_idx_left : trough_idx + 1]))
decay_steepness = np.max(np.diff(self.data_process_sw[trough_idx : peak_idx_right + 1]))
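The asymmetry being pointed at can be seen on a toy segment. This is an illustrative sketch: the `data` array and indexes are made up, and the `np.abs`-based variant shown is an assumption about one plausible fix, not necessarily the correction that was actually committed:

```python
import numpy as np

# Toy waveform: left peak at index 0, trough at index 3, right peak at index 6.
data = np.array([5.0, 3.0, 1.0, 0.0, 2.0, 4.0, 6.0])
peak_idx_left, trough_idx, peak_idx_right = 0, 3, 6

# As written in the quoted code, both sides take np.max of the raw differences:
rise = np.max(np.diff(data[peak_idx_left : trough_idx + 1]))
decay = np.max(np.diff(data[trough_idx : peak_idx_right + 1]))
print(rise, decay)  # -1.0 2.0 -- on the falling segment, np.max picks the
                    # *shallowest* step, not the steepest one.

# One plausible fix (assumption): take the steepest step magnitude on both sides.
rise_fixed = np.max(np.abs(np.diff(data[peak_idx_left : trough_idx + 1])))
decay_fixed = np.max(np.abs(np.diff(data[trough_idx : peak_idx_right + 1])))
print(rise_fixed, decay_fixed)  # 2.0 2.0
```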
Yes! I think you're right. That's a bug. It should have been always the
Cool! I was a bit hesitant to bring it up, since I still don't really understand what a lot of the features are doing, but this one was easy to visualize, so the typo stood out to me. I added the correction to the PR.
When running the profiler on an offline analysis, sharpwaves was, together with bursts, one of the most computationally costly features to calculate. I made some changes that brought the computation time for this feature down to 50%:

Commit 1 - Simplify `_get_peaks_around`: The profiler showed a very high number of calls to `argsort` from the method `_get_peaks_around`, and I realized that for each trough in the data this function was being called and was subsequently calling `argsort` twice. I worked out the logic of the function and realized that, since `scipy.signal.find_peaks` returns the peak indexes in order, the calculation can be reduced to an array filtering operation, which is much faster.

Commit 2 - Single-pass adjacent peak finding: Even with the previous change, it seemed to me that calling `_get_peaks_around` once per trough and then doing a comparison on the whole list of indexes was not necessary, so I replaced the `_get_peaks_around` calls with a variable `right_peak_idx` that is incremented on each loop iteration until it goes past the current trough, keeping track of the right adjacent peak to the current trough at each step of the loop. This means all the work `_get_peaks_around` was doing we now get for free.

Commit 3 - fftconvolve: Since `signal.convolve` was calling `signal.fftconvolve` under the hood, I changed the call to use `signal.fftconvolve` directly, which removes the overhead of having to decide which method to use and, in my opinion, makes the code more transparent. This barely impacts performance.

- Before changes: 58.611 seconds per run.
- After commit 1: 53.140 seconds per run.
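The array-filtering idea behind commits 1 and 2 can be sketched with `np.searchsorted`: because `scipy.signal.find_peaks` returns indexes in ascending order, the peaks adjacent to every trough can be located without any sorting or per-trough scans. This is an illustrative sketch of the idea under those assumptions, not the PR's actual code (the signal and variable names are made up):

```python
import numpy as np
from scipy import signal

rng = np.random.default_rng(0)
data = rng.standard_normal(1000)

peaks, _ = signal.find_peaks(data)     # indexes come out sorted ascending
troughs, _ = signal.find_peaks(-data)  # troughs are peaks of the inverted signal

# For each trough, its insertion position in the sorted peak array is the index
# of the nearest peak to the right; the entry just before it is the left peak.
pos = np.searchsorted(peaks, troughs)

# Keep only troughs that have a peak on both sides.
valid = (pos > 0) & (pos < len(peaks))
left_peaks = peaks[pos[valid] - 1]
right_peaks = peaks[pos[valid]]

# Each kept trough now sits strictly between its two adjacent peaks.
assert np.all(left_peaks < troughs[valid]) and np.all(troughs[valid] < right_peaks)
```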
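On commit 3, the two call paths can be checked for agreement directly: `signal.convolve` with the default `method="auto"` dispatches between direct and FFT-based convolution per call, while `signal.fftconvolve` always takes the FFT path. A quick illustrative equivalence check (array size and kernel are made up):

```python
import numpy as np
from scipy import signal

rng = np.random.default_rng(1)
x = rng.standard_normal(10_000)
kernel = signal.windows.hann(201)

# convolve() decides between "direct" and "fft" each call; fftconvolve()
# skips that dispatch and always uses the FFT implementation.
auto = signal.convolve(x, kernel, mode="same")
fft = signal.fftconvolve(x, kernel, mode="same")

# Both paths agree to floating-point precision.
print(np.allclose(auto, fft))  # True
```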