JAXAgents

JAXAgents is a high-performance (multi-agent) reinforcement learning (RL) library built on JAX. It is designed for rapid experimentation, scalable training of RL agents, and fast hyperparameter tuning. It supports a variety of algorithms and environments, which makes it suitable for both research and practical applications.

Features

  • RL: Implementations of popular RL algorithms, including:
    • Q-learning:
      • Deep Q Networks (DQN)
      • Double Deep Q Networks (DDQN)
      • Categorical DQN (C51)
      • Quantile Regression DQN (QRDQN)
    • Policy Gradient:
      • REINFORCE
      • Proximal Policy Optimization (PPO) with Generalized Advantage Estimation (GAE)
    • Multi-Agent RL:
      • Independent PPO (IPPO)
  • High Performance: Uses JAX's just-in-time compilation and automatic differentiation for efficient computation on CPUs and GPUs.

  • Modular Design: Structured for easy extension and customization, facilitating experimentation with new algorithms and environments.
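
To make the PPO with GAE and just-in-time compilation items above more concrete, here is a minimal sketch of Generalized Advantage Estimation in plain JAX. It is not taken from the JAXAgents source: the function name `compute_gae`, its arguments, and the default `gamma` and `lam` values are assumptions chosen for illustration, and the library's own implementation may differ.

```python
import jax
import jax.numpy as jnp


@jax.jit
def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    """Illustrative GAE for one trajectory of length T (hypothetical helper).

    rewards, values, dones: float arrays of shape (T,)
    last_value: scalar value estimate for the state after the final step
    """
    # V(s_{t+1}) for every step; the final entry is the bootstrap value.
    next_values = jnp.append(values[1:], last_value)
    # Mask that cuts bootstrapping and accumulation at episode ends.
    not_done = 1.0 - dones

    # One-step TD errors: r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t).
    deltas = rewards + gamma * next_values * not_done - values

    def backward_step(next_advantage, inputs):
        delta, mask = inputs
        # A_t = delta_t + gamma * lam * (1 - done_t) * A_{t+1}
        advantage = delta + gamma * lam * mask * next_advantage
        return advantage, advantage

    # Scan from the last step to the first; outputs keep the original time order.
    init = jnp.zeros((), dtype=deltas.dtype)
    _, advantages = jax.lax.scan(
        backward_step, init, (deltas, not_done), reverse=True
    )
    return advantages
```

As a quick check, with rewards of 1 at each of three steps, zero value estimates, an episode end at the last step, and `gamma=1.0, lam=1.0`, the advantages are the undiscounted returns-to-go: 3, 2 and 1. Writing the backward recursion with `jax.lax.scan` rather than a Python loop keeps the compiled program size independent of the trajectory length.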

Installation

Ensure you have Python 3.10 or higher installed. Then, install JAXAgents via pip:

pip install jaxagents

Note: Also available on PyPI

Getting Started

Here's a simple example to train a PPO agent:

import jaxagents

# Initialize environment and agent
env = jaxagents.environments.make('CartPole-v1')
agent = jaxagents.agents.PPO(env)

# Train the agent
agent.train(num_episodes=1000)

For more detailed examples and usage, refer to the documentation.

Documentation

Comprehensive documentation is available at jax-agents.readthedocs.io, covering:

  • Installation and setup
  • Detailed API references
  • Tutorials and examples
  • Advanced topics and customization

Development

To contribute or modify the library:

git clone https://github.com/amavrits/jax-agents.git
cd jax-agents
pip install -e .

License

This project is licensed under a proprietary license. For more information, please refer to the LICENSE file.


For any questions or contributions, feel free to open an issue or submit a pull request on the GitHub repository.