
将数据集按比例划分为 train、test、val,并将平行语料拆分为源语言/目标语言两个文件。
处理步骤:
- 随机打乱数据集
- 按比例划分数据集(train / test / val)
- 划分平行语料(生成 .en 与 .zh 文件)
import os
import random
def data_split(config, file, train_ratio=0.98, shuffle=True):
    """Split a parallel-corpus file into train / test / val portions.

    Reads ``config/file`` as UTF-8 text (one sentence pair per line),
    optionally shuffles the lines, writes ``train.txt`` / ``test.txt`` /
    ``val.txt`` back into ``config``, then calls ``para_divide`` to split
    each portion into per-language files under ``config/dataset``.

    :param config: directory containing the data file
    :param file: name (with extension) of the data file to split
    :param train_ratio: fraction of lines used for the training set
    :param shuffle: whether to shuffle the lines before splitting
    :return: None
    """
    with open(os.path.join(config, file), 'r', encoding='utf-8') as fp:
        # BUG FIX: original used split('n') (a mangled '\n' escape),
        # which splits on the letter 'n' instead of line breaks.
        lines = fp.read().strip().split('\n')
    n = len(lines)
    if shuffle:
        random.shuffle(lines)  # shuffle in place before slicing
    train_len = int(n * train_ratio)
    # the remaining (1 - train_ratio) share is halved between val and test
    val_len = int(n * (1 - train_ratio) / 2)
    train_data = lines[:train_len]
    # BUG FIX: original sliced val_len + 1 lines into val (off-by-one),
    # stealing one line from the test set.
    val_data = lines[train_len:train_len + val_len]
    test_data = lines[train_len + val_len:]
    s_config = os.path.join(config, 'dataset')  # where per-language files go
    # make sure the output directory exists before para_divide writes into it
    os.makedirs(s_config, exist_ok=True)
    for fname, portion in (('train.txt', train_data),
                           ('test.txt', test_data),
                           ('val.txt', val_data)):
        with open(os.path.join(config, fname), 'w', encoding='utf-8') as fp:
            # BUG FIX: original used "n".join (mangled '\n'), which glued
            # all lines together with the letter 'n'.
            fp.write('\n'.join(portion))
        para_divide(s_config, fname)
    print('总共有数据:{}条'.format(n))
    print('训练集:{}条'.format(len(train_data)))
    print('测试集:{}条'.format(len(test_data)))
    print('验证集:{}条'.format(len(val_data)))
def para_divide(config, data):
    """Split one tab-separated parallel-corpus file into two monolingual files.

    Reads ``config[:4]/data`` — each line is ``source<TAB>target`` — and
    writes ``config/<stem>.en`` (left column) and ``config/<stem>.zh``
    (right column), one sentence per line.

    NOTE(review): the input directory is derived as ``config[:4]``, which
    relies on ``config`` starting with the literal prefix ``'data'`` as set
    up by ``data_split``; fragile, but preserved for caller compatibility.

    :param config: output directory for the per-language files
    :param data: file name (with extension) of the combined corpus file
    :return: None
    """
    stem = data[:-4]  # drop the 4-char '.txt' suffix
    # context managers replace the manual open/close triple, so the files
    # are closed even if a line fails to parse
    with open(os.path.join(config[:4], data), 'r', encoding='utf-8',
              errors='ignore') as f_data, \
         open(os.path.join(config, stem + '.en'), 'w', encoding='utf-8') as en, \
         open(os.path.join(config, stem + '.zh'), 'w', encoding='utf-8') as zh:
        for line in f_data:
            line = line.strip()
            if not line:
                continue  # skip blank lines instead of crashing on unpack
            # BUG FIX: original used split('t') (a mangled '\t' escape),
            # which splits on the letter 't'. maxsplit=1 tolerates extra
            # tabs inside the target side.
            left, right = line.split('\t', 1)
            en.write(left.strip() + '\n')
            zh.write(right.strip() + '\n')
def main():
    """Entry point: split the English-Chinese corpus located under ``data/``."""
    data_split("data", "en-zh.txt")


if __name__ == '__main__':
    main()
欢迎分享,转载请注明来源:内存溢出
微信扫一扫
支付宝扫一扫
评论列表(0条)