2011-11-15 18 views
8

Estoy tratando de usar PyBrain para un entrenamiento simple de NN. Lo que no sé hacer es cargar los datos de entrenamiento de un archivo. No se explica en su sitio web en ninguna parte. No me importa el formato porque puedo compilarlo ahora, pero tengo que hacerlo en un archivo en lugar de agregarlo fila por fila manualmente, porque tendré varios cientos de filas.Cómo cargar datos de entrenamiento en PyBrain?

+1

Varios cientos de filas significa que tiene un conjunto muy pequeño y no debe preocuparse por el rendimiento. ¿Pero PyBrain no solo acepta matrices NumPy? –

+0

No lo sé, apenas estoy comenzando a usarlo, pero en ninguna parte dicen cómo usar las matrices NumPy con su NN:/ –

Respuesta

21

Aquí es cómo lo hice:

 
ds = SupervisedDataSet(6,3) 

tf = open('mycsvfile.csv','r') 

for line in tf.readlines(): 
    data = [float(x) for x in line.strip().split(',') if x != ''] 
    indata = tuple(data[:6]) 
    outdata = tuple(data[6:]) 
    ds.addSample(indata,outdata) 

n = buildNetwork(ds.indim,8,8,ds.outdim,recurrent=True) 
t = BackpropTrainer(n,learningrate=0.01,momentum=0.5,verbose=True) 
t.trainOnDataset(ds,1000) 
t.testOnData(verbose=True) 

En este caso, la red neuronal tiene 6 entradas y 3 salidas. El archivo csv tiene 9 valores en cada línea separados por una coma. Los primeros 6 valores son valores de entrada y los últimos tres son salidas.

+0

que es genial, muchas gracias. ¿Sabes cómo puedo acceder a los valores de peso para cada neurona? –

+1

Puede acceder a capas individuales como esta: n ['in'] para la capa de entrada yn ['out'] para la salida o n ['hidden0'] para la primera capa oculta. No lo sé, pero supongo que puede acceder a los nodos de la capa de alguna manera. dir (n ['in']) debería darle una pista de lo que puede hacer – c0m4

+0

No puedo encontrar cómo hacerlo. Haré una nueva pregunta. Gracias por tu ayuda. –

1

que acaba de utilizar matrices de pandas de esta manera

import pandas as pd 

ds = SupervisedDataSet(6,3) 

dataset = pd.read_csv('mycsvfile.csv','r', delimiter=',',skiprows=1) 
ds.setfield('input' dataset.values[:,0:6]) 
ds.setfield('target', dataset.values[:,-2:-1]) 

y que son buenos para ir.

Cuestiones relacionadas