Pytorch学习: Pytorch Lightning


Pytorch学习: Pytorch Lightning

Pytorch Lightning是在Pytorch基础上封装的框架, 号称”Pytorch里的Keras”, 如官网所述, 它具有灵活, 解耦, 易于复现, 自动化, 扩展性好等优点(实际上大多也是Keras的优点哈哈哈). 知乎上对Pytorch Lightning的议论比较多, 有些人认为Pytorch Lightning纯属过度封装, 但它事实上确实能解决一些Pytorch自身不好解决的问题. 最主要的其实是保证了代码复用, 节省时间.
和Huggingface出品的Trainer相比, 我感觉在大多数任务上, Pytorch Lightning要更加灵活一些.

Introduction

下面是一个官方给出的VAE在MNIST上的例子, 大概建立一下Pytorch Lightning的初印象:

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl

class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
          nn.Linear(28 * 28, 64),
          nn.ReLU(),
          nn.Linear(64, 3))
        self.decoder = nn.Sequential(
          nn.Linear(3, 64),
          nn.ReLU(),
          nn.Linear(64, 28 * 28))

    def forward(self, x):
        embedding = self.encoder(x)
        return embedding

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)    
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('val_loss', loss)

# data
dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
mnist_train, mnist_val = random_split(dataset, [55000, 5000])

train_loader = DataLoader(mnist_train, batch_size=32)
val_loader = DataLoader(mnist_val, batch_size=32)

# model
model = LitAutoEncoder()

# training
trainer = pl.Trainer(gpus=4, num_nodes=8, precision=16, limit_train_batches=0.5)
trainer.fit(model, train_loader, val_loader)

Pytorch Lightning最重要的两个API便是LightningModuleTrainer. pl.LightningModulenn.Module有点像对吧? 都有forward(). 没错, “A LightningModule is still just a torch.nn.Module“.
从代码里可以看出, 在pl.LightningModule下重写了training_step, validation_step, 完成模型训练和验证的内部流程即可, 整个训练的逻辑已经被它封装好了, 无需重写.
同时, DataLoader使用的是Pytorch自己的DataLoader, 二者兼容. Pytorch Lightning有对DataLoader在逻辑上的进一步封装, 方便组织数据的加载逻辑. 但是我自己用的不是很习惯, 本文中就不提及了, 感兴趣的去这里查阅.
通过trainer.fit就开始了模型的训练, 和Keras很像.
事实上, 整个pl.LightningModule只是组织了下列6种行为的逻辑:

  • Computations (init).
  • Train Loop (training_step)
  • Validation Loop (validation_step)
  • Test Loop (test_step)
  • Prediction Loop (predict_step)
  • Optimizers and LR Schedulers (configure_optimizers)
    记住, 它并没有做进一步抽象, 只是简单的把逻辑组织在一起.

Initialization & Forward

pl.LightningModule继承nn.Module, 也就是说你call它的时候会默认调用它的forward().

但是, forward的具体行为在Training和Validation甚至是Prediction的时候可能是不同的, 所以只能写模型自身的逻辑, 不要把Loss的计算也写进去, 也不要把logits.argmax写进去.
一般来说, pl.LightningModule的初始化和forward是这样写的:

class TaskModel(pl.LightningModule):  
    def __init__(self, model):  
        super().__init__()  
        self.model = model  

    def forward(self, **inputs):  
        return self.model(**inputs)

没错, 仅仅是将模型在pl.LightningModule初始化时作为参数传进来, 然后添加一个Hook… 就像这样. 强烈建议把模型本身和训练逻辑解耦, 将来改起来方便很多.

Training & Validation

Training

pl.LightningModule组织的训练逻辑伪代码如下:

def fit_loop():
    on_train_epoch_start()

    for batch in train_dataloader():
        on_train_batch_start()

        on_before_batch_transfer()
        transfer_batch_to_device()
        on_after_batch_transfer()

        training_step()

        on_before_zero_grad()
        optimizer_zero_grad()

        on_before_backward()
        backward()
        on_after_backward()

        on_before_optimizer_step()
        configure_gradient_clipping()
        optimizer_step()

        on_train_batch_end()

        if should_check_val:
            val_loop()
    # end training epoch
    training_epoch_end()

    on_train_epoch_end()

