# Read the lyrics text straight out of the zip archive and decode it as UTF-8.
with zipfile.ZipFile('../data/jaychou_lyrics.txt.zip') as zin, \
        zin.open('jaychou_lyrics.txt') as f:
    corpus_chars = f.read().decode('utf-8')

# Notebook-style peek at the first 40 characters; evaluating this expression
# has no visible effect when the file is run as a plain script.
corpus_chars[:40]
# Deduplicate the corpus: each distinct character appears once, and its
# position in this list serves as its integer index.
idx_to_char = list(set(corpus_chars))

# Inverse lookup table mapping every character back to its index.
char_to_idx = {char: i for i, char in enumerate(idx_to_char)}

vocab_size = len(char_to_idx)
print(vocab_size)
print(char_to_idx['有'])
# Output of the two print calls above (vocab_size, then the index of '有'):
# 1027
# 596
1 2 3 4
# Encode the whole corpus as a list of vocabulary indices.
corpus_indices = [char_to_idx[char] for char in corpus_chars]

# Show the first 20 entries both as characters and as their indices.
sample = corpus_indices[:20]
print('chars:', ''.join(idx_to_char[idx] for idx in sample))
print('indices:', sample)
for i in range(epoch_size): # 每次读取batch_size个随机样本 i = i * batch_size batch_indices = example_indices[i: i + batch_size] X = [_data(j * num_steps) for j in batch_indices] Y = [_data(j * num_steps + 1) for j in batch_indices] yield nd.array(X, ctx), nd.array(Y, ctx)
1 2 3
# Toy sequence 0..39 used to demonstrate the samplers.
my_seq = list(range(40))

# Print every (X, Y) minibatch pair produced by random sampling.
for X, Y in data_iter_random(my_seq, batch_size=2, num_steps=6):
    print('X: ', X, '\nY:', Y, '\n')
def data_iter_consecutive(corpus_indices, batch_size, num_steps, ctx=None):
    """Yield (X, Y) minibatches that are adjacent along the sequence.

    The index sequence is truncated to a multiple of ``batch_size`` and laid
    out as ``batch_size`` rows; successive minibatches are successive
    ``num_steps``-wide column windows of that layout, so row ``r`` of one
    minibatch continues directly where row ``r`` of the previous one ended.

    Args:
        corpus_indices: Sequence of integer indices; converted with
            ``nd.array`` (presumably MXNet's ``nd`` module -- it is not
            imported within this block).
        batch_size: Number of rows in every minibatch.
        num_steps: Number of columns (time steps) in every minibatch.
        ctx: Passed through to ``nd.array`` as ``ctx``.

    Yields:
        Tuples ``(X, Y)`` where ``X`` is the window of columns
        ``[i, i + num_steps)`` and ``Y`` is the same window shifted one
        column to the right, i.e. the next-element targets for ``X``.
    """
    corpus_indices = nd.array(corpus_indices, ctx=ctx)
    data_len = len(corpus_indices)  # total number of indices
    # Length of each row once the data is split into batch_size rows
    # (not the number of minibatches).
    batch_len = data_len // batch_size
    # Drop the tail that does not fill a complete row, then reshape to
    # batch_size rows by batch_len columns.
    indices = corpus_indices[0: batch_size * batch_len].reshape(
        (batch_size, batch_len))
    # The "- 1" reserves one trailing column so that Y, which is shifted by
    # one, never runs past the end of a row.
    epoch_size = (batch_len - 1) // num_steps
    for i in range(epoch_size):
        i = i * num_steps  # first column of this window
        X = indices[:, i: i + num_steps]
        Y = indices[:, i + 1: i + num_steps + 1]
        yield X, Y
1 2
# Print every (X, Y) minibatch pair produced by consecutive sampling of my_seq.
for X, Y in data_iter_consecutive(my_seq, batch_size=2, num_steps=6):
    print('X: ', X, '\nY:', Y, '\n')