加载中…
个人资料
  • 博客等级:
  • 博客积分:
  • 博客访问:
  • 关注人气:
  • 获赠金笔:0支
  • 赠出金笔:0支
  • 荣誉徽章:
正文 字体大小:

SVM算法PYTHON实现

(2017-06-04 21:07:24)
一。数据
x1 x2 y
7.15 14.8 -1
8.85 13 -1
11.45 12.9 -1
19.6 14.1 -1
19.25 16.2 -1
11.65 15.75 -1
8.9 15.85 -1
10.85 14.6 -1
14.3 15.55 -1
16.25 14.95 -1
13.6 14.05 -1
15.5 14.05 -1
16.85 15 -1
7.25 10.25 1
8 9.2 1
13.7 8.1 1
17.65 7.8 1
17.8 9.3 1
14.75 9.85 1
10.35 10 1
9.1 8.4 1
10.8 7.95 1
11.15 8.35 1
13.45 9.35 1
16.25 8.6 1
19 9.05 1
16.8 9.7 1
15.45 9.25 1
11.65 8.45 1
8.45 10.25 1
8.45 10.3 1
10.1 9.65 1
12.9 9.1 1
13.75 9.5 1
16.25 9.05 1
12.35 15.05 1
二、算法

# coding:utf-8
from numpy import *
import numpy.random as rd
#from dml.tool import sign
import matplotlib.pyplot as plt