似乎很多对吧, 事实上我们只需要关注下面两个函数:

  1. training_step(self, batch, batch_idx).
  2. training_epoch_end(self, training_step_outputs).

其他的函数在项目规模不大的时候不会用到, 例如on_train_epoch_end, on_train_batch_end, 看起来比较美好, 但是实际上有些鸡肋, 因为合适的逻辑已经在training_steptraining_epoch_end里搞定了, 而且它们不存在耦合问题.
我们使用Pytorch Lightning的目的就是为了快速搭建一套能跑的流程, 如果真的用到了再去查文档就好.

training_step

batch就是DataLoader里返回的batch, 一般来说training_step里就是把batch解包, 然后计算loss. 例如:

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    return loss

返回值可以是loss, 也可以是一个字典, 如果你想在每个训练epoch结束的时候再计算点别的什么东西, 可以这样写:

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    preds = ...
    return {
        "loss": loss, 
        "other_stuff": preds,
    }

这样在training_epoch_end中可以取到other_stuff. 但是一定要保证里面有个loss, 这样才能保证整个batch正常工作.

training_epoch_end

在每个epoch训练结束时调用training_epoch_end, 其参数training_step_outputs实际上是每个step返回的字典的一个列表.
例如:

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    preds = ...
    return {"loss": loss, "other_stuff": preds}


def training_epoch_end(self, training_step_outputs):
    all_preds = torch.stack(training_step_outputs)
    ...

training_epoch_end无返回值限制.
例子中的preds应该也是一个Tensor, 我们也可以在每个step结束时返回其他类型的值.

log

在训练时一般都要把loss记录下来, 使用self.log()就可以把标量记录下来, 在其他地方也都可以随时使用. 例如:

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)

    # logs metrics for each training_step,
    # and the average across the epoch, to the progress bar and logger
    self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return loss

on_step即一个step记录一次, 如果也同时on_epoch, 它会将整个epoch的loss加起来求个平均, 在上述代码里同时记录了train_loss_steptrain_loss_epoch.
记录的值可以在Tensorboard里看到, 非常方便.

  • 如果有多个要记录的值, 可以把它们都放进一个字典里, 然后使用self.log_dict(dict)一并记录下来.
  • 如果要记录的内容是图像, 语音等其他类型, 则需要调用logger来存储, 从这里获取更多信息.

Validataion

验证和被包含在训练逻辑中, 但流程几乎是一样的, 只是少了梯度优化的参与.
pl.LightningModule组织的验证逻辑伪代码如下:

def val_loop():
    on_validation_model_eval()  # calls `model.eval()`
    torch.set_grad_enabled(False)

    on_validation_start()
    on_validation_epoch_start()

    val_outs = []
    for batch_idx, batch in enumerate(val_dataloader()):
        on_validation_batch_start(batch, batch_idx)

        batch = on_before_batch_transfer(batch)
        batch = transfer_batch_to_device(batch)
        batch = on_after_batch_transfer(batch)

        out = validation_step(batch, batch_idx)

        on_validation_batch_end(batch, batch_idx)
        val_outs.append(out)

    validation_epoch_end(val_outs)

    on_validation_epoch_end()
    on_validation_end()

    # set up for train
    on_validation_model_train()  # calls `model.train()`
    torch.set_grad_enabled(True)

与训练不同的是, 在验证开始前, pl.LightningModule会自动为我们启用model.eval(), 还会禁用梯度. 可以不必重复声明torch.no_grad, 如果不放心的话可以再包上一层.
我们同样只需要关注与训练过程相似的两个函数:

  1. validation_step(self, batch, batch_idx).
  2. validation_epoch_end(self, validation_step_outputs).

validation_step

与训练中的training_step相同. 直接贴出一个例子:

class LitModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("val_loss", loss)

validation_epoch_end

与训练中的training_epoch_end相同, 这里拿到的validation_step_outputs也是每个validation_step的返回值的一个字典的列表. 例如:

def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred = ...
    return pred


