加载中…
个人资料
Yode
Yode
  • 博客等级:
  • 博客积分:0
  • 博客访问:596,795
  • 关注人气:250
  • 获赠金笔:0支
  • 赠出金笔:0支
  • 荣誉徽章:
相关博文
推荐博文
谁看过这篇博文
加载中…
正文 字体大小:

PRanking的代码实现(小蒯的代码被俺简单地合并了一下)

(2007-04-30 16:32:46)
  希望小蒯不要追究俺地责任呢,数据集呢就是OHSUMED数据集了,query1-9做trainingset(不要包含4),query4做testset.代码如下:
#include <iostream>
#include <fstream>
#include <limits>
#include <iomanip>
using namespace std;
#define K 3//排序的序数,即如排成全相关,部分相关,不相关,序数就是3
#define N 25//特征的维数
int *b,*y,*t;
double *w;

//从文件中获得特征值
bool getData(double *x,int &yt,ifstream &fin)
{
    if (fin.eof())
        return false;

    char data[1024];
    int index = 1;
    fin.getline(data,1024);
    char *p = data;
    char q[100];
    q[0] = p[0];
    q[1] = '\0';
    yt = atoi(q) + 1;
    p = p+7;//跳过qid:1的冒号
    for(;*p != '\0';++p)
    {
        if(*p == ':')
        {
            ++p;
            for(int i=0; *p != ' ';i++,p++)
            {
                q[i] = *p;
            }

            q[i] = '\0';

           
            x[index ++] = atof(q);
        }

    }

    return true;
}

//各变量进行初始化
void Initialize()
{
    w = new double[N+1];
    b = new int[K+1];
    y = new int[K+1];
    t = new int[K+1];

    int i;
    for(i=1; i<=N;i++)
        w[i] = 0 ;

    for(i=1;i<=K-1;i++)
        b[i] = 0;

    b[K] = std::numeric_limits<int>::max();//无穷大
   
}

//利用Prank算法进行训练
void PrankTraining(double *x,int yt)
{
    int i;
    double wx = 0;
    for(i =1; i<=N; i++)
        wx+= w[i] * x[i];
    for(i =1; i<=K; i++)
    {
        if(wx - b[i] <0 )
            break;
    }

    int yy = i ;
    if (yy == yt)
        return;
    else
    {

        for(i=1; i<K; i++)
        {
            if(yt <= i)
                y[i] = -1;
            else
                y[i] = 1;
        }

        for(i=1;i<K;i++)
        {

            if ((wx-b[i])*y[i] <= 0)
            {
                t[i] = y[i];

            }
            else
                t[i] = 0;
        }
       
        //调整 w 和 b
        int sumt = 0;
        for(i=1;i<K;i++)
            sumt = sumt + t[i];

        for(i=1;i<=N;i++)
            w[i] = w[i] + sumt*x[i];

        for(i=1;i<K;i++)
            b[i] = b[i] - t[i];
    }
}
//利用得到的model进行测试
int Pranking(double *x)
{

    int i;
    double wx = 0;
    for(i =1; i<=N; i++)
        wx = wx + w[i] * x[i];

    for(i =1; i<=K;i++)
        if(wx - b[i] <0 )
            break;

    return i;

}


int main(int argc,char **argv)
{
    int right=0,wrong=0;//排正确和错误的样本数

    if(argc !=5)
    {
        cout <<"Usage: PRank testFile modelFile resultFile"<<endl;
        return -1;
    }

    ifstream fin_train(argv[1]);
    if(fin_train.fail())
    {
        cout << "can't open the traningsetFile!"<<endl;
        return -1;
    }
   
    ofstream fout_model(argv[2]);
    if(fout_model.fail())
    {
        cout << "can't open the ModelFile!"<<endl;
        return -1;
    }

    ifstream fin_test(argv[3]);
    if(fin_test.fail())
    {
        cout << "can't open the testsetFile!"<<endl;
        return -1;
    }

    ofstream fout_result(argv[4]);
    if(fout_result.fail())
    {
        cout << "open resultFile  failed!"<<endl;
        return -1;
    }

    double *tr = new double[N+1];
    int yt;
    Initialize();
    int i = 0;

    //读入训练数据进行训练得到model
    while(true)
    {
        if (getData(tr,yt,fin_train))
        {
            PrankTraining(tr,yt);//训练
        }
        else
            break;
    }

   //将得到的w和b写入文件
    char   buff[128];
    cout<<"训练出的w为:\n";
    for(i=1; i<=N; i++)//写w
    {
        cout<<setw(8)<<w[i]<<'\t';
        memset(buff,0,sizeof(buff));  
        sprintf(buff,"%f",w[i]);
        fout_model << buff << " ";
    }
    fout_model<<endl;

    cout<<"\n\n训练出的b为:\n";
    for(i = 1; i<K;i++)//写b
    {
        cout<<b[i]<<'\t';
        memset(buff,0,sizeof(buff));  
        sprintf(buff,"%d",b[i]);
        fout_model << buff << " ";
    }


   //读入测试数据进行测试得到正确率
    while(true)
    {
        if (getData(tr,yt,fin_test))
        {
            int yy = Pranking(tr);
            char p[2];
            p[0] = yy -1 + 48;
            p[1] = '\0';
            fout_result << p << endl;

            if (yy == yt)
                right ++;
            else
                wrong ++;

        }
        else
            break;
    }
    cout<<"\n\n排正确的个数为"<<right<<",错误的个数为"<<wrong<<",正确率为%"<<right*100*1.0/(right+wrong)<<endl;
   
   
   //释放申请的空间并关闭文件  
    delete []w;   
    delete []y;
    delete []t;
    delete []b;
    delete []tr;
    fin_train.close();
    fin_test.close();
    fout_result.close();
    fout_model.close();
    return 0;
}
 

0

阅读 评论 收藏 转载 喜欢 打印举报/Report
  • 评论加载中,请稍候...
发评论

    发评论

    以上网友发言只代表其个人观点,不代表新浪网的观点或立场。

      

    新浪BLOG意见反馈留言板 电话:4000520066 提示音后按1键(按当地市话标准计费) 欢迎批评指正

    新浪简介 | About Sina | 广告服务 | 联系我们 | 招聘信息 | 网站律师 | SINA English | 会员注册 | 产品答疑

    新浪公司 版权所有