Commit b2f5e44
update meta trainer
wj-Mcat committed May 5, 2022
1 parent 1aa61d6 commit b2f5e44
Showing 4 changed files with 123 additions and 107 deletions.
6 changes: 3 additions & 3 deletions PaddleFSL/examples/optim/anil_example.py
@@ -107,11 +107,11 @@ def init_models(config: Config):

config = Config().parse_args(known_only=True)
config.device = 'gpu'
config.k_shot = 5
config.k_shot = 1

# config.dataset = 'omniglot'
# config.dataset = 'miniimagenet'
config.dataset = 'cifarfs'
config.dataset = 'miniimagenet'
# config.dataset = 'cifarfs'
# config.dataset = 'fc100'
# config.dataset = 'cub'

62 changes: 62 additions & 0 deletions PaddleFSL/examples/optim/anil_text_classification.py
@@ -0,0 +1,62 @@
"""ANIL example for optimization"""
from __future__ import annotations
import os
import paddle
from paddle import nn
from paddle.optimizer import Adam
import paddlefsl
from paddlefsl.metaopt.anil import ANILLearner
from paddlenlp.transformers.ernie.modeling import ErnieModel
from paddlenlp.transformers.ernie.tokenizer import ErnieTokenizer

from examples.optim.meta_trainer import Config, Trainer, load_datasets

class SequenceClassifier(nn.Layer):
"""Sequence Classifier"""
def __init__(self, hidden_size: int, output_size: int, dropout: float = 0.1):
super().__init__()
self.dropout = nn.Dropout(dropout)
self.classifier = nn.Linear(hidden_size, output_size)

def forward(self, embedding):
"""handle the main logic"""
embedding = self.dropout(embedding)
logits = self.classifier(embedding)
return logits


if __name__ == '__main__':

config = Config().parse_args(known_only=True)
config.device = 'gpu'

train_dataset = paddlefsl.datasets.few_rel.FewRel('train')
valid_dataset = paddlefsl.datasets.few_rel.FewRel('valid')
test_dataset = paddlefsl.datasets.few_rel.FewRel('valid')

config.tracking_uri = os.environ.get('TRACKING_URI', None)
config.experiment_id = os.environ.get('EXPERIMENT_ID', None)

tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
feature_model, head_layer = ErnieModel.from_pretrained('ernie-1.0'), SequenceClassifier(hidden_size=768, output_size=config.n_way)

criterion = nn.CrossEntropyLoss()
learner = ANILLearner(
feature_model=feature_model,
head_layer=head_layer,
learning_rate=config.inner_lr,
)
scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=config.meta_lr, T_max=config.epochs)
optimizer = Adam(parameters=learner.parameters(), learning_rate=scheduler)
trainer = Trainer(
config=config,
train_dataset=train_dataset,
dev_dataset=valid_dataset,
test_dataset=test_dataset,
learner=learner,
optimizer=optimizer,
scheduler=scheduler,
criterion=criterion,
tokenizer=tokenizer
)
trainer.train()
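
Note on the ANIL setup above: only the SequenceClassifier head is adapted in the inner loop, while the ERNIE backbone provides fixed features. Below is a minimal first-order sketch of that head-only adaptation, using random tensors as stand-ins for ERNIE pooled outputs; the actual ANILLearner.clone()/adapt() calls additionally keep the update differentiable for the outer optimizer step.

    # First-order sketch of ANIL's inner loop: adapt a copy of the head on
    # the support embeddings while the backbone stays untouched. The random
    # tensors below are stand-ins, not real ERNIE outputs.
    import copy

    import paddle
    from paddle import nn

    def adapt_head(head, feats, labels, steps=5, inner_lr=0.05):
        """Clone the head and take a few SGD steps on the support set."""
        adapted = copy.deepcopy(head)
        opt = paddle.optimizer.SGD(learning_rate=inner_lr,
                                   parameters=adapted.parameters())
        criterion = nn.CrossEntropyLoss()
        for _ in range(steps):
            loss = criterion(adapted(feats), labels)
            loss.backward()
            opt.step()
            opt.clear_grad()
        return adapted

    support_feats = paddle.randn([25, 768])      # 5-way 5-shot stand-in
    support_labels = paddle.randint(0, 5, [25])  # int64 class ids
    head = SequenceClassifier(hidden_size=768, output_size=5)
    adapted_head = adapt_head(head, support_feats, support_labels)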
11 changes: 6 additions & 5 deletions PaddleFSL/examples/optim/maml_example.py
@@ -80,11 +80,12 @@ def init_models(config: Config):

config = Config().parse_args(known_only=True)
config.device = 'gpu'
# config.dataset = 'omniglot'
# config.dataset = 'miniimagenet'
# config.dataset = 'cifarfs'
# config.dataset = 'fc100'
config.dataset = 'cub'
if not config.dataset:
    # config.dataset = 'omniglot'
    # config.dataset = 'miniimagenet'
    config.dataset = 'cifarfs'
    # config.dataset = 'fc100'
    # config.dataset = 'cub'

config.tracking_uri = os.environ.get('TRACKING_URI', None)
config.experiment_id = os.environ.get('EXPERIMENT_ID', None)
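
The guard added here turns the hard-coded dataset into a fallback: Tap fills config.dataset from the command line, so the in-file value only applies when no flag was given. A small sketch of the resulting precedence (the --dataset flag name follows Tap's default field-to-flag mapping):

    # Hypothetical invocation; with the guard, the CLI value wins:
    #   python maml_example.py --dataset omniglot
    config = Config().parse_args(known_only=True)
    if not config.dataset:           # '' is the new "not set" sentinel
        config.dataset = 'cifarfs'   # fallback used only without --dataset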
151 changes: 52 additions & 99 deletions PaddleFSL/examples/optim/meta_trainer.py
@@ -1,33 +1,24 @@
"""MAML example for optimization"""
from __future__ import annotations
import os
from sched import scheduler
from typing import Optional, Tuple
import warnings
import math
from loguru import logger

import paddle
from paddle import nn
from paddle.optimizer import Adam, Optimizer
from paddle.optimizer import Optimizer
from paddle.optimizer.lr import LRScheduler
from paddle.nn import Layer
from paddle.metric.metrics import Accuracy
from tap import Tap
from tqdm import tqdm
from tabulate import tabulate
from mlflow.tracking import MlflowClient
from mlflow import set_tag
import numpy as np

import paddlefsl
from paddlefsl.datasets.cv_dataset import CVDataset
from paddlefsl.backbones.conv import ConvBlock
from paddlefsl.metaopt.base_learner import BaseLearner
from paddlefsl.metaopt.anil import ANILLearner

"""Data Utils for Meta Optimzations Algorithms"""
from typing import Tuple, Dict
import paddlefsl
from paddlefsl.datasets.cv_dataset import CVDataset

@@ -76,7 +67,7 @@ def load_datasets(name: str) -> Tuple[CVDataset, CVDataset, CVDataset]:

class Config(Tap):
"""Alernative for Argument Parse"""
dataset: str = 'omniglot'
dataset: str = ''
input_size: Optional[str] = None
n_way: int = 5
k_shot: int = 1
@@ -160,8 +151,6 @@ def __init__(

) -> None:
self.config = config

logger.info(self.config)

self.train_dataset = train_dataset
self.dev_dataset = dev_dataset
@@ -197,6 +186,11 @@ def __init__(
self.client.log_param(self.config.run_id, 'learner', value=self.learner.__class__.__name__)
set_tag("mlflow.runName", self.learner.__class__.__name__)

learner_name = learner.__class__.__name__
file_name = f'{config.dataset}-{learner_name}-{config.run_id}.log'
logger.add(os.path.join('logs', file_name))
logger.info(self.config)

def _set_device(self):
paddle.device.set_device(self.config.device)
if paddle.distributed.get_world_size() > 1:
@@ -208,63 +202,32 @@ def on_train_epoch_end(self):

self.train_bar.update()

bar_info = f'Epoch: {self.context.epoch}/{self.config.epochs} \t dev-loss: {self.context.dev_loss} \t\t dev-acc: {self.context.dev_acc}'
bar_info = f'Epoch: {self.context.epoch}/{self.config.epochs} \t train-loss: {self.context.train_loss} \t\t train-acc: {self.context.train_acc}'
self.train_bar.set_description(bar_info)

if self.config.tracking_uri:
self.client.log_metric(self.config.run_id, key='train-loss', value=self.context.train_loss)
self.client.log_metric(self.config.run_id, key='train-acc', value=self.context.train_acc)

self.client.log_metric(self.config.run_id, key='dev-loss', value=self.context.dev_loss)
self.client.log_metric(self.config.run_id, key='dev-acc', value=self.context.dev_acc)

def fast_adapt(self, task, learner: BaseLearner):
"""make inner loop fast adaption based on the task
Args:
task (_type_): which contains the support_set & query_set data
learner (BaseLearner): the meta optimization algriothm, which contains the model
Returns:
Tuples[Tensor, Tensor]: the loss and acc of the inner loop
"""
support_set, support_set_labels = paddle.to_tensor(task.support_data, dtype='float32'), paddle.to_tensor(task.support_labels, dtype='int64')
query_set, query_set_labels = paddle.to_tensor(task.query_data, dtype='float32'), paddle.to_tensor(task.query_labels, dtype='int64')

# run the fast adaptation on the support-set data
for _ in range(self.config.train_inner_adapt_steps):
logits = learner(support_set)
loss = self.criterion(logits, support_set_labels)
learner.adapt(loss)

# evaluate the model on query set data
logits = learner(query_set)
val_loss = self.criterion(logits, query_set_labels)
val_acc = self.metric.compute(logits, query_set_labels)
val_acc = self.metric.update(val_acc)
return val_loss, val_acc

def compute_loss(self, input_data, labels, learner: BaseLearner, inner_steps: int = 1):
def compute_loss(self, input_data, labels, learner: BaseLearner):
"""compute the loss based on the input_data and labels"""
input_data, labels = paddle.to_tensor(input_data, dtype='float32'), paddle.to_tensor(labels, dtype='int64')

all_loss, all_acc = 0.0, 0.0
for _ in range(inner_steps):
logits = learner(input_data)
loss = self.criterion(logits, labels)
all_loss += loss
logits = learner(input_data)
loss = self.criterion(logits, labels)

acc = self.metric.compute(logits, labels)
all_acc += self.metric.update(acc)
acc = self.metric.compute(logits, labels)
acc = self.metric.update(acc)

return all_loss / inner_steps, all_acc / inner_steps
return loss, acc

def train_epoch(self):
"""train one epoch"""
self.learner.train()

self.context.train_loss = 0

train_loss, train_acc = 0, 0
val_loss, val_acc = 0, 0

self.metric.reset()
self.optimizer.clear_grad()
@@ -276,32 +239,24 @@ def train_epoch(self):
learner = self.learner.clone()

# inner loop
inner_loss, inner_acc = self.compute_loss(
task.support_data, task.support_labels,
learner,
inner_steps=self.config.train_inner_adapt_steps
)
learner.adapt(inner_loss)
train_loss += inner_loss.numpy()[0]
train_acc += inner_acc
for _ in range(self.config.train_inner_adapt_steps):
inner_loss, _ = self.compute_loss(
task.support_data, task.support_labels, learner
)
learner.adapt(inner_loss)

# outer loop: compute loss on the validation dataset
val_loss_, val_acc_ = self.compute_loss(
task.query_data,
task.query_labels,
learner,
inner_steps=self.config.train_inner_adapt_steps
loss, acc = self.compute_loss(
task.query_data, task.query_labels, learner
)
val_loss += val_loss_
val_acc += val_acc_

self.context.train_loss, self.context.train_acc = train_loss / self.config.meta_batch_size, train_acc / self.config.meta_batch_size
self.context.dev_loss, self.context.dev_acc = val_loss.numpy()[0] / self.config.meta_batch_size, val_acc / self.config.meta_batch_size
train_loss += loss
train_acc += acc

self.optimizer.clear_grad()
val_loss.backward()
train_loss.backward()
self.optimizer.step()
self.scheduler.step()
self.context.train_loss, self.context.train_acc = train_loss.numpy()[0] / self.config.meta_batch_size, train_acc / self.config.meta_batch_size
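
The rewritten train_epoch follows the standard MAML recipe: clone the learner, run train_inner_adapt_steps adaptation steps on the support set, accumulate the query loss over the meta batch, and take a single optimizer step on the summed loss. A self-contained toy in plain paddle showing that accumulate-then-step structure; the 1-D quadratic "tasks" and learning rates are illustrative only:

    import paddle

    theta = paddle.to_tensor([0.0], stop_gradient=False)  # meta-parameter

    for _ in range(100):                                  # meta iterations
        meta_loss = paddle.to_tensor([0.0])
        for target in (1.0, -1.0):                        # two toy "tasks"
            w = theta
            for _ in range(3):                            # inner loop, kept on the tape
                loss = (w - target) ** 2
                (grad,) = paddle.grad(loss, [w], create_graph=True)
                w = w - 0.1 * grad                        # differentiable update
            meta_loss = meta_loss + (w - target) ** 2     # query-style loss
        (meta_grad,) = paddle.grad(meta_loss, [theta])    # outer gradient
        theta = (theta - 0.05 * meta_grad).detach()       # manual meta-SGD step
        theta.stop_gradient = False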

def eval(self, dataset: CVDataset, learner: BaseLearner, mode: str = 'dev'):
"""eval the model on the dataset
@@ -313,9 +268,9 @@ def eval(self, dataset: CVDataset, learner: BaseLearner, mode: str = 'dev'):
"""
logger.info(f'start doing {mode} on the dataset ...')
eval_bar = tqdm(total=self.config.test_epoch, desc=f'{mode} Bar')
test_loss, test_acc = 0, []
test_loss, test_acc = [], []
for _ in range(self.config.test_epoch):
epoch_acc = 0.0
val_loss, val_acc = 0.0, 0.0
for _ in range(self.config.test_batch_size):

task = dataset.sample_task_set(
@@ -325,40 +280,38 @@ def eval(self, dataset: CVDataset, learner: BaseLearner, mode: str = 'dev'):
learner = self.learner.clone()

# inner loop
inner_loss, inner_acc = self.compute_loss(
task.support_data, task.support_labels,
learner,
inner_steps=self.config.test_inner_adapt_steps
)
learner.adapt(inner_loss)
for _ in range(self.config.test_inner_adapt_steps):
inner_loss, _ = self.compute_loss(
task.support_data, task.support_labels, learner
)
learner.adapt(inner_loss)

# outer loop: compute loss on the validation dataset
_, val_acc_ = self.compute_loss(
task.query_data,
task.query_labels,
learner,
inner_steps=self.config.test_inner_adapt_steps
loss, acc = self.compute_loss(
task.query_data, task.query_labels, learner,
)
epoch_acc += val_acc_
val_loss += loss.numpy()[0]
val_acc += acc

test_acc.append(val_acc / self.config.test_batch_size)
test_loss.append(val_loss / self.config.test_batch_size)

test_acc.append(epoch_acc / self.config.test_batch_size)
eval_bar.update()
eval_bar.set_description(
f'acc {test_acc[-1]:.6f}'
)
result = [
test_loss / self.config.test_epoch * self.config.test_batch_size,
min(test_acc),
sum(test_acc) / len(test_acc),
max(test_acc)
]
logger.success(test_acc)
logger.success("\n" + tabulate([result], ['loss', 'min-acc', 'mean-acc', 'max-acc'], tablefmt="grid"))
mean_loss, std_loss = np.mean(test_loss), np.std(test_loss)
mean_acc, std_acc = np.mean(test_acc), np.std(test_acc)

logger.success(f'======================Epoch: {self.context.epoch}/{self.config.epochs}-{mode}======================')
logger.success(f'mean-loss: {mean_loss:.6f}, std-loss: {std_loss:.6f}')
logger.success(f'mean-acc: {mean_acc:.6f}, std-acc: {std_acc:.6f}')
logger.success('================================================')
if self.config.tracking_uri:
self.client.log_metric(self.config.run_id, f'{mode}-loss', result[0])
self.client.log_metric(self.config.run_id, f'{mode}-min-acc', result[1])
self.client.log_metric(self.config.run_id, f'{mode}-mean-acc', result[2])
self.client.log_metric(self.config.run_id, f'{mode}-max-acc', result[3])
self.client.log_metric(self.config.run_id, f'{mode}-mean-loss', mean_loss)
self.client.log_metric(self.config.run_id, f'{mode}-std-loss', std_loss)
self.client.log_metric(self.config.run_id, f'{mode}-mean-acc', mean_acc)
self.client.log_metric(self.config.run_id, f'{mode}-std-acc', std_acc)

def train(self):
"""handle the main train"""
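The reworked eval aggregates per-epoch results into mean and standard deviation rather than min/mean/max. For reference, the 95% confidence interval commonly reported in few-shot papers follows directly from these statistics; a sketch with made-up accuracies (the interval itself is not computed in this commit):

    import numpy as np

    test_acc = [0.58, 0.61, 0.60, 0.63, 0.59]       # per-test-epoch accuracies
    mean_acc, std_acc = np.mean(test_acc), np.std(test_acc)
    ci95 = 1.96 * std_acc / np.sqrt(len(test_acc))  # assumes a normal approximation
    print(f'mean-acc: {mean_acc:.4f} +/- {ci95:.4f}')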
