Skip to content

Commit

Permalink
refactor: pep8 continued
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasellinger committed May 2, 2024
1 parent d2fb3ac commit 92f4ed7
Show file tree
Hide file tree
Showing 12 changed files with 41 additions and 43 deletions.
2 changes: 2 additions & 0 deletions fetchers/wikipedia.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@


class Wikipedia:
"""Wrapper for wikipedia api calls."""

USER_AGENT = 'summaryBot ([email protected])'
URL = "https://en.wikipedia.org/w/api.php"

Expand Down
Empty file removed formatter/__init__.py
Empty file.
5 changes: 0 additions & 5 deletions formatter/document_formatter.py

This file was deleted.

4 changes: 3 additions & 1 deletion germandpr_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Test for german dpr dataset."""

from datasets import load_dataset

# we can concatenate train + test since we do not train on
Expand Down Expand Up @@ -26,4 +28,4 @@ def create_fact(entry):
return entry

filtered_dataset = filtered_dataset.map(create_fact)
print('hi')
print('hi')
13 changes: 6 additions & 7 deletions losses/supcon.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
"""Module for supervised contrastive loss.
ref: https://arxiv.org/abs/2004.11362"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import tensor

from torch import nn
from torch.nn.functional import cosine_similarity

class SupConLoss(nn.Module):
"""Supervised Contrastive Loss."""
def __init__(self, temperature=0.5):
super(SupConLoss, self).__init__()
super().__init__()
self.temperature = temperature # higher temperature leads to lower loss

def forward(self, anchor: tensor, references: tensor, labels: tensor) -> tensor:
def forward(self, anchor: torch.tensor, references: torch.tensor, labels: torch.tensor) -> (
torch.tensor):
"""
Calculate the mean supervised contrastive loss over a batch. Each entry of the batch
can have a different amount of positives and negatives. These are marked with the labels.
Expand All @@ -24,7 +23,7 @@ def forward(self, anchor: tensor, references: tensor, labels: tensor) -> tensor:
"""
pos_count = torch.sum(torch.eq(labels, 1), dim=-1)

similarity = F.cosine_similarity(anchor, references, dim=-1) / self.temperature
similarity = cosine_similarity(anchor, references, dim=-1) / self.temperature
logits_max = torch.max(similarity, dim=-1)[0].detach()
similarity = similarity - logits_max.unsqueeze(-1)

Expand Down
2 changes: 1 addition & 1 deletion mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
13 changes: 6 additions & 7 deletions models/claim_verification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,20 @@
Input: Claim to verify and sentences which should be used to verify the claim.
Output: SUPPORTS | REFUTES | NOT ENOUGH INFO
"""
from datetime import datetime

import torch
from torch import nn


class ClaimVerificationModel(nn.Module):
"""Model to verify a claim. Hypotheses, Premise style."""

def __init__(self, model):
super().__init__()
self.model = model

def forward(self, input_ids=None, attention_mask=None,):
def forward(self, input_ids=None, attention_mask=None):
"""Forward function."""
return self.model(input_ids=input_ids, attention_mask=attention_mask)

def save(self, name):
timestamp = datetime.now().strftime("%m-%d_%H-%M")
model_path = f'{name}_{timestamp}.pth'
torch.save(self.state_dict(), model_path)
"""Stores the model."""
self.model.save_pretrained(f'{name}')
14 changes: 8 additions & 6 deletions models/evidence_selection_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
Input: Claim to verify and document which should be used for verifying.
Output: Sentence Embeddings.
"""
from datetime import datetime

import torch
from torch import nn
import torch.nn.functional as F


class EvidenceSelectionModel(nn.Module):
"""Model to compute sentence embeddings."""

def __init__(self, model, feed_forward=False, normalize_before_fc=True, out_features=256):
super().__init__()
self.model = model
Expand All @@ -22,22 +22,24 @@ def __init__(self, model, feed_forward=False, normalize_before_fc=True, out_feat
self.fc = nn.Linear(1024, out_features)

def forward(self, input_ids=None, attention_mask=None, sentence_mask=None):
"""Forward function."""
if sentence_mask is None:
# keep in mind that here cls and end token are inside the mask.
sentence_mask = attention_mask.unsqueeze(dim=1)

outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)['last_hidden_state']
outputs = self.model(input_ids=input_ids,
attention_mask=attention_mask)['last_hidden_state']

sentence_embeddings = self.sentence_mean_pooling(outputs, sentence_mask)
if self.feed_forward:
if self.normalize_before_fc:
sentence_embeddings = F.normalize(sentence_embeddings, dim=2)
return self.fc(sentence_embeddings)
else:
return sentence_embeddings
return sentence_embeddings

