# -*- coding: utf-8 -*-
import csv
import torch
import torch.utils.data as tud
from torch.nn.utils.rnn import pad_sequence
from tokenizer import Tokenizer
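
# Assumption (based on usage in BertDataset below): Tokenizer.encode(content,
# summary, max_length=...) is this project's local helper and returns four
# aligned lists: input_ids, token_type_ids, token_type_ids_for_mask, labels.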
TRAIN_DATA_PATH = './data/train.tsv'
DEV_DATA_PATH = './data/dev.tsv'
MAX_LEN = 512  # per-example cap; 512 matches standard BERT's position-embedding limit
BATCH_SIZE = 32


def collate_fn(batch_data):
    """
    collate_fn for the DataLoader: converts a list of instances into padded tensors.
    Args:
        batch_data: the list of instance dicts that makes up one batch
    Returns:
        A dict of four LongTensors, each padded to the longest sequence in the batch.
    """
    input_ids_list, token_type_ids_list, token_type_ids_for_mask_list, labels_list = [], [], [], []
    for instance in batch_data:
        # Collect each field as a tensor; padding to the batch maximum happens below.
        input_ids_temp = instance["input_ids"]
        token_type_ids_temp = instance["token_type_ids"]
        token_type_ids_for_mask_temp = instance["token_type_ids_for_mask"]
        labels_temp = instance["labels"]
        input_ids_list.append(torch.tensor(input_ids_temp, dtype=torch.long))
        token_type_ids_list.append(torch.tensor(token_type_ids_temp, dtype=torch.long))
        token_type_ids_for_mask_list.append(torch.tensor(token_type_ids_for_mask_temp, dtype=torch.long))
        labels_list.append(torch.tensor(labels_temp, dtype=torch.long))
    # pad_sequence pads every tensor in the list to the longest length in the
    # batch, filling with padding_value.
    return {"input_ids": pad_sequence(input_ids_list, batch_first=True, padding_value=0),
            "token_type_ids": pad_sequence(token_type_ids_list, batch_first=True, padding_value=0),
            "token_type_ids_for_mask": pad_sequence(token_type_ids_for_mask_list, batch_first=True, padding_value=-1),
            "labels": pad_sequence(labels_list, batch_first=True, padding_value=-100)}


class BertDataset(tud.Dataset):
    def __init__(self, data_path):
        super(BertDataset, self).__init__()
        self.data_set = []
        with open(data_path, 'r', encoding='utf8') as rf:
            r = csv.reader(rf, delimiter='\t')
            next(r)  # skip the header row
            for row in r:
                summary = row[0]
                content = row[1]
                input_ids, token_type_ids, token_type_ids_for_mask, labels = Tokenizer.encode(
                    content, summary, max_length=MAX_LEN)
                self.data_set.append({"input_ids": input_ids,
                                      "token_type_ids": token_type_ids,
                                      "token_type_ids_for_mask": token_type_ids_for_mask,
                                      "labels": labels})

    def __len__(self):
        return len(self.data_set)

    def __getitem__(self, idx):
        return self.data_set[idx]
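
# Expected TSV layout (an illustrative sketch inferred from the reader above;
# the actual header names in train.tsv/dev.tsv may differ):
#
#   summary<TAB>content
#   A short abstract.<TAB>The full article text ...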


# Training batches are shuffled each epoch; the dev set keeps file order so
# evaluation is reproducible.
traindataset = BertDataset(TRAIN_DATA_PATH)
traindataloader = tud.DataLoader(traindataset, BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
valdataset = BertDataset(DEV_DATA_PATH)
valdataloader = tud.DataLoader(valdataset, BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
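
# Note: both datasets are tokenized eagerly at construction time, so the whole
# corpus sits in memory before training starts. For a large corpus, a variant
# that stores raw rows and calls Tokenizer.encode inside __getitem__ (a sketch,
# not the original behavior) would trade startup memory for per-batch CPU work.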


if __name__ == '__main__':
    # Sanity check: inspect the first validation batch when this file is run
    # directly (importing this module skips this block).
    for batch in valdataloader:
        print(batch["input_ids"])
        print(batch["input_ids"].shape)
        print('------------------')
        print(batch["token_type_ids"])
        print(batch["token_type_ids"].shape)
        print('------------------')
        print(batch["token_type_ids_for_mask"])
        print(batch["token_type_ids_for_mask"].shape)
        print('------------------')
        print(batch["labels"])
        print(batch["labels"].shape)
        print('------------------')
        break  # one batch is enough