Double Deep Q-Network: DQN mit Stabilitätsupgrades
Das Deep Q-Network leidet manchmal unter verschiedenen Problemen. Die Probleme werden hier vorgestellt und eine Lösung für diese Probleme präsentiert.
Henrik Bartsch
Einordnung
In einem früheren Post wurde das Prinzip des Deep Q-Networks vorgestellt. Bei genauerer Untersuchung des Netzwerkes stellt
sich häufig heraus, dass das Modell unter Overestimation leidet, welches das Training instabil werden lässt.
Overestimation beschreibt das Prinzip, das erwartete Rewards als zu hoch vorhergesagt werden.
Overestimation verringert hierbei die Güte des Trainingsprozesses und sollte damit verringert oder vermieden werden.
Eine weitere Erklärung für Probleme im DQN ist die Tatsache, dass das sogenannte Q-Target kein konstanter Wert ist.
Als Q-Target wir der Term yi=ri+γmaxai′Q(si′,ai′) in L=N1∑i=0N−1(Q(si,ai)−yi)2 bezeichnet.
Durch die Approximation eines nicht konstanten Wertes ist die Approximation grundsätzlich instabiler. Ein Target-Network verringert die Problematik hierbei. 12
Die Forschung im Bereich der künstlichen neuronalen Netze erzielte dann ab dem Jahr 2010 Erfolge darin, dieser Algorithmus zu verbessern. Das Ergebnis wurde als Double Deep Q-Network
bezeichnet. Im Gegensatz zu Deep Q-Networks verzichtet es hierbei auf Frozen Target-Networks, um unter anderem die oben angesprochene Overestimation weiter zu verringern.
Im Folgenden kann dieser Algorithmus auch als D2QN bezeichnet werden.
Versionen des Algorithmus
Die folgenden Informationen über die verschiedenen Arten des Algorithmus wurden aus der Quelle 3 entnommen.
Das Grundprinzip des Algorithmus besteht darin, dass eine Kombinnation aus zwei verschiedenen Netzwerken verwendet wird, um die Overestimation zu verringern. Dies geschieht dadurch, dass die
beiden Netze miteinander trainiert werden, und so den Bias aus den Netzwerk-Updates herausbekommen.
Der ursprüngliche Algorithmus aus dem Jahr 2010 beinhaltet zwei verschiedene Netzwerke, welche in einer Art ϵ-greedy-Schema (mit ϵ=0.5) ausgewählt werden. In jedem
Zeitschritt wird ein zufälliges Netzwerk ausgewählt, das Update wird anschließend über den Mean Squared Error des Unterschiedes in der Vorhersage gefittet.
Problematisch bei dieser Implementierung (im Vergleich zu den neueren Alternativen) liegt darin, dass theoretisch jedem Netzwerk lediglich 50% der generierten Informationen zukommen, die
anderen 50 % werden nicht für Updates des anderen Netzwerkes verwendet. Dies verringert die Sample Efficiency des Algorithmus bedeutend. Weiterhin kann das Problem der Overestimation
noch weiter verbessert werden.
Sample Efficieny beschreibt wie effizient ein Algorithmus aus den gegebenen Informationen lernen kann; ein effizienter Algorithmus wird mit bedeutend weniger Episoden (Samples) auskommen.
Der Algorithmus aus dem Jahr 2015 stellt einen großen Meilenstein in der Entwicklung des DDQN-Algorithmus dar. Hier wird das erste Mal ein Primary- und Target-Network eingeführt.
Diese beiden Netzwerke sind hierbei nur als teilweise unabhängig definiert, beide werden am Anfang auch klassischerweise mit den identischen Gewichten initialisiert.
Das Primary-Network stellt das Netzwerk dar, welches für die Auswahl der Aktionen in Abhängigkeit des aktuellen Zustands verantwortlich ist.
Das Target-Network stellt das Netzwerk dar, welches die Overestimation verhindert. Es stellt einen “älteren Zustand” des Netzwerkes dar und verhindert das schnelle Vergessen von Informationen,
welche bereits durch das Primary-Network gelernt wurden. Das Target-Network “bewertet” damit die ausgewählte Aktion.
Eine wichtige Änderung ist hierbei das Update der Netzwerke. Der Target-Q-Value für das Primary-Network wird hierbei durch die Voraussage des Target-Networks bestimmt und über Mean Squared
Error gefittet, während wir bei dem Target-Network ein Soft-Update durchführen:
θ′←τθ+(1−τ)θ′
Hierbei gilt für die Konstante τ klassischerweise die folgende Beschränkung: τ∈(0,1). Für den Grenzfall τ=1 erhält man zwei identische Netzwerke, für τ=0
erhält man kein Update auf dem Target-Network. Hierdurch würde es nicht mehr lernen.
Das Original-Paper zu diesem Algorithmus von Hasselt ist hier zu finden.
Eine kleine Verbesserung wurde anschließend noch 2018 veröffentlicht. Man kann die Overestimation noch dadurch verringern, indem man das Minimum der Voraussage von den beiden Netzwerken dafür
verwendet, den Target-Q-Value zu berechnen. Ansonsten existiert hier kein Unterschied zu der Veröffentlichung von 2015.
Implementierung
Unter den Imports findet sich nichts Neues, diese sind identisch zu der Implementierung vom Deep Q-Network. Auch das
ExperienceMemory ist identisch.
In der DDQNAgent-Klasse könnten einem schnell ein paar Änderungen ins Auge fallen. Hier ist eine klare Trennung zwischen primary_network und target_network, welche identisch initialisiert
werden müssen. Zusätzlich wurde der Soft-Update-Parameterτ∈[0,1] hier definiert. Ansonsten findet sich hierbei keine Änderung.
Bei der Trainingsfunktion findet man direkt die Änderung zum Deep Q-Network, welche sich in der Berechnung des Q-Targets findet. Hierbei wird diesmal nicht nur primary_network oder target_network
verwendet, um das Q-Target zu berechnen, sondern jeweils komponentenweise das Minimum der Vorhersagen. Anschließend wird hierbei das primary_network auf dem Q-Target gefittet und das target_network
über ein Soft-Update updatet.
Die Trainings- und Evaluierungsloops orientiert sich am Standard von der Implementierung des Deep Q-Networks’s.
Als Ergebnis erhält man beispielsweise das folgende Diagramm:
Hinweis: Trainingsperformance kann sich von Gerät zu Gerät und entsprechenden Seeds teilweise stark unterscheiden. Allgemeine Reproduzierbarkeit von solchen Ergebnissen ist im Allgemeinen nicht
garantierbar.
Hierbei kann man klar erkennen, dass das D2QN am Anfang bedeutend besser traininert als dies das DQN tut. Später flacht der Erfolg allerdings ab, was wahrscheinlich auf eine schlechte Einstellung
der Hyperparameter zurückzuführen ist.
Zum Vergleich wurden zwei Netzwerke von ähnlicher Struktur verwendet; beide Netzwerke (dqn_network und primary_network) hatten hierbei die gleiche Anzahl von Neuronen. Dies muss bei praktischer
Anwendung nicht zielführend sein sondern gilt hier lediglich der Vergleichbarkeit.
Änderungen
[23.01.2023] Einführung interaktiver Plots, Hinweis auf Nicht-Reproduzierbarkeit von Ergebnissen