IPPO
Bases: IPPOBase
IPPO clip agent using the GAE (PPO2) for calculating the advantage. The actor loss function standardizes the advantage.
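In standard notation, the two statements above correspond to the GAE recursion and the clipped (PPO2) surrogate objective evaluated on a standardized advantage. The symbols below (gamma, lambda, epsilon) are the usual discount, GAE and clipping hyperparameters; they are written here for orientation and are not the library's attribute names.

```latex
% TD error and GAE recursion that produce the advantage
\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t), \qquad
A_t = \delta_t + \gamma \lambda A_{t+1}

% Standardized advantage and probability ratio
\hat{A}_t = \frac{A_t - \operatorname{mean}(A)}{\operatorname{std}(A)}, \qquad
r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}

% Clipped (PPO2) surrogate objective
L^{\mathrm{CLIP}}(\theta) =
\mathbb{E}_t\!\left[\min\!\left(r_t(\theta)\,\hat{A}_t,\;
\operatorname{clip}\!\left(r_t(\theta),\,1-\varepsilon,\,1+\varepsilon\right)\hat{A}_t\right)\right]
```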
Source code in jaxagents\ippo.py, lines 1404-1583
_actor_loss(training, obs, actions, log_prob_old, advantage, hyperparams)
Calculates the actor loss. For the PPO clip agent, the advantage is the GAE computed over the trajectory batch, and it is standardized within the loss.
:param training: The actor TrainState object.
:param obs: The observations in the trajectory batch.
:param actions: The actions in the trajectory batch.
:param log_prob_old: The log-probabilities of the old policy collected over the trajectory batch.
:param advantage: The GAE over the trajectory batch.
:param hyperparams: The HyperParameters object used for training.
:return: A tuple containing the actor loss and the KL divergence (for the early-stopping criterion check).
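A minimal sketch of such a loss, assuming the log-probabilities of the current and the old policy have already been evaluated for the minibatch; the helper name, clip_eps and the 1e-8 stabilizer are illustrative and not taken from the source.

```python
import jax.numpy as jnp


def actor_loss_sketch(log_prob, log_prob_old, advantage, clip_eps=0.2):
    # Standardize the GAE advantage over the minibatch.
    advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)
    # Probability ratio between the current and the old policy.
    ratio = jnp.exp(log_prob - log_prob_old)
    # Clipped surrogate objective (maximized, hence the minus sign for a loss).
    surrogate = jnp.minimum(
        ratio * advantage,
        jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantage,
    )
    loss = -surrogate.mean()
    # Crude estimate of the KL divergence between the old and the new policy,
    # useful for an early-stopping check on further epochs over the same batch.
    approx_kl = (log_prob_old - log_prob).mean()
    return loss, approx_kl
```

The returned pair mirrors the docstring: the loss drives the gradient step, while the KL estimate can be compared against a threshold to stop additional update epochs on the same batch.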
Source code in jaxagents\ippo.py, lines 1441-1490
_actor_loss_input(update_runner, traj_batch)
Prepares the input required by the actor loss function. For the PPO agent, this entails:
- the observations collected over the trajectory batch,
- the actions collected over the trajectory batch,
- the log-probabilities of the actions collected over the trajectory batch,
- the GAE advantages over the trajectory batch,
- the training hyperparameters.
The input is reshaped so that it is split into minibatches (see the sketch below).
:param update_runner: The Runner object used in training.
:param traj_batch: The batch of trajectories.
:return: A tuple of input to the actor loss function.
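As an illustration of the reshaping step, a common pattern is to shuffle the flattened trajectory arrays and split them into minibatches along the leading axis; the function and argument names below are placeholders, not the library's API.

```python
import jax
import jax.numpy as jnp


def split_into_minibatches(x, rng, n_minibatches):
    # Shuffle along the leading (batch) axis, then split into equally sized
    # minibatches; the leading axis must be divisible by n_minibatches.
    batch_size = x.shape[0]
    permutation = jax.random.permutation(rng, batch_size)
    shuffled = x[permutation]
    return shuffled.reshape((n_minibatches, batch_size // n_minibatches) + x.shape[1:])
```

Applied with jax.tree_util.tree_map over a trajectory pytree (using the same rng for every leaf), this yields the per-minibatch tuples the loss functions consume.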
Source code in jaxagents\ippo.py, lines 1517-1551
_critic_loss(training, obs, targets, hyperparams)
Calculates the critic loss.
:param training: The critic TrainState object.
:param obs: The observations in the trajectory batch.
:param targets: The targets over the trajectory batch for training the critic.
:param hyperparams: The HyperParameters object used for training.
:return: The critic loss.
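A minimal sketch of a critic loss of this form; the mean-squared-error shape is an assumption here, since the function body is not reproduced on this page.

```python
import jax.numpy as jnp


def critic_loss_sketch(values_pred, targets):
    # Squared error between the critic's value estimates and the regression
    # targets, averaged over the minibatch.
    return jnp.mean((values_pred - targets) ** 2)
```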
Source code in jaxagents\ippo.py, lines 1492-1515
_critic_loss_input(update_runner, traj_batch)
Prepares the input required by the critic loss function. For the PPO agent, this entails:
- the states collected over the trajectory batch,
- the targets (returns = GAE + next_value) over the trajectory batch,
- the training hyperparameters.
The input is reshaped so that it is split into minibatches.
:param update_runner: The Runner object used in training.
:param traj_batch: The batch of trajectories.
:return: A tuple of input to the critic loss function.
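The targets named above follow from the GAE identity A_t ≈ R_t - V(s_t), so once the advantages and values are available the computation reduces to an elementwise sum; this is a sketch of that relationship, not the library's exact code.

```python
import jax.numpy as jnp


def critic_targets_sketch(advantages, values):
    # Recover the bootstrapped return R_t = A_t + V as the critic target,
    # where V is the value estimate paired with each advantage
    # (the docstring phrases this as "returns = GAE + next_value").
    return advantages + values
```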
Source code in jaxagents\ippo.py, lines 1553-1583
_trajectory_advantages(advantage, traj)
Calculates the GAE per episode step over a batch of trajectories.
:param advantage: The GAE advantages of the steps in the trajectory according to the critic (including that of the last state). At the beginning of the method, 'advantage' is the advantage of the state in the next step of the trajectory (not in the reverse iteration), and after the calculation it is the advantage of the examined state in each step.
:param traj: The trajectory batch.
:return: An array of advantages.
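The reverse iteration the docstring describes is naturally expressed as a jax.lax.scan with reverse=True; the field names (rewards, values, next_values, dones) stand in for whatever the trajectory batch actually stores, and the hyperparameter defaults are illustrative.

```python
import jax
import jax.numpy as jnp


def gae_sketch(rewards, values, next_values, dones, gamma=0.99, gae_lambda=0.95):
    # Compute GAE advantages for a time-major trajectory.
    def step(advantage_next, transition):
        reward, value, next_value, done = transition
        # TD error of the examined step.
        delta = reward + gamma * next_value * (1.0 - done) - value
        # GAE recursion: this step's advantage from the next step's advantage.
        advantage = delta + gamma * gae_lambda * (1.0 - done) * advantage_next
        return advantage, advantage

    _, advantages = jax.lax.scan(
        step,
        jnp.zeros_like(values[0]),
        (rewards, values, next_values, dones),
        reverse=True,
    )
    return advantages
```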
Source code in jaxagents\ippo.py, lines 1425-1439
_trajectory_returns(value, traj)
Calculates the returns per episode step over a batch of trajectories.
:param value: The values of the steps in the trajectory according to the critic (including that of the last state). At the beginning of the method, 'value' is the value of the state in the next step of the trajectory (not in the reverse iteration), and after the calculation it is the value of the examined state in the examined step.
:param traj: The trajectory batch.
:return: An array of returns.
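Similarly, the returns can be sketched as a reverse scan that bootstraps from the critic's value of the state following the trajectory; names and defaults are illustrative.

```python
import jax
import jax.numpy as jnp


def returns_sketch(rewards, dones, last_value, gamma=0.99):
    # Compute bootstrapped discounted returns for a time-major trajectory.
    def step(return_next, transition):
        reward, done = transition
        # Return of the examined step from the next step's return.
        ret = reward + gamma * return_next * (1.0 - done)
        return ret, ret

    _, returns = jax.lax.scan(step, last_value, (rewards, dones), reverse=True)
    return returns
```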
Source code in jaxagents\ippo.py, lines 1411-1423