2012-05-21 28 views
7

Estoy un poco avergonzado de admitir esto, pero parece que estoy bastante perplejo por lo que debería ser un problema de programación simple. Estoy construyendo una implementación de árbol de decisión, y he estado usando recursividad para tomar una lista de muestras etiquetadas, dividir recursivamente la lista por la mitad y convertirla en un árbol.Codificando la creación recursiva de árboles con while loop + stacks

Desafortunadamente, con árboles profundos me encuentro con errores de desbordamiento de pila (ha!), Así que mi primer pensamiento fue utilizar continuaciones para convertirlo en recursividad de cola. Desafortunadamente, Scala no es compatible con ese tipo de TCO, por lo que la única solución es usar un trampolín. Un trampolín parece un poco ineficiente y esperaba que hubiera una solución imperativa simple basada en la pila para este problema, pero estoy teniendo muchos problemas para encontrarlo.

La versión recursiva se ve algo así como (simplificado):

private def trainTree(samples: Seq[Sample], usedFeatures: Set[Int]): DTree = { 
    if (shouldStop(samples)) { 
    DTLeaf(makeProportions(samples)) 
    } else { 
    val featureIdx = getSplittingFeature(samples, usedFeatures) 
    val (statsWithFeature, statsWithoutFeature) = samples.partition(hasFeature(featureIdx, _)) 
    DTBranch(
     trainTree(statsWithFeature, usedFeatures + featureIdx), 
     trainTree(statsWithoutFeature, usedFeatures + featureIdx), 
     featureIdx) 
    } 
} 

Así que básicamente estoy de forma recursiva la subdivisión de la lista en dos de acuerdo con alguna característica de los datos, y que pasa a través de una lista de funciones que se utilicen de manera No repito, todo se maneja en la función "getSplittingFeature" para que podamos ignorarlo. ¡El código es realmente simple! Aún así, estoy teniendo problemas para encontrar una solución basada en pila que no solo use cierres y se convierta efectivamente en un trampolín. Sé que al menos tendremos que mantener pequeños "marcos" de argumentos en la pila, pero me gustaría evitar las llamadas de cierre.

Entiendo que debería escribir explícitamente lo que el contador de callstack y el programa manejan implícitamente en la solución recursiva, pero tengo problemas para hacerlo sin continuación. En este punto, casi no se trata de eficiencia, solo tengo curiosidad. Entonces, por favor, no hay necesidad de recordarme que la optimización prematura es la raíz de todo mal y que la solución basada en el trampolín probablemente funcionará bien. Sé que probablemente lo haga, esto es básicamente un rompecabezas por su propio bien.

¿Alguien puede decirme cuál es la solución canónica de este tipo de cosas para el bucle y la pila?

ACTUALIZACIÓN: Basado en la excelente solución de Thipor Kong, he codificado una implementación basada en while-loops/stacks/hashtable del algoritmo que debería ser una traducción directa de la versión recursiva. Esto es exactamente lo que estaba buscando:

ACTUALIZACIÓN FINAL: He utilizado índices de números enteros secuenciales, así como poner todo de nuevo en matrices en lugar de mapas para el rendimiento, la compatibilidad con maxDepth añadido, y finalmente tener una solución con el mismo rendimiento que la versión recursiva (no estoy seguro sobre el uso de la memoria, pero yo supongo menos):

private def trainTreeNoMaxDepth(startingSamples: Seq[Sample], startingMaxDepth: Int): DTree = { 
    // Use arraybuffer as dense mutable int-indexed map - no IndexOutOfBoundsException, just expand to fit 
    type DenseIntMap[T] = ArrayBuffer[T] 
    def updateIntMap[@specialized T](ab: DenseIntMap[T], idx: Int, item: T, dfault: T = null.asInstanceOf[T]) = { 
    if (ab.length <= idx) {ab.insertAll(ab.length, Iterable.fill(idx - ab.length + 1)(dfault)) } 
    ab.update(idx, item) 
    } 
    var currentChildId = 0 // get childIdx or create one if it's not there already 
    def child(childMap: DenseIntMap[Int], heapIdx: Int) = 
    if (childMap.length > heapIdx && childMap(heapIdx) != -1) childMap(heapIdx) 
    else {currentChildId += 1; updateIntMap(childMap, heapIdx, currentChildId, -1); currentChildId } 
    // go down 
    val leftChildren, rightChildren = new DenseIntMap[Int]() // heapIdx -> childHeapIdx 
    val todo = Stack((startingSamples, Set.empty[Int], startingMaxDepth, 0)) // samples, usedFeatures, maxDepth, heapIdx 
    val branches = new Stack[(Int, Int)]() // heapIdx, featureIdx 
    val nodes = new DenseIntMap[DTree]() // heapIdx -> node 
    while (!todo.isEmpty) { 
    val (samples, usedFeatures, maxDepth, heapIdx) = todo.pop() 
    if (shouldStop(samples) || maxDepth == 0) { 
     updateIntMap(nodes, heapIdx, DTLeaf(makeProportions(samples))) 
    } else { 
     val featureIdx = getSplittingFeature(samples, usedFeatures) 
     val (statsWithFeature, statsWithoutFeature) = samples.partition(hasFeature(featureIdx, _)) 
     todo.push((statsWithFeature, usedFeatures + featureIdx, maxDepth - 1, child(leftChildren, heapIdx))) 
     todo.push((statsWithoutFeature, usedFeatures + featureIdx, maxDepth - 1, child(rightChildren, heapIdx))) 
     branches.push((heapIdx, featureIdx)) 
    } 
    } 
    // go up 
    while (!branches.isEmpty) { 
    val (heapIdx, featureIdx) = branches.pop() 
    updateIntMap(nodes, heapIdx, DTBranch(nodes(child(leftChildren, heapIdx)), nodes(child(rightChildren, heapIdx)), featureIdx)) 
    } 
    nodes(0) 
} 
+0

