Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def get_minibatch(lines, index, batch_size, word2ind, max_len, add_start=False, add_end=True):
    """Prepare a padded minibatch of word ids for next-token prediction.

    Takes ``lines[index:index + batch_size]``, optionally wraps every line
    in ``'<s>'`` / ``'</s>'`` markers, truncates each line to ``max_len``
    tokens, and maps tokens to ids. Inputs are each line without its last
    token, outputs are each line without its first token, so
    ``output_lines[i][t]`` is the token that follows ``input_lines[i][t]``.

    Note that truncation is applied *after* the markers are added, so a
    line longer than ``max_len`` loses its trailing ``'</s>'``.

    Args:
        lines: Sequence of tokenized lines (each a list of tokens).
        index: Offset of the first line of the batch within ``lines``.
        batch_size: Maximum number of lines to take starting at ``index``.
        word2ind: Mapping from token to integer id. ``'<pad>'`` must be
            present; ``'<unk>'`` must be present whenever a token is not
            in the mapping.
        max_len: Maximum number of tokens kept per line (markers included).
        add_start: Prepend ``'<s>'`` to every line.
        add_end: Append ``'</s>'`` to every line.

    Returns:
        A tuple ``(input_lines, output_lines, lens, mask)`` where
        ``input_lines`` and ``output_lines`` are ``int32`` arrays of shape
        ``(batch, max(lens) - 1)`` padded with the ``'<pad>'`` id, ``lens``
        is the list of per-line token counts after truncation, and ``mask``
        is a ``float32`` array of the same shape holding 1 on real
        positions and 0 on padding.

    Raises:
        ValueError: If the selected slice contains no lines.
    """
    # Slice unconditionally: previously the slice lived inside the marker
    # branches, so add_start=False/add_end=False used *all* lines.
    batch = lines[index:index + batch_size]
    if add_start:
        batch = [['<s>'] + line for line in batch]
    if add_end:
        batch = [line + ['</s>'] for line in batch]
    batch = [line[:max_len] for line in batch]

    if not batch:
        raise ValueError(
            'empty minibatch: no lines in range [%s, %s + %s)'
            % (index, index, batch_size)
        )

    lens = [len(line) for line in batch]
    # Input/output rows drop one token, hence one column fewer than the
    # longest line; clamp so a batch of empty lines yields zero columns.
    width = max(max(lens) - 1, 0)
    pad_id = word2ind['<pad>']

    def to_id(word):
        # '<unk>' is only looked up when actually needed.
        return word2ind[word] if word in word2ind else word2ind['<unk>']

    input_rows = []
    output_rows = []
    mask_rows = []
    for line in batch:
        ids = [to_id(word) for word in line]
        # Number of real (non-pad) positions; 0 for an empty line, which
        # would otherwise produce a row one element too long.
        n_real = max(len(ids) - 1, 0)
        n_pad = width - n_real
        padding = [pad_id] * n_pad
        input_rows.append(ids[:-1] + padding)
        output_rows.append(ids[1:] + padding)
        mask_rows.append([1] * n_real + [0] * n_pad)

    input_lines = np.array(input_rows).astype(np.int32)
    output_lines = np.array(output_rows).astype(np.int32)
    mask = np.array(mask_rows).astype(np.float32)
    return input_lines, output_lines, lens, mask
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement