LDA(トピックモデル)を使ってみる
はじめに
LDAを試しに使ってみたので備忘として残します。
加藤公一氏著の機械学習図鑑からコードを拝借しました。
学習&分類
以下書籍からの引用です。(コードはgithubから落とせます。)
やってることは以下の通りです。
・scikit-learnのdataset"fetch_20newsgroups"よりサンプルデータを取得
・CountVectorizerにてサンプルデータの文章をベクトル化
・ベクトル化された文章をLDAで次元削減(サンプルデータのラベルが20種類なので、これに合わせて20次元に削減)
from sklearn.datasets import fetch_20newsgroups from sklearn.feature_extraction.text import CountVectorizer from sklearn.decomposition import LatentDirichletAllocation # removeで本文以外の情報を取り除く data = fetch_20newsgroups(remove=('headers', 'footers', 'quotes')) max_features = 1000 # 文書 データをベクトルに変換(n_data, max_features=1,000) tf_vectorizer = CountVectorizer(max_features=max_features, stop_words='english') tf = tf_vectorizer.fit_transform(data.data) n_topics = 20 model = LatentDirichletAllocation(n_components=n_topics) model.fit(tf)
CountVectorizerクラスに設定されている引数2つは以下の通りです。
・max_features
文章をベクトル化する際に最頻n個の単語のみを考慮するように設定できます。
ここでは最頻1,000個の単語をベクトル化するように設定しています。
・stop_words
文章をベクトル化するにあたり除外すべき単語をここで指定します。
ここでは'english'となっていますがこれでbuilt-inの英語のstop-wordsが指定できます。
.get_stop_words()のメソッドで登録されている単語群が確認できます。
stop_words = tf_vectorizer.get_stop_words() stop_words = [w for w in stop_words] print(f'len:{len(stop_words)} \nword_list: {stop_words}') # len:318 # word_list: ['anyone', 'you', 'to', 'thereby', 'via', 'someone', 'seeming', 'find', 'least', 'besides', 'now', 'thick', 'itself', 'own', 'back', 'latter', 'became', 'take', 'name', 'moreover', 'due', 'could', .....
結果
書籍だとトピック毎に頻出する単語をリストにして結果を解釈しておりましたので同じように結果を出力してみます。
やってることは以下の通りです。
・文章毎に削減した20次元のベクトル(トピック)のうち、一番値の大きいものを選ぶ(一番もっともらしいトピックを選ぶ)
・同じトピックに分類された文章毎に単語を頻出頻度順(降順)に出力
import numpy as np # 文書データのベクトルを20次元に削減(n_data, n_topics=20) tpc_cls = np.argmax(model.transform(tf), axis=1) for a in range(20): tpc_id = [i for i,c in enumerate(tpc_cls) if c == a] tpc_stc = [s for i,s in enumerate(data.data) if i in tpc_id] tf_tpc = tf_vectorizer.fit_transform(tpc_stc) term_frequency = np.array(tf_tpc.sum(axis=0))[0] print(f'topic_{a+1}') print([tf_vectorizer.get_feature_names()[b] for b in term_frequency.argsort()[:-10:-1]])
出力はこんな感じになるはずです。
topic_1 ['god', 'jesus', 'people', 'bible', 'christ', 'does', 'christian', 'believe', 'church'] topic_2 ['don', 'know', 'think', 'people', 'does', 'just', 'like', 'say', 'god'] topic_3 ['00', 'new', '50', 'sale', '10', 'dos', 'good', 'offer', '20'] topic_4 ['president', 'mr', 'stephanopoulos', 'people', 'think', 'going', 'know', 'don', 'jobs'] topic_5 ['entry', 'file', 'cx', 'output', 'w7', 'program', 'c_', 'chz', 't7'] topic_6 ['key', 'encryption', 'chip', 'use', 'clipper', 'keys', 'privacy', 'government', 'security'] topic_7 ['pts', 'period', 'la', 'play', 'power', 'pp', '10', 'pt', 'scorer'] topic_8 ['space', 'nasa', '1993', 'research', 'university', 'launch', 'program', 'information', 'use'] topic_9 ['people', 'israel', 'government', 'right', 'israeli', 'state', 'war', 'don', 'jews'] topic_10 ['said', 'people', 'know', 'didn', 'don', 'just', 'went', 'like', 'time'] topic_11 ['edu', 'com', 'available', 'use', 'file', 'image', 'window', 'ftp', 'program'] topic_12 ['10', '25', '55', '11', '12', '14', '15', '20', '16'] topic_13 ['gun', 'file', 'guns', 'firearms', 'crime', 'control', 'law', '000', 'use'] topic_14 ['scsi', 'use', 'bit', 'data', 'bus', 'ide', 'chip', 'like', 'does'] topic_15 ['team', 'game', 'year', 'season', 'games', 'hockey', 'players', 'league', 'nhl'] topic_16 ['drive', 'windows', 'card', 'use', 'thanks', 'disk', 'problem', 'know', 'does'] topic_17 ['like', 'just', 'good', 'don', 'time', 'car', 'think', 'know', 've'] topic_18 ['ax', 'max', 'g9v', 'b8f', 'a86', 'pl', '145', '1d9', '0t'] topic_19 ['db', 'armenian', 'turkish', 'armenians', 'people', 'turkey', 'jews', 'turks', 'genocide'] topic_20 ['people', 'just', 'like', 'think', 'don', 'time', 'post', 'know', 'does']
topic_1はキリスト教、topic_8は宇宙、といったところでしょうか。
反対にtopic_12,18, 20あたりは解釈が難しく、stop_wordsを見直すなどしてより精度の高い分類ができるような工夫が必要になりそうです。