chainerでattentionモデルを実装

chainerで実装してみた

こちらを参考に,実装してみた。


train_pair = [
    ["初めまして。", "初めまして。よろしくお願いします。"],
    ["どこから来たんですか?", "日本から来ました。"],
    ["日本のどこに住んでるんですか?", "東京に住んでいます。"],
    ["仕事は何してますか?", "私は会社員です。"],
    ["お会いできて嬉しかったです。", "私もです!"],
    ["おはよう。", "おはようございます。"],
    ["いつも何時に起きますか?", "6時に起きます。"],
    ["朝食は何を食べますか?", "たいていトーストと卵を食べます。"],
    ["朝食は毎日食べますか?", "たまに朝食を抜くことがあります。"],
    ["野菜をたくさん取っていますか?", "毎日野菜を取るようにしています。"],
    ["週末は何をしていますか?", "友達と会っていることが多いです。"],
    ["どこに行くのが好き?", "私たちは渋谷に行くのが好きです。"]
]
test_pair = [
    ["初めまして。", "初めまして。よろしくお願いします。"],
    ["どこから来たんですか?", "米国から来ました。"],
    ["米国のどこに住んでるんですか?", "ニューヨークに住んでいます。"],
    ["おはよう。", "おはよう。"],
    ["いつも何時に起きますか?", "7時に起きます。"],
    ["夕食は何を食べますか?", "たいていトーストと卵を食べます。"],
    ["夕食は毎日食べますか?", "たまに朝食を抜くことがあります。"],
    ["肉をたくさん取っていますか?", "毎日インクを取るようにしています。"],
    ["週頭は何をしていますか?", "友達と会っていることが多いです。"],
]



# https://nojima.hatenablog.com/entry/2017/10/10/023147


import MeCab
import chainer
import chainer.links as L
import chainer.functions as F
import random
import numpy as np

SIZE=10000
EOS=1
UNK=0


class EncoderDecoder(chainer.Chain):
    def __init__(self, n_layer,n_vocab, n_out, n_hidden,dropout):
        super(EncoderDecoder,self).__init__()

        with self.init_scope():
            self.embed_x = L.EmbedID(n_vocab, n_hidden)
            self.embed_y = L.EmbedID(n_out,n_hidden)

            self.encoder = L.NStepLSTM(
                n_layers=n_layer,
                in_size=n_hidden,
                out_size=n_hidden,
                dropout=dropout)
            self.decoder = L.NStepLSTM(
                n_layers=n_layer,
                in_size=n_hidden,
                out_size=n_hidden,
                dropout=dropout)

            self.W_C = L.Linear(2*n_hidden, n_hidden)
            self.W_D = L.Linear(n_hidden, n_out)

    def __call__(self, xs , ys ):
        xs = [x[::-1] for x in xs]

        eos = self.xp.array([EOS], dtype=np.int32)
        ys_in = [F.concat((eos, y), axis=0) for y in ys]
        ys_out = [F.concat((y, eos), axis=0) for y in ys]

        # Both xs and ys_in are lists of arrays.
        exs = [self.embed_x(x) for x in xs]
        eys = [self.embed_y(y) for y in ys_in]

        # hx:dimension x batchsize
        # cx:dimension x batchsize
        # yx:batchsize x timesize x dimension
        hx, cx, yx = self.encoder(None, None, exs)  # yxに全T方向ステップのyの出力,数はxsの長さと同じ,バッチごとにバラバラ

        _, _, os = self.decoder(hx, cx, eys)

        loss=0
        for o,y,ey in zip(os,yx,ys_out): # バッチごとに処理
            op=self._calculate_attention_layer_output(o,y)
            loss+=F.softmax_cross_entropy(op,ey)
        loss/=len(yx)

        chainer.report({'loss': loss}, self)
        return loss

    def _calculate_attention_layer_output(self, embedded_output, attention):
        inner_prod = F.matmul(embedded_output, attention, transb=True)
        weights = F.softmax(inner_prod)
        contexts = F.matmul(weights, attention)
        concatenated = F.concat((contexts, embedded_output))
        new_embedded_output = F.tanh(self.W_C(concatenated))
        return self.W_D(new_embedded_output)

    def translate(self,xs,max_length=30):
        with chainer.no_backprop_mode(),chainer.using_config("train",False):
            xs=xs[::-1] # reverse list
            #exs = [self.embed_x(x) for x in xs]
            exs = self.embed_x(xs)
            hx, cx, yx = self.encoder(None, None, [exs])

            predicts=[]
            eos = self.xp.array([EOS], dtype=np.int32)
            # EOSだけ入力,あとは予想した出力を入力にして繰り返す
            for y in yx:  # バッチ単位
                predict=[]
                ys_in=[eos]
                for i in range(max_length):
                    eys = [self.embed_y(y) for y in ys_in]
                    _, _, os = self.decoder(hx, cx, eys)
                    op=self._calculate_attention_layer_output(os[0], y)
                    word_id=int(F.argmax(F.softmax(op)).data) # 単語IDに戻す

                    if word_id == EOS:break
                    predict.append(word_id)
                    ys_in=[self.xp.array([word_id], dtype=np.int32)]
                predicts.append(np.array(predict))
            return predict

class Data(chainer.dataset.DatasetMixin):
    def __init__(self):
        mecab = MeCab.Tagger("-Owakati")

        self.vocab={"eos":0,"unk":1}
        def to_dataset(source,target,train=True):
            swords = to_number(mecab.parse(source).strip().split(" "),train)
            twords = to_number(mecab.parse(target).strip().split(" "),train)
            return (np.array(swords).astype(np.int32),np.array(twords).astype(np.int32))

        def to_number(words,train):
            ds=[]
            for w in words:
                if w not in self.vocab:
                    if train:
                        self.vocab[w]=len(self.vocab)
                    else:
                        w="unk"
                ds.append(self.vocab[w])
            return ds

        self.train_data=[]
        self.test_data=[]
        for source,target in train_pair:
            self.train_data.append(to_dataset(source,target))

        for source,target in test_pair:
            self.test_data.append(to_dataset(source,target,False))


        self.vocab_inv={}
        for w in self.vocab.keys():
            self.vocab_inv[self.vocab[w]]=w

