How to test my trained model? #37

FJNU-LWP opened this issue Oct 13, 2023 · 2 comments

@FJNU-LWP
Hello, thank you very much for sharing this wonderful project!

I have now trained my own model and generated a .pth file using this code. How can I use this .pth file to test other data?

Looking forward to your response, and I would greatly appreciate it!
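
In general, evaluating a trained .pth in PyTorch follows roughly this pattern (a minimal sketch; MyModel, the checkpoint file name, and test_loader are hypothetical placeholders, and a full ViViT-specific script is given in the comment below):

import torch

# Minimal sketch of reusing a trained .pth for inference (names here are hypothetical).
model = MyModel()  # rebuild the same architecture that was trained
checkpoint = torch.load("my_trained_model.pth", map_location="cpu")
state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint  # some checkpoints wrap the weights
model.load_state_dict(state_dict, strict=False)
model.eval()

with torch.no_grad():
    for video, label in test_loader:  # hypothetical DataLoader over the new data
        logits = model(video)
        preds = logits.argmax(dim=1)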

@call-me-akeiang

Hello, have you solved this problem? @FJNU-LWP

@hemengfan2002

hemengfan2002 commented Dec 23, 2024

import argparse

import torch
from torch import nn
from video_transformer import ViViT
from transformer import ClassificationHead
from data_trainer import KineticsDataModule
from tqdm import tqdm

def compute_top_k_accuracy(data_loader, model, cls_head, device, top_k=5):
    top_1_correct = 0
    top_5_correct = 0
    total_samples = 0

    model.eval()
    cls_head.eval()

    with torch.no_grad():
        for video, label in tqdm(data_loader, desc="Evaluating"):
            video = video.to(device)
            label = label.to(device)

            # Forward pass through the backbone and classification head
            logits = model(video)
            output = cls_head(logits)
            output = output.view(output.size(0), -1)  # Flatten logits to (batch, num_classes)

            # Take the top-k predictions per sample
            _, topk_preds = output.topk(top_k, dim=1, largest=True, sorted=True)

            # Accumulate Top-1 and Top-5 hits
            top_1_correct += (topk_preds[:, 0] == label).sum().item()
            top_5_correct += (topk_preds == label.unsqueeze(1)).sum().item()

            total_samples += label.size(0)

    top_1_accuracy = top_1_correct / total_samples
    top_5_accuracy = top_5_correct / total_samples

    return top_1_accuracy, top_5_accuracy

def parse_args():
    parser = argparse.ArgumentParser(description='lr receiver')

    # Common
    parser.add_argument(
        '-objective', type=str, default='supervised',
        help='the learning objective from [mim, supervised]')
    parser.add_argument(
        '-eval_metrics', type=str, default='finetune',
        help='the eval metrics chosen from [linear_prob, finetune]')
    parser.add_argument(
        '-batch_size', type=int, default=16,
        help='the batch size of data inputs')
    parser.add_argument(
        '-num_workers', type=int, default=4,
        help='the num workers of loading data')

    # Environment
    parser.add_argument(
        '-gpus', nargs='+', type=int, default=-1,
        help='the available gpus in this experiment')

    # Data
    parser.add_argument(
        '-num_class', type=int, default=400,
        help='the num class of dataset used')
    parser.add_argument(
        '-num_samples_per_cls', type=int, default=10000,
        help='the num samples per class')
    parser.add_argument(
        '-img_size', type=int, default=224,
        help='the size of processed image')
    parser.add_argument(
        '-num_frames', type=int, default=16,
        help='the number of frames sampled')
    parser.add_argument(
        '-frame_interval', type=int, default=16,
        help='the interval of frame sampling')
    parser.add_argument(
        '-data_statics', type=str, default='kinetics',
        help='choose data statics from [imagenet, kinetics, clip]')
    parser.add_argument(
        '-val_data_path', type=str, default="/home/hmf/vivit_clip/VideoTransformer-pytorch-main/OpenMMLab___Kinetics"
                                            "-400/raw/Kinetics-400/kinetics400_val_list_videos.txt",
        help='the path to val set')
    parser.add_argument(
        '-test_data_path', type=str, default=None,
        help='the path to test set')
    parser.add_argument(
        '-auto_augment', type=str, default=None,
        help='the used AutoAugment policy')
    parser.add_argument(
        '-mixup', type=bool, default=False,
        help='whether or not to use mixup')
    parser.add_argument(
        '-multi_crop', type=bool, default=False,
        help='whether or not to use multi crop')

    # Model
    parser.add_argument(
        '-arch', type=str, default='vivit',
        help='the chosen model arch from [timesformer, vivit]')
    parser.add_argument(
        '-attention_type', type=str, default='fact_encoder',
        help='the chosen attention type used in the model')
    parser.add_argument(
        '-pretrain_pth', type=str, default="./vivit_model.pth",
        help='the path to the pretrain weights')
    parser.add_argument(
        '-weights_from', type=str, default='kinetics',
        help='the pretrain params from [imagenet, kinetics, clip]')

    args = parser.parse_args()

    return args

def replace_state_dict(state_dict):
    for old_key in list(state_dict.keys()):
        if old_key.startswith('model'):
            new_key = old_key[6:]  # skip 'model.'
            if 'in_proj' in new_key:
                new_key = new_key.replace('in_proj_', 'qkv.')  # in_proj_weight -> qkv.weight
            elif 'out_proj' in new_key:
                new_key = new_key.replace('out_proj', 'proj')  # out_proj -> proj
            state_dict[new_key] = state_dict.pop(old_key)
        else:  # cls_head
            new_key = old_key[9:]  # skip 'cls_head.'
            state_dict[new_key] = state_dict.pop(old_key)

def init_from_kinetics_pretrain_(module, pretrain_pth):
    if torch.cuda.is_available():
        state_dict = torch.load(pretrain_pth)
    else:
        state_dict = torch.load(pretrain_pth, map_location=torch.device('cpu'))
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']

    replace_state_dict(state_dict)
    msg = module.load_state_dict(state_dict, strict=False)
    return msg

def test_vivit_model(test_checkpoint_path):
    args = parse_args()

    # Step 1: Build the model and load the trained weights
    model = ViViT(num_frames=args.num_frames,
                  img_size=args.img_size,
                  patch_size=16,
                  embed_dims=768,
                  in_channels=3,
                  attention_type=args.attention_type,
                  return_cls_token=True,
                  weights_from=args.weights_from,
                  pretrain_pth=args.pretrain_pth)

    cls_head = ClassificationHead(num_classes=args.num_class, in_channels=768)
    msg_trans = init_from_kinetics_pretrain_(model, test_checkpoint_path)
    msg_cls = init_from_kinetics_pretrain_(cls_head, test_checkpoint_path)

    print(f'Model loaded successfully. Missing keys (transformer): {msg_trans[0]}, (cls_head): {msg_cls[0]}')

    # Step 2: Prepare the validation dataset
    val_data_module = KineticsDataModule(configs=args,
                                         train_ann_path=None,  # No need for training data
                                         val_ann_path=args.val_data_path,
                                         test_ann_path=None)
    val_data_module.setup(stage='test')
    val_dataloader = val_data_module.val_dataloader()

    # Step 3: Run evaluation
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = nn.DataParallel(model)
    model = model.to(device)
    cls_head = cls_head.to(device)
    top_1_accuracy, top_5_accuracy = compute_top_k_accuracy(val_dataloader, model, cls_head, device, top_k=5)
    print(f'Top-1 Accuracy: {top_1_accuracy * 100:.2f}%')
    print(f'Top-5 Accuracy: {top_5_accuracy * 100:.2f}%')

if __name__ == '__main__':
    # Define paths and configurations
    test_checkpoint_path = "/home/hmf/vivit_clip/results/lr_0.000625_optim_sgd_lr_schedule_cosine_weight_decay_0.0001_weights_from_imagenet_num_frames_16_frame_interval_16_mixup_False/ckpt/2024-12-17 19:12:49_ep_5_top1_acc_0.623.pth"  # Path to the checkpoint
    # test_data_path = "/home/hmf/vivit_clip/VideoTransformer-pytorch-main/OpenMMLab___Kinetics-400/raw/Kinetics-400/kinetics400_val_list_videos.txt"  # Path to the test dataset
    """
    # Configuration settings for the model and training
    configs = {
        'num_frames': 16,  # Adjust number of frames as needed
        'img_size': 224,
        'patch_size': 16,
        'embed_dims': 768,
        'in_channels': 3,
        'attention_type': 'fact_encoder',
        'return_cls_token': True,
        'weights_from': 'kinetics',
        'pretrain_pth': test_checkpoint_path,
        'frame_interval': 16,
        'num_class': 400,
        'objective': 'supervised',
        'data_statics': 'kinetics'
    }
    """
    # Call the function to test the model
    test_vivit_model(test_checkpoint_path)
@FJNU-LWP @call-me-akeiang This is my test file. I hope it helps you. If there are errors, please point them out. : )
By the way, what accuracy does the ViViT-K400 model you trained reach on the validation set? My top-1 accuracy is only 62%.
The lr is 5e-2 and the weight_decay is 0.0001; other parameter settings are the same as those in the readme.md.
Could I have a look at your training log file? Or could you give me some advice? Looking forward to your response.
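
As a quick sanity check of the top-1/top-5 counting used in compute_top_k_accuracy above, the same logic can be run on random tensors (the shapes here are only illustrative):

import torch

# Stand-ins for cls_head logits and ground-truth labels, just to exercise the counting logic.
batch, num_classes, top_k = 8, 400, 5
output = torch.randn(batch, num_classes)
label = torch.randint(0, num_classes, (batch,))

_, topk_preds = output.topk(top_k, dim=1, largest=True, sorted=True)
top_1_hits = (topk_preds[:, 0] == label).sum().item()
top_5_hits = (topk_preds == label.unsqueeze(1)).sum().item()
print(f'top-1 hits: {top_1_hits}/{batch}, top-5 hits: {top_5_hits}/{batch}')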