¿No se está descargando a una implementación basada en la pila (donde la pila está en el montón) conceptualmente lo mismo que la trampolín? – ron

+0

Más o menos, pero trampolín significa que mantendrás una pila llena de cierres, donde espero que haya una solución que solo utilice una pila llena de datos. Tal vez los datos etiquetados como StepOne (a, b, c), StepTwo (a, b, c) o varias pilas o algo así, pero no se realizarán llamadas a funciones. – lvilnis

+0

Hice otro cambio en mi código. El espacio de nombre de los identificadores de nodo se usa de forma más económica y puede agregar su propio tipo de ID de nodo (o BigInt, si lo desea). –

Respuesta

3

tienda sólo el árbol binario en una matriz, tal como se describe en Wikipedia: Para el nodo i, el hijo izquierdo entra en 2*i+1 y el niño correcto en 2*i+2. Al hacer "abajo", mantienes una colección de todos, que aún tienen que dividirse para llegar a una hoja. Una vez que tenga solo hojas, para ir hacia arriba (de derecha a izquierda en la matriz) para construir los nodos de decisión:

Actualización: Una versión limpia, que también admite las funciones almacenadas en las ramas (tipo parámetro B) y que es más funcional/completamente puro y que admite árboles dispersos con un mapa como lo sugiere ron.

Update2-3: Haga un uso económico del espacio de nombres para identificadores de nodos y abstracte sobre el tipo de identificadores para permitir árboles grandes. Tome los ID de nodo de Stream.

sealed trait DTree[A, B] 
case class DTLeaf[A, B](a: A, b: B) extends DTree[A, B] 
case class DTBranch[A, B](left: DTree[A, B], right: DTree[A, B], b: B) extends DTree[A, B] 

def mktree[A, B, Id](a: A, b: B, split: (A, B) => Option[(A, A, B)], ids: Stream[Id]) = { 
    @tailrec 
    def goDown(todo: Seq[(A, B, Id)], branches: Seq[(Id, B, Id, Id)], leafs: Map[Id, DTree[A, B]], ids: Stream[Id]): (Seq[(Id, B, Id, Id)], Map[Id, DTree[A, B]]) = 
    todo match { 
     case Nil => (branches, leafs) 
     case (a, b, id) :: rest => 
     split(a, b) match { 
      case None => 
      goDown(rest, branches, leafs + (id -> DTLeaf(a, b)), ids) 
      case Some((left, right, b2)) => 
      val leftId #:: rightId #:: idRest = ids 
      goDown((right, b2, rightId) +: (left, b2, leftId) +: rest, (id, b2, leftId, rightId) +: branches, leafs, idRest) 
     } 
    } 

    @tailrec 
    def goUp[A, B](branches: Seq[(Id, B, Id, Id)], nodes: Map[Id, DTree[A, B]]): Map[Id, DTree[A, B]] = 
    branches match { 
     case Nil => nodes 
     case (id, b, leftId, rightId) :: rest => 
     goUp(rest, nodes + (id -> DTBranch(nodes(leftId), nodes(rightId), b))) 
    } 

    val rootId #:: restIds = ids 
    val (branches, leafs) = goDown(Seq((a, b, rootId)), Seq(), Map(), restIds) 
    goUp(branches, leafs)(rootId) 
} 

// try it out 

def split(xs: Seq[Int], b: Int) = 
    if (xs.size > 1) { 
    val (left, right) = xs.splitAt(xs.size/2) 
    Some((left, right, b + 1)) 
    } else { 
    None 
    } 

val tree = mktree(0 to 1000, 0, split _, Stream.from(0)) 
println(tree) 
+0

¿Qué pasa con el hecho de que cada DTBranch necesita un "featureIndex"? Eso lo hace un poco más complicado ya que para convertir todas las hojas en ramas necesitamos su featureIndex, y luego para combinar esas ramas necesitamos su featureIndexes, y así sucesivamente. Creo que esta es la idea correcta, así que jugaré con eso. – lvilnis

+0

Pone los featureIndices en el montón cuando baja (en lugar de None), para tenerlo disponible para crear el DTBranch, cuando vuelva a subir. –

+0

¡Eso es increíble! Lo probaré y marcaré el tuyo como respuesta dentro de una hora. – lvilnis