Estoy un poco avergonzado de admitir esto, pero parece que estoy bastante perplejo por lo que debería ser un problema de programación simple. Estoy construyendo una implementación de árbol de decisión, y he estado usando recursividad para tomar una lista de muestras etiquetadas, dividir recursivamente la lista por la mitad y convertirla en un árbol.Codificando la creación recursiva de árboles con while loop + stacks
Desafortunadamente, con árboles profundos me encuentro con errores de desbordamiento de pila (ha!), Así que mi primer pensamiento fue utilizar continuaciones para convertirlo en recursividad de cola. Desafortunadamente, Scala no es compatible con ese tipo de TCO, por lo que la única solución es usar un trampolín. Un trampolín parece un poco ineficiente y esperaba que hubiera una solución imperativa simple basada en la pila para este problema, pero estoy teniendo muchos problemas para encontrarlo.
La versión recursiva se ve algo así como (simplificado):
private def trainTree(samples: Seq[Sample], usedFeatures: Set[Int]): DTree = {
if (shouldStop(samples)) {
DTLeaf(makeProportions(samples))
} else {
val featureIdx = getSplittingFeature(samples, usedFeatures)
val (statsWithFeature, statsWithoutFeature) = samples.partition(hasFeature(featureIdx, _))
DTBranch(
trainTree(statsWithFeature, usedFeatures + featureIdx),
trainTree(statsWithoutFeature, usedFeatures + featureIdx),
featureIdx)
}
}
Así que básicamente estoy de forma recursiva la subdivisión de la lista en dos de acuerdo con alguna característica de los datos, y que pasa a través de una lista de funciones que se utilicen de manera No repito, todo se maneja en la función "getSplittingFeature" para que podamos ignorarlo. ¡El código es realmente simple! Aún así, estoy teniendo problemas para encontrar una solución basada en pila que no solo use cierres y se convierta efectivamente en un trampolín. Sé que al menos tendremos que mantener pequeños "marcos" de argumentos en la pila, pero me gustaría evitar las llamadas de cierre.
Entiendo que debería escribir explícitamente lo que el contador de callstack y el programa manejan implícitamente en la solución recursiva, pero tengo problemas para hacerlo sin continuación. En este punto, casi no se trata de eficiencia, solo tengo curiosidad. Entonces, por favor, no hay necesidad de recordarme que la optimización prematura es la raíz de todo mal y que la solución basada en el trampolín probablemente funcionará bien. Sé que probablemente lo haga, esto es básicamente un rompecabezas por su propio bien.
¿Alguien puede decirme cuál es la solución canónica de este tipo de cosas para el bucle y la pila?
ACTUALIZACIÓN: Basado en la excelente solución de Thipor Kong, he codificado una implementación basada en while-loops/stacks/hashtable del algoritmo que debería ser una traducción directa de la versión recursiva. Esto es exactamente lo que estaba buscando:
ACTUALIZACIÓN FINAL: He utilizado índices de números enteros secuenciales, así como poner todo de nuevo en matrices en lugar de mapas para el rendimiento, la compatibilidad con maxDepth añadido, y finalmente tener una solución con el mismo rendimiento que la versión recursiva (no estoy seguro sobre el uso de la memoria, pero yo supongo menos):
private def trainTreeNoMaxDepth(startingSamples: Seq[Sample], startingMaxDepth: Int): DTree = {
// Use arraybuffer as dense mutable int-indexed map - no IndexOutOfBoundsException, just expand to fit
type DenseIntMap[T] = ArrayBuffer[T]
def updateIntMap[@specialized T](ab: DenseIntMap[T], idx: Int, item: T, dfault: T = null.asInstanceOf[T]) = {
if (ab.length <= idx) {ab.insertAll(ab.length, Iterable.fill(idx - ab.length + 1)(dfault)) }
ab.update(idx, item)
}
var currentChildId = 0 // get childIdx or create one if it's not there already
def child(childMap: DenseIntMap[Int], heapIdx: Int) =
if (childMap.length > heapIdx && childMap(heapIdx) != -1) childMap(heapIdx)
else {currentChildId += 1; updateIntMap(childMap, heapIdx, currentChildId, -1); currentChildId }
// go down
val leftChildren, rightChildren = new DenseIntMap[Int]() // heapIdx -> childHeapIdx
val todo = Stack((startingSamples, Set.empty[Int], startingMaxDepth, 0)) // samples, usedFeatures, maxDepth, heapIdx
val branches = new Stack[(Int, Int)]() // heapIdx, featureIdx
val nodes = new DenseIntMap[DTree]() // heapIdx -> node
while (!todo.isEmpty) {
val (samples, usedFeatures, maxDepth, heapIdx) = todo.pop()
if (shouldStop(samples) || maxDepth == 0) {
updateIntMap(nodes, heapIdx, DTLeaf(makeProportions(samples)))
} else {
val featureIdx = getSplittingFeature(samples, usedFeatures)
val (statsWithFeature, statsWithoutFeature) = samples.partition(hasFeature(featureIdx, _))
todo.push((statsWithFeature, usedFeatures + featureIdx, maxDepth - 1, child(leftChildren, heapIdx)))
todo.push((statsWithoutFeature, usedFeatures + featureIdx, maxDepth - 1, child(rightChildren, heapIdx)))
branches.push((heapIdx, featureIdx))
}
}
// go up
while (!branches.isEmpty) {
val (heapIdx, featureIdx) = branches.pop()
updateIntMap(nodes, heapIdx, DTBranch(nodes(child(leftChildren, heapIdx)), nodes(child(rightChildren, heapIdx)), featureIdx))
}
nodes(0)
}
¿No se está descargando a una implementación basada en la pila (donde la pila está en el montón) conceptualmente lo mismo que la trampolín? – ron
Más o menos, pero trampolín significa que mantendrás una pila llena de cierres, donde espero que haya una solución que solo utilice una pila llena de datos. Tal vez los datos etiquetados como StepOne (a, b, c), StepTwo (a, b, c) o varias pilas o algo así, pero no se realizarán llamadas a funciones. – lvilnis
Hice otro cambio en mi código. El espacio de nombre de los identificadores de nodo se usa de forma más económica y puede agregar su propio tipo de ID de nodo (o BigInt, si lo desea). –