chainerでCRFを用いて固有表現抽出

このあたりを参考に書き直してみました。
CRFなどの解説はリンク先を参照のこと

仕様

単語列から名詞の塊を抜き出す。
英語で複合名詞など複数の名詞で構成される名詞の塊にフラグをつける。
下記の例では、Oが名詞以外、Bが名詞の開始位置、Iが複合名詞の2個目以降を示している。

 the wall street journal reported today that apple corporation made money 
O    B   I      I      I      O        O    O     B     I          O    O      O

これを、Linearの1層で学習し、CRFで出力する

コード



import numpy as np
import argparse
import os
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training, optimizers, initializers,reporter
from chainer.training import extensions
import random
from itertools import chain


# my model
class IconDetector(chainer.Chain):

    def __init__(self, n_vocab, window, n_label, n_unit):
        super().__init__()
        with self.init_scope():
            self.embed=L.EmbedID(n_vocab, n_unit)
            self.lin=L.Linear(n_unit , n_label)
            self.crf=L.CRF1d(n_label=n_label)
        self.window=window

    def forward(self, xs):
        l = xs.shape[1]
        ys = []
        # 1 wordづつ
        for i in range(l):
            x = self.embed(xs[:,i])
            h = F.tanh(x)
            y = self.lin(h)
            ys.append(y)
        return ys #[window, batchsize, n_label]

    def __call__(self, xs, ts):
        """error function"""
        ys = self.forward(xs)
        ts = [ts[:, i] for i in range(ts.data.shape[1])] # [window,batchsize]
        loss = self.crf(ys, ts)
        reporter.report({'loss': loss}, self)
        return loss

    def predict(self, xs):
        ts = self.forward(xs)
        _, ys = self.crf.argmax(ts)
        return ys

class WindowIterator(chainer.dataset.Iterator):

    def __init__(self, text, label, window, batch_size, shuffle= True,repeat=True):
        self.text = np.asarray(text, dtype=np.int32)
        self.label = np.asarray(label, dtype=np.int32)
        self.window = window
        self.batch_size = batch_size
        self._repeat = repeat
        self._shuffle=shuffle

        if self._shuffle:
            self.order = np.random.permutation(
                len(text) - window ).astype(np.int32)

        else:
            self.order=np.array(list(range(len(text) - window )))
        self.current_position = 0
        self.epoch = 0
        self.is_new_epoch = False

    def __next__(self):
        if not self._repeat and self.epoch > 0:
            raise StopIteration

        i = self.current_position
        i_end = i + self.batch_size
        position = self.order[i: i_end]
        offset = np.concatenate([np.arange(0, self.window )])
        pos = position[:, None] + offset[None, :]
        context = self.text.take(pos)
        doc = self.label.take(pos)

        if i_end >= len(self.order):
            np.random.shuffle(self.order)
            self.epoch += 1
            self.is_new_epoch = True
            self.current_position = 0
        else:
            self.is_new_epoch = False
            self.current_position = i_end

        return  context, doc

    @property
    def epoch_detail(self):
        return self.epoch + float(self.current_position) / len(self.order)

@chainer.dataset.converter()
def convert(batch, device):
    context, doc = batch
    xp = device.xp
    doc = xp.asarray(doc)
    context = xp.asarray(context)
    return context, doc


def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--device', '-d', type=str, default='-1',
                        help='Device specifier. Either ChainerX device '
                        'specifier or an integer. If non-negative integer, '
                        'CuPy arrays with specified device id are used. If '
                        'negative integer, NumPy arrays are used')
    parser.add_argument('--window', '-w', default=5, type=int,
                        help='window size')
    parser.add_argument('--batchsize', '-b', type=int, default=2,
                        help='learning minibatch size')
    parser.add_argument('--epoch', '-e', default=100, type=int,
                        help='number of epochs to learn')
    parser.add_argument('--out', default=os.environ["WORK"]+"/summarization/sm_icgw/result",
                        help='Directory to output the result')

    args = parser.parse_args()

    random.seed(12345)
    np.random.seed(12345)

    device = chainer.get_device(args.device)

    training_data = [(
        " the wall street journal reported today that apple corporation made money ".split(),
        "O B I I I O O O B I O O O".split()
    ), (
        " georgia tech is a university in georgia ".split(),
        "O B I O O O O B O".split()
    )]

    validation_data = [(' georgia tech reported today '.split(),"O B I O O O".split())]

    word2id={"":0,"":1,"":2,"":3}
    label2id={"B":0,"I":1,"O":2}

    def get_dataset(data,is_train=True):
        texts=[]
        labels=[]
        for word,attrib in data:
            for w,a in zip(word,attrib):
                if w not in word2id:
                    if is_train:
                        word2id[w]=len(word2id)
                    else:
                        w=""
                texts.append(word2id[w])
                labels.append(label2id[a])
        return texts,labels

    train_text,train_label=get_dataset(training_data)
    valid_text,valid_label=get_dataset(validation_data,False)

    n_vocab=len(word2id)
    n_label=len(label2id)
    n_unit=n_vocab//2
    model=IconDetector(n_vocab=n_vocab, window=args.window,n_label=n_label,n_unit=n_unit)

    model.to_device(device)
    optimizer = optimizers.Adam()
    optimizer.setup(model)

    train_iter = WindowIterator(train_text, train_label, args.window, args.batchsize)
    valid_iter = WindowIterator(valid_text, valid_label, args.window,args.batchsize, repeat=False, shuffle=False)

    updater = training.StandardUpdater(train_iter, optimizer, converter=convert, device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)
    trainer.extend(extensions.LogReport(),trigger=(10, 'epoch'))
    trainer.extend(extensions.PrintReport(['epoch', 'main/loss','validation/main/loss']))
    #trainer.extend(extensions.ProgressBar())

    trainer.extend(extensions.Evaluator(valid_iter, model,converter=convert, device=device),trigger=(10, 'epoch'))
    trainer.run()

    # testing
    testing_data = [(' the street journal is a university '.split(),"O B I I O O O O".split())]
    test_text,test_label=get_dataset(testing_data,False)
    with chainer.using_config('train', False), \
            chainer.using_config('enable_backprop', False):
        ys=model.predict(np.array([test_text],dtype=np.int32))

        ys=list(chain.from_iterable(ys))
        print(ys)
        print(test_label)

if __name__ == '__main__':
    main()

実行結果

epoch       main/loss   validation/main/loss
10          6.24186     7.77101               
20          5.43089     6.29938               
30          3.16409     5.09233               
40          2.68146     4.02053               
50          1.47248     3.07853               
60          1.63073     2.28975               
70          1.13568     1.67798               
80          0.842947    1.22796               
90          0.364242    0.912943              
100         0.658557    0.692736              
[2, 0, 1, 1, 2, 2, 2, 2]
[2, 0, 1, 1, 2, 2, 2, 2]

ちゃんと学習できているようです

LSTMでSIN波を学習させる

このあたりを参考に書き直してみました。

プログラム


import chainer
import chainer.links as L
import chainer.functions as F
import numpy as np
from chainer import optimizers
from chainer.datasets import tuple_dataset
from chainer import reporter
import matplotlib.pyplot as plt

class MyClassifier(chainer.Chain):

    def __init__(self, seqlen, out_units=1, dropout=0.1):
        super(MyClassifier, self).__init__()
        with self.init_scope():
            self.encoder = L.LSTM(seqlen, out_units)
        self.dropout = dropout

    def forward(self, xs, ys):
        concat_outputs =F.concat( self.predict(xs), axis=0)
        concat_truths = F.concat(ys, axis=0)

        loss = F.mean_squared_error(concat_outputs, concat_truths)
        reporter.report({'loss': loss}, self)
        return loss

    def predict(self, xs, softmax=False, argmax=False):
        concat_encodings = F.dropout(self.encoder(xs), ratio=self.dropout)
        if softmax:
            return F.softmax(concat_encodings).array
        elif argmax:
            return self.xp.argmax(concat_encodings.array, axis=1)
        else:
            return concat_encodings

    def reset_state(self):
        self.encoder.reset_state()


# main
if __name__ == "__main__":

    # データ
    def get_dataset(N,st=0,ed=2*np.pi):
        xval = np.linspace(st,ed, N)  # x軸の値
        y = np.sin(xval)              # y軸の値
        return y,xval

    # 0-2piまでを100等分
    N = 100
    train_ds,train_xval = get_dataset(N)

    # valid
    N_valid = 50
    valid_ds,valid_xval = get_dataset(N_valid,2*np.pi,3*np.pi)


    # 学習パラメータ
    batchsize = 10
    n_epoch = 1000
    seqlen=10
    gpu = -1

    def seq_split(seqlen,ds,xval):
        X=[]    # seqlen 分Y軸の値を取り出し
        y=[]    # seqlen+1のY軸の値を正解として格納
        ret_xval=[]   # その時のX軸の値
        for i in range(len(ds)-seqlen-1):
            X.append(np.array(ds[i:i+seqlen]))
            y.append(np.array([ds[i+seqlen]]))
            ret_xval.append(xval[i+seqlen])
        return X,y,ret_xval


    X_train,y_train,_=seq_split(seqlen,train_ds,train_xval)
    train=tuple_dataset.TupleDataset(np.array(X_train, dtype=np.float32), np.array(y_train, dtype=np.float32))

    X_valid,y_valid,_=seq_split(seqlen,valid_ds,valid_xval)
    valid=tuple_dataset.TupleDataset(np.array(X_valid, dtype=np.float32), np.array(y_valid, dtype=np.float32))

    # モデル作成
    model = MyClassifier(seqlen=seqlen)
    optimizer = optimizers.Adam()
    optimizer.setup(model)

    train_iter = chainer.iterators.SerialIterator(train, batchsize)
    valid_iter = chainer.iterators.SerialIterator(valid, batchsize, repeat=False, shuffle=False)

    updater = chainer.training.updaters.StandardUpdater(train_iter, optimizer, device=gpu)
    trainer = chainer.training.Trainer(updater, (n_epoch, 'epoch'), out="snapshot")

    trainer.extend(chainer.training.extensions.Evaluator(valid_iter, model, device=gpu))
    trainer.extend(chainer.training.extensions.dump_graph('main/loss'))

    # Take a snapshot for each specified epoch
    frequency = 10
    trainer.extend(chainer.training.extensions.snapshot(filename="snapshot_cureent"), trigger=(frequency, 'epoch'))

    trainer.extend(chainer.training.extensions.LogReport(trigger=(10, "epoch")))
    trainer.extend(chainer.training.extensions.PrintReport(
        ["epoch", "main/loss", "validation/main/loss", "elapsed_time"]))
    # trainer.extend(chainer.training.extensions.ProgressBar())
    trainer.extend(chainer.training.extensions.PlotReport(['main/loss', 'validation/main/loss'], trigger=(10, "epoch"),
                                                          filename="loss.png"))
    trainer.run()


    # テストデータ
    N_test = 50
    test_ds,test_xval = get_dataset(N_test,3*np.pi,4*np.pi)
    X_test,y_test,x=seq_split(seqlen,test_ds,test_xval)

    with chainer.using_config("train", False), chainer.using_config('enable_backprop', False):
        model_predict=[]
        for xt in X_test:
            model_predict.append(model.predict(np.array([xt],dtype=np.float32)))
        y=F.concat(y_test,axis=0).data
        p=F.concat(model_predict, axis=0).data
        plt.plot(x,y,color="b")
        plt.plot(x,p,color="r")
        plt.show()

xが0から2πまでをN=100等分し、その時のyの値を学習します

出力

