# 如何绘制美妙的注意力序列可视化图
本章节主要介绍如何采用latex绘制注意力序列的可视化图。这里主要采用的package是tcolorbox。 故首先要引用该包如下:
~~~
\usepackage{tcolorbox}
~~~
在序列上体现注意力的核心思想是通过渐变色来体现注意力的权重,即越大的权重对应越深的颜色,越小的权重对应越浅的颜色。例如,如果想在单词 `What` 上体现 `30%` 的注意力权重,可以使用如下命令来实现
~~~
\colorbox{red!30}{\strut What}
~~~
其中 `red!30` 指的即是不透明度为`30%`的红色,而 `\strut` 是为了让 `colorbox` 与所包裹的文字边界对齐。
掌握了基本原理后,我们可以手动地构造序列,或者使用 Python脚本自动生成。下面我们提供了一个Python脚本,你需要先安装 `numpy` 库
~~~
pip install numpy
~~~
然后在代码 `TODO:` 处填写你需要的内容,将生成的 `text_attention.tex` 中的图片块复制到需要展示的地方即可。
> 以下代码修改自开源库[Text-Attention-Heatmap-Visualization](https://github.com/jiesutd/Text-Attention-Heatmap-Visualization),如果您觉得该代码有用,请考虑引用原作者的论文。
```python
# -*- coding: utf-8 -*-
# @Author: Jie Yang
# @Date: 2019-03-29 16:10:23
# @Last Modified by: Jie Yang, Contact: jieynlp@gmail.com
# @Last Modified time: 2019-04-12 09:56:12
## convert the text/attention list to latex code, which will further generates the text heatmap based on attention weights.
import numpy as np
latex_special_token = ["!@#$%^&*()"]
def generate(text_list, attention_list, latex_file, color='red', rescale_value=False):
assert (len(text_list) == len(attention_list))
if rescale_value:
attention_list = rescale(attention_list)
word_num = len(text_list)
text_list = clean_word(text_list)
with open(latex_file, 'w') as f:
f.write(r'''\begin{figure}
\centering
''')
string = r'''{\setlength{\fboxsep}{0pt}\colorbox{white!0}{\parbox{0.85\textwidth}{''' + "\n"
for idx in range(word_num):
string += "\\colorbox{%s!%s}{" % (color, attention_list[idx]) + "\\strut " + text_list[idx] + "} "
string += "\n}}}"
f.write(string + "\n \end{figure}")
def rescale(input_list):
the_array = np.asarray(input_list)
the_max = np.max(the_array)
the_min = np.min(the_array)
rescale = (the_array - the_min) / (the_max - the_min) * 100
return rescale.tolist()
def clean_word(word_list):
new_word_list = []
for word in word_list:
for latex_sensitive in ["\\", "%", "&", "^", "#", "_", "{", "}"]:
if latex_sensitive in word:
word = word.replace(latex_sensitive, '\\' + latex_sensitive)
new_word_list.append(word)
return new_word_list
if __name__ == '__main__':
# TODO: 文本输入处,以空格分割单词
sent = "Who are the only plaerys listed that played in 2011 ?"
words = sent.split()
# TODO: 注意力权重输入处,最大值是100.0
attention = [14.9, 13.8, 9.7, 6.5, 12.3, 6.9, 7.1, 8.5, 5.6, 3.8, 12.1]
assert len(attention) == len(words)
# TODO: latex 支持的颜色,包括 red, green, blue, cyan, magenta, yellow, black, gray, white, darkgray, lightgray, brown, lime, olive, orange, pink, purple, teal, violet 等
color = 'red'
generate(words, attention, "text_attention.tex", color)
```