def validation_epoch_end(self, validation_step_outputs):
    all_preds = torch.stack(validation_step_outputs)
    ...

validation_epoch_end无返回值限制.

Optimizer & LR Scheduler

在文章最开始的例子中, 我们重写了configure_optimizers()来为模型准备优化器. 大多数时候我们只需要一个optimizer和scheduler:

def configure_optimizers(self):
    return Adam(self.parameters(), lr=1e-3)
# or
def configure_optimizers(self):
    optimizer = Adam(self.parameters(), lr=1e-3)
    scheduler = get_linear_schedule_with_warmup(optimizer, self.total_step)
    return [optimizer], [scheduler]

如果只有optimizer, 直接返回即可, 如果还有scheduler, 则需要把optimizer和scheduler分别套上一个list返回.
同时, 在pl.LightningModule内部使用self.parameters()可以获得所有的模型参数, 因为它继承了nn.Module.
再复杂一点, 也可以通过返回字典来控制optimizer和scheduler执行的间隔(interval / frequency):

# example with step-based learning rate schedulers
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_disc.parameters(), lr=0.02)
    gen_sched = {'scheduler': ExponentialLR(gen_opt, 0.99),
                 'interval': 'step'}  # called after each training step
    dis_sched = CosineAnnealing(discriminator_opt, T_max=10) # called every epoch
    return [gen_opt, dis_opt], [gen_sched, dis_sched]

# example with optimizer frequencies
# see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
# https://arxiv.org/abs/1704.00028
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_disc.parameters(), lr=0.02)
    n_critic = 5
    return (
        {'optimizer': dis_opt, 'frequency': n_critic},
        {'optimizer': gen_opt, 'frequency': 1}
    )

特殊情况会用到多个优化器或者多个Scheduler, 首先参考这里 , 并在training_step中使用optimizer_idx来控制loss和optimizer的关联, 参考这里.

Test & Predict

PL应该是为了满足客制化而将Test和Predict区分开. 在我们跑实验而没有部署时, Test和Predict行为并没有什么区别, 但测试和真正Inference的时候的Predict还是不一样的, Predict没有标签.
和验证时相同, model.eval()和torch.no_grad()会自动在测试和预测时自动配上.
当Trainer调用trainer.test()时, 会调用test_step(), 它与training_step, validation_step类似, 一般重写test_step时只是一层对validation_step的封装.
在测试结束时, 我比较推荐在test_step返回batch级的预测结果, test_epoch_end一并保存实验结果, 这样封装一层比较有意义.
Predict仅有predict_step, 而没有predict_epoch_end.

Trainer

pl.LightningModule组织了逻辑, 而pl.Trainer驱动了流程.
其拟合阶段伪代码如下:

def fit(self):
    if global_rank == 0:
        # prepare data is called on GLOBAL_ZERO only
        prepare_data()

    configure_callbacks()

    with parallel(devices):
        # devices can be GPUs, TPUs, ...
        train_on_device(model)


def train_on_device(model):
    # called PER DEVICE
    on_fit_start()
    setup("fit")
    configure_optimizers()

    # the sanity check runs here

    on_train_start()
    for epoch in epochs:
        fit_loop()
    on_train_end()

    on_fit_end()
    teardown("fit")

一般我们这样使用Trainer来完成包含测试在内的整个流程:

model = MyLightningModule()

trainer = Trainer()
trainer.fit(model, train_dataloader, val_dataloader)
trainer.test(model, test_dataloader, ckpt="best")

对Trainer的简单用法说明如下:

  • 使用trainer.stage_name()可以让模型执行相应的阶段(fit, validate, test, predict).
    如果不主动调用trainer.test(), 则不会执行测试阶段.
  • trainer.validate(), trainer.predict(), 可以分别让模型执行验证和预测阶段, 前者被包含在模型的训练过程中, 无需重复调用.
  • 虽然官网有写trainer.test(), trainer.predict()会自动加载最好的模型检查点后再测试和预测, 但我实测的时候没有加载, 默认是使用最后一个epoch测试和预测. 而在设置ckpt_path="best"才会加载最好的模型, 否则是以最后一个epoch的模型进行测试的. 该参数在trainer.fit()中附加也可以让模型从该检查点开始训练.

