Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add vanilla HMC method #75

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft

add vanilla HMC method #75

wants to merge 6 commits into from

Conversation

master
Copy link
Contributor

@master master commented May 29, 2023

Add full-batch Hamiltonian Monte Carlo implementation.

Pull request type

Please check the type of change your PR introduces:

  • Bugfix
  • Feature
  • Code style update (formatting, renaming)
  • Refactoring (no functional changes, no api changes)
  • Build related changes
  • Documentation content changes
  • Other (please describe):

momentum, _ = jax.flatten_util.ravel_pytree(momentum)
kinetic = 0.5 * jnp.dot(momentum, momentum)
hamiltonian = kinetic + state.log_prob
accept_prob = jnp.minimum(1.0, jnp.exp(hamiltonian - state.hamiltonian))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed, you can avoid the minimum and the exponential here. You can define

log_accept_ratio = hamiltonian - state.hamiltonian

See later for the accept/reject part.

return revert_updates, state.params, state.hamiltonian

updates, new_params, new_hamiltonian = jax.lax.cond(
jax.random.uniform(uniform_key) < accept_prob,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Following the comment above, this line should become

jnp.log(jax.random.uniform(uniform_key)) < log_accept_ratio.

This is equivalent to what you have written but with one operation less. Alternatively, notice that -log(U) ~ Exponential(1)) if U~Uniform(0, 1). This means that you can also write

-jax.random.exponential(uniform_key)) < log_accept_ratio.

All of these should be equivalent. Please check that the lines I wrote are correct :-)

"""

encoded_name: jnp.ndarray = convert_string_to_jnp_array("HMCState")
_encoded_which_params: Optional[Dict[str, List[Array]]] = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was expecting to see the stored _hamiltonian here too?

**kwargs,
)
state = state.replace(
opt_state=state.opt_state._replace(log_prob=aux["loss"]),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should opt_state be added to the parameters of HMCState?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants