Named entity extraction with a CRF in Chainer

I rewrote the code using these references as a starting point.
For an explanation of CRFs and related background, see the linked pages.

Specification

Extract noun chunks from a sequence of words.
In English, a noun chunk made up of several nouns, such as a compound noun, should be flagged as one unit.
In the example below, O marks a token outside any noun chunk, B marks the first word of a chunk, and I marks the second and later words of a compound; the sentence-boundary tokens (<bos>, <eos>) also get the label O.

<bos> the wall street journal reported today that apple corporation made money <eos>
O     B   I    I      I       O        O     O    B     I           O    O     O
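
As a side note, recovering the noun chunks from such a tag sequence is straightforward: a chunk starts at a B and extends over the following I tags. The helper below is a minimal sketch, not part of this article's code, applied to the core sentence without the boundary tokens:

def extract_chunks(words, tags):
    """Collect the word spans tagged B I ... I."""
    chunks, current = [], []
    for w, t in zip(words, tags):
        if t == "B":                    # a new chunk starts
            if current:
                chunks.append(current)
            current = [w]
        elif t == "I" and current:      # the current chunk continues
            current.append(w)
        else:                           # O (or a stray I): close any open chunk
            if current:
                chunks.append(current)
            current = []
    if current:
        chunks.append(current)
    return chunks

words = "the wall street journal reported today that apple corporation made money".split()
tags = "B I I I O O O B I O O".split()
print(extract_chunks(words, tags))
# [['the', 'wall', 'street', 'journal'], ['apple', 'corporation']]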

The tagging itself is learned with a single Linear layer on top of word embeddings, and the output sequence is decoded with a CRF.
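
For reference, the CRF layer used below, L.CRF1d, expects the per-step label scores and the gold labels as lists along the time axis. This is a minimal sketch with toy shapes (batch size 2, sequence length 3, 3 labels), not code from the article:

import numpy as np
import chainer.links as L

crf = L.CRF1d(n_label=3)
# one (batchsize, n_label) score array per time step
ys = [np.random.randn(2, 3).astype(np.float32) for _ in range(3)]
# one (batchsize,) array of gold label ids per time step
ts = [np.random.randint(0, 3, size=2).astype(np.int32) for _ in range(3)]
loss = crf(ys, ts)            # mean negative log-likelihood over the batch
score, path = crf.argmax(ys)  # Viterbi decoding: best label ids per step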

Code



import argparse
import os
import random
from itertools import chain

import numpy as np

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import optimizers, reporter, training
from chainer.training import extensions


# Model: word embedding -> tanh -> Linear -> CRF
class IconDetector(chainer.Chain):

    def __init__(self, n_vocab, window, n_label, n_unit):
        super().__init__()
        with self.init_scope():
            self.embed = L.EmbedID(n_vocab, n_unit)
            self.lin = L.Linear(n_unit, n_label)
            self.crf = L.CRF1d(n_label=n_label)
        self.window = window

    def forward(self, xs):
        seq_len = xs.shape[1]
        ys = []
        # score each position independently; the CRF adds the sequence structure
        for i in range(seq_len):
            x = self.embed(xs[:, i])
            h = F.tanh(x)
            y = self.lin(h)
            ys.append(y)
        return ys  # list of length window; each element is (batchsize, n_label)

    def __call__(self, xs, ts):
        """Loss function."""
        ys = self.forward(xs)
        ts = [ts[:, i] for i in range(ts.shape[1])]  # list of length window; each (batchsize,)
        loss = self.crf(ys, ts)
        reporter.report({'loss': loss}, self)
        return loss

    def predict(self, xs):
        ys = self.forward(xs)
        _, path = self.crf.argmax(ys)  # Viterbi decoding of the best label sequence
        return path

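# WindowIterator slices the flat token/label streams into fixed-length
# windows of `window` tokens and serves them in minibatches: each example
# pairs text[p:p+window] with label[p:p+window] for a start position p.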
class WindowIterator(chainer.dataset.Iterator):

    def __init__(self, text, label, window, batch_size, shuffle=True, repeat=True):
        self.text = np.asarray(text, dtype=np.int32)
        self.label = np.asarray(label, dtype=np.int32)
        self.window = window
        self.batch_size = batch_size
        self._repeat = repeat
        self._shuffle = shuffle

        if self._shuffle:
            self.order = np.random.permutation(len(text) - window).astype(np.int32)
        else:
            self.order = np.arange(len(text) - window, dtype=np.int32)
        self.current_position = 0
        self.epoch = 0
        self.is_new_epoch = False

    def __next__(self):
        if not self._repeat and self.epoch > 0:
            raise StopIteration

        i = self.current_position
        i_end = i + self.batch_size
        position = self.order[i: i_end]
        offset = np.arange(self.window)
        pos = position[:, None] + offset[None, :]
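        # e.g. with window=5 and position=[2, 7], pos is
        #   [[2, 3, 4, 5, 6],
        #    [7, 8, 9, 10, 11]]
        # so the take() calls below pull one window per batch row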
        context = self.text.take(pos)
        doc = self.label.take(pos)

        if i_end >= len(self.order):
            np.random.shuffle(self.order)
            self.epoch += 1
            self.is_new_epoch = True
            self.current_position = 0
        else:
            self.is_new_epoch = False
            self.current_position = i_end

        return context, doc

    @property
    def epoch_detail(self):
        return self.epoch + float(self.current_position) / len(self.order)

@chainer.dataset.converter()
def convert(batch, device):
    """Send a (context, doc) batch to the target device.

    A custom converter is needed because the iterator yields a tuple of
    already-batched arrays rather than a list of examples.
    """
    context, doc = batch
    xp = device.xp
    doc = xp.asarray(doc)
    context = xp.asarray(context)
    return context, doc


def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--device', '-d', type=str, default='-1',
                        help='Device specifier. Either ChainerX device '
                        'specifier or an integer. If non-negative integer, '
                        'CuPy arrays with specified device id are used. If '
                        'negative integer, NumPy arrays are used')
    parser.add_argument('--window', '-w', default=5, type=int,
                        help='window size')
    parser.add_argument('--batchsize', '-b', type=int, default=2,
                        help='learning minibatch size')
    parser.add_argument('--epoch', '-e', default=100, type=int,
                        help='number of epochs to learn')
    parser.add_argument('--out', default=os.environ.get("WORK", ".") + "/summarization/sm_icgw/result",
                        help='Directory to output the result')

    args = parser.parse_args()

    random.seed(12345)
    np.random.seed(12345)

    device = chainer.get_device(args.device)

    training_data = [(
        "<bos> the wall street journal reported today that apple corporation made money <eos>".split(),
        "O B I I I O O O B I O O O".split()
    ), (
        "<bos> georgia tech is a university in georgia <eos>".split(),
        "O B I O O O O B O".split()
    )]

    validation_data = [(
        "<bos> georgia tech reported today <eos>".split(),
        "O B I O O O".split()
    )]

    # special tokens: padding, sentence begin/end, unknown word
    word2id = {"<pad>": 0, "<bos>": 1, "<eos>": 2, "<unk>": 3}
    label2id = {"B": 0, "I": 1, "O": 2}

    def get_dataset(data, is_train=True):
        texts = []
        labels = []
        for words, attribs in data:
            for w, a in zip(words, attribs):
                if w not in word2id:
                    if is_train:
                        word2id[w] = len(word2id)
                    else:
                        # map unseen words to the unknown token at test time
                        w = "<unk>"
                texts.append(word2id[w])
                labels.append(label2id[a])
        return texts, labels

    train_text, train_label = get_dataset(training_data)
    valid_text, valid_label = get_dataset(validation_data, False)

    n_vocab = len(word2id)
    n_label = len(label2id)
    n_unit = n_vocab // 2
    model = IconDetector(n_vocab=n_vocab, window=args.window, n_label=n_label, n_unit=n_unit)

    model.to_device(device)
    optimizer = optimizers.Adam()
    optimizer.setup(model)

    train_iter = WindowIterator(train_text, train_label, args.window, args.batchsize)
    valid_iter = WindowIterator(valid_text, valid_label, args.window, args.batchsize,
                                repeat=False, shuffle=False)

    updater = training.StandardUpdater(train_iter, optimizer, converter=convert, device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)
    trainer.extend(extensions.LogReport(), trigger=(10, 'epoch'))
    trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'validation/main/loss']))
    # trainer.extend(extensions.ProgressBar())

    trainer.extend(extensions.Evaluator(valid_iter, model, converter=convert, device=device),
                   trigger=(10, 'epoch'))
    trainer.run()

    # testing
    testing_data = [(
        "<bos> the street journal is a university <eos>".split(),
        "O B I I O O O O".split()
    )]
    test_text, test_label = get_dataset(testing_data, False)
    with chainer.using_config('train', False), \
            chainer.using_config('enable_backprop', False):
        ys = model.predict(np.array([test_text], dtype=np.int32))

        ys = list(chain.from_iterable(ys))
        print(ys)
        print(test_label)

if __name__ == '__main__':
    main()

Execution results

epoch       main/loss   validation/main/loss
10          6.24186     7.77101               
20          5.43089     6.29938               
30          3.16409     5.09233               
40          2.68146     4.02053               
50          1.47248     3.07853               
60          1.63073     2.28975               
70          1.13568     1.67798               
80          0.842947    1.22796               
90          0.364242    0.912943              
100         0.658557    0.692736              
[2, 0, 1, 1, 2, 2, 2, 2]
[2, 0, 1, 1, 2, 2, 2, 2]

The predicted label ids match the gold labels, so the model seems to have learned properly.
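
To read the predicted ids back as B/I/O tags, you can invert label2id; this small helper is not part of the original code:

label2id = {"B": 0, "I": 1, "O": 2}
id2label = {v: k for k, v in label2id.items()}

ys = [2, 0, 1, 1, 2, 2, 2, 2]
print([id2label[y] for y in ys])  # ['O', 'B', 'I', 'I', 'O', 'O', 'O', 'O']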