返回信息流def train(emb_dim=32, n_layer=3, n_head=4):
dataset = utils.DateData(4000)
print("Chinese time order: yy/mm/dd ", dataset.date_cn[:3], "\nEnglish time order: dd/M/yyyy", dataset.date_en[:3])
print("Vocabularies: ", dataset.vocab)
print(f"x index sample: \n{dataset.idx2str(dataset.x[0])}\n{dataset.x[0]}",
f"\ny index sample: \n{dataset.idx2str(dataset.y[0])}\n{dataset.y[0]}")
# shuffle=True改成了False
loader = DataLoader(dataset, batch_size=32, shuffle=False)
model = Transformer(n_vocab=dataset.num_word, max_len=MAX_LEN, n_layer=n_layer, emb_dim=emb_dim, n_head=n_head,
drop_rate=0.1, padding_idx=0)
if torch.cuda.is_available():
print("GPU train avaliable")
device = torch.device("cuda")
model = model.cuda()
else:
device = torch.device("cpu")
model = model.cpu()
for i in range(100):
for batch_idx, batch in enumerate(loader):
bx, by, decoder_len = batch
bx, by = torch.from_numpy(utils.pad_zero(bx, max_len=MAX_LEN)).type(torch.LongTensor).to(
device), torch.from_numpy(utils.pad_zero(by, MAX_LEN + 1)).type(torch.LongTensor).to(device)
loss, logits = model.step(bx, by)
if batch_idx % 50 == 0:
target = dataset.idx2str(by[0, 1:-1].cpu().data.numpy())
pred = model.translate(bx[0:1], dataset.v2i, dataset.i2v)
res = dataset.idx2str(pred[0].cpu().data.numpy())
src = dataset.idx2str(bx[0].cpu().data.numpy())
print(
"Epoch: ", i,
"| t: ", batch_idx,
"| loss: %.3f" % loss,
"| input: ", src,
"| target: ", target,
"| inference: ", res,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--emb_dim", type=int, help="change the model dimension")
parser.add_argument("--n_layer", type=int, help="change the number of layers in Encoder and Decoder")
parser.add_argument("--n_head", type=int, help="change the number of heads in MultiHeadAttention")
args = parser.parse_args()
args = dict(filter(lambda x: x[1], vars(args).items()))
train(**args)
transfomer程序运行出错
Traceback (most recent call last):
File "D:/software/pycharm/pythonProject/CNN/transformer.py", line 284, in <module>
train(**args)
File "D:/software/pycharm/pythonProject/CNN/transformer.py", line 255, in train
for batch_idx, batch in enumerate(loader):
File "D:\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 681, in __next__
data = self._next_data()
File "D:\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 720, in _next_data
index = self._next_index() # may raise StopIteration
File "D:\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 671, in _next_index
return next(self._sampler_iter) # may raise StopIteration
File "D:\Anaconda3\lib\site-packages\torch\utils\data\sampler.py", line 247, in __iter__
for idx in self.sampler:
File "D:\Anaconda3\lib\site-packages\torch\utils\data\sampler.py", line 76, in __iter__
return iter(range(len(self.data_source)))
TypeError: object of type 'DateData' has no len()
这是一条镜像帖。来源:北邮人论坛 / python / #26193同步于 2022/10/8
Python机器人发帖
【求助】pycharm上transformer运行出错
painting
2022/10/8镜像同步0 回复
订阅后,新回复会通过你的通知中心匿名送达。
0 条回复
暂无回复 · 你可以订阅本帖等待新回复。