
Checkpoints for CRIS #44

Open

berkegokmen1 opened this issue Aug 3, 2024 · 0 comments

@berkegokmen1
Hi,

Firstly, thank you for this work. I've been trying to use it to segment diseases in chest X-ray images. I tried it with the "pretrained" checkpoints of CLIP and CRIS, using the following script:

# Imports and device setup (the CRIS import path below is a guess; adjust it to this repo's layout)
import torch
import torchvision.transforms as T
import matplotlib.pyplot as plt
import open_clip
from PIL import Image
from transformers import CLIPTokenizer
from model.segmenter import CRIS  # assumed location of the CRIS class

device = "cuda" if torch.cuda.is_available() else "cpu"

prompts = [
    "Atelectasis",
    "Cardiomegaly",
    "Consolidation",
    "Edema",
    "Pleural Effusion",
]


clip_pretrain = "pretrained/RN50.pt"
word_len = 77
fpn_in = [512, 1024, 1024]
fpn_out = [256, 512, 1024]
vis_dim = 512
word_dim = 1024
num_layers = 3
num_head = 8
dim_ffn = 2048
dropout = 0.2
context_length = 77  # 77 for clipseg
intermediate = False
cris_pretrain = "pretrained/cris.pt"
tokenizer_type = "clipseg"
img_mean = [0.48145466, 0.4578275, 0.40821073]
img_std = [0.26862954, 0.26130258, 0.27577711]


prompts = [f"Findings of {el}" for el in prompts] + ["Support devices"]

model = CRIS(
    clip_pretrain=clip_pretrain,
    word_len=word_len,
    fpn_in=fpn_in,
    fpn_out=fpn_out,
    vis_dim=vis_dim,
    word_dim=word_dim,
    num_layers=num_layers,
    num_head=num_head,
    dim_ffn=dim_ffn,
    dropout=dropout,
    intermediate=intermediate,
    cris_pretrain=cris_pretrain,
)

model.to(device)
model.eval()  # inference only

if tokenizer_type == "biomedclip":
    tokenizer = open_clip.get_tokenizer("hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224").tokenizer
else:  # i.e. tokenizer_type == "clipseg"
    tokenizer = CLIPTokenizer.from_pretrained("CIDAS/clipseg-rd64-refined")

# Load image and text
transform = T.Compose([T.Resize((416, 416)), T.ToTensor(), T.Normalize(mean=img_mean, std=img_std)])

# convert to RGB so the 3-channel normalization above also works for grayscale X-rays
image = Image.open("image.png").convert("RGB")

pixel_values = transform(image).unsqueeze(0)
input_ids = tokenizer(
    prompts,
    max_length=context_length,
    truncation=True,
    padding="max_length",
    return_tensors="pt",
).input_ids

print(pixel_values.shape, input_ids.shape)

pixel_values = pixel_values.to(device).expand(len(prompts), -1, -1, -1)
input_ids = input_ids.to(device)

out = model(pixel_values, input_ids)

out = out.squeeze(1)

_, ax = plt.subplots(1, len(prompts) + 1, figsize=(15, 4))
[a.axis("off") for a in ax.flatten()]
ax[0].imshow(image)
[ax[i + 1].imshow(torch.sigmoid(out[i]).cpu().detach().numpy()) for i in range(len(prompts))]
[ax[i + 1].text(0, -15, prompts[i]) for i in range(len(prompts))]

plt.savefig("result.png")

This was just to try the model out. As expected, I got the following result:

[Screenshot 2024-08-03 at 02:32:09: output of the script above]

This is not exactly what I needed. So I wanted to ask whether you are planning to share your trained models, so that I can try them out and see whether they work better than a pretrained DenseNet-121 + Grad-CAM baseline.
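
For context, the baseline I am comparing against is roughly the following (a minimal Grad-CAM sketch; a generic ImageNet-pretrained torchvision densenet121 stands in here for the actual chest X-ray classifier, and class_idx is the hypothetical index of the finding in that classifier's output):

import torch
import torch.nn.functional as F
import torchvision

densenet = torchvision.models.densenet121(weights="IMAGENET1K_V1").eval()

def grad_cam(pixel_values, class_idx):
    # Re-run DenseNet's forward pass step by step (mirroring torchvision's forward())
    # so the last convolutional feature map is kept around for its gradient.
    feat = densenet.features(pixel_values)                # (1, C, h, w)
    feat.retain_grad()
    pooled = F.adaptive_avg_pool2d(F.relu(feat), (1, 1)).flatten(1)
    logits = densenet.classifier(pooled)                  # (1, num_classes)

    densenet.zero_grad()
    logits[0, class_idx].backward()

    weights = feat.grad.mean(dim=(2, 3), keepdim=True)    # channel weights from pooled gradients
    cam = F.relu((weights * feat).sum(dim=1, keepdim=True))
    cam = F.interpolate(cam, size=pixel_values.shape[-2:], mode="bilinear", align_corners=False)
    cam = cam.squeeze()
    return (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)

The resulting heat map can be overlaid on the input and compared visually against the CRIS output for the same finding.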

Thanks in advance, and looking forward to your answer.

Berke