epoch       main/loss   validation/main/loss  elapsed_time
10          0.412313    0.22419               0.680705      
20          0.34348     0.176668              1.45632       
30          0.322577    0.157206              2.21422       
40          0.313264    0.142921              3.01513       
50          0.286482    0.156473              3.79678       
60          0.271693    0.145663              4.5567        
70          0.26161     0.153716              5.27208       
80          0.249188    0.148146              6.01087       
90          0.229289    0.135516              6.7483        
100         0.235577    0.135556              7.4885        
110         0.210943    0.122771              8.23302       
120         0.206481    0.125734              8.96983       
130         0.199006    0.110476              9.70469       
140         0.185752    0.105235              10.443        
150         0.170962    0.0997412             11.2183       
160         0.17236     0.0935808             11.9595       
170         0.159314    0.08463               12.7018       
180         0.159687    0.0768325             13.4395       
190         0.137582    0.0830929             14.18         
200         0.146723    0.0774658             14.9236       
210         0.144295    0.0731078             15.6569       
220         0.137017    0.0640324             16.3937       
230         0.120282    0.0675966             17.1413       
240         0.122946    0.065422              17.886        
250         0.112485    0.0601812             18.6319       
260         0.115785    0.0622176             19.3996       
270         0.108337    0.0572397             20.1444       
280         0.113232    0.0529092             20.8836       
290         0.111058    0.0531542             21.6315       
300         0.0978656   0.0537841             22.367        
310         0.107318    0.0500048             23.1059       
320         0.100115    0.0471024             23.8513       
330         0.0941351   0.0466782             24.6336       
340         0.0992865   0.0462204             25.4045       
350         0.0909913   0.0444264             26.1611       
360         0.0862495   0.04312               26.9144       
370         0.0855708   0.0425527             27.7444       
380         0.0872798   0.0402729             28.5102       
390         0.0902846   0.0412268             29.265        
400         0.0840953   0.0421445             30.0178       
410         0.0782678   0.0389577             30.7646       
420         0.0813485   0.0383136             31.5095       
430         0.0892697   0.0380051             32.2756       
440         0.0784913   0.0378043             33.0377       
450         0.077315    0.0382185             33.7859       
460         0.0797667   0.0370378             34.5328       
470         0.085258    0.0353085             35.2766       
480         0.0904794   0.0364417             36.0569       
490         0.0760617   0.0347273             36.7939       
500         0.0776822   0.0349566             37.5342       
510         0.0753997   0.0348118             38.2753       
520         0.084298    0.0351735             39.0137       
530         0.0827682   0.0336927             39.7616       
540         0.0885914   0.032941              40.4968       
550         0.0755433   0.0328157             41.2496       
560         0.0864087   0.0329004             41.9931       
570         0.0696817   0.0333915             42.7455       
580         0.0673041   0.0310085             43.4925       
590         0.0695492   0.0322058             44.2814       
600         0.0626493   0.0321012             45.0249       
610         0.0780004   0.0311171             45.7709       
620         0.0842924   0.0313559             46.5104       
630         0.0869392   0.0310007             47.253        
640         0.0807384   0.0300728             48.0017       
650         0.061898    0.0308209             48.799        
660         0.0656733   0.0304086             49.5613       
670         0.0785108   0.0301482             50.3314       
680         0.0748608   0.0305042             51.1063       
690         0.0720685   0.0300295             51.874        
700         0.071668    0.0300219             52.6741       
710         0.0577448   0.0295377             53.4246       
720         0.07059     0.0295795             54.1749       
730         0.0677278   0.0291774             54.9278       
740         0.0747214   0.0296802             55.6766       
750         0.0796093   0.0296162             56.4188       
760         0.0659514   0.0292572             57.1649       
770         0.0651634   0.0294423             57.9178       
780         0.0732363   0.0291244             58.6651       
790         0.0750731   0.0292386             59.54         
800         0.0768709   0.0289816             60.4078       
810         0.0712155   0.0293295             61.1858       
820         0.0613775   0.0290006             62.1138       
830         0.0674119   0.0289881             62.9673       
840         0.0720593   0.0289764             63.7099       
850         0.0653162   0.028784              64.4681       
860         0.0873163   0.0285498             65.2305       
870         0.0674722   0.0286272             65.9901       
880         0.0718548   0.0282714             66.7374       
890         0.0659643   0.028604              67.5049       
900         0.0790577   0.0286806             68.2621       
910         0.0734288   0.028266              69.0192       
920         0.0664547   0.0282677             69.7996       
930         0.0799885   0.0284383             70.5581       
940         0.0687402   0.0284575             71.3119       
950         0.0721005   0.0282444             72.0637       
960         0.0750532   0.0282908             72.8485       
970         0.0663102   0.0282821             73.6094       
980         0.0807248   0.0281474             74.3701       
990         0.0730301   0.0283711             75.155        
1000        0.0674249   0.0282971             75.9349    

テスト結果

青い線が正解で赤い線が学習したのちの予測結果です。
Xの値を与えて、Yを予測した結果ですが、あまりうまく予測できていませんね

chainerでsin関数を学習させてみたをtrainerで書き直す

