加载中…
个人资料
  • 博客等级:
  • 博客积分:
  • 博客访问:
  • 关注人气:
  • 获赠金笔:0支
  • 赠出金笔:0支
  • 荣誉徽章:
正文 字体大小:

EM算法分析与实现

(2012-12-16 16:27:00)
标签:

杂谈

分类: MachineLearning

    花了2天时间,看了多方面的资料,也得到了师兄的指导,参考了别人写的代码,终于自己还是在matlab下实现了,下面先简单分析下EM算法的思想(主要针对混合高斯分布的情况):

    对于混合高斯分布的情况,已知条件是样本数据,其他的都是未知的。为了简单分析,下面假设有3个不同的高斯分布(一般是3--5),每个高斯分布的均值和方差都是未知的。EM算法就是通过不断计算每个样本的均值和方差,使得似然函数达到最大值。

http://s1/mw690/98b365154d0f471baec90&690

http://s15/mw690/98b3651507b4ba5aa0d8e&690
2.EM算法
http://s10/mw690/98b365154d0f4724e5639&690
http://s16/mw690/98b365154d0f48070a64f&690

   详细介绍可以参考Pattern Recognition and Machine Learning.pdf,或者http://wenku.baidu.com/view/60c583294b73f242336c5fbc.html

   下面就是我基于matlab实现的EM算法,真心觉得理解原理不容易,在写代码的过程中也要很仔细,在实现的过程中还是参照了一位大牛的代码,理解思想,看了他的代码之后自己重新写的,思想多少有点受他影响,但是个人觉得代码还是有优化过。

%EM
M=3;          % M个高斯分布混合
N=600;        % 样本数
th=0.000001;  % 收敛阈值
K=2;          % 样本维数
% 待生成数据的参数
a_real =[2/3;1/6;1/6];%混合模型中基模型高斯密度函数的权重
mu_real=[3 4 6;5 3 7];%均值
cov_real(:,:,1)=[5 0;0 0.2];%协方差
cov_real(:,:,2)=[0.1 0;0 0.1];
cov_real(:,:,3)=[0.1 0;0 0.1];                    
%生成符合标准的样本数据(每一列为一个样本)
x=[ mvnrnd( mu_real(:,1) , cov_real(:,:,1) , round(N*a_real(1)) )' ,...
    mvnrnd( mu_real(:,2) , cov_real(:,:,2) , round(N*a_real(2)) )' ,...
    mvnrnd( mu_real(:,3) , cov_real(:,:,3) , round(N*a_real(3)) )' ];
%初始化参数
a=[1/3;1/3;1/3];
mu=[1 2 3;2 1 4];
cov(:,:,1)=[1 0;0 1];
cov(:,:,2)=[1 0;0 1];
cov(:,:,3)=[1 0;0 1];
t=inf;
while t>=th
    a_old  = a;
    mu_old = mu;
    cov_old= cov;     
    rznk_temp=zeros(M,N);
    for k=1:M
        for n=1:N
            %计算P(x|mu_cm,cov_cm)
            rznk_temp(k,n)=exp(-1/2*(x(:,n)-mu(:,k))'*inv(cov(:,:,k))*(x(:,n)-mu(:,k)));
        end
        rznk_temp(k,:)=rznk_temp(k,:)/sqrt(det(cov(:,:,k)));
    end
    rznk_temp=rznk_temp*(2*pi)^(-K/2);
%E step
    %求rznk
    rznk=zeros(M,N);
    for n=1:N
        for k=1:M
            rznk(k,n)=a(k)*rznk_temp(k,n);
        end
        rznk(:,n)=rznk(:,n)/sum(rznk(:,n));
    end
% M step
    %求Nk
    nk=zeros(1,M);
    nk=sum(rznk');
   
    % 求a
    a=nk/N;
       
    % 求MU
    for k=1:M
        mu_k_sum=0;
        for n=1:N
            mu_k_sum=mu_k_sum+rznk(k,n)*x(:,n);
        end
        mu(:,k)=mu_k_sum/nk(k);
    end
   
    % 求COV  
    for k=1:M
        cov_k_sum=0;
        for n=1:N
            cov_k_sum=cov_k_sum+rznk(k,n)*(x(:,n)-mu(:,k))*(x(:,n)-mu(:,k))';
        end
        cov(:,:,k)=cov_k_sum/nk(k);
    end
      
    t=max([norm(a_old(:)-a(:))/norm(a_old(:));norm(mu_old(:)-mu(:))/norm(mu_old(:));norm(cov_old(:)-cov(:))/norm(cov_old(:))]); 
end 

%输出结果并比较

a_real
a

mu_real
mu

cov_real
cov

%结果

a_real =

    0.6667
    0.1667
    0.1667


a =

    0.6657    0.1681    0.1662


mu_real =

           6
           7


mu =

    3.0366    3.9987    6.0406
    4.9941    2.9888    7.0190


cov_real(:,:,1) =

    5.0000         0
           0.2000


cov_real(:,:,2) =

    0.1000         0
           0.1000


cov_real(:,:,3) =

    0.1000         0
           0.1000


cov(:,:,1) =

    5.4894   -0.0389
   -0.0389    0.1939


cov(:,:,2) =

    0.0682    0.0038
    0.0038    0.0959


cov(:,:,3) =

    0.0866   -0.0033
   -0.0033    0.0761
   通过输出结果发现算法的准确性还是比较高的,算法迭代得到的值与实际值出入不是很大。

0

阅读 收藏 喜欢 打印举报/Report
  

新浪BLOG意见反馈留言板 欢迎批评指正

新浪简介 | About Sina | 广告服务 | 联系我们 | 招聘信息 | 网站律师 | SINA English | 产品答疑

新浪公司 版权所有