M-H采样已经可以很好的解决蒙特卡罗方法需要的任意概率分布的样本集的问题。但是M-H采样有两个缺点:一是需要计算接受率,在高维时计算量大。并且由于接受率的原因导致算法收敛时间变长。二是有些高维数据,特征的条件概率分布好求,但是特征的联合分布不好求。因此需要一个好的方法来改进M-H采样,这就是我们下面讲到的Gibbs采样。
上面的这个算法推广到多维的时候也是成立的。比如一个 $\mathrm{n}$ 维的概率分布 $\pi\left(x_1, x_2, \ldots x_n\right)$ ,我们可以通过在 $\mathrm{n}$ 个坐标轴上轮换采样,来得到新的样 本。对于轮换到的任意一个坐标轴 $x_i$ 上的转移,马尔科夫链的状态转移概率为 $P\left(x_i \mid x_1, x_2, \ldots, x_{i-1}, x_{i+1}, \ldots, x_n\right)$ ,即固定 $n-1$ 个坐标轴,在某一个坐 标轴上移动。 具体的算法过程如下: 1) 输入平稳分布 $\pi\left(x_1, x_2, \ldots, x_n\right)$ 或者对应的所有特征的条件概率分布,设定状态转移次数阈值 $n_1$ ,需要的样本个数 $n_2$ 2) 随机初始化初始状态值 $\left(x_1^{(0)}, x_2^{(0)}, \ldots, x_n^{(0)}\right.$ 3) for $t=0$ to $n_1+n_2-1$ : a) 从条件概率分布 $P\left(x_1 \mid x_2^{(t)}, x_3^{(t)}, \ldots, x_n^{(t)}\right)$ 中采样得到样本 $x_1^{t+1}$ b) 从条件概率分布 $P\left(x_2 \mid x_1^{(t+1)}, x_3^{(t)}, x_4^{(t)}, \ldots, x_n^{(t)}\right)$ 中采样得到样本 $x_2^{t+1}$ c) $\ldots$ d) 从条件概率分布 $P\left(x_j \mid x_1^{(t+1)}, x_2^{(t+1)}, \ldots, x_{j-1}^{(t+1)}, x_{j+1}^{(t)} \ldots, x_n^{(t)}\right)$ 中采样得到样本 $x_j^{t+1}$ e)... f) 从条件概率分布 $P\left(x_n \mid x_1^{(t+1)}, x_2^{(t+1)}, \ldots, x_{n-1}^{(t+1)}\right)$ 中采样得到样本 $x_n^{t+1}$ 样本集 $\left\{\left(x_1^{\left(n_1\right)}, x_2^{\left(n_1\right)}, \ldots, x_n^{\left(n_1\right)}\right), \ldots,\left(x_1^{\left(n_1+n_2-1\right)}, x_2^{\left(n_1+n_2-1\right)}, \ldots, x_n^{\left(n_1+n_2-1\right)}\right)\right\}$ 即为我们需要的平稳分布对应的样本集。
整个采样过程和Lasso回归的坐标轴下降法算法非常类似,只不过Lasso回归是固定 $n-1$ 个特征,对某一个特征求极值。而Gibbs采样是固定 $n-1$ 个 特征在某一个特征采样。 同样的,轮换坐标轴不是必须的,我们可以随机选择某一个坐标轴进行状态转移,只不过常用的Gibbs采样的实现都是基于坐标轴轮换的。
具体的代码如下:
from mpl_toolkits.mplot3d import Axes3D
from scipy.stats import multivariate_normal
import matplotlib.pyplot as plt
samplesource = multivariate_normal(mean=[5,-1], cov=[[1,1],[1,4]])
import numpy as np
import math
def p_ygivenx(x, m1, m2, s1, s2):
return (random.normalvariate(m2 + rho * s2 / s1 * (x - m1), math.sqrt((1 - rho ** 2) * (s2**2))))
def p_xgiveny(y, m1, m2, s1, s2):
return (random.normalvariate(m1 + rho * s1 / s2 * (y - m2), math.sqrt((1 - rho ** 2) * (s1**2))))
N = 5000
K = 20
x_res = []
y_res = []
z_res = []
m1 = 5
m2 = -1
s1 = 1
s2 = 2
rho = 0.5
y = m2
for i in range(N):
for j in range(K):
x = p_xgiveny(y, m1, m2, s1, s2)
y = p_ygivenx(x, m1, m2, s1, s2)
z = samplesource.pdf([x,y])
x_res.append(x)
y_res.append(y)
z_res.append(z)
num_bins = 50
plt.hist(x_res, num_bins, density=1, facecolor='green', alpha=0.5)
plt.hist(y_res, num_bins, density=1, facecolor='red', alpha=0.5)
plt.title('Histogram')
plt.show()
然后我们看看样本集生成的二维正态分布,代码如下:
fig = plt.figure()
ax = Axes3D(fig, elev=30, azim=20)
ax.scatter(x_res, y_res, z_res,marker='o')
plt.show()
由于Gibbs采样在高维特征时的优势,目前我们通常意义上的MCMC采样都是用的Gibbs采样。当然Gibbs采样是从M-H采样的基础上的进化而来的,同时Gibbs采样要求数据至少有两个维度,一维概率分布的采样是没法用Gibbs采样的,这时M-H采样仍然成立。
有了Gibbs采样来获取概率分布的样本集,有了蒙特卡罗方法来用样本集模拟求和,他们一起就奠定了MCMC算法在大数据时代高维数据模拟求和时的作用。
来源
random.normalvariate
from mpl_toolkits.mplot3d import Axes3D
from scipy.stats import multivariate_normal
import matplotlib.pyplot as plt
samplesource = multivariate_normal(mean=[5,-1], cov=[[1,1],[1,4]])
import numpy as np
import math
def p_ygivenx(x, m1, m2, s1, s2):
return (random.normalvariate(m2 + rho * s2 / s1 * (x - m1), math.sqrt((1 - rho ** 2) * (s2**2))))
def p_xgiveny(y, m1, m2, s1, s2):
return (random.normalvariate(m1 + rho * s1 / s2 * (y - m2), math.sqrt((1 - rho ** 2) * (s1**2))))
N = 5000
K = 20
x_res = []
y_res = []
z_res = []
m1 = 5
m2 = -1
s1 = 1
s2 = 2
rho = 0.5
y = m2
for i in range(N):
for j in range(K):
x = p_xgiveny(y, m1, m2, s1, s2)
y = p_ygivenx(x, m1, m2, s1, s2)
z = samplesource.pdf([x,y])
x_res.append(x)
y_res.append(y)
z_res.append(z)
num_bins = 50
plt.hist(x_res, num_bins, density=1, facecolor='green', alpha=0.5)
plt.hist(y_res, num_bins, density=1, facecolor='red', alpha=0.5)
plt.title('Histogram')
plt.show()
fig = plt.figure()
ax = Axes3D(fig, elev=30, azim=20)
ax.scatter(x_res, y_res, z_res,marker='o')
plt.show()
C:\Users\Haihua Wang\AppData\Local\Temp\ipykernel_43748\3351083173.py:2: MatplotlibDeprecationWarning: Axes3D(fig) adding itself to the figure is deprecated since 3.4. Pass the keyword argument auto_add_to_figure=False and use fig.add_axes(ax) to suppress this warning. The default value of auto_add_to_figure will change to False in mpl3.5 and True values will no longer work in 3.6. This is consistent with other Axes classes. ax = Axes3D(fig, elev=30, azim=20)