今回はPytorch初学者が苦戦することが多いDatasetとtransformsの自作方法をご紹介します。今回は例としてRNNを用いた自動文章生成のようなタスク向けのデータセットを作成します。自動文章生成では以下のように学習を進めていきます。

なので正解ラベルとして入力データを1つずらしたデータが必要になります。
例: 入力: ["i", "study", "very", "hard"]→正解ラベル: ["study", "very", "hard", "."]

サンプルとして以下の文のリストを使用します。

sentences = [["i", "am", "so", "happy", "."],
["you", "say", "hello", "and", "i", "say", "hi", "."],
["i", "work", "very", "hard", "every", "day", "."],
["my", "dog", "is", "so", "cute", "."]]

必要なライブラリ等を読み込んでおきます。

import torch
from torch.utils.data import DataLoader
from tensorflow.keras.preprocessing.sequence import pad_sequences

単語をidに変換するtransformsを作成

ネットワークで文章を学習するためには単語をidに変換する必要があります。今回はデータの前処理を行うtransformsを自作してid化の処理を行ってみます。本来この処理はtransforms自作して行う必要はありませんが今回は練習のために自作してみます。

公式のチュートリアルによるとtransformsを実装するためには予め用意されているtransformsの動作に習うために「コール可能なクラス」として実装する必要があるので、__call__を必ず含むクラスを定義します。今回は以下のようなクラスを定義します。

class EncoderDecoder(object):
    def __init__(self, sentences, bos=False, eos=False):
        # word_to_idの辞書
        self.w2i = {}
        # id_to_wordの辞書
        self.i2w = {}
        # 文頭の記号を入れるか
        self.bos = bos
        # 文末の記号を入れるか
        self.eos = eos
        # 予約語(パディング, 文章の始まり)
        self.special_chars = ['<pad>', '<s>', '</s>', '<unk>']
        self.bos_char = self.special_chars[1]
        self.eos_char = self.special_chars[2]
        self.oov_char = self.special_chars[3]
        # 全ての単語を読み込んで辞書作成
        self.fit(sentences)
    # コールされる関数
    def __call__(self, sentence):
        return self.transform(sentence)
    # 辞書作成
    def fit(self, sentences):
        self._words = set()
        # 未知の単語の集合を作成する
        for sentence in sentences:
            self._words.update(sentence)
        # 予約語分ずらしてidを振る
        self.w2i = {w: (i + len(self.special_chars))
                    for i, w in enumerate(self._words)}
        # 予約語を辞書に追加する(<pad>:0, <s>:1, </s>:2, <unk>:3)
        for i, w in enumerate(self.special_chars):
            self.w2i[w] = i
        # word_to_idの辞書を用いてid_to_wordの辞書を作成する
        self.i2w = {i: w for w, i in self.w2i.items()}
    # 1文をidに変換する
    def transform(self, sentence):
        # 指定があれば始まりと終わりの記号を追加する
        if self.bos:
            sentence = [self.bos_char] + sentence
        if self.eos:
            sentence = sentence + [self.eos_char]
        output = self.encode(sentence)
        return output
    # 1文ずつidにする
    def encode(self, sentence):
        output = []
        for w in sentence:
            if w not in self.w2i:
                idx = self.w2i[self.oov_char]
            else:
                idx = self.w2i[w]
            output.append(idx)
        return output
    # # 1文ずつ単語リストに直す
    # def decode(self, sentence):
    #     return [self.i2w[id] for id in sentence]

以下のように使うことができます。eos=Trueとbos=Trueを指定しているので文頭に1, 文末に2が追加されています。

transform = EncoderDecoder(sentences, bos=True, eos=True)
transform(["i", "am", "so", "happy", "."])
[1, 17, 20, 16, 21, 5, 2]

ちなみに今回は使用しないのでコメントアウトしてありますがtransform.decodeを使うと以下のようにidを単語に戻すことができます。

Datasetを作成

先ほど作成したtransformsを引数にとってデータセットを作成することができるDatasetを作成します。Datasetを実装するためには

  • torch.utils.data.Datasetを継承
  • __len__と、__getitem__を実装
    する必要があります。

