How to calculate the metric Pixel-MSE? #141

Open
lihuining opened this issue Nov 3, 2024 · 3 comments

Comments

@lihuining

No description provided.

@williamyang1991
Owner

williamyang1991 commented Nov 4, 2024

Assume I1 and I2 are two consecutive input frames, and O1 and O2 are the corresponding edited results.

Use the following code to compute the optical flow between I1 and I2, and use that flow to warp O1 so that it aligns with O2.

warped_O1, mask, optical_flow = get_warped_and_mask(flow_model, I1, I2, O1)

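import torch

# InputPadder, flow_warp and forward_backward_consistency_check are the project's
# GMFlow-based flow utilities; they are assumed to be importable here.


def get_warped_and_mask(flow_model,
                        image1,
                        image2,
                        image3=None,
                        pixel_consistency=False):
    """Warp image3 (default: image1) from frame 1 into frame 2.

    image1, image2: [3, H, W] frames fed to the flow model.
    image3: [1, C, H, W] tensor on the GPU to warp (e.g. the edited frame O1).
    Returns (warped image3, occlusion mask with 1 = occluded, backward flow).
    """
    if image3 is None:
        # Batch and move to GPU so flow_warp receives [1, C, H, W] like bwd_flow.
        image3 = image1[None].cuda()
    padder = InputPadder(image1.shape, padding_factor=8)
    image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
    results_dict = flow_model(image1,
                              image2,
                              attn_splits_list=[2],
                              corr_radius_list=[-1],
                              prop_radius_list=[-1],
                              pred_bidir_flow=True)
    flow_pr = results_dict['flow_preds'][-1]  # [B, 2, H, W]
    fwd_flow = padder.unpad(flow_pr[0]).unsqueeze(0)  # [1, 2, H, W]
    bwd_flow = padder.unpad(flow_pr[1]).unsqueeze(0)  # [1, 2, H, W]
    fwd_occ, bwd_occ = forward_backward_consistency_check(
        fwd_flow, bwd_flow)  # [1, H, W] float
    if pixel_consistency:
        # Also mark pixels whose warped colour differs strongly as occluded.
        # Unpad first so the images match the unpadded flow; the 255 * 0.25
        # threshold assumes image1/image2 are in [0, 255].
        image1, image2 = padder.unpad(image1), padder.unpad(image2)
        warped_image1 = flow_warp(image1, bwd_flow)
        bwd_occ = torch.clamp(
            bwd_occ +
            (abs(image2 - warped_image1).mean(dim=1) > 255 * 0.25).float(), 0,
            1).unsqueeze(0)
    # The backward flow maps frame-2 pixels to frame 1, so warping image3 with it
    # aligns image3 to frame 2.
    warped_results = flow_warp(image3, bwd_flow)
    return warped_results, bwd_occ, bwd_flow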
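get_warped_and_mask relies on two helpers that are not shown here: flow_warp and forward_backward_consistency_check. The sketch below gives hypothetical stand-ins for both (flow_warp_sketch and fb_consistency_sketch are our names, not the project's). It assumes the flows are in pixels, with channel 0 = x and channel 1 = y. Warping is done by backward sampling with torch.nn.functional.grid_sample. Occlusion comes from a common forward-backward consistency check, where the alpha and beta thresholds are assumed values. The project's own helpers may differ in their details.

```python
import torch
import torch.nn.functional as F


def flow_warp_sketch(feature: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Backward-warp feature [B, C, H, W] with flow [B, 2, H, W] given in pixels.

    Output pixel (x, y) is bilinearly sampled from feature at
    (x + flow[:, 0], y + flow[:, 1]); samples outside the image read as 0.
    """
    _, _, h, w = feature.shape
    ys, xs = torch.meshgrid(
        torch.arange(h, device=flow.device, dtype=flow.dtype),
        torch.arange(w, device=flow.device, dtype=flow.dtype),
        indexing="ij",
    )
    x = xs.unsqueeze(0) + flow[:, 0]  # [B, H, W] absolute sample positions
    y = ys.unsqueeze(0) + flow[:, 1]
    # grid_sample wants (x, y) normalised to [-1, 1]; with align_corners=True,
    # -1 and 1 are the centres of the first and last pixels.
    grid = torch.stack((2.0 * x / max(w - 1, 1) - 1.0,
                        2.0 * y / max(h - 1, 1) - 1.0), dim=-1)  # [B, H, W, 2]
    return F.grid_sample(feature, grid, mode="bilinear",
                         padding_mode="zeros", align_corners=True)


def fb_consistency_sketch(fwd_flow: torch.Tensor, bwd_flow: torch.Tensor,
                          alpha: float = 0.01, beta: float = 0.5):
    """Forward-backward check; returns (fwd_occ, bwd_occ), each [B, H, W], 1 = occluded."""
    # Where the flow is reliable, following it forward and then back lands
    # near the starting pixel, so fwd + (bwd sampled at the target) is close to 0.
    warped_bwd = flow_warp_sketch(bwd_flow, fwd_flow)
    warped_fwd = flow_warp_sketch(fwd_flow, bwd_flow)
    diff_fwd = torch.norm(fwd_flow + warped_bwd, dim=1)
    diff_bwd = torch.norm(bwd_flow + warped_fwd, dim=1)
    # Tolerate larger errors for larger motions (alpha) plus a constant slack (beta).
    magnitude = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1)
    threshold = alpha * magnitude + beta
    return (diff_fwd > threshold).float(), (diff_bwd > threshold).float()
```

With a uniform one-pixel shift to the right as the forward flow, only the column whose match leaves the image is flagged as occluded. This is the behaviour that (1 - mask) relies on to drop unreliable pixels from the error.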

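Then compute the MSE between the warped O1 and O2, excluding occluded pixels. The mask is 1 where a pixel is occluded, so multiplying by (1-mask) keeps only the pixels with a reliable flow:

err = F.mse_loss(warped_O1*(1-mask), O2*(1-mask))

Finally, average err over all pairs of consecutive frames in the video.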
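The per-video averaging can be sketched as a loop over consecutive pairs. This is a minimal sketch, not the project's evaluation script. video_pixel_mse and warp_fn are hypothetical names, and warp_fn stands in for a wrapper around get_warped_and_mask that returns only (warped O_t, mask).

```python
from typing import Callable, Sequence, Tuple

import torch
import torch.nn.functional as F

# warp_fn(I_t, I_next, O_t) -> (O_t warped into frame t+1, occlusion mask with 1 = occluded)
WarpFn = Callable[[torch.Tensor, torch.Tensor, torch.Tensor],
                  Tuple[torch.Tensor, torch.Tensor]]


def video_pixel_mse(inputs: Sequence[torch.Tensor],
                    outputs: Sequence[torch.Tensor],
                    warp_fn: WarpFn) -> float:
    """Average the masked MSE over every consecutive pair (t, t + 1) of a video."""
    if len(inputs) != len(outputs) or len(inputs) < 2:
        raise ValueError("need matching input/output lists with at least two frames")
    errs = []
    for t in range(len(inputs) - 1):
        warped, mask = warp_fn(inputs[t], inputs[t + 1], outputs[t])
        valid = 1 - mask  # keep only non-occluded pixels
        err = F.mse_loss(warped * valid, outputs[t + 1] * valid)
        errs.append(err.item())
    return sum(errs) / len(errs)
```

With the code above, a warp_fn could be `lambda i1, i2, o1: get_warped_and_mask(flow_model, i1, i2, o1)[:2]`. F.mse_loss uses reduction='mean' by default, so it divides by every element, including the zeroed occluded pixels. Dividing by the number of valid pixels instead would generally give a larger value, so it is worth matching the convention used for the reported numbers.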

@lihuining
Author

lihuining commented Nov 4, 2024

@williamyang1991 I used this code to calculate the MSE and tested it on two consecutive frames, but I get an error of about 450. The error reported in the paper is below 1, so what is wrong with my code?

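import sys
sys.path.append("..")
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from deps.gmflow.gmflow.gmflow import GMFlow
from deps.ControlNet.annotator.util import HWC3
# get_warped_and_mask (and its flow helpers) is the function from the previous comment.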
flow_model = GMFlow(
    feature_channels=128,
    num_scales=1,
    upsample_factor=8,
    num_head=1,
    attention_type='swin',
    ffn_dim_expansion=4,
    num_transformer_layers=6,
).to('cuda')
def preprocess(image_path):
    frame = cv2.imread(image_path)
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    img = HWC3(frame)
    image2 = torch.from_numpy(img).permute(2, 0, 1).float()
    return image2
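def save_tensor(input_tensor):
    # Save the occlusion mask as an 8-bit image, min-max scaled to [0, 255].
    bwd_occ_np = input_tensor.squeeze().detach().cpu().numpy()
    value_range = bwd_occ_np.max() - bwd_occ_np.min()
    # Guard against division by zero when the mask is constant (e.g. no occlusion).
    bwd_occ_np = (bwd_occ_np - bwd_occ_np.min()) / max(value_range, 1e-8) * 255
    bwd_occ_np = bwd_occ_np.astype(np.uint8)
    image = Image.fromarray(bwd_occ_np)
    image.save('bwd_occ_mask.png')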

checkpoint = torch.load('/media/allenyljiang/564AFA804AFA5BE51/Codes/Video_Editing/Rerender_A_Video/models/gmflow_sintel-0c07dcb3.pth',
                        map_location=lambda storage, loc: storage)
weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
flow_model.load_state_dict(weights, strict=False)
flow_model.eval()

path1 = "/media/allenyljiang/564AFA804AFA5BE51/Codes/Video_Editing/Rerender_A_Video/flow/test_video/000000.png"
path2 = "/media/allenyljiang/564AFA804AFA5BE51/Codes/Video_Editing/Rerender_A_Video/flow/test_video/000001.png"
I1 =preprocess(path1)
I2 = preprocess(path2)
O1 = I1.unsqueeze(0).to("cuda")
print(I1.shape,I2.shape,O1.shape)
warped_O1, mask, optical_flow =  get_warped_and_mask(flow_model,  I1,  I2, O1)
save_tensor(mask)
err = F.mse_loss(warped_O1*(1-mask), O1*(1-mask))  # then average over the whole video
print(err)

@williamyang1991
Copy link
Owner

Sorry, you should use O2 rather than O1 here:

err = F.mse_loss(warped_O1*(1-mask), O2*(1-mask))

Also, your preprocess() outputs images in the range [0, 255], but we use images in the range [-1, 1].
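A minimal sketch of the range conversion, assuming O1 and O2 are [0, 255] float tensors like the ones preprocess() returns. to_minus1_1 and rescale_mse_0_255_to_minus1_1 are hypothetical helper names. Only O1 and O2 are rescaled here. Our assumption is that I1 and I2 should stay in [0, 255] for the flow model, since GMFlow normalizes its inputs by 255 internally and the 255 * 0.25 threshold in get_warped_and_mask is written for that range. Because MSE is quadratic in pixel values, mapping [0, 255] to [-1, 1] divides it by 127.5^2 = 16256.25.

```python
import torch
import torch.nn.functional as F


def to_minus1_1(img_0_255: torch.Tensor) -> torch.Tensor:
    """Map pixel values from [0, 255] to [-1, 1] (0 -> -1, 127.5 -> 0, 255 -> 1)."""
    return img_0_255 / 127.5 - 1.0


def rescale_mse_0_255_to_minus1_1(mse_0_255: float) -> float:
    """Convert an MSE measured on [0, 255] pixels to its value on [-1, 1] pixels.

    The -1 offset cancels in the difference, and the 1 / 127.5 scale is squared.
    """
    return mse_0_255 / 127.5 ** 2
```

For example, rescale_mse_0_255_to_minus1_1(450) is about 0.028. This only changes the units, so the O1 to O2 fix above is still needed for the error to measure temporal consistency between frames.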
