Word2Vecの出力をMDSで解釈してみる
はじめに
前回紹介したWord2Vecを用いた文章の分散表現をMDSで次元削減し視覚的に解釈可能にしてみます。
oryou-san.hatenablog.com
前処理
scikit-learnのdataset"fetch_20newsgroups"を学習用データとして用います。
Word2Vecの入力に適するように文章を単語毎に区切られたリストに変換し、記号は空白に置換します。
from sklearn.datasets import fetch_20newsgroups data = fetch_20newsgroups(remove=('headers', 'footers', 'quotes')) texts = [txt.replace('.', '').replace(',', '').replace('/', '').split() for txt in data.data] texts = [w for w in texts] print(texts[0]) # ['I', 'was', 'wondering', 'if', 'anyone', 'out', 'there', 'could', 'enlighten', ... , 'e-mail']
学習(Word2Vec)
前処理を施したデータをモデルに投入します。
単語の類似度をどの程度学習できたかを確かめるのに'windows'という単語ベクトルとコサイン類似度の高い単語を降順で出力してみます。
PC関係の単語がずらずらと出力されてきたのでモデルの学習が正常に進んだことを確認できました。
from gensim.models import Word2Vec model = Word2Vec(texts, min_count=1, seed=1) key = 'windows' for w in model.wv.most_similar(key, topn = 10): print(w) # ('Windows', 0.936229407787323) # ('drivers', 0.9341704845428467) # ('memory', 0.9085516929626465) # ('disks', 0.9075465798377991) # ... # ('bus', 0.8947420716285706)
単語の分散表現から文章トピックの分散表現を獲得
データセットに存在する単語の分散表現を獲得できたので、これを用いて文章トピックの分散表現を計算してみます。
以下の手順で計算しています。
1.文章に登場する単語のベクトルの算術平均を取り、これを文章の分散表現とする
2.トピックに登場する単語のベクトルの算術平均を取り、これをトピックの分散表現とする
import numpy as np vec_news = [np.average([model.wv[w] for w in txt], axis = 0) if txt != [] else np.zeros(100) for txt in texts] vec_topics = [] for i in range(20): tpc_i = [idx for idx,d in enumerate(data.target) if d == i] vec_topic = np.average(np.array(vec_news)[tpc_i], axis=0) vec_topics.append(vec_topic) vec_topics = np.array(vec_topics)
学習(MDS)
いよいよMDSに学習データを投入します。
先ほど獲得したトピック毎の分散表現(トピック数(20) * ベクトル次元数(100))から各トピック間のL2ノルムを計算します(トピック数(20) * トピック数(20))。
L2ノルム計算にあたってはブロードキャスト処理を用いて効率化しております。
(記事末尾に記載のサイト様を参考にしました。)
diffs = np.expand_dims(vec_topics, axis=1) - np.expand_dims(vec_topics, axis=0) dist = np.sqrt(np.sum(diffs ** 2, axis=-1)) from sklearn import manifold mds = manifold.MDS(n_components=2, dissimilarity="precomputed", random_state=1) pos = mds.fit_transform(dist)
結果を散布図にして出力
MDSにて2次元に圧縮されたトピックの分散表現を散布図に出力してみました。
トピックの大分類(7つ)を用いてドットを色分けしていますが、散布図上でも綺麗に色毎にまとまっているように見えます。
labels = data.target_names label_big = [label[:(label.find('.'))] for label in labels] %matplotlib inline import matplotlib.pyplot as plt x = pos[:,0] y = pos[:,1] for label in list(set(label_big)): label_idx = [i for i,l in enumerate(label_big) if l == label] plt.scatter(x[label_idx], y[label_idx], s=20) plt.legend(list(set(label_big))) plt.show()
参考