
Memory Model: MDN-RNN

Explore how the MDN-RNN predicts future states and maintains temporal context.



Overview

The Memory Model in World Models uses a Mixture Density Network combined with a Recurrent Neural Network (MDN-RNN) to predict future latent states: at each step it models the distribution P(z_{t+1} | z_t, a_t, h_t), conditioning on the current latent z_t, the action a_t, and its hidden state h_t.

Why MDN-RNN?

The environment's dynamics are often:

  • Stochastic: The same action can lead to different outcomes
  • Multimodal: Multiple valid future states may exist
  • Temporal: Current state depends on history

An MDN-RNN addresses these challenges by:

  1. Using an RNN to capture temporal dependencies
  2. Using an MDN to model multimodal distributions
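This two-part design can be sketched end to end. The following is a minimal numpy sketch, assuming a latent size of 32, an action size of 3, and a simplified tanh recurrent cell standing in for the LSTM used in World Models; all weights are random stand-ins, not trained parameters:

```python
import numpy as np

rng = np.random.default_rng(0)

Z, A, H, K = 32, 3, 64, 5   # latent size, action size, hidden size, mixture components

# Stand-in weights for a simplified (tanh) recurrent cell and an MDN head.
W_h = rng.normal(0, 0.1, (H, Z + A + H))
W_mdn = rng.normal(0, 0.1, (K * (2 * Z + 1), H))

def mdn_rnn_step(z, a, h):
    """One step: update the hidden state, then emit mixture parameters for z_{t+1}."""
    h_next = np.tanh(W_h @ np.concatenate([z, a, h]))   # RNN: temporal dependency
    out = W_mdn @ h_next                                # MDN: raw mixture parameters
    logit_pi = out[:K]                                  # mixture weights (unnormalized)
    mu = out[K:K + K * Z].reshape(K, Z)                 # component means
    log_sigma = out[K + K * Z:].reshape(K, Z)           # component log-std-devs
    pi = np.exp(logit_pi - logit_pi.max())
    pi /= pi.sum()                                      # softmax over components
    return h_next, pi, mu, np.exp(log_sigma)

h = np.zeros(H)
h, pi, mu, sigma = mdn_rnn_step(rng.normal(size=Z), np.zeros(A), h)
```

The RNN part carries history forward through h, while the MDN head turns each hidden state into a full distribution over the next latent rather than a single point prediction.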

Mixture Density Networks

An MDN models the output distribution as a mixture of Gaussians, allowing the model to:

  • Capture uncertainty in predictions
  • Model multimodal distributions
  • Handle stochastic environments
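A mixture of Gaussians is typically trained by minimizing negative log-likelihood. Below is a minimal sketch of that loss; the function name and the diagonal-covariance assumption are ours, and the log-sum-exp trick is used for numerical stability:

```python
import numpy as np

def mdn_nll(z_next, pi, sigma_mu, sigma_sd=None, *, mu=None, sigma=None):
    raise NotImplementedError  # placeholder; see mdn_nll below

def mdn_nll(z_next, pi, mu, sigma):
    """Negative log-likelihood of z_next under a diagonal Gaussian mixture.

    pi:    (K,)   mixture weights, summing to 1
    mu:    (K, Z) component means
    sigma: (K, Z) component standard deviations
    """
    # Per-component log N(z | mu_k, sigma_k), summed over latent dimensions
    log_probs = -0.5 * (((z_next - mu) / sigma) ** 2
                        + 2 * np.log(sigma)
                        + np.log(2 * np.pi)).sum(axis=1)
    # log sum_k pi_k N(...) via log-sum-exp for stability
    weighted = np.log(pi) + log_probs
    m = weighted.max()
    return -(m + np.log(np.exp(weighted - m).sum()))
```

Because the loss keeps all K components, gradient descent can place probability mass on several distinct futures at once instead of averaging them into one implausible mean.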

The Role of Hidden State

The LSTM hidden state h serves multiple purposes:

  1. Memory: Stores information about past observations
  2. Context: Provides temporal context for predictions
  3. Belief State: Represents the agent's belief about the world
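In World Models the controller is a single linear layer over the concatenated latent and hidden state, so actions depend on the belief state rather than on the current frame alone. A sketch with random stand-in weights (the tanh squashing is our assumption, used here to bound the actions):

```python
import numpy as np

rng = np.random.default_rng(1)
Z, H, A = 32, 64, 3   # latent size, hidden size, action size

# Stand-in linear controller weights (random, not trained)
W_c = rng.normal(0, 0.1, (A, Z + H))
b_c = np.zeros(A)

def controller(z, h):
    # Concatenating h with z lets the action depend on remembered history,
    # not just the current observation's latent code.
    return np.tanh(W_c @ np.concatenate([z, h]) + b_c)
```

Keeping the controller this small is deliberate: the heavy lifting of perception and memory lives in the vision and memory models, so the policy itself has very few parameters to optimize.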

Dream Training

A key application of the MDN-RNN is dream training:

  1. Initialize with a real observation
  2. Sample actions from the controller
  3. Use MDN-RNN to predict next states
  4. Train the controller entirely in this "dream"
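The four steps above can be sketched as a single rollout loop. Everything below is a self-contained toy with random weights and small dimensions; note the temperature parameter, which World Models uses to make dreams more stochastic so the controller cannot exploit the model's errors:

```python
import numpy as np

rng = np.random.default_rng(0)
Z, A, H, K = 8, 2, 16, 3  # small sizes for illustration

# Stand-in model and controller weights (random; a real run uses trained parameters)
W_h = rng.normal(0, 0.1, (H, Z + A + H))
W_mdn = rng.normal(0, 0.1, (K * (2 * Z + 1), H))
W_c = rng.normal(0, 0.1, (A, Z + H))

def step_model(z, a, h):
    """Simplified MDN-RNN step: update hidden state, emit mixture parameters."""
    h = np.tanh(W_h @ np.concatenate([z, a, h]))
    out = W_mdn @ h
    pi = np.exp(out[:K]); pi /= pi.sum()
    mu = out[K:K + K * Z].reshape(K, Z)
    sigma = np.exp(out[K + K * Z:].reshape(K, Z))
    return h, pi, mu, sigma

def dream_rollout(z0, steps=50, temperature=1.15):
    """Roll out entirely inside the learned model: no environment interaction."""
    h, z, traj = np.zeros(H), z0, []
    for _ in range(steps):
        a = np.tanh(W_c @ np.concatenate([z, h]))   # 2. sample action from controller
        h, pi, mu, sigma = step_model(z, a, h)      # 3. predict next-state distribution
        k = rng.choice(K, p=pi)                     #    pick a mixture component
        z = mu[k] + sigma[k] * np.sqrt(temperature) * rng.normal(size=Z)
        traj.append((z, a))
    return traj

traj = dream_rollout(rng.normal(size=Z))  # 1. initialize from a (here random) latent
```

Because every transition is generated by the model, the controller can be trained on these trajectories (step 4) without a single real environment step, which is the point of dream training.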

References

Christopher M. Bishop (1994). Mixture Density Networks.

Sepp Hochreiter and Jürgen Schmidhuber (1997). Long Short-Term Memory.