How to Optimize a Trainer: Balancing Flexibility and Ease of Use

This article discusses how to balance flexibility and ease of use when training models. It focuses on the Trainers from pytorch lightning and huggingface, as well as a code framework aimed at multimodal large models. It also walks through the features and flexibility of an open-source project; feedback and suggestions are welcome.

Training a model without a flexible Trainer is like making a documentary without MacArthur.

When it comes to Trainers, most people think of pytorch lightning and huggingface, and there are plenty of discussions comparing the two. After using huggingface's Trainer, I think it has the following two drawbacks:

  • Its ease of use comes from many layers of wrapping, but customizing a module (e.g., setting a min_lr for the cosine scheduler, or implementing layer-wise learning-rate decay for a ViT) is fairly cumbersome (see the sketch after this list)
  • It has a few too many parameters and features; they are all coupled together and can get confusing. For a small personal project or for research, most of them are unnecessary
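
To make the first point concrete, here is a minimal plain-PyTorch sketch (not huggingface's API) of building layer-wise lr-decayed parameter groups for a ViT-style model, the kind of customization the first bullet refers to. The model.blocks attribute and the group layout are assumptions for illustration only:

import torch

def build_layer_decay_param_groups(model, base_lr=1e-4, layer_decay=0.9, weight_decay=0.05):
    """Give the last block the full base_lr and earlier blocks progressively smaller lrs.

    Assumes the model exposes its transformer blocks as `model.blocks`;
    adapt the attribute name to your own ViT implementation.
    """
    num_layers = len(model.blocks)
    param_groups = []
    for layer_id, block in enumerate(model.blocks):
        scale = layer_decay ** (num_layers - layer_id - 1)  # the last block keeps base_lr
        param_groups.append({
            "params": [p for p in block.parameters() if p.requires_grad],
            "lr": base_lr * scale,
            "weight_decay": weight_decay,
        })
    return param_groups

# usage with a hypothetical ViT instance:
# optimizer = torch.optim.AdamW(build_layer_decay_param_groups(vit), lr=1e-4)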

Last month, riding the wave of interest in Qwen (通义千问), I wrote this post:

juejin.cn/post/729832…

The project is open-sourced at:

github.com/Coobiw/Mini…

That project was built mainly by restructuring lavis (github.com/salesforce/… ), a very popular open-source repository in the multimodal field; many multimodal large models such as BLIP2, InstructBLIP, and MiniGPT4 are developed on top of lavis. After reading its source code carefully, I really liked its code framework, so I refactored its Trainer to make it easier to adapt or migrate to your own tasks, models, and datasets.

This clean, flexible, and not overly bloated Trainer is open-sourced at:

github.com/Coobiw/Mini…

Suggestions are welcome via private messages, Zhihu, or issues on the GitHub repo. If the project helps you, please give it a star! It really means a lot to me :)

Implemented Features

  • Registry mechanism: registries are built for model, dataset, processor (the preprocessing transforms), lr_scheduler, and task (the task being run, e.g., classification, segmentation, image2prompt)
  • Complete and flexible config files: one config file corresponds to one full run (training), with plenty of settable parameters but no redundant ones
  • Redundancy removal:
    • Every component type in the registries above comes with a base class, reducing duplicated code
    • Duplicate or redundant features are removed
  • Extensibility/flexibility, satisfied top-down:
    • Extensible tasks (similar to how OpenMMLab supports so many vision and multimodal tasks on top of MMEngine and MMCV): any task can be supported. This project ships image classification (cat-vs-dog as the example) and Image2Prompt (implemented to fit my first Kaggle competition (www.kaggle.com/competition…), which ended with a silver medal; I was a rookie but it was memorable, so I added it. Its simplified pipeline is shown in the figure below)
    • Extensible models
    • Extensible datasets (including extensible preprocessing)
    • Extensible schedulers

[Figure: simplified pipeline of the Image2Prompt task]

Quick Tutorial on Supporting New Features

Define Your Dataset

Under the datasets directory, implement your dataset by inheriting the BaseDataset class; if you need a custom collator, implement it here as well.

For example, the following code defines a new classification dataset; see example_data/classification for its directory structure.

from common.registry import registry
from .base_dataset import BaseDataset

from PIL import Image
from pathlib import Path
import os

import torch

@registry.register_dataset('example_cls_dataset')
class ExampleClsDataset(BaseDataset):
    def __init__(self, transform, data_root):
        super().__init__(transform=transform)
        self.data_root = Path(data_root)
        self.cls_dir = sorted(list(os.listdir(self.data_root)))

        self.data = []
        self.labels = []

        self.idx2cls,self.cls2idx = {}, {}
        for i,cls in enumerate(self.cls_dir):
            self.idx2cls[i] = cls
            self.cls2idx[cls] = i
            imgs = [str(self.data_root/cls/img) for img in os.listdir(self.data_root/cls) if self.is_img(img)]
            self.data.extend(imgs)
            self.labels.extend([i]*len(imgs))

        assert len(self.data) == len(self.labels)


    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image_path = self.data[index]
        label = self.labels[index]

        image = Image.open(image_path).convert("RGB")

        return self.transform(image), torch.tensor(label,dtype=torch.long)

    @staticmethod
    def is_img(img_path):
        return Path(img_path).suffix.lower() in ['.jpg', '.jpeg', '.png']

    def collator(self,batch):
        images, labels = zip(*batch)
        images = torch.stack(images)
        labels = torch.stack(labels)

        return {"images": images, "labels": labels}

Define Your Model

Under the models directory, implement your model by inheriting the BaseModel class; you can refer to the resnet_clip model shipped with this repo.

Note: please implement the train_step and val_step methods for your model; they are called by task.train_step and task.val_step respectively.
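
As a rough sketch only (the class, config fields, and return formats below are invented for illustration; resnet_clip and BaseModel in the repo are the actual reference), a registered model might look like this:

import torch.nn as nn
import torchvision

from common.registry import registry
from .base_model import BaseModel  # import path assumed; see the repo's models directory

@registry.register_model('example_resnet_cls')
class ExampleResNetCls(BaseModel):
    def __init__(self, num_classes=2):
        super().__init__()
        self.backbone = torchvision.models.resnet18(weights=None)
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    @classmethod
    def from_config(cls, cfg):
        # train.py builds the model via model_cls.from_config(config.model)
        return cls(num_classes=cfg.get("num_classes", 2))

    def forward(self, images):
        return self.backbone(images)

    def train_step(self, samples):
        # called by task.train_step; returns the loss used for backprop
        logits = self(samples["images"])
        return {"loss": self.criterion(logits, samples["labels"])}

    def val_step(self, samples):
        # called by task.val_step; returns whatever your task's metric needs
        logits = self(samples["images"])
        return {"predictions": logits.argmax(dim=-1), "labels": samples["labels"]}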

Define a New Task

Under the tasks directory, implement your task by inheriting the BaseTask class; you can refer to the ClassificationTask provided in this repo.

Note: generally, you only need to modify val_step for your task and its corresponding metric.
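
For illustration, a minimal sketch of such a task, assuming the model's val_step returns predictions and labels as in the model sketch above and that the registry exposes register_task analogous to register_model (the exact BaseTask method signatures in the repo may differ):

from common.registry import registry
from .base_task import BaseTask  # import path assumed; see the repo's tasks directory

@registry.register_task('example_classification')
class ExampleClassificationTask(BaseTask):
    def val_step(self, model, samples):
        # per-batch metric computation; usually the only method you need to change
        outputs = model.val_step(samples)
        correct = (outputs["predictions"] == outputs["labels"]).sum().item()
        return {"correct": correct, "total": outputs["labels"].numel()}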

Why Do We Need a Trainer and a Registry Mechanism?

To see why a Trainer is needed, let's first imagine a world without one and build a training pipeline with plain pytorch only. We then have to:

  1. Define the Dataset and Dataloader
  2. Define the model
  3. Define the loss function
  4. Define the optimizer
  5. Define the learning-rate schedule used during training (the scheduler)
  6. Loop and iteratively update the model

The pytorch code looks roughly like this:

import torch
from torch.utils.data import DataLoader

from dataset import train_data, val_data
from network import Net

# dataloader
train_loader = DataLoader(dataset=train_data, batch_size=16, shuffle=True)
val_loader = DataLoader(dataset=val_data, batch_size=16, shuffle=False)

# define the model
model = Net(...)

# define the loss function
criterion = torch.nn.CrossEntropyLoss()

# define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

# define the scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, ...)

# number of training epochs
epochs = 50
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# the training/validation loop
for epoch in range(epochs):
    train_loss = 0
    train_acc = 0

    model.train()
    for i, (x, y) in enumerate(train_loader):
        x = x.to(device)
        y = y.to(device)
        out = model(x)
        loss = criterion(out, y)
        train_loss += loss.item()
        prediction = torch.max(out, 1)[1]
        pred_correct = (prediction == y).sum()
        train_acc += pred_correct.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()
    print(...)
    # validation
    model.eval()
    with torch.no_grad():
        eval_loss = 0
        eval_acc = 0
        for i, (x, y) in enumerate(val_loader):
            x = x.to(device)
            y = y.to(device)
            out = model(x)
            loss = criterion(out, y)
            eval_loss += loss.item()
            prediction = torch.max(out, 1)[1]
            pred_correct = (prediction == y).sum()
            eval_acc += pred_correct.item()
        print(...)

This code looks fine at first glance: it is short and easy to follow. The cost, however, is the following:

In terms of code encapsulation:

  • The current code is hardly encapsulated at all; every part of the pipeline lives in one script file, which is not elegant from a program-design standpoint

In fact, all we should need to write is:

# definitions of the various components
......

trainer.train()

trainer.eval()

The fixed iterate-and-update loop is written into the trainer's train() function, and inside train() we then have:

class Trainer:
    ......

    def train(xxxx):
        for _ in range(epoch):
            self.train_step(xxxxx)

            self.val_step(xxxx)

    def train_step(xxx):
        for sample in dataloader:
            loss = model.train_step(sample)
            # backward pass
            ......
    def val_step(xx):
        for sample in dataloader:
            metric = model.val_step(sample)
            # logging, checkpoint saving, and other features
            ......

In terms of flexibility:

  • For the current task, I want to swap the model!

    • Sure: change from network import Net to from network import NewNet
  • For another task, I want that task's model!

    • Sure: change from network import Net to from network import NewTaskNet
  • I want a different dataset!

    • Sure: change from dataset import train_data to from dataset import new_train_data
  • I want a different optimizer! A different scheduler!

    • Sure: either implement it in another file, as above, and change the import,

    • or edit the variable definitions directly in this script

Both problems can be solved by implementing a Trainer and adding a register mechanism, which turns the definition of every module in the pipeline into a string, so that a single yaml file can define all the components of one run.

You can think of it this way: the Register mechanism defines the individual components, while the Trainer encapsulates the required functionality and provides the slots these components plug into!

Flexible Definition of Basic Components: the Register Mechanism, Turning Class Definitions into Strings

Today's high-impact open-source repositories often rely on a register mechanism, e.g., timm and the openmmlab family of repos. The code of a register mechanism looks roughly like this:

class Registry:
    mapping = {
        "dataset_name_mapping": {},
        "task_name_mapping": {},
        "processor_name_mapping": {},
        "model_name_mapping": {},
        "lr_scheduler_name_mapping": {},
    }
    @classmethod
    def register_model(cls, name):
        def wrap(model_cls):
            from models import BaseModel

            assert issubclass(
                model_cls, BaseModel
            ), "All models must inherit BaseModel class"
            if name in cls.mapping["model_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["model_name_mapping"][name]
                    )
                )
            cls.mapping[  "model_name_mapping"  ][name] = model_cls
            return model_cls

        return wrap
        
    @classmethod
    def get_model_class(cls, name):
        return cls.mapping["model_name_mapping"].get(name, None)

registry = Registry()

As you can see, the key line of code is:

cls.mapping[  "model_name_mapping"  ][name] = model_cls

That is, a defined model class is placed into the model_name_mapping dict (usually called the registry). When we later need this model, we only have to do:

model_name = xxx  # the registered name
model = registry.get_model_class(model_name)(...)  # pass the model's init args here

This gives us a mapping from strings to model classes. From then on, there is no need to import each newly defined model class as described earlier; simply changing the model_name in the config file is enough.

The code that performs the registration looks like this:

@registry.register_model('model_name')
class NewNet(BaseModel): # inherit the predefined BaseModel class to cut down on copy-pasted boilerplate
    ......

Earlier, when defining the registry, we wrote:

mapping = {
    "dataset_name_mapping": {},
    "task_name_mapping": {},
    "processor_name_mapping": {},
    "model_name_mapping": {},
    "lr_scheduler_name_mapping": {},
}

These registries cover: dataset, task, processor (input preprocessing, e.g., image loading, data augmentation, ToTensor, normalization), model, and lr_scheduler. If there are further components you would like to make swappable, such as the optimizer, you can build a registry for them as well.

With this, the definition of every component in the training pipeline is fully string-based. Below is part of a final config file with brief annotations, so you can get a direct feel for the benefits of this string-based approach!

[Figure: excerpt of an annotated config file]
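
Since the screenshot is not reproduced here, the following is only a rough sketch of what such a yaml config might look like. The field names are inferred from the keys read in the Trainer and train.py code shown later in this post (config.model.arch, config.dataset.train_cfg/val_cfg, config.run.*); the registered names and the exact layout (for instance where the task name lives) are assumptions, and the files under the repo's projects/ directory are the real reference:

model:
  arch: example_resnet_cls            # looked up in the model registry

dataset:
  train_cfg:
    name: example_cls_dataset         # looked up in the dataset registry
    data_root: example_data/classification/train
    transform:
      name: example_train_processor   # looked up in the processor registry
  val_cfg:
    name: example_cls_dataset
    data_root: example_data/classification/val
    transform:
      name: example_val_processor

run:
  task: classification                # looked up in the task registry (exact key location per the repo)
  lr_sched: linear_warmup_cosine_lr   # looked up in the lr_scheduler registry
  device: cuda
  distributed: False
  batch_size: 16
  batch_size_val: 16
  num_worker: 4
  max_epoch: 50
  init_lr: 1e-4
  min_lr: 1e-6
  warmup_lr: 1e-6
  warmup_steps: 200
  weight_decay: 0.05
  amp: True
  eval_freq: 1
  save_freq: 1
  output_dir: output/example_cls

Swapping the model, dataset, scheduler, or task then only requires editing these strings, with no changes to any import statement.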

A Full-Featured Trainer with "Slots" for All the Components

Once the components are defined, we need a full-featured Trainer with slots so they can do their work. The Trainer defines the entire training and validation flow and further wraps the components that are plugged in (e.g., turning a dataset into a dataloader).

From a bottom-up programming perspective, the Trainer should sit at the very top layer: it needs to be big and global enough to accommodate changes in the components underneath it.

First, the process by which the Trainer takes the components in (not every parameter is covered here, only the most common components and parameters; see the repository source code for more details):

class Trainer:
    def __init__(self,config,model,datasets,task,job_id):
        self.config = config
        self.job_id = job_id
        self._model = model
        self.datasets = datasets
        self.task = task

        self._wrapped_model = None
        self._device = None
        self._optimizer = None
        self._scaler = None
        self._dataloaders = None
        self._lr_sched = None

        self.start_epoch = 0

        self.setup_output_dir()
 
    @property
    def device(self):
        if self._device is None:
            self._device = torch.device(self.config.run.device)

        return self._device

    @property
    def use_distributed(self):
        return self.config.run.distributed

    @property
    def model(self):
        """
        A property to get the DDP-wrapped model on the device.
        """
        # move model to device
        if self._model.device != self.device:
            self._model = self._model.to(self.device)

            # ddp training wrapper
            if self.use_distributed:
                if self._wrapped_model is None:
                    self._wrapped_model = DDP(
                        self._model, device_ids=[self.config.run.gpu]
                    )
            else:
                self._wrapped_model = self._model

        return self._wrapped_model

    @property
    def dataloaders(self) -> dict:
        run_cfg = self.config.run
        if self._dataloaders is None:
            self._dataloaders = get_dataloaders(
                datasets = self.datasets,
                batch_size = run_cfg.batch_size,
                batch_size_val = run_cfg.batch_size_val,
                num_worker = run_cfg.num_worker,
                ddp = run_cfg.distributed,
            )
        return self._dataloaders

    @property
    def optimizer(self):
        if self._optimizer is None:
            # this can be used to implement layer-wise lr decay
            # it requires overriding the model's get_optimizer_params; see lavis's vit for reference
            lr_scale = self.config.run.get("lr_layer_decay", 1)
            weight_decay = self.config.run.get("weight_decay", 0.05)
            optim_params = self._model.get_optimizer_params(weight_decay,lr_scale)

            num_parameters = 0
            for p_group in optim_params:
                for p in p_group["params"]:
                    num_parameters += p.data.nelement()
            logging.info("number of trainable parameters: {}".format(num_parameters))

            beta2 = self.config.run.get("beta2", 0.999)

            self._optimizer = torch.optim.AdamW(
                optim_params,
                lr=float(self.config.run.init_lr),
                betas=(0.9, beta2),
            )
        return self._optimizer

    @property
    def scaler(self):
        amp = self.config.run.get("amp", False)

        if amp:
            if self._scaler is None:
                self._scaler = torch.cuda.amp.GradScaler()

        return self._scaler

    @property
    def lr_scheduler(self):
        """
        A property to get and create learning rate scheduler by split just in need.
        """
        if self._lr_sched is None:
            lr_sched_cls = registry.get_lr_scheduler_class(self.config.run.lr_sched)

            # max_epoch = self.config.run.max_epoch
            max_epoch = self.max_epoch
            # min_lr = self.config.run.min_lr
            min_lr = self.min_lr
            # init_lr = self.config.run.init_lr
            init_lr = self.init_lr

            # optional parameters
            decay_rate = self.config.run.get("lr_decay_rate", None)
            warmup_start_lr = self.config.run.get("warmup_lr", -1)
            warmup_steps = self.config.run.get("warmup_steps", 0)

            self._lr_sched = lr_sched_cls(
                optimizer=self.optimizer,
                max_epoch=max_epoch,
                min_lr=min_lr,
                init_lr=init_lr,
                decay_rate=decay_rate,
                warmup_start_lr=warmup_start_lr,
                warmup_steps=warmup_steps,
            )

        return self._lr_sched

    @property
    def train_loader(self):
        train_dataloader = self.dataloaders["train"]

        return train_dataloader

The training flow then provides roughly the following features:

  • Resume: checkpoints saved during training (model, optimizer state, scheduler state, epoch, etc.) can be reloaded to continue training after an interruption
  • Checkpoint saving, including saving the best checkpoint according to the eval results
  • Logging
  • The basic training and validation loop

The corresponding code is roughly as follows:

    @main_process
    def _save_checkpoint(self, cur_epoch, is_best=False):
        """
        Save the checkpoint at the current epoch.
        """
        model_no_ddp = self.unwrap_dist_model(self.model)
        param_grad_dic = {
            k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()
        }
        state_dict = model_no_ddp.state_dict()
        for k in list(state_dict.keys()):
            if k in param_grad_dic.keys() and not param_grad_dic[k]:
                # delete parameters that do not require gradient
                del state_dict[k]

        save_obj = {
            "model": state_dict,
            "optimizer": self.optimizer.state_dict(),
            "config": OmegaConf.to_container(self.config),
            "scaler": self.scaler.state_dict() if self.scaler else None,
            "epoch": cur_epoch,
        }
        save_to = os.path.join(
            self.output_dir,
            "checkpoint_{}.pth".format("best" if is_best else cur_epoch),
        )
        logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch, save_to))
        torch.save(save_obj, save_to)

    def _reload_best_model(self, model):
        """
        Load the best checkpoint for evaluation.
        """
        checkpoint_path = os.path.join(self.output_dir, "checkpoint_best.pth")

        logging.info("Loading checkpoint from {}.".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        try:
            model.load_state_dict(checkpoint["model"])
        except RuntimeError as e:
            logging.warning(
                """
                Key mismatch when loading checkpoint. This is expected if only part of the model is saved.
                Trying to load the model with strict=False.
                """
            )
            model.load_state_dict(checkpoint["model"], strict=False)
        return model

    def _load_checkpoint(self, filename):
        """
        Resume from a checkpoint.
        """
        if os.path.isfile(filename):
            checkpoint = torch.load(filename, map_location=self.device)
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"]
        self.unwrap_dist_model(self.model).load_state_dict(state_dict)

        self.optimizer.load_state_dict(checkpoint["optimizer"])
        if self.scaler and "scaler" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler"])

        self.start_epoch = checkpoint["epoch"] + 1
        logging.info("Resume checkpoint from {}".format(filename))

    @main_process
    def log_stats(self, stats, split_name):
        if isinstance(stats, dict):
            log_stats = {**{f"{split_name}_{k}": v for k, v in stats.items()}}
            with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
                f.write(json.dumps(log_stats) + "\n")
        elif isinstance(stats, list):
            pass

    @main_process
    def log_config(self):
        with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
            f.write(json.dumps(OmegaConf.to_container(self.config), indent=4) + "\n")
  
    @torch.no_grad()
    def eval_epoch(self, cur_epoch, skip_reload=False):
        """
        Evaluate the model on a given split.

        Args:
            split_name (str): name of the split to evaluate on.
            cur_epoch (int): current epoch.
            skip_reload_best (bool): whether to skip reloading the best checkpoint.
                During training, we will reload the best checkpoint for validation.
                During testing, we will use provided weights and skip reloading the best checkpoint .
        """
        data_loader = self.dataloaders.get('val', None)
        assert data_loader, "data_loader for split {} is None.".format("val")

        # TODO In validation, you need to compute loss as well as metrics
        # TODO consider moving to model.before_evaluation()
        model = self.unwrap_dist_model(self.model)
        if not skip_reload and cur_epoch == "best":
            model = self._reload_best_model(model)
        model.eval()

        self.task.before_evaluation(
            model=model,
            dataset=self.datasets["val"],
        )
        results = self.task.evaluation(model, data_loader)

        if results is not None:
            return self.task.after_evaluation(
                val_result=results,
                epoch=cur_epoch,
            )

    def train(self):
        start_time = time.time()
        best_agg_metric = 0
        best_epoch = 0
        best_metrics = {}

        self.log_config()

        # resume from checkpoint if specified
        if not self.evaluate_only and self.resume_ckpt_path is not None:
            self._load_checkpoint(self.resume_ckpt_path)

        for cur_epoch in range(self.start_epoch, self.max_epoch):
            # training phase
            if not self.evaluate_only:
                logging.info("Start training")
                # See https://github.com/salesforce/LAVIS/issues/449
                # if cur_epoch == self.start_epoch:
                #     self.task.before_training(
                #         model=self.unwrap_dist_model(self.model),
                #         dataset=self.datasets["train"],
                #     )
                train_stats = self.train_epoch(cur_epoch)
                self.log_stats(split_name="train", stats=train_stats)

            # evaluation phase
            if cur_epoch % self.eval_freq == 0 or cur_epoch == self.max_epoch -1:
                logging.info("Evaluating on {}.".format("val"))

                val_log = self.eval_epoch(
                    cur_epoch=cur_epoch,
                )
                if val_log is not None:
                    if is_main_process():
                        assert (
                            "agg_metrics" in val_log
                        ), "No agg_metrics found in validation log."

                        agg_metrics = val_log["agg_metrics"]
                        if agg_metrics > best_agg_metric:
                            best_epoch, best_agg_metric = cur_epoch, agg_metrics
                            best_metrics = deepcopy(val_log)

                            self._save_checkpoint(cur_epoch, is_best=True)

                        if cur_epoch % self.save_freq == 0 or cur_epoch == self.max_epoch -1:
                            self._save_checkpoint(cur_epoch, is_best=False)
                        val_log.update({"best_epoch": best_epoch})
                        self.log_stats(val_log, "val")
                else:  # the task does not define an evaluation
                    if cur_epoch % self.save_freq == 0 or cur_epoch == self.max_epoch -1:
                        self._save_checkpoint(cur_epoch, is_best=False)

            else:
                if not self.evaluate_only:
                    if cur_epoch % self.save_freq == 0:
                        self._save_checkpoint(cur_epoch, is_best=False)

            if self.evaluate_only:
                break

            if is_dist_avail_and_initialized():
                dist.barrier()

        return best_metrics

    def train_epoch(self, epoch):
        # train
        self.model.train()

        return self.task.train_epoch(
            epoch=epoch,
            model=self.model,
            data_loader=self.train_loader,
            optimizer=self.optimizer,
            scaler=self.scaler,
            lr_scheduler=self.lr_scheduler,
            cuda_enabled=self.cuda_enabled,
            log_freq=self.log_freq,
            accum_grad_iters=self.accum_grad_iters,
            grad_norm_clip=self.grad_norm_clip,
        )

The train.py File: Define the Components and Plug Them into the Trainer to Run Training

The commands for running train.py:

# single GPU
python train.py --cfg-path projects/train_classification.yaml

# multi-GPU
python -m torch.distributed.run --nproc_per_node=4 train.py --cfg-path projects/train_classification.yaml

# switch to another task
python train.py --cfg-path projects/train_image2prompt.yaml

It mainly does the following:

  • Parse the yaml config file
  • Based on the yaml config, look up in the registries and instantiate the specified model, dataset, processor, task, lr_scheduler, and other basic components
  • Plug these basic components into the Trainer and call trainer.train() to run training and validation

import os
from pathlib import Path

import warnings

import argparse
from omegaconf import OmegaConf

import random
import numpy as np
import torch
import torch.distributed as dist

from common.dist_utils import (
    init_distributed_mode,
    main_process,
)

from common.registry import registry
from common.logger import setup_logger
from tasks import setup_task

from trainer import Trainer

# imports modules for registration
from common.optims import (
    LinearWarmupCosineLRScheduler,
    LinearWarmupStepLRScheduler,
    ConstantLRScheduler,
)  # added to the registry, not used directly (since this is a from-style import, all classes in optims.py get registered, so importing just one of them would also work)

from processors import load_processor
from models import *
from datasets import load_dataset

warnings.filterwarnings('ignore')

def now():
    from datetime import datetime

    return datetime.now().strftime("%Y%m%d%H%M")[:-1]

def seed_everything(seed):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True

def get_config(args):
    cfg_path = Path(args.cfg_path)
    assert cfg_path.suffix == '.yaml', 'config file must be .yaml file'
    config = OmegaConf.load(cfg_path)
    init_distributed_mode(config.run)
    return config

def get_transforms(config) -> dict:
    dataset_cfg = config.dataset

    transforms = {}
    transforms['train'] = load_processor(**dataset_cfg.train_cfg.transform)
    transforms['val'] = load_processor(**dataset_cfg.val_cfg.transform)

    return transforms

def get_datasets(config,transforms) -> dict:
    dataset_cfg = config.dataset

    datasets = {}
    train_cfg = dict(dataset_cfg.pop('train_cfg'))
    val_cfg = dict(dataset_cfg.pop('val_cfg'))
    train_cfg['transform'], val_cfg['transform']= transforms['train'],transforms['val']
    datasets["train"] = load_dataset(train_cfg.pop('name'),train_cfg)
    datasets['val'] = load_dataset(val_cfg.pop('name'),val_cfg)

    return datasets

def get_model(config):
    model_cfg = config.model
    model_cls = registry.get_model_class(model_cfg.arch)
    return model_cls.from_config(model_cfg)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg-path',type=str)
    parser.add_argument('--seed',type=int,default=42)
    args = parser.parse_args()

    seed_everything(args.seed)
    config = get_config(args)

    setup_logger()

    transforms = get_transforms(config)
    datasets = get_datasets(config,transforms)
    model = get_model(config)
    task = setup_task(config)
    job_id = now()

    trainer = Trainer(config,model,datasets,task,job_id)
    trainer.train()

if __name__ == "__main__":
    main()