def convert(batch, device):
    def to_device_batch(batch):
        return [chainer.dataset.to_device(device, x) for x in batch]

    res= {'xs': to_device_batch([x for x, _ in batch]),
            'ys': to_device_batch([y for _, y in batch])}
    return res




seed = 12345
random.seed(seed)
np.random.seed(seed)

data = Data()

batchsize=5
train_iter = chainer.iterators.SerialIterator(data.train_data,batchsize)
test_iter=chainer.iterators.SerialIterator(data.test_data,len(data.test_data))


n_vocab=len(data.vocab)
n_out=len(data.vocab)
n_hidden=300
n_layer=1
dropout=0.3
print("n_vocab:",n_vocab)

optimizer=chainer.optimizers.Adam()

mlp=EncoderDecoder(n_layer,n_vocab,n_out,n_hidden,dropout)
optimizer.setup(mlp)

updater=chainer.training.StandardUpdater(train_iter,optimizer,converter=convert,device=-1)

#train
epochs=20
trainer=chainer.training.Trainer(updater,(epochs,"epoch"),out="dialog_result")
trainer.extend(chainer.training.extensions.LogReport())
trainer.extend(chainer.training.extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy']))

trainer.run()
mlp.to_cpu()
chainer.serializers.save_npz("attention.model",mlp)




# https://nojima.hatenablog.com/entry/2017/10/17/034840
mlp=EncoderDecoder(n_layer,n_vocab,n_out,n_hidden,dropout)
chainer.serializers.load_npz("attention.model",mlp,path="")

for source,target in data.test_data:
    predict=mlp.translate(np.array(source))
    print("-----")
    print("source:",[data.vocab_inv[w] for w in source])
    print("predict:",[data.vocab_inv[w] for w in predict])
    print("target:",[data.vocab_inv[w] for w in target])


実行結果

n_vocab: 71
epoch       main/loss   main/accuracy
1           4.00112                    
2           3.148                      
3           2.33681                    
4           1.62404                    
5           1.30338                    
6           0.93243                    
7           0.6091                     
8           0.3701                     
9           0.233166                   
10          0.202335                   
11          0.119416                   
12          0.0804442                  
13          0.0629114                  
14          0.0467078                  
15          0.032828                   
16          0.0285745                  
17          0.0225082                  
18          0.0183743                  
19          0.0152494                  
20          0.0140195                  
-----
source: ['初め', 'まして', '。']
predict: ['初め', 'まして', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします']
target: ['初め', 'まして', '。', 'よろしくお願いします', '。']
-----
source: ['どこ', 'から', '来', 'た', 'ん', 'です', 'か', '?']
predict: ['日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来']
target: ['unk', 'から', '来', 'まし', 'た', '。']
-----
source: ['unk', 'の', 'どこ', 'に', '住ん', 'でる', 'ん', 'です', 'か', '?']
predict: ['東京', 'に', '住ん', 'で', 'い', 'ます', '。']
target: ['unk', 'に', '住ん', 'で', 'い', 'ます', '。']
-----
source: ['おはよう。']
predict: ['おはよう', 'ござい', 'ます', '。']
target: ['おはよう。']
-----
source: ['いつも', '何', '時', 'に', '起き', 'ます', 'か', '?']
predict: ['6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に']
target: ['unk', '時', 'に', '起き', 'ます', '。']
-----
source: ['unk', 'は', '何', 'を', '食べ', 'ます', 'か', '?']
predict: ['たいてい', 'トースト', 'と', 'たいてい', 'トースト', 'と', 'たいてい', 'トースト', 'と', 'たいてい', 'トースト', 'と', 'たいてい', 'トースト', 'と', 'たいてい', 'トースト', 'と', 'たいてい', 'トースト', 'と', 'たいてい', 'トースト', 'と', 'たいてい', 'トースト', 'と', 'たいてい', 'トースト', 'と']
target: ['たいてい', 'トースト', 'と', '卵', 'を', '食べ', 'ます', '。']
-----
source: ['unk', 'は', '毎日', '食べ', 'ます', 'か', '?']
predict: ['たまに', '朝食', '抜く', 'たまに', '朝食', '抜く', 'たまに', '朝食', '抜く', 'たまに', '朝食', '抜く', 'たまに', '朝食', '抜く', 'たまに', '朝食', '抜く', 'たまに', '朝食', '抜く', 'たまに', '朝食', '抜く', 'たまに', '朝食', '抜く', 'たまに', '朝食', '抜く']
target: ['たまに', '朝食', 'を', '抜く', 'こと', 'が', 'あり', 'ます', '。']
-----
source: ['unk', 'を', 'たくさん', '取っ', 'て', 'い', 'ます', 'か', '?']
predict: ['毎日', '野菜', 'を', '取る', 'よう', 'に', 'し', 'て', 'い', 'ます', '。']
target: ['毎日', 'unk', 'を', '取る', 'よう', 'に', 'し', 'て', 'い', 'ます', '。']
-----
source: ['unk', 'unk', 'は', '何', 'を', 'し', 'て', 'い', 'ます', 'か', '?']
predict: ['友達', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と']
target: ['友達', 'と', '会っ', 'て', 'いる', 'こと', 'が', '多い', 'です', '。']

あまり,結果が良くないが,訓練データを増やし,Epochを増やすと良くなるかもしれない。