A common question we get is how to set up model checkpoints so that training can be continued later. In this document, we walk through the process using this PPO example.
Save model checkpoints
The first step is to save the model periodically. In this example, we save the model to `wandb` whenever experiment tracking (`args.track`) is enabled:
```python
num_updates = args.total_timesteps // args.batch_size
CHECKPOINT_FREQUENCY = 50
starting_update = 1

for update in range(starting_update, num_updates + 1):
    # ... do rollouts and train models
    if args.track:
        # make sure to tune `CHECKPOINT_FREQUENCY`
        # so models are not saved too frequently
        if update % CHECKPOINT_FREQUENCY == 0:
            torch.save(agent.state_dict(), f"{wandb.run.dir}/agent.pt")
            wandb.save(f"{wandb.run.dir}/agent.pt", policy="now")
```
Then we could run the following command to train our agent:
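A minimal launch might look like the sketch below; the script name `ppo.py` is an assumption based on the PPO example referenced above, and `--track` is the flag behind `args.track` in the snippets:

```
# assumes the training script is called ppo.py; substitute your own file name
python ppo.py --track
```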
Restore model checkpoints
The second step is to restore the model checkpoint when W&B resumes the run, then continue the training loop from the restored update. Note that we fetch the run from the W&B API before reading its summary, so `run` is defined before it is used:

```python
num_updates = args.total_timesteps // args.batch_size
CHECKPOINT_FREQUENCY = 50
starting_update = 1

if args.track and wandb.run.resumed:
    # fetch the resumed run from the W&B API
    api = wandb.Api()
    run = api.run(f"{wandb.run.entity}/{wandb.run.project}/{wandb.run.id}")
    # continue from the update after the last logged one
    starting_update = run.summary.get("charts/update") + 1
    global_step = starting_update * args.batch_size
    # download the latest checkpoint and load it into the agent
    model = run.file("agent.pt")
    model.download(f"models/{experiment_name}/")
    agent.load_state_dict(torch.load(f"models/{experiment_name}/agent.pt", map_location=device))
    agent.eval()
    print(f"resumed at update {starting_update}")

for update in range(starting_update, num_updates + 1):
    # ... do rollouts and train models
    if args.track:
        # make sure to tune `CHECKPOINT_FREQUENCY`
        # so models are not saved too frequently
        if update % CHECKPOINT_FREQUENCY == 0:
            torch.save(agent.state_dict(), f"{wandb.run.dir}/agent.pt")
            wandb.save(f"{wandb.run.dir}/agent.pt", policy="now")
```
To resume training, note that the ID of the experiment is `21421tda`, as in the URL https://wandb.ai/costa-huang/cleanRL/runs/21421tda, so we need to pass the ID in via environment variables to trigger W&B's resume mode:
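Concretely, the launch could look like the following. `WANDB_RUN_ID` and `WANDB_RESUME` are W&B's documented resume environment variables; the script name `ppo.py` is again an assumption:

```
# WANDB_RESUME=must makes wandb.init error out if run 21421tda cannot be resumed
WANDB_RUN_ID=21421tda WANDB_RESUME=must python ppo.py --track
```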