Bevy, dfdx, and the Classic Cart Pole

Sunday, October 23, 2022

I'm relatively new to Rust and was recently looking for a new project to challenge the meager skills I have. I decided to tackle the Cart Pole Problem not because of any extensive experience with Bevy or dfdx, but because of my lack thereof. This post does not show best practices or code that should be copied, but rather a quick, easy, and fun way to solve the Cart Pole Problem using Bevy and dfdx.

I have never built a game before, but I found Bevy, one of Rust's more popular ECS (entity component system) game engines, amazing to work with. If you haven't heard of Bevy, I highly encourage checking it out; they have an awesome community: https://bevyengine.org/

dfdx is a simple deep learning library written in Rust. It has a bunch of awesome features that make deep learning a breeze. Once again, if you haven't seen it before, I would check it out: https://github.com/coreylowman/dfdx

What We Are Building
====================
The Cart Pole Problem
=====================

For those not familiar, the Cart Pole Problem is the hello world of AI. If you have any experience with TensorFlow, PyTorch, or JAX, you have probably seen examples solving the Cart Pole Problem.

The premise is simple. You control a cart on a track that only allows horizontal movement. On top of the cart is a pole, attached by a hinge that limits rotation to the z axis (it can only tip left and right). The goal is to keep the pole upright as long as possible. The modern version of Cart Pole ends after 500 steps (approximately 10 seconds, according to OpenAI).

For those interested in learning more, please check out OpenAI's wiki. Note that we are playing v1.

Deep Q-Learning
===============

There are a multitude of ways to solve the Cart Pole Problem. From my understanding, most of them use some variation of Reinforcement Learning, where the AI learns from its past attempts. In our case we will be using Deep Q-Learning, or, technically, Double Deep Q-Learning, as we will have two models: a target model and one that actually plays the game.

While building a deep reinforcement learning model from scratch would make enough content for a post of its own, dfdx provides a number of luxuries that greatly speed up the process. Using dfdx, we can abstract away the nitty-gritty of moving forward and backward through the network, as well as the training loop. The library's capabilities extend far beyond what we use here.

I won't go into the details of Q-Learning, as that is far outside the scope of this post, but for anyone interested I highly recommend Hugging Face's deep reinforcement learning class: https://huggingface.co/blog/deep-rl-dqn
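
For reference, the update rule the train function below implements is the standard DQN target: for a transition (s, a, r, s'), the network's prediction Q(s, a) is pushed toward r + gamma * max over a' of Q_target(s', a'), where gamma is the discount factor (NEXT_STATE_DISCOUNT in the code) and Q_target is the separate target model.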

Some Code
=========

First we want to create our world. Luckily that is incredibly simple with Bevy.

use bevy::{prelude::*, window::PresentMode};

fn main() {
    App::new()
        .insert_resource(WindowDescriptor {
            title: "Cart Pole".to_string(),
            present_mode: PresentMode::AutoVsync,
            ..default()
        })
        .add_plugins(DefaultPlugins)
        .add_startup_system(add_camera)
        .add_system(size_scaling)
        .add_startup_system(add_cart_pole)
        .add_startup_system(add_model.exclusive_system())
        .add_system(step)
        .run();
}

Important things to note:

  1. Bevy's DefaultPlugins include a number of systems. Most importantly for us, they include the actual game loop and the window so we can see our game. We don't need all of the systems in DefaultPlugins, but for simplicity's sake, we will leave them.
  2. We add five custom systems: add_camera, size_scaling (which handles sprite scaling), add_cart_pole, add_model, and step.

The add_camera system adds a camera so we can see the game. size_scaling is code taken from one of Bevy's tutorials that scales sprites as the window size changes. Neither system is worth walking through in detail here, though a rough guess at add_camera is shown below.
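
For completeness, add_camera likely amounts to a single spawn; this is just a guess assuming Bevy 0.8, not the post's actual code:

fn add_camera(mut commands: Commands) {
    // Spawn a 2D camera so our sprites are rendered to the window
    commands.spawn_bundle(Camera2dBundle::default());
}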

The add_cart_pole system adds the cart and pole entities to our world.
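
The post only shows the body of this system; the surrounding signature is presumably the usual Bevy startup-system shape (an assumption on my part):

// A guess at the signature; only the body below comes from the post
fn add_cart_pole(mut commands: Commands, asset_server: Res<AssetServer>) {
    // ...body shown below...
}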

let cart_handle = asset_server.load("cart.png");
let pole_handle = asset_server.load("pole.png");
commands
    .spawn_bundle(SpriteBundle {
        sprite: Sprite {
            custom_size: Some(Vec2::new(1., 1.)),
            ..default()
        },
        texture: cart_handle,
        transform: Transform {
            translation: Vec3::new(0., 0., 0.),
            scale: Vec3::new(1., 1., 1.),
            ..default()
        },
        ..default()
    })
    .insert(Cart)
    .insert(Velocity::default())
    .insert(Size {
        width: 0.6,
        height: 0.3,
    });
commands
    .spawn_bundle(SpriteBundle {
        sprite: Sprite {
            anchor: sprite::Anchor::BottomCenter,
            custom_size: Some(Vec2::new(1., 1.)),
            ..default()
        },
        texture: pole_handle,
        transform: Transform {
            translation: Vec3::new(0., 0., 1.),
            scale: Vec3::new(1., 1., 1.),
            ..default()
        },
        ..default()
    })
    .insert(Pole)
    .insert(Velocity::default())
    .insert(Size {
        width: 0.1,
        height: 1.,
    });

This is a pretty standard way to add sprites in Bevy; the only important things to note here are the Velocity components we add to both entities. The Velocity components are used in the step system when calculating the movement of the cart and pole. For those curious, the Size components are used in the size_scaling system, but nowhere else.
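
The component types used above are never shown in the post's snippets; here is a minimal guess at their definitions, with the field layout inferred from how they are used rather than taken from the original code:

#[derive(Component)]
struct Cart;

#[derive(Component)]
struct Pole;

// A single f32: linear velocity for the cart, angular velocity for the pole
#[derive(Component, Default)]
struct Velocity(f32);

// Rough width and height in world units, consumed only by size_scaling
#[derive(Component)]
struct Size {
    width: f32,
    height: f32,
}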

The last entity we add to our world is the Model. This is the agent that is going to solve the Cart Pole Problem. Note that above, when we pass the add_model function into the App, we pass it as an exclusive system. Dfdx makes use of Rust's Rc type. This type cannot be sent safely between threads, so we mark this system as exclusive to let Bevy know that it, and any system using the model, must run on the main thread.
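
A sketch of what add_model might look like as an exclusive system; the initialization details are my assumptions (the Model struct itself is defined just below), and the important part is the &mut World access and the non-send resource:

fn add_model(world: &mut World) {
    let mut model = Model::default();
    // dfdx zero-initializes layers, so give the network random starting weights
    model.model.reset_params(&mut rand::thread_rng());
    model.target = model.model.clone();
    model.epsilon = 1.0;
    // Rc is not Send, so the model is stored as a non-send resource
    world.insert_non_send_resource(model);
}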

The Model is defined by the following code.

type Mlp = (
    Linear<4, 64>,
    (Linear<64, 64>, ReLU),
    (Linear<64, 32>, ReLU),
    Linear<32, 2>,
);

type Transition = ([f32; 4], i32, i32, Option<[f32; 4]>);

#[derive(Debug, Default)]
struct Model {
    model: Mlp,
    target: Mlp,
    optimizer: Adam<Mlp>,
    steps_since_last_merge: i32,
    survived_steps: i32,
    episode: i32,
    epsilon: f32,
    experience: Vec<Transition>,
}

There are probably more effective places to store items like survived_steps and the epsilon for choosing random actions, but for this simple example I thought it was good enough.
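
push_experience, which the step system below relies on, is also not shown here; a minimal sketch might just append to the buffer and keep it bounded (MAX_EXPERIENCE is a made-up constant for illustration):

impl Model {
    fn push_experience(&mut self, transition: Transition) {
        self.experience.push(transition);
        // Drop the oldest transition once the replay buffer is full
        if self.experience.len() > MAX_EXPERIENCE {
            self.experience.remove(0);
        }
    }
}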

The Game Logic
==============

The final system to discuss is step, which contains the actual game logic: a single "step" of the game world.
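
Before walking through the body, here is a rough sketch of the signature the queries below imply; the exact filters and the NonSendMut access to the model are my assumptions:

fn step(
    mut model: NonSendMut<Model>,
    mut q_cart: Query<(&mut Transform, &mut Velocity), (With<Cart>, Without<Pole>)>,
    mut q_pole: Query<(&mut Transform, &mut Velocity), (With<Pole>, Without<Cart>)>,
    mut q_text: Query<&mut Text>,
) {
    // ...body below...
}

The body starts by pulling the cart, pole, and episode text out of those queries.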

let (mut cart_transform, mut cart_velocity) = q_cart
    .get_single_mut()
    .expect("Could not get the cart information");
let (mut pole_transform, mut pole_velocity) = q_pole
    .get_single_mut()
    .expect("Could not get the pole information");
let mut text = q_text
    .get_single_mut()
    .expect("Could not get the text with the episode info");

let observation = [
    cart_transform.translation.x,
    cart_velocity.0,
    pole_transform.rotation.z,
    pole_velocity.0,
];

let action = match model.epsilon > rand::random::<f32>() {
    true => match rand::random::<bool>() {
        true => 0,
        false => 1,
    },
    false => {
        let tensor_observation: Tensor1D<4> = TensorCreator::new(observation);
        let prediction = model.model.forward(tensor_observation);
        match prediction.data()[0] > prediction.data()[1] {
            true => 0,
            false => 1,
        }
    }
};
model.epsilon = (model.epsilon - EPSILON_DECAY).max(0.05);

// These calculations are directly from openai https://github.com/openai/gym/blob/master/gym/envs/classic_control/cartpole.py
let force = match action {
    1 => FORCE_MAG * -1.,
    _ => FORCE_MAG,
};
let costheta = pole_transform.rotation.z.cos();
let sintheta = pole_transform.rotation.z.sin();
let temp =
    (force + POLEMASS_LENGTH * pole_velocity.0.powi(2) * sintheta) / TOTAL_MASS;
let thetaacc = (GRAVITY * sintheta - costheta * temp)
    / (LENGTH * (4.0 / 3.0 - MASS_POLE * (costheta * costheta) / TOTAL_MASS));
let xacc = temp - POLEMASS_LENGTH * thetaacc * costheta / TOTAL_MASS;

// Apply above calculations
cart_transform.translation.x += TAU * cart_velocity.0 * cart_transform.scale.x;
cart_velocity.0 += TAU * xacc;
pole_transform.rotation.z += TAU * pole_velocity.0;
pole_velocity.0 += TAU * thetaacc;
// Match the pole x to the cart x
pole_transform.translation.x = cart_transform.translation.x;

// Check if the episode is over
if pole_transform.rotation.z > THETA_THRESHOLD_RADIANS
    || pole_transform.rotation.z < -1. * THETA_THRESHOLD_RADIANS
    || (cart_transform.translation.x / cart_transform.scale.x) > X_THRESHOLD
    || (cart_transform.translation.x / cart_transform.scale.x) < -1. * X_THRESHOLD
    || model.survived_steps > 499
{
    println!(
        "RESETTING Episode: {}  SURVIVED: {}",
        model.episode, model.survived_steps,
    );

    // Reset cart and pole variables just like openai does
    let mut rng = rand::thread_rng();
    cart_velocity.0 = rng.gen_range(-0.05..0.05);
    pole_velocity.0 = rng.gen_range(-0.05..0.05);
    cart_transform.translation.x = rng.gen_range(-0.05..0.05);
    pole_transform.translation.x = cart_transform.translation.x;
    pole_transform.rotation.z = rng.gen_range(-0.05..0.05);

    // Update the latest episode survived text
    text.sections[0].value = format!(
        "Episode: {} - Survived: {}",
        model.episode, model.survived_steps
    );

    // Reset the survived_steps, increment episode count, and push_experience
    model.survived_steps = 0;
    model.episode += 1;
    model.push_experience((observation, action, 0, None));
} else {
    model.survived_steps += 1;
    let next_observation = [
        cart_transform.translation.x,
        cart_velocity.0,
        pole_transform.rotation.z,
        pole_velocity.0,
    ];
    model.push_experience((observation, action, 1, Some(next_observation)));
}

// Train if we have the necessary experience
if model.experience.len() > BATCH_SIZE {
    model.train();
}

// Merge the target model after a certain number of steps
if model.steps_since_last_merge > 10 {
    model.target = model.model.clone();
    model.steps_since_last_merge = 0;
} else {
    model.steps_since_last_merge += 1;
}

While it seems like a lot (and could be broken into smaller systems using Bevy's events), the actual logic is fairly simple, and anyone who has worked with reinforcement learning should recognize the pattern.

The first thing we do is grab the Cart and Pole positions and velocities. From these we build the current step's observation. We then choose an action epsilon-greedily and, given that action, calculate the force applied to the cart and pole.

The calculations for the pole's rotation and the cart's position are taken from OpenAI's cart pole. This project would not have been possible without that already clear code. For those curious, there is a paper outlining how to correctly derive the dynamics, but it is far above my head: "Correct equations for the dynamics of the cart-pole system".
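
For reference, the constants used in the step system mirror the values in OpenAI's cartpole.py; the names below match the code above, but the values are reconstructed from the gym source rather than copied from the post (EPSILON_DECAY, BATCH_SIZE, and NEXT_STATE_DISCOUNT are project-specific and not reproduced here):

const GRAVITY: f32 = 9.8;
const MASS_CART: f32 = 1.0;
const MASS_POLE: f32 = 0.1;
const TOTAL_MASS: f32 = MASS_CART + MASS_POLE;
const LENGTH: f32 = 0.5; // actually half the pole's length
const POLEMASS_LENGTH: f32 = MASS_POLE * LENGTH;
const FORCE_MAG: f32 = 10.0;
const TAU: f32 = 0.02; // seconds between state updates
const THETA_THRESHOLD_RADIANS: f32 = 12.0 * 2.0 * std::f32::consts::PI / 360.0;
const X_THRESHOLD: f32 = 2.4;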

If the episode is over, we reset the cart and pole position, velocity, and rotation, reset the survived_steps to 0, increment the episode counter, and push this observation with a reward of 0. If the episode is not over, we increase the survived_steps and push the observation with a reward of 1.

Training The Model
==================

After each step we train the Model.

pub fn train(&mut self) {
    // Select the experience batch
    let mut rng = rand::thread_rng();
    let distribution = rand::distributions::Uniform::from(0..self.experience.len());
    let experience: Vec<Transition> = (0..BATCH_SIZE)
        .map(|_index| self.experience[distribution.sample(&mut rng)])
        .collect();

    // Get the models expected rewards
    let observations: Vec<_> = experience.iter().map(|x| x.0.to_owned()).collect();
    let observations: [[f32; 4]; BATCH_SIZE] = observations.try_into().unwrap();
    let observations: Tensor2D<BATCH_SIZE, 4> = TensorCreator::new(observations);
    let predictions = self.model.forward(observations.trace());
    let actions_indices: Vec<_> = experience.iter().map(|x| x.1 as usize).collect();
    let actions_indices: [usize; BATCH_SIZE] = actions_indices.try_into().unwrap();
    let predictions: Tensor1D<BATCH_SIZE, dfdx::prelude::OwnedTape> =
        predictions.select(&actions_indices);

    // Get the targets expected rewards for the next_observation
    // This could be optimized, but I can't think of an easy way to do it without making this
    // code much more gross, and since we are already far faster than we need to be, this is
    // fine. BUT: when not rendering the window, this is the bottleneck in the program.
    let mut target_predictions: [f32; BATCH_SIZE] = [0.; BATCH_SIZE];
    for (i, x) in experience.iter().enumerate() {
        let target_prediction = match x.3 {
            Some(next_observation) => {
                let next_observation: Tensor1D<4> = TensorCreator::new(next_observation);
                let target_prediction = self.target.forward(next_observation);
                let target_prediction =
                    target_prediction.data()[0].max(target_prediction.data()[1]);
                target_prediction * NEXT_STATE_DISCOUNT + experience[i].2 as f32
            }
            None => experience[i].2 as f32,
        };
        target_predictions[i] = target_prediction;
    }
    let target_predictions: Tensor1D<BATCH_SIZE> = TensorCreator::new(target_predictions);

    // Get the loss and train the model
    let loss = mse_loss(predictions, &target_predictions);
    self.optimizer
        .update(&mut self.model, loss.backward())
        .expect("Oops, we messed up");
}

The train function is relatively simple: we select a random batch from the experience buffer, get the current model's value predictions for those states, compare them to the target model's predictions for the next_observation (plus the current observation's reward), and train on the loss.

As mentioned in the comments above, this is the biggest performance bottleneck. Specifically, because the next_observation can be None, we cannot run the entire batch of next_observations through the target model as easily as we did for the acting model. I am sure there are ways to improve this, for example reducing the vector to an index-aware set of observations without the Nones, but for the purposes of this post the performance is already more than adequate.

Every 10 steps we copy the current model into the target model.

The Result
==========

The last thing to do is create the sprites for the entities.

I've been having fun using NixOS as my daily system, so instead of creating the sprites in Pixelmator or Adobe, I used GIMP, an open-source image editor.

The sprites themselves are very simple. They pretty closely match the traditional cart pole, except for the Ferris decal I added to the cart.

All said and done we end up with this:

To keep the video short, I only show it training until it survives 500 steps for the first time.

All of the code for this is publicly available on my GitHub.

Thanks for reading!