-
Notifications
You must be signed in to change notification settings - Fork 97
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 9c890c3
Showing
13 changed files
with
907 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
__pycache__ | ||
*.pyc | ||
vqa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[submodule "resnet"] | ||
path = resnet | ||
url = https://github.com/Cyanogenoid/pytorch-resnet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# Strong baseline for visual question answering | ||
|
||
This is a re-implementation of Vahid Kazemi and Ali Elqursh's paper [Show, Ask, Attend, and Answer: A Strong Baseline For Visual Question Answering][0] in [PyTorch][1]. | ||
|
||
The paper shows that with a relatively simple model, using only common building blocks in Deep Learning, you can get better accuracies than the majority of previously published work on the popular [VQA v1][2] dataset. | ||
|
||
This repository is intended to provide a straightforward implementation of the paper for other researchers to build on. | ||
The results closely match the reported results, as the majority of details should be exactly the same as the paper. (Thanks to the authors for answering my questions about some details!) | ||
This implementation seems to consistently converge to about 0.1% better results, but I am not aware of what implementation difference is causing this. | ||
|
||
A fully trained model (convergence shown below) is [available for download][5]. | ||
|
||
![Graph of convergence of implementation versus paper results](http://i.imgur.com/moWYEm8.png) | ||
|
||
|
||
## Running the model | ||
|
||
- Clone this repository with: | ||
``` | ||
git clone https://github.com/Cyanogenoid/pytorch-vqa --recursive | ||
``` | ||
- Set the paths to your downloaded [questions, answers, and MS COCO images][4] in `config.py`. | ||
- `qa_path` should contain the files `OpenEnded_mscoco_train2014_questions.json`, `OpenEnded_mscoco_val2014_questions.json`, `mscoco_train2014_annotations.json`, `mscoco_val2014_annotations.json`. | ||
- `train_path`, `val_path`, `test_path` should contain the train, validation, and test `.jpg` images respectively. | ||
- Pre-process images (93 GiB of free disk space required for f16 accuracy) with [ResNet152 weights ported from Caffe][3] and vocabularies for questions and answers with: | ||
``` | ||
python preprocess-images.py | ||
python preprocess-vocab.py | ||
``` | ||
- Train the model in `model.py` with: | ||
``` | ||
python train.py | ||
``` | ||
This will alternate between one epoch of training on the train split and one epoch of validation on the validation split while printing the current training progress to stdout and saving logs in the `logs` directory. | ||
The logs contain the name of the model, training statistics, contents of `config.py`, model weights, evaluation information (per-question answer and accuracy), and question and answer vocabularies. | ||
- During training (which takes a while), plot the training progress with: | ||
``` | ||
python view-log.py <path to .pth log> | ||
``` | ||
|
||
|
||
## Python 3 dependencies (tested on Python 3.6.2) | ||
|
||
- torch | ||
- torchvision | ||
- h5py | ||
- tqdm | ||
|
||
|
||
|
||
[0]: https://arxiv.org/abs/1704.03162 | ||
[1]: https://github.com/pytorch/pytorch | ||
[2]: http://visualqa.org/ | ||
[3]: https://github.com/ruotianluo/pytorch-resnet | ||
[4]: http://visualqa.org/vqa_v1_download.html | ||
[5]: https://github.com/Cyanogenoid/pytorch-vqa/releases |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# paths | ||
qa_path = 'vqa' # directory containing the question and annotation jsons | ||
train_path = 'mscoco/train2014' # directory of training images | ||
val_path = 'mscoco/val2014' # directory of validation images | ||
test_path = 'mscoco/test2015' # directory of test images | ||
preprocessed_path = '/ssd/resnet-14x14.h5' # path where preprocessed features are saved to and loaded from | ||
vocabulary_path = 'vocab.json' # path where the used vocabularies for question and answers are saved to | ||
|
||
task = 'OpenEnded' | ||
dataset = 'mscoco' | ||
|
||
# preprocess config | ||
preprocess_batch_size = 64 | ||
image_size = 448 # scale shorter end of image to this size and centre crop | ||
output_size = image_size // 32 # size of the feature maps after processing through a network | ||
output_features = 2048 # number of feature maps thereof | ||
central_fraction = 0.875 # only take this much of the centre when scaling and centre cropping | ||
|
||
# training config | ||
epochs = 50 | ||
batch_size = 128 | ||
initial_lr = 1e-3 # default Adam lr | ||
lr_halflife = 50000 # in iterations | ||
data_workers = 8 | ||
max_answers = 3000 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,256 @@ | ||
import json | ||
import os | ||
import os.path | ||
import re | ||
|
||
from PIL import Image | ||
import h5py | ||
import torch | ||
import torch.utils.data as data | ||
import torchvision.transforms as transforms | ||
|
||
import config | ||
import utils | ||
|
||
|
||
def get_loader(train=False, val=False, test=False): | ||
""" Returns a data loader for the desired split """ | ||
assert train + val + test == 1, 'need to set exactly one of {train, val, test} to True' | ||
split = VQA( | ||
utils.path_for(train=train, val=val, test=test, question=True), | ||
utils.path_for(train=train, val=val, test=test, answer=True), | ||
config.preprocessed_path, | ||
answerable_only=train, | ||
) | ||
loader = torch.utils.data.DataLoader( | ||
split, | ||
batch_size=config.batch_size, | ||
shuffle=train, # only shuffle the data in training | ||
pin_memory=True, | ||
num_workers=config.data_workers, | ||
collate_fn=collate_fn, | ||
) | ||
return loader | ||
|
||
|
||
def collate_fn(batch): | ||
# put question lengths in descending order so that we can use packed sequences later | ||
batch.sort(key=lambda x: x[-1], reverse=True) | ||
return data.dataloader.default_collate(batch) | ||
|
||
|
||
class VQA(data.Dataset): | ||
""" VQA dataset, open-ended """ | ||
def __init__(self, questions_path, answers_path, image_features_path, answerable_only=False): | ||
super(VQA, self).__init__() | ||
with open(questions_path, 'r') as fd: | ||
questions_json = json.load(fd) | ||
with open(answers_path, 'r') as fd: | ||
answers_json = json.load(fd) | ||
with open(config.vocabulary_path, 'r') as fd: | ||
vocab_json = json.load(fd) | ||
self._check_integrity(questions_json, answers_json) | ||
|
||
# vocab | ||
self.vocab = vocab_json | ||
self.token_to_index = self.vocab['question'] | ||
self.answer_to_index = self.vocab['answer'] | ||
|
||
# q and a | ||
self.questions = list(prepare_questions(questions_json)) | ||
self.answers = list(prepare_answers(answers_json)) | ||
self.questions = [self._encode_question(q) for q in self.questions] | ||
self.answers = [self._encode_answers(a) for a in self.answers] | ||
|
||
# v | ||
self.image_features_path = image_features_path | ||
self.coco_id_to_index = self._create_coco_id_to_index() | ||
self.coco_ids = [q['image_id'] for q in questions_json['questions']] | ||
|
||
# only use questions that have at least one answer? | ||
self.answerable_only = answerable_only | ||
if self.answerable_only: | ||
self.answerable = self._find_answerable() | ||
|
||
@property | ||
def max_question_length(self): | ||
if not hasattr(self, '_max_length'): | ||
self._max_length = max(map(len, self.questions)) | ||
return self._max_length | ||
|
||
@property | ||
def num_tokens(self): | ||
return len(self.token_to_index) + 1 # add 1 for <unknown> token at index 0 | ||
|
||
def _create_coco_id_to_index(self): | ||
""" Create a mapping from a COCO image id into the corresponding index into the h5 file """ | ||
with h5py.File(self.image_features_path, 'r') as features_file: | ||
coco_ids = features_file['ids'][()] | ||
coco_id_to_index = {id: i for i, id in enumerate(coco_ids)} | ||
return coco_id_to_index | ||
|
||
def _check_integrity(self, questions, answers): | ||
""" Verify that we are using the correct data """ | ||
qa_pairs = list(zip(questions['questions'], answers['annotations'])) | ||
assert all(q['question_id'] == a['question_id'] for q, a in qa_pairs), 'Questions not aligned with answers' | ||
assert all(q['image_id'] == a['image_id'] for q, a in qa_pairs), 'Image id of question and answer don\'t match' | ||
assert questions['data_type'] == answers['data_type'], 'Mismatched data types' | ||
assert questions['data_subtype'] == answers['data_subtype'], 'Mismatched data subtypes' | ||
|
||
def _find_answerable(self): | ||
""" Create a list of indices into questions that will have at least one answer that is in the vocab """ | ||
answerable = [] | ||
for i, answers in enumerate(self.answers): | ||
answer_has_index = len(answers.nonzero()) > 0 | ||
# store the indices of anything that is answerable | ||
if answer_has_index: | ||
answerable.append(i) | ||
return answerable | ||
|
||
def _encode_question(self, question): | ||
""" Turn a question into a vector of indices and a question length """ | ||
vec = torch.zeros(self.max_question_length).long() | ||
for i, token in enumerate(question): | ||
index = self.token_to_index.get(token, 0) | ||
vec[i] = index | ||
return vec, len(question) | ||
|
||
def _encode_answers(self, answers): | ||
""" Turn an answer into a vector """ | ||
# answer vec will be a vector of answer counts to determine which answers will contribute to the loss. | ||
# this should be multiplied with 0.1 * negative log-likelihoods that a model produces and then summed up | ||
# to get the loss that is weighted by how many humans gave that answer | ||
answer_vec = torch.zeros(len(self.answer_to_index)) | ||
for answer in answers: | ||
index = self.answer_to_index.get(answer) | ||
if index is not None: | ||
answer_vec[index] += 1 | ||
return answer_vec | ||
|
||
def _load_image(self, image_id): | ||
""" Load an image """ | ||
if not hasattr(self, 'features_file'): | ||
# Loading the h5 file has to be done here and not in __init__ because when the DataLoader | ||
# forks for multiple works, every child would use the same file object and fail | ||
# Having multiple readers using different file objects is fine though, so we just init in here. | ||
self.features_file = h5py.File(self.image_features_path, 'r') | ||
index = self.coco_id_to_index[image_id] | ||
dataset = self.features_file['features'] | ||
img = dataset[index].astype('float32') | ||
return torch.from_numpy(img) | ||
|
||
def __getitem__(self, item): | ||
if self.answerable_only: | ||
# change of indices to only address answerable questions | ||
item = self.answerable[item] | ||
|
||
q, q_length = self.questions[item] | ||
a = self.answers[item] | ||
image_id = self.coco_ids[item] | ||
v = self._load_image(image_id) | ||
# since batches are re-ordered for PackedSequence's, the original question order is lost | ||
# we return `item` so that the order of (v, q, a) triples can be restored if desired | ||
# without shuffling in the dataloader, these will be in the order that they appear in the q and a json's. | ||
return v, q, a, item, q_length | ||
|
||
def __len__(self): | ||
if self.answerable_only: | ||
return len(self.answerable) | ||
else: | ||
return len(self.questions) | ||
|
||
|
||
# this is used for normalizing questions | ||
_special_chars = re.compile('[^a-z0-9 ]*') | ||
|
||
# these try to emulate the original normalisation scheme for answers | ||
_period_strip = re.compile(r'(?!<=\d)(\.)(?!\d)') | ||
_comma_strip = re.compile(r'(\d)(,)(\d)') | ||
_punctuation_chars = re.escape(r';/[]"{}()=+\_-><@`,?!') | ||
_punctuation = re.compile(r'([{}])'.format(re.escape(_punctuation_chars))) | ||
_punctuation_with_a_space = re.compile(r'(?<= )([{0}])|([{0}])(?= )'.format(_punctuation_chars)) | ||
|
||
|
||
def prepare_questions(questions_json): | ||
""" Tokenize and normalize questions from a given question json in the usual VQA format. """ | ||
questions = [q['question'] for q in questions_json['questions']] | ||
for question in questions: | ||
question = question.lower()[:-1] | ||
yield question.split(' ') | ||
|
||
|
||
def prepare_answers(answers_json): | ||
""" Normalize answers from a given answer json in the usual VQA format. """ | ||
answers = [[a['answer'] for a in ans_dict['answers']] for ans_dict in answers_json['annotations']] | ||
# The only normalisation that is applied to both machine generated answers as well as | ||
# ground truth answers is replacing most punctuation with space (see [0] and [1]). | ||
# Since potential machine generated answers are just taken from most common answers, applying the other | ||
# normalisations is not needed, assuming that the human answers are already normalized. | ||
# [0]: http://visualqa.org/evaluation.html | ||
# [1]: https://github.com/VT-vision-lab/VQA/blob/3849b1eae04a0ffd83f56ad6f70ebd0767e09e0f/PythonEvaluationTools/vqaEvaluation/vqaEval.py#L96 | ||
|
||
def process_punctuation(s): | ||
# the original is somewhat broken, so things that look odd here might just be to mimic that behaviour | ||
# this version should be faster since we use re instead of repeated operations on str's | ||
if _punctuation.search(s) is None: | ||
return s | ||
s = _punctuation_with_a_space.sub('', s) | ||
if re.search(_comma_strip, s) is not None: | ||
s = s.replace(',', '') | ||
s = _punctuation.sub(' ', s) | ||
s = _period_strip.sub('', s) | ||
return s.strip() | ||
|
||
for answer_list in answers: | ||
yield list(map(process_punctuation, answer_list)) | ||
|
||
|
||
class CocoImages(data.Dataset): | ||
""" Dataset for MSCOCO images located in a folder on the filesystem """ | ||
def __init__(self, path, transform=None): | ||
super(CocoImages, self).__init__() | ||
self.path = path | ||
self.id_to_filename = self._find_images() | ||
self.sorted_ids = sorted(self.id_to_filename.keys()) # used for deterministic iteration order | ||
print('found {} images in {}'.format(len(self), self.path)) | ||
self.transform = transform | ||
|
||
def _find_images(self): | ||
id_to_filename = {} | ||
for filename in os.listdir(self.path): | ||
if not filename.endswith('.jpg'): | ||
continue | ||
id_and_extension = filename.split('_')[-1] | ||
id = int(id_and_extension.split('.')[0]) | ||
id_to_filename[id] = filename | ||
return id_to_filename | ||
|
||
def __getitem__(self, item): | ||
id = self.sorted_ids[item] | ||
path = os.path.join(self.path, self.id_to_filename[id]) | ||
img = Image.open(path).convert('RGB') | ||
|
||
if self.transform is not None: | ||
img = self.transform(img) | ||
return id, img | ||
|
||
def __len__(self): | ||
return len(self.sorted_ids) | ||
|
||
|
||
class Composite(data.Dataset): | ||
""" Dataset that is a composite of several Dataset objects. Useful for combining splits of a dataset. """ | ||
def __init__(self, *datasets): | ||
self.datasets = datasets | ||
|
||
def __getitem__(self, item): | ||
current = self.datasets[0] | ||
for d in self.datasets: | ||
if item < len(d): | ||
return d[item] | ||
item -= len(d) | ||
else: | ||
raise IndexError('Index too large for composite dataset') | ||
|
||
def __len__(self): | ||
return sum(map(len, self.datasets)) |
Empty file.
Oops, something went wrong.