2012-05-29 19 views
13

Estoy tratando de utilizar los enlaces Java para libsvm:libsvm aplicación java

http://www.csie.ntu.edu.tw/~cjlin/libsvm/ 

he implementado un ejemplo de 'trivial' que es fácilmente separables linealmente en y. Los datos se definen como:

double[][] train = new double[1000][]; 
double[][] test = new double[10][]; 

for (int i = 0; i < train.length; i++){ 
    if (i+1 > (train.length/2)){  // 50% positive 
     double[] vals = {1,0,i+i}; 
     train[i] = vals; 
    } else { 
     double[] vals = {0,0,i-i-i-2}; // 50% negative 
     train[i] = vals; 
    }   
} 

Donde la primera 'característica' es la clase y el conjunto de entrenamiento se define de manera similar.

para entrenar el modelo:

private svm_model svmTrain() { 
    svm_problem prob = new svm_problem(); 
    int dataCount = train.length; 
    prob.y = new double[dataCount]; 
    prob.l = dataCount; 
    prob.x = new svm_node[dataCount][];  

    for (int i = 0; i < dataCount; i++){    
     double[] features = train[i]; 
     prob.x[i] = new svm_node[features.length-1]; 
     for (int j = 1; j < features.length; j++){ 
      svm_node node = new svm_node(); 
      node.index = j; 
      node.value = features[j]; 
      prob.x[i][j-1] = node; 
     }   
     prob.y[i] = features[0]; 
    }    

    svm_parameter param = new svm_parameter(); 
    param.probability = 1; 
    param.gamma = 0.5; 
    param.nu = 0.5; 
    param.C = 1; 
    param.svm_type = svm_parameter.C_SVC; 
    param.kernel_type = svm_parameter.LINEAR;  
    param.cache_size = 20000; 
    param.eps = 0.001;  

    svm_model model = svm.svm_train(prob, param); 

    return model; 
} 

Luego de evaluar el modelo que utilizo:

public int evaluate(double[] features) { 
    svm_node node = new svm_node(); 
    for (int i = 1; i < features.length; i++){ 
     node.index = i; 
     node.value = features[i]; 
    } 
    svm_node[] nodes = new svm_node[1]; 
    nodes[0] = node; 

    int totalClasses = 2;  
    int[] labels = new int[totalClasses]; 
    svm.svm_get_labels(_model,labels); 

    double[] prob_estimates = new double[totalClasses]; 
    double v = svm.svm_predict_probability(_model, nodes, prob_estimates); 

    for (int i = 0; i < totalClasses; i++){ 
     System.out.print("(" + labels[i] + ":" + prob_estimates[i] + ")"); 
    } 
    System.out.println("(Actual:" + features[0] + " Prediction:" + v + ")");    

    return (int)v; 
} 

Cuando la matriz pasada es un punto a partir del conjunto de pruebas.

Los resultados vuelven siempre la clase 0. cuyos resultados serán exacta:

(0:0.9882998314585194)(1:0.011700168541480586)(Actual:0.0 Prediction:0.0) 
(0:0.9883952943701599)(1:0.011604705629839989)(Actual:0.0 Prediction:0.0) 
(0:0.9884899803606306)(1:0.011510019639369528)(Actual:0.0 Prediction:0.0) 
(0:0.9885838957058696)(1:0.011416104294130458)(Actual:0.0 Prediction:0.0) 
(0:0.9886770466322342)(1:0.011322953367765776)(Actual:0.0 Prediction:0.0) 
(0:0.9870913229268679)(1:0.012908677073132284)(Actual:1.0 Prediction:0.0) 
(0:0.9868781382588805)(1:0.013121861741119505)(Actual:1.0 Prediction:0.0) 
(0:0.986661444476744)(1:0.013338555523255982)(Actual:1.0 Prediction:0.0) 
(0:0.9864411843906802)(1:0.013558815609319848)(Actual:1.0 Prediction:0.0) 
(0:0.9862172999068877)(1:0.013782700093112332)(Actual:1.0 Prediction:0.0) 

Puede alguien explicar por qué este clasificador no está funcionando? ¿Hay algún paso que haya echado a perder, o un paso que me falta?

Gracias

Respuesta

13

me parece que su método de evaluación es incorrecto. Debe ser algo como esto:

