load_data.py
# -*- coding: utf-8 -*-
import torch
import jieba
from torchtext.legacy import data

device = "cuda" if torch.cuda.is_available() else "cpu"

def tokenizer(text):
    # Segment Chinese text into a list of tokens with jieba.
    return list(jieba.cut(text))

# Shared field for source and target sequences: prepends <sos>, appends <eos>,
# lowercases, and returns batch-first tensors.
TEXT = data.Field(tokenize=tokenizer,
                  init_token='<sos>',
                  eos_token='<eos>',
                  lower=True,
                  batch_first=True)

# Each TSV row holds a target column followed by a source column.
train, val = data.TabularDataset.splits(
    path='./data/',
    train='train.tsv',
    validation='dev.tsv',
    format='tsv',
    skip_header=True,
    fields=[('trg', TEXT), ('src', TEXT)])

# Build the vocabulary from the training set; tokens seen fewer than
# twice are mapped to <unk>.
TEXT.build_vocab(train, min_freq=2)
id2vocab = TEXT.vocab.itos
vocab2id = TEXT.vocab.stoi
PAD_IDX = vocab2id[TEXT.pad_token]
UNK_IDX = vocab2id[TEXT.unk_token]
SOS_IDX = vocab2id[TEXT.init_token]
EOS_IDX = vocab2id[TEXT.eos_token]

# train_iter shuffles automatically; val_iter is sorted by sort_key.
train_iter, val_iter = data.BucketIterator.splits(
    (train, val),
    batch_sizes=(256, 256),
    sort_key=lambda x: len(x.src),
    device=device)
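
# Usage sketch (an assumption, not part of the original file): it presumes
# ./data/train.tsv and ./data/dev.tsv exist with a header row and two
# tab-separated columns, target then source. It shows how a training loop
# would consume the iterators; shapes follow from batch_first=True above.
if __name__ == '__main__':
    print(f'vocab size: {len(TEXT.vocab)}')
    batch = next(iter(train_iter))
    # batch.src / batch.trg: LongTensor of shape [batch_size, seq_len]
    print(batch.src.shape, batch.trg.shape)
    # Map the first source sequence back to tokens, dropping padding.
    tokens = [id2vocab[i] for i in batch.src[0].tolist() if i != PAD_IDX]
    print(''.join(tokens))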