2012-07-31 15 views
7

Estoy tratando de encontrar la forma más rápida de encontrar el primer valor distinto de cero para cada fila de una matriz ordenada bidimensional. Técnicamente, los únicos valores en la matriz son ceros y unos, y está "ordenado".Encontrar el primer valor distinto de cero a lo largo del eje de una matriz numpy bidimensional ordenada

Por ejemplo, la matriz podría ser similar al siguiente:

v =

0 0 0 1 1 1 1 
0 0 0 1 1 1 1 
0 0 0 0 1 1 1 
0 0 0 0 0 0 1 
0 0 0 0 0 0 1 
0 0 0 0 0 0 1 
0 0 0 0 0 0 0 

que podría utilizar la función argmax

argmax(v, axis=1)) 

encontrar cuando se cambia de cero a uno , pero creo que esto haría una búsqueda exhaustiva a lo largo de cada fila. Mi matriz tendrá un tamaño razonable (~ 2000x2000). ¿Haría argmax aún mejor que simplemente haciendo un enfoque de búsqueda para cada fila dentro de un bucle for, o hay una mejor alternativa?

Además, la matriz siempre será tal que la primera posición de una para una fila siempre es> = la primera posición de una en la fila de arriba (pero no se garantiza que haya una en las últimas filas). Podría explotar esto con un bucle for y un "valor inicial de índice" para cada fila igual a la posición del primer 1 de la fila anterior, pero estoy en lo cierto al pensar que la función numpy argmax aún superará un bucle escrito en python .

Simplemente compararía las alternativas, pero la longitud del borde de la matriz podría cambiar bastante (de 250 a 10.000).

+0

que muy haría mucho esperar que la función argmax sea más rápida. Si es crítico para el rendimiento, podría intentar escribir una extensión en C – SudoNhim

Respuesta

4

Es razonablemente rápido de usar np.where:

>>> a 
array([[0, 0, 0, 1, 1, 1, 1], 
     [0, 0, 0, 1, 1, 1, 1], 
     [0, 0, 0, 0, 1, 1, 1], 
     [0, 0, 0, 0, 0, 0, 1], 
     [0, 0, 0, 0, 0, 0, 1], 
     [0, 0, 0, 0, 0, 0, 1], 
     [0, 0, 0, 0, 0, 0, 0]]) 
>>> np.where(a>0) 
(array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 3, 4, 5]), array([3, 4, 5, 6, 3, 4, 5, 6, 4, 5, 6, 6, 6, 6])) 

que ofrece tuplas con las coordenadas de los valores mayores que 0.

También puede utilizar np.donde para probar cada sub array:

def first_true1(a): 
    """ return a dict of row: index with value in row > 0 """ 
    di={} 
    for i in range(len(a)): 
     idx=np.where(a[i]>0) 
     try: 
      di[i]=idx[0][0] 
     except IndexError: 
      di[i]=None  

    return di  

Prints:

{0: 3, 1: 3, 2: 4, 3: 6, 4: 6, 5: 6, 6: None} 

es decir, la fila 0: Índice 3> 0; fila 4: índice 4> 0; fila 6: ningún índice mayor que 0

Como se sospecha, argmax puede ser más rápido:

def first_true2(): 
    di={} 
    for i in range(len(a)): 
     idx=np.argmax(a[i]) 
     if idx>0: 
      di[i]=idx 
     else: 
      di[i]=None  

    return di  
    # same dict is returned... 

Si usted puede hacer frente a la lógica de no tener un None para las filas de todas las nadas, esto es aún más rápido :

def first_true3(): 
    di={} 
    for i, j in zip(*np.where(a>0)): 
     if i in di: 
      continue 
     else: 
      di[i]=j 

    return di  

Y aquí es una versión que utiliza eje en argmax (como se sugiere en sus comentarios):

def first_true4(): 
    di={} 
    for i, ele in enumerate(np.argmax(a,axis=1)): 
     if ele==0 and a[i][0]==0: 
      di[i]=None 
     else: 
      di[i]=ele 

    return di   

Para las comparaciones de velocidad (en su matriz de ejemplo), me sale esto:

  rate/sec usec/pass first_true1 first_true2 first_true3 first_true4 
first_true1 23,818 41.986   --  -34.5%  -63.1%  -70.0% 
first_true2 36,377 27.490  52.7%   --  -43.6%  -54.1% 
first_true3 64,528 15.497  170.9%  77.4%   --  -18.6% 
first_true4 79,287 12.612  232.9%  118.0%  22.9%   -- 

Si escala que a una matriz np 2000 x 2000, esto es lo que me sale:

  rate/sec usec/pass first_true3 first_true1 first_true2 first_true4 
first_true3  3 354380.107   --  -0.3%  -74.7%  -87.8% 
first_true1  3 353327.084  0.3%   --  -74.6%  -87.7% 
first_true2  11 89754.200  294.8%  293.7%   --  -51.7% 
first_true4  23 43306.494  718.3%  715.9%  107.3%   -- 
+0

En realidad, lo mejor de argmax es que puede especificar un eje, es decir 'argmax (a, axis = 1)' y recorrerá las filas utilizando un bucle escrito en C por lo que no tiene que usar un ciclo de pitón, que debería ser más lento. – user1554752

+0

@ user1554752: Sí, pero si usa 'argmax (a, axis = 1)', existe una ambigüedad entre las filas en 'a' que son' [1, x, x, x,] 'o' [0, 0,0,0] 'ya que' argmax (a, axis = 1) 'devolvería' 0' para cualquier caso. Todavía tendrá que recorrer la matriz que argmax devuelve para probar esta ambigüedad, ¿no? – dawg

+0

Ahí es donde podría aprovechar el patrón en los datos donde el primer 1 nunca está en una posición a la izquierda del primer 1 en la fila de arriba. Una vez que tengo la matriz de argmax (llámala indx), puedo ejecutar un argmin en ella. Si devuelve un valor p! = 0, todas las filas desde p hacia abajo se hicieron únicamente con ceros. – user1554752

4

argmax() utiliza bucle de nivel C, que es mucho más rápido que el bucle de Python, así que creo incluso que escribir un algoritmo inteligente en Python, que es difícil de superar argmax(), puede utilizar Cython a speedup:

@cython.boundscheck(False) 
@cython.wraparound(False) 
def find(int[:,:] a): 
    cdef int h = a.shape[0] 
    cdef int w = a.shape[1] 
    cdef int i, j 
    cdef int idx = 0 
    cdef list r = [] 
    for i in range(h): 
     for j in range(idx, w): 
      if a[i, j] == 1: 
       idx = j 
       r.append(idx) 
       break 
     else: 
      r.append(-1) 
    return r 

En mi PC para la matriz 2000x2000, es 100us vs 3ms.

Cuestiones relacionadas