public double evaluate(double[] features, svm_model model) 
{ 
    svm_node[] nodes = new svm_node[features.length-1]; 
    for (int i = 1; i < features.length; i++) 
    { 
     svm_node node = new svm_node(); 
     node.index = i; 
     node.value = features[i]; 

     nodes[i-1] = node; 
    } 

    int totalClasses = 2;  
    int[] labels = new int[totalClasses]; 
    svm.svm_get_labels(model,labels); 

    double[] prob_estimates = new double[totalClasses]; 
    double v = svm.svm_predict_probability(model, nodes, prob_estimates); 

    for (int i = 0; i < totalClasses; i++){ 
     System.out.print("(" + labels[i] + ":" + prob_estimates[i] + ")"); 
    } 
    System.out.println("(Actual:" + features[0] + " Prediction:" + v + ")");    

    return v; 
} 
+4

¿Podría explicarme cuál es el error en el código de pregunta? ¡Tengo problemas para detectar el error! :( – Daniel

2

Aquí es una reelaboración del ejemplo anterior que he probado utilizando datos procedentes del siguiente código R: http://cbio.ensmp.fr/~jvert/svn/tutorials/practical/svmbasic/svmbasic_notes.pdf

import libsvm.*; 

public class libsvmTest { 

    public static void main(String [] args) { 

     double[][] xtrain = ... 
     double[][] xtest = ... 
     double[][] ytrain = ... 
     double[][] ytest = ... 

     svm_model m = svmTrain(xtrain,ytrain); 

     double[] ypred = svmPredict(xtest, m); 

     for (int i = 0; i < xtest.length; i++){ 
      System.out.println("(Actual:" + ytest[i][0] + " Prediction:" + ypred[i] + ")"); 
     } 

    } 

    static svm_model svmTrain(double[][] xtrain, double[][] ytrain) { 
     svm_problem prob = new svm_problem(); 
     int recordCount = xtrain.length; 
     int featureCount = xtrain[0].length; 
     prob.y = new double[recordCount]; 
     prob.l = recordCount; 
     prob.x = new svm_node[recordCount][featureCount];  

     for (int i = 0; i < recordCount; i++){    
      double[] features = xtrain[i]; 
      prob.x[i] = new svm_node[features.length]; 
      for (int j = 0; j < features.length; j++){ 
       svm_node node = new svm_node(); 
       node.index = j; 
       node.value = features[j]; 
       prob.x[i][j] = node; 
      }   
      prob.y[i] = ytrain[i][0]; 
     }    

     svm_parameter param = new svm_parameter(); 
     param.probability = 1; 
     param.gamma = 0.5; 
     param.nu = 0.5; 
     param.C = 100; 
     param.svm_type = svm_parameter.C_SVC; 
     param.kernel_type = svm_parameter.LINEAR;  
     param.cache_size = 20000; 
     param.eps = 0.001;  

     svm_model model = svm.svm_train(prob, param); 

     return model; 
    } 

    static double[] svmPredict(double[][] xtest, svm_model model) 
    { 

     double[] yPred = new double[xtest.length]; 

     for(int k = 0; k < xtest.length; k++){ 

     double[] fVector = xtest[k]; 

     svm_node[] nodes = new svm_node[fVector.length]; 
     for (int i = 0; i < fVector.length; i++) 
     { 
      svm_node node = new svm_node(); 
      node.index = i; 
      node.value = fVector[i]; 
      nodes[i] = node; 
     } 

     int totalClasses = 2;  
     int[] labels = new int[totalClasses]; 
     svm.svm_get_labels(model,labels); 

     double[] prob_estimates = new double[totalClasses]; 
     yPred[k] = svm.svm_predict_probability(model, nodes, prob_estimates); 

     } 

     return yPred; 
    } 


} 

Aquí está la salida:

(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
+0

Muchas gracias por el útil código. ¿Por qué usaste param.probability = 1 ;? y segundo, ¿sabes cómo se puede establecer el peso si uno tiene clases desbalanceadas? Quiero decir el peso con el que el parámetro C es ponderado – machinery

+0

No prob_estimates scope cuando llama a svm.svm_predict_probability()? – user1040535

+0

Esto es simplemente una publicación para ayudar a comenzar con LIBSVM, de ahí, es el usuario para determinar qué funciona de acuerdo con el problema. Para preguntas sobre esto, le sugiero que visite el sitio de los mantenedores de este paquete: https://www.csie.ntu.edu.tw/~cjlin/libsvm/faq.html#/Q06:_Probability_outputs –