__len__

__len__はlen(Dataset)が実行された時に呼ばれます。今回は引数のデータの長さを返せば問題ないです。

__getitem__

__getitem__はDataset[0]というようにインデックスが指定された時に呼ばれます。ここで前処理をしたデータを返すようにします。

また今回は、transformsが指定されている場合、以下の処理も行います。

  • 0でパディングをして文章の長さを合わせる、長すぎる文章を短くする
  • 整数の配列なのでLongTensor型に変換する
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, data, transform=None, max_length=5):
        self.transform = transform
        self.data_num = len(data)
        # データを1ずつずらす
        self.x = [d[:-1] for d in data]
        self.y = [d[1:] for d in data]
        # パディングして合わせる長さ
        self.max_length = max_length
    def __len__(self):
        return self.data_num
    def __getitem__(self, idx):
        out_data = self.x[idx]
        out_label =  self.y[idx]
        # transformが指定されていたら単語をidにする
        if self.transform:
            out_data = self.transform(out_data)
            out_label = self.transform(out_label)
            # パディングして長さを合わせる
            out_data = pad_sequences([out_data], padding='post', maxlen=self.max_length)[0]
            out_label = pad_sequences([out_label], padding='post', maxlen=self.max_length)[0]
            # LongTensor型に変換する
            out_data = torch.LongTensor(out_data)
            out_label = torch.LongTensor(out_label)
        return out_data, out_label

まずは、transformsを指定せずに実行してみると、入力データと正解ラベルが一つずつずれたデータができていることがわかります。

dataset = MyDataset(sentences)
for d in dataset:
    print("x: {}, y: {}".format(d[0], d[1]))
x: ['i', 'am', 'so', 'happy'], y: ['am', 'so', 'happy', '.']
x: ['you', 'say', 'hello', 'and', 'i', 'say', 'hi'], y: ['say', 'hello', 'and', 'i', 'say', 'hi', '.']
x: ['i', 'work', 'very', 'hard', 'every', 'day'], y: ['work', 'very', 'hard', 'every', 'day', '.']
x: ['my', 'dog', 'is', 'so', 'cute'], y: ['dog', 'is', 'so', 'cute', '.']

次にtransformsを指定して実行すると、全ての処理が正常に行われていることがわかります。

transform = EncoderDecoder(sentences)
dataset = MyDataset(sentences, transform=transform)
for d in dataset:
    print("x: {}, y: {}".format(d[0], d[1]))
x: tensor([17, 20, 16, 21,  0]), y: tensor([20, 16, 21,  5,  0])
x: tensor([19,  7, 17, 18, 12]), y: tensor([ 7, 17, 18, 12,  5])
x: tensor([10,  9, 22,  4, 11]), y: tensor([ 9, 22,  4, 11,  5])
x: tensor([ 6, 13, 14, 16, 15]), y: tensor([13, 14, 16, 15,  5])

ここでパディングする際にtensorflow.keras.preprocessing.sequence.pad_sequenceを使用していますが、pytorchにも系列長を揃えてくれるtorch.nn.utils.rnn.pad_sequenceが存在します。しかし、pytorchのpad_sequenceよりもkerasのものの方が指定できる引数が多いため、使い勝手がいいかと思います。

作成したDatasetをDataloaderでバッチ単位にする

最後にDatasetをDataloaderでバッチ単位にすれば終わりです。これまで作成したクラスさえ定義してしまえば、以下の3行でデータを準備することができます。今回作成したデータは(バッチサイズ, データ数)の次元になっているので、RNNやLSTMで学習する際はbatch_first=Trueにするのを忘れないようにしてください。

transform = EncoderDecoder(sentences)
dataset = MyDataset(sentences, transform=transform)
data_loader = DataLoader(dataset, batch_size=2)
for x, y in data_loader:
    print("x: {}".format(x))
    print("y: {}".format(y))
    break
x: tensor([[17, 20, 16, 21,  0],
        [19,  7, 17, 18, 12]])
y: tensor([[20, 16, 21,  5,  0],
        [ 7, 17, 18, 12,  5]])