chainerのサンプルを探していると,古いものではtrainerを使っていないものも結構あります。
そこで,chainerでsin関数を学習させてみた
を書き直してみました。




# https://qiita.com/hikobotch/items/018808ef795061176824
# https://github.com/kose/chainer-linear-regression/blob/master/train.py

# とりあえず片っ端からimport
import numpy as np
import chainer
from chainer import cuda, Function, gradient_check, Variable, optimizers, serializers, utils
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
import time
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import os
# データ
def get_dataset(N):
    x = np.linspace(0, 2 * np.pi, N)
    y = np.sin(x)
    dataset=[]
    for xx,yy in zip(x,y):
        dataset.append((xx,yy))
    return dataset

# ニューラルネットワーク
class MyChain(Chain):
    def __init__(self, n_units=10):
        super(MyChain, self).__init__(
             l1=L.Linear(1, n_units),
             l2=L.Linear(n_units, n_units),
             l3=L.Linear(n_units, 1))

    def __call__(self, x_data, y_data):
        x = Variable(x_data.astype(np.float32).reshape(len(x_data),1)) # Variableオブジェクトに変換
        y = Variable(y_data.astype(np.float32).reshape(len(y_data),1)) # Variableオブジェクトに変換
        loss= F.mean_squared_error(self.predict(x), y)
        chainer.reporter.report({
            'loss': loss
        }, self)

        return loss

    def  predict(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        h3 = self.l3(h2)
        return h3

    def get_predata(self, x):
        return self.predict(Variable(x.astype(np.float32).reshape(len(x),1))).data

from chainer.training import trigger as trigger_module
class MyChartReport(chainer.training.extension.Extension):

    #trigger = (1, 'epoch')

    def __init__(self,model,N_test=900,trigger = (1, 'epoch')):
        self.N_test=N_test
        self.model=model
        self._trigger = trigger_module.get_trigger(trigger)

    def __call__(self,trainer):
        if self._trigger(trainer):
            with chainer.configuration.using_config('train', False):
                theta = np.linspace(0, 2 * np.pi, self.N_test)
                sin = np.sin(theta)
                test = self.model.get_predata(theta)
                plt.plot(theta, sin, label="sin")
                plt.plot(theta, test, label="test")
                plt.legend()
                plt.grid(True)
                plt.xlim(0, 2 * np.pi)
                plt.ylim(-1.2, 1.2)
                plt.title("sin")
                plt.xlabel("theta")
                plt.ylabel("amp")
                plt.savefig("fig_sin_epoch{}.png".format(trainer.updater.epoch)) 
                plt.clf()



# main
if __name__ == "__main__":

    # 学習データ
    N = 1000
    train = get_dataset(N)

    # テストデータ
    N_test = 900
    test = get_dataset(N_test)

    # 学習パラメータ
    batchsize = 10
    n_epoch = 500
    n_units = 100
    gpu=-1

    # モデル作成
    model = MyChain(n_units)
    optimizer = optimizers.Adam()
    optimizer.setup(model)

    train_iter = chainer.iterators.SerialIterator(train, batchsize)
    test_iter = chainer.iterators.SerialIterator(test, batchsize, repeat=False, shuffle=False)

    updater = chainer.training.updaters.StandardUpdater(train_iter, optimizer, device=gpu)
    trainer = chainer.training.Trainer(updater, (n_epoch, 'epoch'), out="snapshot")

    trainer.extend(chainer.training.extensions.Evaluator(test_iter, model, device=gpu))
    trainer.extend(chainer.training.extensions.dump_graph('main/loss'))

    # Take a snapshot for each specified epoch
    frequency = 10
    trainer.extend(chainer.training.extensions.snapshot(filename="snapshot_cureent"),trigger=(frequency, 'epoch'))


    # Write a log of evaluation statistics for each epoch
    #trainer.extend(chainer.training.extensions.LogReport())

    trainer.extend(MyChartReport(model,trigger=(10, "epoch")))

    trainer.extend(chainer.training.extensions.LogReport(trigger=(10, "epoch")))
    trainer.extend(chainer.training.extensions.PrintReport(
        ["epoch", "main/loss", "validation/main/loss", "elapsed_time"]))
    #trainer.extend(chainer.training.extensions.ProgressBar())
    trainer.extend(chainer.training.extensions.PlotReport(['main/loss', 'validation/main/loss'],trigger=(10, "epoch"),filename="loss.png"))
    trainer.run()

    model.to_cpu()
    chainer.serializers.save_npz("model.h5",model)


ポイントは,extensionを使って,エポックごとにsinの学習結果を画像で保存しているところです。
結構いい加減に作っているので,画像保存ディレクトリやファイル名などは外から与えるようにするといいでしょう。

chainerでattentionモデルを実装

chainerで実装してみた

こちらを参考に,実装してみた。


train_pair = [
    ["初めまして。", "初めまして。よろしくお願いします。"],
    ["どこから来たんですか?", "日本から来ました。"],
    ["日本のどこに住んでるんですか?", "東京に住んでいます。"],
    ["仕事は何してますか?", "私は会社員です。"],
    ["お会いできて嬉しかったです。", "私もです!"],
    ["おはよう。", "おはようございます。"],
    ["いつも何時に起きますか?", "6時に起きます。"],
    ["朝食は何を食べますか?", "たいていトーストと卵を食べます。"],
    ["朝食は毎日食べますか?", "たまに朝食を抜くことがあります。"],
    ["野菜をたくさん取っていますか?", "毎日野菜を取るようにしています。"],
    ["週末は何をしていますか?", "友達と会っていることが多いです。"],
    ["どこに行くのが好き?", "私たちは渋谷に行くのが好きです。"]
]
test_pair = [
    ["初めまして。", "初めまして。よろしくお願いします。"],
    ["どこから来たんですか?", "米国から来ました。"],
    ["米国のどこに住んでるんですか?", "ニューヨークに住んでいます。"],
    ["おはよう。", "おはよう。"],
    ["いつも何時に起きますか?", "7時に起きます。"],
    ["夕食は何を食べますか?", "たいていトーストと卵を食べます。"],
    ["夕食は毎日食べますか?", "たまに朝食を抜くことがあります。"],
    ["肉をたくさん取っていますか?", "毎日インクを取るようにしています。"],
    ["週頭は何をしていますか?", "友達と会っていることが多いです。"],
]



# https://nojima.hatenablog.com/entry/2017/10/10/023147


import MeCab
import chainer
import chainer.links as L
import chainer.functions as F
import random
import numpy as np

SIZE=10000
EOS=1
UNK=0


class EncoderDecoder(chainer.Chain):
    def __init__(self, n_layer,n_vocab, n_out, n_hidden,dropout):
        super(EncoderDecoder,self).__init__()

        with self.init_scope():
            self.embed_x = L.EmbedID(n_vocab, n_hidden)
            self.embed_y = L.EmbedID(n_out,n_hidden)

            self.encoder = L.NStepLSTM(
                n_layers=n_layer,
                in_size=n_hidden,
                out_size=n_hidden,
                dropout=dropout)
            self.decoder = L.NStepLSTM(
                n_layers=n_layer,
                in_size=n_hidden,
                out_size=n_hidden,
                dropout=dropout)

            self.W_C = L.Linear(2*n_hidden, n_hidden)
            self.W_D = L.Linear(n_hidden, n_out)

    def __call__(self, xs , ys ):
        xs = [x[::-1] for x in xs]

        eos = self.xp.array([EOS], dtype=np.int32)
        ys_in = [F.concat((eos, y), axis=0) for y in ys]
        ys_out = [F.concat((y, eos), axis=0) for y in ys]

        # Both xs and ys_in are lists of arrays.
        exs = [self.embed_x(x) for x in xs]
        eys = [self.embed_y(y) for y in ys_in]

        # hx:dimension x batchsize
        # cx:dimension x batchsize
        # yx:batchsize x timesize x dimension
        hx, cx, yx = self.encoder(None, None, exs)  # yxに全T方向ステップのyの出力,数はxsの長さと同じ,バッチごとにバラバラ

        _, _, os = self.decoder(hx, cx, eys)

        loss=0
        for o,y,ey in zip(os,yx,ys_out): # バッチごとに処理
            op=self._calculate_attention_layer_output(o,y)
            loss+=F.softmax_cross_entropy(op,ey)
        loss/=len(yx)

        chainer.report({'loss': loss}, self)
        return loss

    def _calculate_attention_layer_output(self, embedded_output, attention):
        inner_prod = F.matmul(embedded_output, attention, transb=True)
        weights = F.softmax(inner_prod)
        contexts = F.matmul(weights, attention)
        concatenated = F.concat((contexts, embedded_output))
        new_embedded_output = F.tanh(self.W_C(concatenated))
        return self.W_D(new_embedded_output)

    def translate(self,xs,max_length=30):
        with chainer.no_backprop_mode(),chainer.using_config("train",False):
            xs=xs[::-1] # reverse list
            #exs = [self.embed_x(x) for x in xs]
            exs = self.embed_x(xs)
            hx, cx, yx = self.encoder(None, None, [exs])

            predicts=[]
            eos = self.xp.array([EOS], dtype=np.int32)
            # EOSだけ入力,あとは予想した出力を入力にして繰り返す
            for y in yx:  # バッチ単位
                predict=[]
                ys_in=[eos]
                for i in range(max_length):
                    eys = [self.embed_y(y) for y in ys_in]
                    _, _, os = self.decoder(hx, cx, eys)
                    op=self._calculate_attention_layer_output(os[0], y)
                    word_id=int(F.argmax(F.softmax(op)).data) # 単語IDに戻す

                    if word_id == EOS:break
                    predict.append(word_id)
                    ys_in=[self.xp.array([word_id], dtype=np.int32)]
                predicts.append(np.array(predict))
            return predict

class Data(chainer.dataset.DatasetMixin):
    def __init__(self):
        mecab = MeCab.Tagger("-Owakati")

        self.vocab={"eos":0,"unk":1}
        def to_dataset(source,target,train=True):
            swords = to_number(mecab.parse(source).strip().split(" "),train)
            twords = to_number(mecab.parse(target).strip().split(" "),train)
            return (np.array(swords).astype(np.int32),np.array(twords).astype(np.int32))

        def to_number(words,train):
            ds=[]
            for w in words:
                if w not in self.vocab:
                    if train:
                        self.vocab[w]=len(self.vocab)
                    else:
                        w="unk"
                ds.append(self.vocab[w])
            return ds

        self.train_data=[]
        self.test_data=[]
        for source,target in train_pair:
            self.train_data.append(to_dataset(source,target))

        for source,target in test_pair:
            self.test_data.append(to_dataset(source,target,False))


        self.vocab_inv={}
        for w in self.vocab.keys():
            self.vocab_inv[self.vocab[w]]=w

def convert(batch, device):
    def to_device_batch(batch):
        return [chainer.dataset.to_device(device, x) for x in batch]

    res= {'xs': to_device_batch([x for x, _ in batch]),
            'ys': to_device_batch([y for _, y in batch])}
    return res




seed = 12345
random.seed(seed)
np.random.seed(seed)

data = Data()

batchsize=5
train_iter = chainer.iterators.SerialIterator(data.train_data,batchsize)
test_iter=chainer.iterators.SerialIterator(data.test_data,len(data.test_data))


n_vocab=len(data.vocab)
n_out=len(data.vocab)
n_hidden=300
n_layer=1
dropout=0.3
print("n_vocab:",n_vocab)

optimizer=chainer.optimizers.Adam()

mlp=EncoderDecoder(n_layer,n_vocab,n_out,n_hidden,dropout)
optimizer.setup(mlp)

updater=chainer.training.StandardUpdater(train_iter,optimizer,converter=convert,device=-1)

#train
epochs=20
trainer=chainer.training.Trainer(updater,(epochs,"epoch"),out="dialog_result")
trainer.extend(chainer.training.extensions.LogReport())
trainer.extend(chainer.training.extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy']))

trainer.run()
mlp.to_cpu()
chainer.serializers.save_npz("attention.model",mlp)




# https://nojima.hatenablog.com/entry/2017/10/17/034840
mlp=EncoderDecoder(n_layer,n_vocab,n_out,n_hidden,dropout)
chainer.serializers.load_npz("attention.model",mlp,path="")

for source,target in data.test_data:
    predict=mlp.translate(np.array(source))
    print("-----")
    print("source:",[data.vocab_inv[w] for w in source])
    print("predict:",[data.vocab_inv[w] for w in predict])
    print("target:",[data.vocab_inv[w] for w in target])


実行結果

n_vocab: 71
epoch       main/loss   main/accuracy
1           4.00112                    
2           3.148                      
3           2.33681                    
4           1.62404                    
5           1.30338                    
6           0.93243                    
7           0.6091                     
8           0.3701                     
9           0.233166                   
10          0.202335                   
11          0.119416                   
12          0.0804442                  
13          0.0629114                  
14          0.0467078                  
15          0.032828                   
16          0.0285745                  
17          0.0225082                  
18          0.0183743                  
19          0.0152494                  
20          0.0140195                  
-----
source: ['初め', 'まして', '。']
predict: ['初め', 'まして', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします', '。', 'よろしくお願いします']
target: ['初め', 'まして', '。', 'よろしくお願いします', '。']
-----
source: ['どこ', 'から', '来', 'た', 'ん', 'です', 'か', '?']
predict: ['日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来', '日本', 'から', '来']
target: ['unk', 'から', '来', 'まし', 'た', '。']
-----
source: ['unk', 'の', 'どこ', 'に', '住ん', 'でる', 'ん', 'です', 'か', '?']
predict: ['東京', 'に', '住ん', 'で', 'い', 'ます', '。']
target: ['unk', 'に', '住ん', 'で', 'い', 'ます', '。']
-----
source: ['おはよう。']
predict: ['おはよう', 'ござい', 'ます', '。']
target: ['おはよう。']
-----
source: ['いつも', '何', '時', 'に', '起き', 'ます', 'か', '?']
predict: ['6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に', '6', '時', 'に']
target: ['unk', '時', 'に', '起き', 'ます', '。']
-----
source: ['unk', 'は', '何', 'を', '食べ', 'ます', 'か', '?']
predict: ['たいてい', 'トースト', 'と', 'たいてい', 'トースト', 'と', 'たいてい', 'トースト', 'と', 'たいてい', 'トースト', 'と', 'たいてい', 'トースト', 'と', 'たいてい', 'トースト', 'と', 'たいてい', 'トースト', 'と', 'たいてい', 'トースト', 'と', 'たいてい', 'トースト', 'と', 'たいてい', 'トースト', 'と']
target: ['たいてい', 'トースト', 'と', '卵', 'を', '食べ', 'ます', '。']
-----
source: ['unk', 'は', '毎日', '食べ', 'ます', 'か', '?']
predict: ['たまに', '朝食', '抜く', 'たまに', '朝食', '抜く', 'たまに', '朝食', '抜く', 'たまに', '朝食', '抜く', 'たまに', '朝食', '抜く', 'たまに', '朝食', '抜く', 'たまに', '朝食', '抜く', 'たまに', '朝食', '抜く', 'たまに', '朝食', '抜く', 'たまに', '朝食', '抜く']
target: ['たまに', '朝食', 'を', '抜く', 'こと', 'が', 'あり', 'ます', '。']
-----
source: ['unk', 'を', 'たくさん', '取っ', 'て', 'い', 'ます', 'か', '?']
predict: ['毎日', '野菜', 'を', '取る', 'よう', 'に', 'し', 'て', 'い', 'ます', '。']
target: ['毎日', 'unk', 'を', '取る', 'よう', 'に', 'し', 'て', 'い', 'ます', '。']
-----
source: ['unk', 'unk', 'は', '何', 'を', 'し', 'て', 'い', 'ます', 'か', '?']
predict: ['友達', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と', 'と']
target: ['友達', 'と', '会っ', 'て', 'いる', 'こと', 'が', '多い', 'です', '。']

あまり,結果が良くないが,訓練データを増やし,Epochを増やすと良くなるかもしれない。

chainer.functions.pad_sequence

chainer.functions.pad_sequenceのサンプル


import chainer.functions as F
import chainer 
import numpy as np

a=chainer.Variable(np.array([[1],[2],[3]],dtype=np.float32))
b=chainer.Variable(np.array([[11],[12]],dtype=np.float32))
c=chainer.Variable(np.array([[21]],dtype=np.float32))

tpl=(a,b,c)
F.pad_sequence(tpl)

variable([[[ 1.],
           [ 2.],
           [ 3.]],

          [[11.],
           [12.],
           [ 0.]],

          [[21.],
           [ 0.],
           [ 0.]]])