class SVMC:
    def Gauss_kernel(x,z,sigma=2):  
        return exp(-sum((x-z)**2)/(2*sigma**2))  
    def Linear_kernel(x,z):  
        return sum(x*z)
    def __init__(self,x,y,C=5,tol=0.01,kernel=Linear_kernel):
        self.X=array(x)
        self.y=array(y).flatten(1) 
        self.tol=tol
        self.C=C
        self.kernel=kernel
        self.N,self.M=self.X.shape
        self.E=zeros((1,self.M)).flatten(1)
        self.alpha=zeros((1,self.M)).flatten(1)
        self.b=0

    def fitKKT(self,i):
        if ((self.y[i]*self.E[i]<-self.tol) and (self.alpha[i]self.tol)) and (self.alpha[i]>0)):
            return False  
        return True
    
    def selectalpha1(self):
        for i in range(self.M):
            #self.updateE(i)
            if (not self.fitKKT(i)):
                break
        return i
            
    def selectalpha2(self,i):
        nonalpha=nonzero((self.alpha))[0]
        if (len(nonalpha)>0):
            maxDelta=-1
            kk=-1
            for j in nonalpha:
                if i==j:
                    continue
                if abs(self.E[i]-self.E[j])>maxDelta:
                    maxDelta=abs(self.E[i]-self.E[j])
                    kk=j
            return kk
        else:
            #kk=rd.sample(range(self.M),1)
            kk=rd.choice(range(self.M),1)
            while kk==i:
                kk=rd.choice(range(self.M),1)
            return kk[0]

    def updateE(self,i):
        self.E[i]=0
        for j in range(self.M):
            self.E[i]+=self.alpha[j]*self.y[j]*self.kernel(self.X[:,i],self.X[:,j])
        self.E[i]=self.E[i]-self.y[i]+self.b

    def innerLoop(self,i,threshold):
        #确定alpha1,alpha2
        #i=self.selectalpha1()
        j=self.selectalpha2(i)
        #求解L,H
        if (self.y[i]==self.y[j]):  
            L=max(0,self.alpha[i]+self.alpha[j]-self.C)  
            H=min(self.C,self.alpha[i]+self.alpha[j])  
        else:  
            L=max(0,self.alpha[j]-self.alpha[i])  
            H=min(self.C,self.C+self.alpha[j]-self.alpha[i])
        a2_old=self.alpha[j]  
        a1_old=self.alpha[i]
        K11=self.kernel(self.X[:,i],self.X[:,i])  
        K22=self.kernel(self.X[:,j],self.X[:,j])  
        K12=self.kernel(self.X[:,i],self.X[:,j])  
        eta=K11+K22-2*K12  
        if eta==0:  
            return True
        #更新alpha2
        self.updateE(i)
        self.updateE(j)
        self.alpha[j]=a2_old+self.y[j]*(self.E[i]-self.E[j])/eta
        if self.alpha[j]>H:
            self.alpha[j]=H
        elif self.alpha[j]
            self.alpha[j]=L
            
        #更新alpha1
        self.alpha[i]=a1_old+self.y[i]*self.y[j]*(a2_old-self.alpha[j])
        #更新b
        b1_new=self.b-self.E[i]-self.y[i]*K11*(self.alpha[i]-a1_old)-self.y[j]*K12*(self.alpha[j]-a2_old)  
        b2_new=self.b-self.E[j]-self.y[i]*K12*(self.alpha[i]-a1_old)-self.y[j]*K22*(self.alpha[j]-a2_old)
        if self.alpha[i]>0 and self.alpha[i]
            self.b=b1_new
        elif self.alpha[j]>0 and self.alpha[j]
            self.b=b2_new
        else:
            self.b=(b1_new+b2_new)/2
        #self.updateE(j)  
        #self.updateE(i) 

        #计算精度是否满足
        if abs(self.alpha[j]-a2_old)
            return True
        else:
            return False

    def train(self,maxiter=100,threshold=0.000001):
        iters=0
        flag=False
        while (iters
            flag=True
            iters+=1
            temp_supportVec=nonzero((self.alpha>0))[0]
            #先从边界点找不满足的点确定alpha1,再从不是支持向量的点中找alpha1
            for i in temp_supportVec:  
                self.updateE(i)  
                if (not self.fitKKT(i)):  
                    flag=flag and self.innerLoop(i,threshold)  
                    #if not flag:break  
            if (flag):  
                for i in range(self.M):  
                    self.updateE(i)  
                    if (not self.fitKKT(i)):  
                        flag= flag and self.innerLoop(i,threshold) 
            #判断所有的alpha是否满足kkt条件
            #for i in range(self.M):
                #self.updateE(i)
                #if (not self.fitKKT(i)):
                    #flag=False
                    #self.innerLoop(threshold)
        self.supportVec=nonzero((self.alpha>0))[0]

    def predict(self,x):
        f=0
        for i in range(self.supportVec):
            f+=self.alpha[t]*self.y[t]*self.kernel(self.X[:,t],x).flatten(1)
        f+=self.b
        return sign(f)
                
    def pred(self,X):  
        test_X=np.array(X)  
        y=[]  
        for i in range(test_X.shape[1]):  
            y.append(self.predict(test_X[:,i]))  
        return y

    def prints_test_linear(self):  
        w=0  
        for t in self.supportVec:  
            w+=self.alpha[t]*self.y[t]*self.X[:,t].flatten(1)  
        w=w.reshape(1,w.size)  
        #print sum(sign(np.dot(w,self.X)+self.b).flatten(1)!=self.y),"errrr"  
        #print w,self.b  
        x1=0  
        y1=-self.b/w[0][1]  
        y2=0  
        x2=-self.b/w[0][0]  
        plt.plot([x1+x1-x2,x2],[y1+y1-y2,y2])  
        #plt.plot([x1+x1-x2,x2],[y1+y1-y2-1,y2-1])  
        plt.axis([0,30,0,30])  
  
        for i in range(self.M):  
            if  self.y[i]==-1:  
                plt.plot(self.X[0,i],self.X[1,i],'or')  
            elif  self.y[i]==1:  
                plt.plot(self.X[0,i],self.X[1,i],'ob')  
        for i in self.supportVec:  
            plt.plot(self.X[0,i],self.X[1,i],'oy')  
        plt.show()  
三、结果

0

阅读 收藏 喜欢 打印举报/Report
  

新浪BLOG意见反馈留言板 欢迎批评指正

新浪简介 | About Sina | 广告服务 | 联系我们 | 招聘信息 | 网站律师 | SINA English | 产品答疑

新浪公司 版权所有