-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Register the overloads added by CustomDist in worker processes #7241
Conversation
] |
Hi @EliasRas, I'll need some time to review this one properly. Thanks for taking the initiative |
Looks like I messed up by rebasing instead of merging and introduced plenty of unnecessary commits to this feature. Does it need to be fixed? |
Yes, that needs to be fixed, happens to everyone. One approach is to start clean and cherry-pick your commits. |
506e49f
to
54ca772
Compare
tests/smc/test_smc.py
Outdated
return rng.normal(loc=mu, scale=1, size=size) | ||
|
||
with pm.Model(): | ||
mu = pm.CustomDist("mu", 0, logp=_logp, random=_random) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are two kinds of CustomDist, if instead of random
you pass dist
you get a different Op type, that I guess would still fail after this PR
Edit: I see you mentioned this in your top message
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why does it fail with random
but not dist
. Testing locally it seems to fail with both for me?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just tested both the code I used in the linked issue and your example below. I still get no errors using pm.Potential
or dist
argument. Originally I was using pymc==5.10.0
but now I tested with pymc==5.13.0
. Maybe there are differences between Windows and Linux if you're using one?
I'm not completely sure why using dist
works for me but, based on some quick testing, DistributionMeta.__new__
is called when e.g. Normal
is defined and the overloads for builtin distributions are registered there. I'm not well versed in multiprocessing or the way that Python does importing but my hunch is that the worker processes automatically import stuff from pymc
and the overloads get registered as a side effect. For user-defined logprob
etc. this is not the case since the registration isn't done during importing.
I think it's more complicated than this. The following example has specific dispatch, but no RV that shows up in the graph: import pymc as pm
def _logp(value, mu):
return -((value - mu) ** 2)
def _dist(mu, size=None):
return pm.Normal.dist(mu, 1, size=size)
with pm.Model():
mu = pm.Normal("mu", 0)
pm.Potential("term", pm.logp(pm.CustomDist.dist(mu, logp=_logp, dist=_dist), [1, 2]))
pm.sample_smc(draws=6, cores=1) It also fails even with a single core |
22e8f0b did refactoring for |
Co-authored-by: Ricardo Vieira <[email protected]>
Somehow, in main, I am getting |
Okay it's something about the new progressbal and pycharm interactive python console. If I use from ipython/terminal it works. But also works in main for me? |
I cannot reproduce a failure with your test locally (after avoding the pycharm issue) nor in a Colab environment: https://colab.research.google.com/drive/1I1n6c9IlmXknIfhxC5s7sAQghv0vfRSY?usp=sharing Can you share more details about your environment/setup? |
I added the output of Basically I followed the install instructions and the pull request tutorial when installing. Might have also |
We should have at least one person reproduce the problem because I cannot. It may be a VSCode environment issue. Ideally we wouldn't have to change the codebase |
The test does fail without the changes when I run it from miniforge prompt though. |
Not sure what miniforge prompt is, can we try to reproduce here on the CI then? Push just the test without the fixes into a new PR and well run it to see if we can reproduce |
Is there anything that needs to be done here besides running the tests? |
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #7241 +/- ##
==========================================
+ Coverage 92.18% 92.19% +0.01%
==========================================
Files 103 103
Lines 17259 17282 +23
==========================================
+ Hits 15910 15933 +23
Misses 1349 1349
|
Thanks @EliasRas, I haven't been able to reproduce this yet but that's just because I'm in the middle of switching workstations and haven't gotten everything setup yet. |
I guess the underlying reason for the failure is that pickling of |
I don’t think the problem is about pickling. The |
I don't mean that the pickling itself throws an error (it doesn't), but that it would be the responsibility of the For instance the following fails with the import pymc as pm
import cloudpickle
import multiprocessing
def use_logp_func(pickled_model):
model = cloudpickle.loads(pickled_model)
logp = model.logp()
func = pm.pytensorf.compile_pymc(model.value_vars, logp)
print(func(1.0))
if __name__ == "__main__":
with pm.Model() as model:
def logp(value):
return -(value**2)
pm.DensityDist("x", logp=logp)
logp = model.logp()
func = pm.pytensorf.compile_pymc(model.value_vars, logp)
pickled_model = cloudpickle.dumps(model)
ctx = multiprocessing.get_context("spawn")
process = ctx.Process(target=use_logp_func, args=(pickled_model,))
process.start()
process.join() |
I completely agree that this problem isn’t unique to smc and is a design caveat that needs to be addressed more comprehensively.
I’m not sure if these two methods can cover all use patterns though. |
Alternatively we could pass the functions needed to each process which is more like what This also avoids recompiling the same functions multiple times? |
@lucianopaz |
I started work on
|
I think we should explore an alternative where we compile the functions SMC needs and fork afterwards like pm.sample does. This approach seems more brittle? It would also avoid re-compiling the same functions in each chain |
I agree on the re-compiling part but shouldn't this still be fixed? It feels like an arbitrary decision to "disallow" using multiprocessing this way only on Windows even if it is a bad way. |
I think this limitation is likely deeper than what you're addressing here. As @aseyboldt and @lucianopaz mentioned we're using dynamic dispatching as a recurring theme in our codebase and pytensor's However, I don't agree with their solutions |
Using the class that's being dispatched to register the dispatches during pickling seems at odds with the point of dispatching. The class shouldn't have to know what's being dispatched upon. For instance we also have icdf methods, what if someone dispatched on it from the outside, does pickling work for it? Or would the setstate/getstate need to know about icdf (as well any other dispatch that may not even be part of PyMC)? It's also not a PyMC model responsibility. CustomDist can be defined just fine outside of a model |
Thank you for taking the time to explain. I'll start working on the compilation approach. |
However since this fixes existing behavior I think we can go ahead and merge it as a temporary patch? |
Co-authored-by: Ricardo Vieira <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you open a follow up issue for the compiling once and then forking?
I started working on compiling the functions in the main process. Should I close this or is there any chance of this getting merged before the newer PR? |
@ricardoV94 @lucianopaz @aseyboldt @twiecki Just following up on this. Should I close the PR? |
@EliasRas I had approved this PR but didn't merge. Apologies, did it just now |
Description
Currently
sample_smc
can fail due to aNotImplementedError
if it's used with a model defined usingCustomDist
. If aCustomDist
is used withoutdist
parameter, the overloads for_logprob
,_logcdf
and_support_point
are registered only in the main process.This PR adds an initializer which registers the overloads in the worker processes of the pool used in
sample_smc
.Related Issue
pymc.sample_smc
fails withpymc.CustomDist
#7224Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7241.org.readthedocs.build/en/7241/