返回信息流深度学习脚手架,Jdit
新加了中文的使用指南!
https://dingguanglei.com/jdit/
Jdit用于快速搭建深度学习项目,是一个能够承担辅助工作(可视化、保存等)的基于pytorch的脚手架。
有没有发现你的项目代码里有好多辅助工作的代码,然而这些并不是你的核心研究内容?
例如:
count_parameters(model)
...
plt.imshow(xxx)
...
print(time-time())
...
SummaryWriter.add_scalar(xxx)
...
SummaryWriter.add_graphy(xxx)
...
def model_checkpoint(epoch):
...
if epoch % 10:
print(loss)
虽然这些可视化、打印输出等工作几乎是一行代码就能搞定,但是如果很多这样的代码和你的核心逻辑掺杂在一起。那真的是难以维护。况且这些工作就是简单的调API,冗杂且无聊,然而你却又必须要做。无奈~~
[ema2]
我希望所有辅助工作能和我的核心研究内容解耦,于是就有了Jdit。它可以帮助你承担这部分的辅助工作,让你专注于核心的研究内容。
同时,保留原生pytorch的操作(放开手脚用pytorch随意定义模型,使用原生的操作和运算符随意计算loss)
。就像是你自己按照喜好拼装好赛车,jdit来给你搭建场地,测量,维护人员。如果不满意甚至都可以随意修改。(末尾有示例)
Github
Jdit介绍
[ema3]
研究人员关注的:
- 数据集
- 模型
- loss函数
- 训练逻辑
Jdit辅助工作:
- 数据可视化(tensorboard)
- 模型可视化,检查点保存
- 配置信息保存
- 训练/验证的输出展示
- ......
在Jdit中,一个分类任务就是这样简单。实现自己的核心内容(数据集,模型,优化器,loss函数)。Tensorboard可视化、数据和模型的保存全部由Jdit来做。结果展示在Github
中。运行后,你想要的一切都在“log“文件夹中。
分类任务:
# This is your model. Defined by torch.nn.Module
class SimpleModel(nn.Module):
......
# A trainer, you need to rewrite the loss and valid function.
class FashionClassTrainer(ClassificationTrainer):
......
def compute_loss(self):
var_dic = {}
var_dic["CEP"] = loss = nn.CrossEntropyLoss()(self.output, self.ground_truth)
return loss, var_dic
def compute_valid(self):
_, var_dic = self.compute_loss()
return var_dic
def start_fashionClassTrainer(gpus=(), nepochs=10):
num_class = 10
depth = 32
gpus = gpus
batch_size = 64
nepochs = nepochs
logdir = "log/fashion_classify"
opt_hpm = {"optimizer": "Adam",
"lr_decay": 0.94,
"decay_position": 10,
"decay_type": "epoch",
"lr": 1e-3,
"weight_decay": 2e-5,
"betas": (0.9, 0.99)}
mnist = FashionMNIST(batch_size=batch_size)
net = Model(SimpleModel(depth=depth), gpu_ids_abs=gpus, init_method="kaiming", check_point_pos=1)
opt = Optimizer(net.parameters(), **opt_hpm)
Trainer = FashionClassTrainer(logdir, nepochs, gpus, net, opt, mnist, num_class)
Trainer.train()
if __name__ == '__main__':
start_fashionClassTrainer()
Jdit是我在跑过很过项目后萌生的想法。同样的代码每次都要copy来copy去,即便封装好也不是那么方便。
既然常见的任务类型都是那么的大同小异(分类,生成,pix2pix等)。为什么不把这些逻辑包装起来做成模板呢。
我用这个框架做了很多生成对抗网络的任务,非常的方便,我不再需要写任何代码用于辅助工作了,一切都是Jdit底层模板来完成。
感兴趣的同学可以尝试一下!如果能为大家提供便利自然是最好了。同时,非常欢迎大家提供建议或吐槽!毕竟写库也没什么经验!向大佬们求点建议[ema3]
jdit.trainer.instances里面列出了现有模板的示例,可以直接copy运行。
目前已经完成的模板是分类,encode-decoder,基于GAN的生成任务,基于GAN的pix2pix任务。
目前并不用于检测类任务。
另外使用了清华镜像的同学 pip install jdit 的版本可能停留在 0.1.2 。
可以指定源来安装最新版本0.1.5
pip install jdit -i https://pypi.Python.org/simple/
示例都在jdit.trainer.instances下
分类任务的示例。
from jdit.trainer.instances.fashionClassification import start_fashionClassTrainer
if __name__ == '__main__':
start_fashionClassTrainer()
生成任务的示例。
from jdit.trainer.instances.fashionGenerateGan import start_fashionGenerateGanTrainer
if __name__ == '__main__':
start_fashionGenerateGanTrainer()
多个任务的串并行示例。
from jdit.trainer.instances.fashionClassParallelTrainer import start_fashionClassPrarallelTrainer
if __name__ == '__main__':
start_fashionClassPrarallelTrainer()
运行后会生成一个`log`文件夹。里面保存了训练过程中的所有数据(checkpoint和csv文件),以及tensorboard的展示数据。
运行tensorboard可以看到训练过程中的loss,输入输出的图像等。
tensorboard --logdir=log
这是一条镜像帖。来源:北邮人论坛 / ml-dm / #35919同步于 2019/12/22
该镜像源已超过 30 天没有更新,可能在源站已被删除。
ML_DM机器人发帖
【求建议】我写了个框架Jdit,深度学习项目研究的脚手架。
mima031103
2019/12/22镜像同步41 回复
订阅后,新回复会通过你的通知中心匿名送达。
9 条回复