@staticmethod
def sentence_mean_pooling(model_output, sentence_mask):
"""Mean pooling of the embeddings of the sentences."""
token_embeddings = model_output.unsqueeze(1)

masks_size = sentence_mask.count_nonzero(dim=-1)
Expand All @@ -48,7 +50,7 @@ def sentence_mean_pooling(model_output, sentence_mask):
return sentence_embeddings

def save(self, name):
# model_path = f'{name}.pth'
"""Stores the model."""
if self.feed_forward:
torch.save(self.fc.state_dict(), f'{name}_fc.pth')
else:
Expand Down
10 changes: 5 additions & 5 deletions pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,20 +168,20 @@ def fetch_evidence(self, word: str) -> list[list[str]]:
def select_evidence(self, claim: str, evidence_list: list[list[str]], top_k=3,
max_evidence_count=3) -> list[str]:
if len(evidence_list) > max_evidence_count:
evidence_txts = [" ".join(txt) for txt in evidence_list]
ranked_indices = rank_docs(claim, evidence_txts, k=max_evidence_count)
ranked_indices = rank_docs(claim, [" ".join(txt) for txt in evidence_list],
k=max_evidence_count)
evidence_list = [evidence_list[i] for i in ranked_indices]

sentence_similarities = []
for sentences in evidence_list:
claim_model_input, sentences_model_input = self.build_selection_model_input(claim,
sentences)
claim_model_input, sentences_model_input = self._build_selection_model_input(claim,
sentences)
with torch.no_grad():
claim_embedding = self.selection_model(**claim_model_input)
sentence_embeddings = self.selection_model(**sentences_model_input)
claim_similarities = cosine_similarity(claim_embedding,
sentence_embeddings, dim=2).tolist()[0]
sentence_similarity = [(x, y) for x, y in zip(sentences, claim_similarities)]
sentence_similarity = list(zip(sentences, claim_similarities))
sentence_similarities.extend(sentence_similarity)

sorted_sentences = sorted(sentence_similarities, key=lambda x: x[1], reverse=True)
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ numpy~=1.26.4
matplotlib~=3.8.4
peft~=0.10.0
requests~=2.31.0
beautifulsoup4~=4.12.2
beautifulsoup4~=4.12.2
rank_bm25~=0.2.2
3 changes: 1 addition & 2 deletions test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from datetime import datetime

import torch
import gc
import numpy as np
import transformers

Expand All @@ -11,7 +10,7 @@
from torch.utils.data import DataLoader
from torch import optim
from tqdm import tqdm
from transformers import AutoTokenizer, BigBirdModel, get_linear_schedule_with_warmup, AutoModel
from transformers import AutoTokenizer, BigBirdModel
from matplotlib import pyplot as plt
from torch.cuda.amp import GradScaler, autocast

Expand Down
15 changes: 7 additions & 8 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def rank_docs(query: str, docs: List[str], k=5, get_indices=True) -> List[str] |
:return: List of most similar documents.
"""
def preprocess(txt: str):
return txt.lower()
return txt.lower()

query = preprocess(query)
docs = [preprocess(doc) for doc in docs]
Expand All @@ -36,8 +36,7 @@ def preprocess(txt: str):
if get_indices:
scores = np.array(bm25.get_scores(query.split(" ")))
return np.flip(np.argsort(scores)[-k:]).tolist()
else:
return bm25.get_top_n(query.split(" "), docs, k)
return bm25.get_top_n(query.split(" "), docs, k)


def calc_bin_stats(gt_labels: List, pr_labels: List, values: List) -> Dict:
Expand Down Expand Up @@ -65,11 +64,11 @@ def calc_bin_stats(gt_labels: List, pr_labels: List, values: List) -> Dict:
bin_gt_labels = gt_labels[bin_mask]

if len(bin_pr_labels) > 0:
acc = accuracy_score(bin_gt_labels, bin_pr_labels)
f1_weighted = f1_score(bin_gt_labels, bin_pr_labels, average='weighted')
f1_macro = f1_score(bin_gt_labels, bin_pr_labels, average='macro')
bin_stats[bin_upper] = {'acc': acc, 'f1_weighted': f1_weighted, 'f1_macro': f1_macro}

bin_stats[bin_upper] = {'acc': accuracy_score(bin_gt_labels, bin_pr_labels),
'f1_weighted': f1_score(bin_gt_labels, bin_pr_labels,
average='weighted'),
'f1_macro': f1_score(bin_gt_labels, bin_pr_labels,
average='macro')}
return bin_stats


Expand Down

0 comments on commit 92f4ed7

Please sign in to comment.