BBYR Achieve
返回信息流
这是一条镜像帖。来源:北邮人论坛 / ml-dm / #35919同步于 2019/12/22
该镜像源已超过 30 天没有更新,可能在源站已被删除。
ML_DM机器人发帖

【求建议】我写了个框架Jdit,深度学习项目研究的脚手架。

mima031103
2019/12/22镜像同步41 回复
深度学习脚手架,Jdit 新加了中文的使用指南! https://dingguanglei.com/jdit/ Jdit用于快速搭建深度学习项目,是一个能够承担辅助工作(可视化、保存等)的基于pytorch的脚手架。 有没有发现你的项目代码里有好多辅助工作的代码,然而这些并不是你的核心研究内容? 例如: count_parameters(model) ... plt.imshow(xxx) ... print(time-time()) ... SummaryWriter.add_scalar(xxx) ... SummaryWriter.add_graphy(xxx) ... def model_checkpoint(epoch): ... if epoch % 10: print(loss) 虽然这些可视化、打印输出等工作几乎是一行代码就能搞定,但是如果很多这样的代码和你的核心逻辑掺杂在一起。那真的是难以维护。况且这些工作就是简单的调API,冗杂且无聊,然而你却又必须要做。无奈~~ [ema2] 我希望所有辅助工作能和我的核心研究内容解耦,于是就有了Jdit。它可以帮助你承担这部分的辅助工作,让你专注于核心的研究内容。 同时,保留原生pytorch的操作(放开手脚用pytorch随意定义模型,使用原生的操作和运算符随意计算loss) 。就像是你自己按照喜好拼装好赛车,jdit来给你搭建场地,测量,维护人员。如果不满意甚至都可以随意修改。(末尾有示例) Github Jdit介绍 [ema3] 研究人员关注的: - 数据集 - 模型 - loss函数 - 训练逻辑 Jdit辅助工作: - 数据可视化(tensorboard) - 模型可视化,检查点保存 - 配置信息保存 - 训练/验证的输出展示 - ...... 在Jdit中,一个分类任务就是这样简单。实现自己的核心内容(数据集,模型,优化器,loss函数)。Tensorboard可视化、数据和模型的保存全部由Jdit来做。结果展示在Github 中。运行后,你想要的一切都在“log“文件夹中。 分类任务: # This is your model. Defined by torch.nn.Module class SimpleModel(nn.Module): ...... # A trainer, you need to rewrite the loss and valid function. class FashionClassTrainer(ClassificationTrainer): ...... def compute_loss(self): var_dic = {} var_dic["CEP"] = loss = nn.CrossEntropyLoss()(self.output, self.ground_truth) return loss, var_dic def compute_valid(self): _, var_dic = self.compute_loss() return var_dic def start_fashionClassTrainer(gpus=(), nepochs=10): num_class = 10 depth = 32 gpus = gpus batch_size = 64 nepochs = nepochs logdir = "log/fashion_classify" opt_hpm = {"optimizer": "Adam", "lr_decay": 0.94, "decay_position": 10, "decay_type": "epoch", "lr": 1e-3, "weight_decay": 2e-5, "betas": (0.9, 0.99)} mnist = FashionMNIST(batch_size=batch_size) net = Model(SimpleModel(depth=depth), gpu_ids_abs=gpus, init_method="kaiming", check_point_pos=1) opt = Optimizer(net.parameters(), **opt_hpm) Trainer = FashionClassTrainer(logdir, nepochs, gpus, net, opt, mnist, num_class) Trainer.train() if __name__ == '__main__': start_fashionClassTrainer() Jdit是我在跑过很过项目后萌生的想法。同样的代码每次都要copy来copy去,即便封装好也不是那么方便。 既然常见的任务类型都是那么的大同小异(分类,生成,pix2pix等)。为什么不把这些逻辑包装起来做成模板呢。 我用这个框架做了很多生成对抗网络的任务,非常的方便,我不再需要写任何代码用于辅助工作了,一切都是Jdit底层模板来完成。 感兴趣的同学可以尝试一下!如果能为大家提供便利自然是最好了。同时,非常欢迎大家提供建议或吐槽!毕竟写库也没什么经验!向大佬们求点建议[ema3] jdit.trainer.instances里面列出了现有模板的示例,可以直接copy运行。 目前已经完成的模板是分类,encode-decoder,基于GAN的生成任务,基于GAN的pix2pix任务。 目前并不用于检测类任务。 另外使用了清华镜像的同学 pip install jdit 的版本可能停留在 0.1.2 。 可以指定源来安装最新版本0.1.5 pip install jdit -i https://pypi.Python.org/simple/ 示例都在jdit.trainer.instances下 分类任务的示例。 from jdit.trainer.instances.fashionClassification import start_fashionClassTrainer if __name__ == '__main__': start_fashionClassTrainer() 生成任务的示例。 from jdit.trainer.instances.fashionGenerateGan import start_fashionGenerateGanTrainer if __name__ == '__main__': start_fashionGenerateGanTrainer() 多个任务的串并行示例。 from jdit.trainer.instances.fashionClassParallelTrainer import start_fashionClassPrarallelTrainer if __name__ == '__main__': start_fashionClassPrarallelTrainer() 运行后会生成一个`log`文件夹。里面保存了训练过程中的所有数据(checkpoint和csv文件),以及tensorboard的展示数据。 运行tensorboard可以看到训练过程中的loss,输入输出的图像等。 tensorboard --logdir=log
订阅后,新回复会通过你的通知中心匿名送达。
9 条回复
a543151514机器人#1 · 2019/12/22
好文 绑定
CrazyDream机器人#2 · 2019/12/22
Bd
heng1995520机器人#3 · 2019/12/22
听起来很高大上,帮顶哈哈哈 通过『我邮2.0』发布
zsygxh5机器人#4 · 2019/12/22
bd 通过『我邮2.0』发布
dtlqzx机器人#5 · 2019/12/22
bd
GoKu机器人#6 · 2019/12/22
bd
wujackjack机器人#7 · 2019/12/22
想法很赞
yo1995机器人#8 · 2019/12/22
bd
deepquitlear机器人#9 · 2019/12/22
帮顶