Parameters

定义Trainer时有很多参数很好用, 在这里推荐一些.

  • max_epochs: 最大Epoch, 肯定要设置.
  • default_root_dir: 默认存储模型, 日志地址. 如果不设置, 每次跑实验时候都会多一个version_x文件夹, 看个人喜好和需求.
  • val_check_interval: 验证间隔, 计量单位是epoch, 如果有更高的验证频次需求, 也可以设置为小数, 即不到1个epoch验证一次.
  • gpus: 使用的GPU数量. 在即将出现的2.0版本中会被accelerator=gpu, device=x取代.
  • precision: 全精度 / 半精度训练.
  • accumulate_grad_batches: 梯度累加, 可以多个batch更新一次梯度, 以间接的近似大batch的效果. PS: 听说对比学习不能用.
  • gradient_clip_val: 梯度裁剪, 将梯度大小限制在该值内, 防止梯度过大崩掉.
  • num_sanity_val_steps: 在执行训练前会先用几个batch的验证数据跑一下, 检查代码是否有问题, 设置为-1为全部, 0为不检测. 我一般设置为0.
  • callbacks: 回调函数, 接受值为回调函数的列表, 下小节会讲.

Callback

一般来说, 早停检查点是两个比较常用的Callback, 需要在Trainer定义时作为参数传入. 例如:

from pytorch_lightning import Trainer  
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint  

early_stopping = EarlyStopping('val_loss')  
checkpoint = ModelCheckpoint(  
    save_weights_only=True,  
    save_on_train_epoch_end=False,  
    monitor="valid_f1",  
    mode="max",  
    save_top_k=3,  
)

trainer = Trainer(callbacks=[early_stopping, checkpoint])

仅当ModelCheckpointsave_on_train_epoch_end设置为False时才会在验证时保存, 否则设置为True时是在训练时保存, 默认为None.

还有一个PrintTableMetricsCallback, 不用带参数, 会在每个epoch结束时打印表格, 不过我基本不用.

Trainer in Python scripts

通常情况下, 使用ArgumentParser能更灵活的跑实验. 可以对Trainer手动添加参数:

from argparse import ArgumentParser


def main(hparams):
    model = LightningModule()
    trainer = Trainer(accelerator=hparams.accelerator, devices=hparams.devices)
    trainer.fit(model)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--accelerator", default=None)
    parser.add_argument("--devices", default=None)
    args = parser.parse_args()

    main(args)

如果需要修改某些参数可以在命令行附带上:

python main.py --accelerator 'gpu' --devices 2

但上面手动很麻烦, Trainer支持自动添加参数到里面:

from argparse import ArgumentParser


def main(args):
    model = LightningModule()
    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    main(args)

也可以走混合路线, 同时定义别的超参和Trainer的参数到parser里.

其实都不如Hydra来的优雅, 见文末Recommended推荐的模板.

Tips

这一节写一些我自己使用过程中用到的一些很有用的小技巧.

  1. pl.LightningModule的构造函数里面, 使用self.save_hyperparameters()可以将pl.LightningModule中所有的传入参数记录到yaml文件里, 非常方便于实验记录.

  2. pl.seedeverything(), 彻底告别自己写随机种子设置函数.

  3. 有的时候想把模型的预测结果和模型本身的权重保存到同一个目录下, 但是我不想自己按照规则去写路径, 而是和Trainer的设置同步, 该怎么办呢? 在pl.LightningModule里会添加Trainer的Hook, 调用self.trainer就能够获得它身上的属性. 例如我想把模型预测结果保存到日志目录下, 应该这么写:

    pred_save_path = os.path.join(self.trainer.log_dir, "prediction.json")  
    your_save_function(pred_save_path)

    其他需要的属性也是同理, 通过Hook可以轻松拿到Trainer身上的属性.

  4. 使用pl.LightningModule.load_from_checkpoint(ckpt_path)可以一条命令直接为TaskModel加载超参和模型权重.


文章作者: AnNing
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 AnNing !
评论
  目录