
argmax problem in DDIM_sample #1

Open
JaejunHwang opened this issue Oct 18, 2024 · 4 comments
Labels: fixed (addressed issue), question (Further information is requested)

Comments

@JaejunHwang

Hi, I am trying to reproduce your code. However, I found something odd.
[screenshot: the argmax code in DDIM_sample]

With this code, mask_logits.shape = [B, 1, H, W]. The channel dimension is already 1, so after argmax there are only zeros.

@haipengzhou856
Owner

Hi Jaejun!
In this snippet (ddim_sample_one_frame), the code only samples a single frame, so B (i.e., one frame) is always 1 by default. As for the channel, our task is binary segmentation, so the channel dimension becomes 1.

Hence, you can read mask_pred = torch.argmax(mask_logit, dim=1) as shape-wise equivalent to mask_pred = mask_logit.squeeze(1) (squeezing the channel). Here it does nothing more than squeeze, but it will make sense if we deploy our method on a multi-class VOS task. You may also refer to DDP (ICCV 2023) and their code at line 270 to check the implementation.
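A minimal PyTorch sketch of that point (shapes taken from this thread, not the repo's actual code):

```python
import torch

# Binary-segmentation logits for one frame: [B=1, C=1, H, W].
mask_logit = torch.randn(1, 1, 4, 4)

# With a channel dim of size 1, argmax can only ever return index 0,
# so mask_pred is all zeros but has the squeezed shape [1, H, W].
mask_pred = torch.argmax(mask_logit, dim=1)

# Squeezing gives the same shape, though it keeps the raw logit values
# rather than class indices.
squeezed = mask_logit.squeeze(1)

assert mask_pred.shape == squeezed.shape == (1, 4, 4)
assert (mask_pred == 0).all()
```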

Please let me know if you have other problems.

Cheers :)
Rydeen

@haipengzhou856 added the question (Further information is requested) label on Oct 23, 2024
@JaejunHwang
Author

Thank you for your reply. I rechecked DDP and your code, but I still have questions.
In your code, the shape of mask_logit is [batch=1, 1, H, W].
[screenshot: the mask_logit shape in the sampling code]

Then the argmax sets all values to 0, because there is only one channel.
[screenshot: the argmax result]

The final prediction is decided by mask_logit, so the model can still achieve some performance; however, I am not sure the diffusion is actually working.

[screenshot: mask_pred at the prior and present steps]

This is the output of mask_pred at the prior and present steps.

  • Is there any reason to sample with a single batch? (DDP also looks the same.)

@haipengzhou856
Owner

Hi Jaejun!

I have tidied up your questions and summarized them as follows. Hope it helps!

Q1: Why keep mask_pred = torch.argmax(mask_logit, dim=1)?

It keeps the code in line with the DDP setting and leaves room for future work. True, here it always makes mask_pred ALL zeros. But if we turn to multi-class VOS, that changes: the argmax does nothing for the binary VOS task, yet it is meaningful for multi-class VOS. I will try this in the future, as indicated in the TODO, and will post a technical report on arXiv.
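For contrast, a hedged sketch of the hypothetical multi-class case (num_classes is illustrative, not from the repo):

```python
import torch

# Hypothetical multi-class VOS head: C > 1 channels instead of 1.
num_classes = 3
mask_logit = torch.randn(1, num_classes, 4, 4)

# Now the same argmax line genuinely picks the winning class per pixel,
# so mask_pred contains indices in {0, ..., num_classes - 1}.
mask_pred = torch.argmax(mask_logit, dim=1)
print(mask_pred.unique())  # e.g. tensor([0, 1, 2])
```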

Q2: How does it work in the diffusion process?

You will find that mask_pred and mask_t control the SNR, i.e., the degree of noise. Refer to the Algorithm 2 flowchart in DDP (ICCV 2023) and pix2seq (ICCV 2023); the sampling code corresponds to that process.

An early implementation of diffusion for segmentation originated in pix2seq (ICCV 2023), which is based on Bit Diffusion (TensorFlow version, proposed by Hinton's group). lucidrains reproduced it in a PyTorch version, and DDP's code and mine are built on top of that.
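To make the role of mask_pred and mask_t concrete, here is a schematic, deterministic (eta = 0) DDIM update in PyTorch; this is a generic sketch of the standard algorithm, not the repo's actual ddim_sample_one_frame:

```python
import torch

def ddim_step(mask_t, mask_pred, alpha_bar_t, alpha_bar_prev):
    """Schematic DDIM update: renoise the predicted clean mask (x_0)
    back to the next, less-noisy timestep.

    mask_t       -- current noisy mask x_t
    mask_pred    -- model's estimate of the clean mask x_0
    alpha_bar_*  -- cumulative noise-schedule products (scalar tensors);
                    their values set the SNR, i.e. how much signal
                    versus noise remains at each step
    """
    # Noise implied by x_t given the x_0 estimate.
    eps = (mask_t - alpha_bar_t.sqrt() * mask_pred) / (1 - alpha_bar_t).sqrt()
    # Re-inject that noise at the lower noise level of the previous step.
    return alpha_bar_prev.sqrt() * mask_pred + (1 - alpha_bar_prev).sqrt() * eps
```

As alpha_bar_prev approaches 1, the update keeps more of mask_pred and less noise, which is exactly the SNR control mentioned above.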

Q3: Why sample only a single batch? Why is batch_size always 1?

No specific reason. I also tried increasing the batch size, but there are some bugs due to device conflicts. It works well when bs=1, and it keeps the FPS computation fair. So I have been too lazy to figure it out :) Forgive me.

Feel free to ask if anything is still puzzling!

Cheers :)

Rydeen

@JaejunHwang
Author

Thank you for your prompt reply!

Sincerely, JaejunHwang

@haipengzhou856 added the fixed (addressed issue) label on Oct 23, 2024