BBYR Achieve
返回信息流
这是一条镜像帖。来源:北邮人论坛 / python / #26193同步于 2022/10/8
Python机器人发帖

【求助】pycharm上transformer运行出错

painting
2022/10/8镜像同步0 回复
def train(emb_dim=32, n_layer=3, n_head=4): dataset = utils.DateData(4000) print("Chinese time order: yy/mm/dd ", dataset.date_cn[:3], "\nEnglish time order: dd/M/yyyy", dataset.date_en[:3]) print("Vocabularies: ", dataset.vocab) print(f"x index sample: \n{dataset.idx2str(dataset.x[0])}\n{dataset.x[0]}", f"\ny index sample: \n{dataset.idx2str(dataset.y[0])}\n{dataset.y[0]}") # shuffle=True改成了False loader = DataLoader(dataset, batch_size=32, shuffle=False) model = Transformer(n_vocab=dataset.num_word, max_len=MAX_LEN, n_layer=n_layer, emb_dim=emb_dim, n_head=n_head, drop_rate=0.1, padding_idx=0) if torch.cuda.is_available(): print("GPU train avaliable") device = torch.device("cuda") model = model.cuda() else: device = torch.device("cpu") model = model.cpu() for i in range(100): for batch_idx, batch in enumerate(loader): bx, by, decoder_len = batch bx, by = torch.from_numpy(utils.pad_zero(bx, max_len=MAX_LEN)).type(torch.LongTensor).to( device), torch.from_numpy(utils.pad_zero(by, MAX_LEN + 1)).type(torch.LongTensor).to(device) loss, logits = model.step(bx, by) if batch_idx % 50 == 0: target = dataset.idx2str(by[0, 1:-1].cpu().data.numpy()) pred = model.translate(bx[0:1], dataset.v2i, dataset.i2v) res = dataset.idx2str(pred[0].cpu().data.numpy()) src = dataset.idx2str(bx[0].cpu().data.numpy()) print( "Epoch: ", i, "| t: ", batch_idx, "| loss: %.3f" % loss, "| input: ", src, "| target: ", target, "| inference: ", res, ) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--emb_dim", type=int, help="change the model dimension") parser.add_argument("--n_layer", type=int, help="change the number of layers in Encoder and Decoder") parser.add_argument("--n_head", type=int, help="change the number of heads in MultiHeadAttention") args = parser.parse_args() args = dict(filter(lambda x: x[1], vars(args).items())) train(**args) transfomer程序运行出错 Traceback (most recent call last): File "D:/software/pycharm/pythonProject/CNN/transformer.py", line 284, in <module> train(**args) File "D:/software/pycharm/pythonProject/CNN/transformer.py", line 255, in train for batch_idx, batch in enumerate(loader): File "D:\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 681, in __next__ data = self._next_data() File "D:\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 720, in _next_data index = self._next_index() # may raise StopIteration File "D:\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 671, in _next_index return next(self._sampler_iter) # may raise StopIteration File "D:\Anaconda3\lib\site-packages\torch\utils\data\sampler.py", line 247, in __iter__ for idx in self.sampler: File "D:\Anaconda3\lib\site-packages\torch\utils\data\sampler.py", line 76, in __iter__ return iter(range(len(self.data_source))) TypeError: object of type 'DateData' has no len()
订阅后,新回复会通过你的通知中心匿名送达。
0 条回复
暂无回复 · 你可以订阅本帖等待新回复。