88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402 | class PPOAgentBase(ABC):
"""
Base for PPO agents.
Can be used for both discrete and continuous action environments; which one applies depends on the provided actor network.
Follows the instructions of: https://spinningup.openai.com/en/latest/algorithms/ppo.html
Uses lax.scan for rollout, so trajectories may be truncated.
Training relies on jitting several methods by treating the 'self' arg as static. According to suggested practice,
this can prove dangerous (https://jax.readthedocs.io/en/latest/faq.html#how-to-use-jit-with-methods -
How to use jit with methods?); if attrs of 'self' change during training, the changes will not be registered in
jit. In this case, neither agent training nor evaluation change any 'self' attrs, so using Strategy 2 of the
suggested practice is valid. Otherwise, strategy 3 should have been used.
"""
# Function for performing a minibatch update of the actor network.
_actor_minibatch_fn: ClassVar[Callable[
[Tuple[TrainState, ActorLossInputType, float]],
Tuple[TrainState, ActorLossInputType, float]]
]
# Function for performing a minibatch update of the critic network.
_critic_minibatch_fn: ClassVar[Callable[
[Tuple[TrainState, CriticLossInputType]],
Tuple[TrainState, CriticLossInputType]]
]
agent_trained: ClassVar[bool] = False # Whether the agent has been trained.
training_runner: ClassVar[Optional[Runner]] = None # Runner object after training.
actor_training: ClassVar[Optional[TrainState]] = None # Actor training object.
critic_training: ClassVar[Optional[TrainState]] = None # Critic training object.
training_metrics: ClassVar[Optional[Dict[str, Float[Array, "1"]]]] = None # Metrics collected during training.
eval_during_training: ClassVar[bool] = False # Whether the agent's performance is evaluated during training
# The maximum step reached in previous training. Zero by default when starting a new training run. Set when
# restoring a checkpoint or when passing in an already trained agent (e.g. from serial training).
previous_training_max_step: ClassVar[int] = 0
def __init__(
    self,
    env: Environment,
    env_params: EnvParams,
    config: AgentConfig,
    eval_during_training: bool = True
) -> None:
    """
    Stores the agent configuration, prepares checkpointing and sets up the (wrapped) environment.
    :param env: A gymnax or custom environment that inherits from the basic gymnax class.
    :param env_params: A dataclass named "EnvParams" containing the parametrization of the environment.
    :param config: The configuration of the agent as an AgentConfig object (from vpf_utils).
    :param eval_during_training: Whether the agent's performance is evaluated during training. Disabling it leads
        to faster training.
    """
    self.config, self.eval_during_training = config, eval_during_training
    # Checkpointing must be set up after the config is stored: it reads it and logs it via __str__.
    self._init_checkpointer()
    self._init_env(env, env_params)
def __str__(self) -> str:
    """
    Returns a human-readable summary of the agent configuration: a header line followed by one "field: value"
    line for every field in config._fields. All fields are listed, including those left at their default values.
    """
    lines = ['Agent configuration:']
    # '!s' keeps the exact str() rendering of each value (f-strings would otherwise use format()).
    lines.extend(f'{field}: {getattr(self.config, field)!s}' for field in self.config._fields)
    return '\n'.join(lines)
""" GENERAL METHODS"""
def _init_env(self, env: Environment, env_params: EnvParams) -> None:
    """
    Wraps the environment in a TruncationWrapper configured with config.max_episode_steps and stores it, together
    with the environment parameters, on the agent.
    :param env: A gymnax or custom environment that inherits from the basic gymnax class.
    :param env_params: A dataclass containing the parametrization of the environment.
    :return:
    """
    # Only the truncation wrapper is applied (no observation-flattening or logging wrappers).
    self.env = TruncationWrapper(env, self.config.max_episode_steps)
    self.env_params = env_params
def _init_checkpointer(self) -> None:
    """
    Sets whether checkpointing should be performed, decided by whether a checkpoint directory has been provided. If
    so, sets up the checkpoint manager using orbax; otherwise the manager is set to None.
    When a new training is started (config.restore_agent is False), the checkpoint directory is created if needed,
    every existing sub-directory in it (presumably checkpoints of an earlier run) is deleted, and the agent
    configuration is written to 'training_configuration.txt'.
    :return:
    """
    self.checkpointing = self.config.checkpoint_dir is not None
    if not self.checkpointing:
        self.checkpoint_manager = None
        return

    ckpt_dir = self.config.checkpoint_dir
    if not self.config.restore_agent:
        # exist_ok avoids the race between checking for the directory and creating it.
        os.makedirs(ckpt_dir, exist_ok=True)
        # Start from a clean slate: delete every sub-directory. Plain files are kept.
        for entry in os.listdir(ckpt_dir):
            entry_path = os.path.join(ckpt_dir, entry)
            if os.path.isdir(entry_path):
                shutil.rmtree(entry_path)
        # Log training configuration
        with open(os.path.join(ckpt_dir, 'training_configuration.txt'), "w", encoding="utf-8") as f:
            f.write(str(self))

    orbax_checkpointer = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler())
    options = orbax.checkpoint.CheckpointManagerOptions(
        create=True,
        step_prefix='trainingstep',
    )
    self.checkpoint_manager = orbax.checkpoint.CheckpointManager(
        ckpt_dir,
        orbax_checkpointer,
        options
    )
def _create_empty_trainstate(self, network) -> TrainState:
    """
    Builds a placeholder TrainState for the given network. It is used as the restore target when loading
    checkpoints, so its parameter values are irrelevant.
    :param network: The actor or critic network class.
    :return: A TrainState with freshly initialized parameters and an optimizer built from default OptimizerParams.
    """
    dummy_rng = jax.random.PRNGKey(1)  # Any valid key will do for initializing the network parameters.
    network, params = self._init_network(dummy_rng, network)
    tx = self._init_optimizer(OptimizerParams())  # Default values of the OptimizerParams object.
    return TrainState.create(apply_fn=network.apply, params=params, tx=tx)
def restore(
    self,
    mode: str = "best",
    best_fn: Optional[Callable[[Dict[str, Float[Array, "1"]]], int]] = None
) -> None:
    """
    Restores a checkpoint (best or latest) and collects the history of metrics as assessed during training. Then,
    post-processes the restored checkpoint with collect_training.
    :param mode: "best" to restore the best performing checkpoint (requires best_fn), or "last" to restore the
        latest checkpoint.
    :param best_fn: Function that receives the metrics history (dict of arrays whose first axis indexes the
        checkpoints, in the order of checkpoint_manager.all_steps()) and returns the index of the best checkpoint.
    :return:
    :raises ValueError: If checkpointing is disabled, no checkpoints exist, the mode is unknown, or mode is "best"
        and best_fn is not provided.
    """
    # Validate the arguments before the (costly) restoring of every checkpoint.
    if self.checkpoint_manager is None:
        raise ValueError("Checkpointing is disabled: no checkpoint directory was provided in the agent config.")
    if mode not in ("best", "last"):
        raise ValueError("Unknown method for selecting a checkpoint.")
    if mode == "best" and best_fn is None:
        raise ValueError("Function for determining best checkpoint not provided")

    steps = self.checkpoint_manager.all_steps()
    if len(steps) == 0:
        raise ValueError(f"No checkpoints found in {self.config.checkpoint_dir}.")

    # Collect history of metrics in training. Useful for continuing training. The metric keys are all keys of the
    # first checkpoint except "runner".
    history = None
    for ckpt_step in steps:
        ckpt = self.checkpoint_manager.restore(ckpt_step)
        if history is None:
            history = {key: [] for key in ckpt if key != "runner"}
        for key, values in history.items():
            values.append(ckpt[key][jnp.newaxis, :])
    metrics = {key: jnp.concatenate(values, axis=0) for key, values in history.items()}

    if mode == "best":
        step = steps[best_fn(metrics)]
    else:  # mode == "last"
        step = self.checkpoint_manager.latest_step()

    # Create an empty target for restoring the checkpoint. Some of the arguments come from the last restored ckpt.
    empty_actor_training = self._create_empty_trainstate(self.config.actor_network)
    empty_critic_training = self._create_empty_trainstate(self.config.critic_network)
    # Get some obs and envstate for restoring the checkpoint.
    _, obs, envstate = self.env_reset(jax.random.PRNGKey(1))
    empty_runner = Runner(
        actor_training=empty_actor_training,
        critic_training=empty_critic_training,
        envstate=envstate,
        obs=obs,
        rng=jax.random.split(jax.random.PRNGKey(1), self.config.batch_size),  # Dummy keys, replaced on restore.
        # Hyperparams can be loaded as a dict. If training continues, new hyperparams will be provided.
        hyperparams=ckpt["runner"]["hyperparams"]
    )
    # One placeholder per saved metric, shaped like a single checkpoint's entry. Building it from the saved keys
    # (instead of a fixed list) keeps the target consistent with what was actually checkpointed.
    target_ckpt = {"runner": empty_runner}
    target_ckpt.update({key: jnp.zeros_like(values[0]) for key, values in metrics.items()})
    ckpt = self.checkpoint_manager.restore(step, items=target_ckpt)
    self.collect_training(ckpt["runner"], metrics, previous_training_max_step=max(steps))
def _init_optimizer(self, optimizer_params: OptimizerParams) -> optax.chain:
    """
    Optimizer initialization. This method uses the optax optimizer function given in the agent configuration to
    initialize the appropriate optimizer. In this way, the optimizer can be initialized within the "train" method,
    and thus several combinations of its parameters can be run with jax.vmap. Jit is neither possible nor necessary.
    Only the OptimizerParams fields that the optimizer function accepts as arguments are passed to it. 'grad_clip'
    is never passed; it sets the global-norm gradient clipping that precedes the optimizer in the chain.
    :param optimizer_params: A NamedTuple containing the parametrization of the optimizer.
    :return: An optimizer in optax.chain.
    :raises ValueError: If none of the OptimizerParams fields (other than 'grad_clip') is an optimizer argument.
    """
    import inspect  # Local import: only needed here, to read the optimizer's signature.

    optimizer_params_dict = optimizer_params._asdict()  # Transform from NamedTuple to dict
    optimizer_params_dict.pop('grad_clip', None)  # Not an optimizer arg; used for clipping below.
    # Parameter names accepted by the optimizer. inspect.signature reports the real parameters, whereas
    # __code__.co_varnames also contains local variable names and is missing for e.g. functools.partial objects.
    optimizer_arg_names = inspect.signature(self.config.optimizer).parameters
    optimizer_kwargs = {
        arg_name: value
        for arg_name, value in optimizer_params_dict.items()
        if arg_name in optimizer_arg_names
    }
    if not optimizer_kwargs:
        raise ValueError(
            "The defined optimizer parameters do not include relevant arguments for this optimizer. "
            "The optimizer has not been implemented yet. Define your own OptimizerParams object."
        )
    # No need to scale by -1.0. 'TrainState.apply_gradients' is used for training, which subtracts the update.
    return optax.chain(
        optax.clip_by_global_norm(optimizer_params.grad_clip),
        self.config.optimizer(**optimizer_kwargs)
    )
def _init_network(
    self,
    rng: PRNGKeyArray,
    network: type[flax.linen.Module]
) -> Tuple[flax.linen.Module, FrozenDict]:
    """
    Instantiates the actor or critic network and initializes its parameters with an all-zero input of shape
    (1, observation size).
    :param rng: Random key for initialization.
    :param network: The actor or critic network class; it is instantiated as network(self.config).
    :return: The instantiated network and its initial parameters.
    """
    network = network(self.config)
    # The first sub-key is discarded. The 3-way split is kept so that initializations stay reproducible.
    _, dummy_reset_rng, network_init_rng = jax.random.split(rng, 3)
    # The environment is reset only to find the size of an observation.
    dummy_obs, _ = self.env.reset(dummy_reset_rng, self.env_params)
    init_x = jnp.zeros((1, dummy_obs.size))
    params = network.init(network_init_rng, init_x)
    return network, params
@partial(jax.jit, static_argnums=(0,))
def env_reset(self, rng: PRNGKeyArray) -> Tuple[PRNGKeyArray, ObsType, LogEnvState | EnvState | TruncationEnvState]:
    """
    Resets the environment using a sub-key split off the given key.
    :param rng: Random key for initialization.
    :return: The carried-over random key (after splitting), the reset observation and the reset environment state.
    """
    next_rng, key_for_reset = jax.random.split(rng)
    reset_obs, reset_envstate = self.env.reset(key_for_reset, self.env_params)
    return next_rng, reset_obs, reset_envstate
@partial(jax.jit, static_argnums=(0,))
def env_step(
    self,
    rng: PRNGKeyArray,
    envstate: LogEnvState | EnvState | TruncationEnvState,
    action: ActionType
) -> Tuple[
    PRNGKeyArray,
    ObsType,
    LogEnvState | EnvState | TruncationEnvState,
    Float[Array, "1"],
    Bool[Array, "1"],
    Dict[str, float | bool]
]:
    """
    Advances the environment by one step with the given action, using a sub-key split off the given key.
    :param rng: Random key for initialization.
    :param envstate: The current environment state.
    :param action: The action selected by the agent; it is squeezed before being passed to the environment.
    :return: A tuple of: the carried-over random key (after splitting), the next observation, the next environment
        state, the collected reward, the done flag and the environment's info dictionary.
    """
    next_rng, key_for_step = jax.random.split(rng)
    step_output = self.env.step(key_for_step, envstate, action.squeeze(), self.env_params)
    next_obs, next_envstate, reward, done, info = step_output
    return next_rng, next_obs, next_envstate, reward, done, info
""" METHODS FOR TRAINING """
@partial(jax.jit, static_argnums=(0,))
def _make_transition(
    self,
    obs: ObsType,
    action: ActionType,
    value: Float[Array, "1"],
    log_prob: Float[Array, "1"],
    reward: Float[Array, "1"],
    next_obs: ObsType,
    terminated: Bool[Array, "1"],
) -> Transition:
    """
    Packs the inputs and outputs of one environment step into a Transition, giving every field a new leading axis
    of size one.
    :param obs: The current observation (squeezed before storing).
    :param action: The action selected by the agent.
    :param value: The critic value of the observation.
    :param log_prob: The actor log-probability of the selected action.
    :param reward: The collected reward after executing the action.
    :param next_obs: The next observation.
    :param terminated: Episode termination.
    :return: The Transition for this step.
    """
    def _with_leading_axis(leaf):
        return jnp.expand_dims(leaf, axis=0)

    step = Transition(obs.squeeze(), action, value, log_prob, reward, next_obs, terminated)
    return jax.tree_util.tree_map(_with_leading_axis, step)
@partial(jax.jit, static_argnums=(0,))
def _generate_metrics(self, runner: Runner, update_step: int) -> Dict[str, Float[Array, "1"]]:
    """
    Generates metrics for on-policy learning. The actor and critic losses stored in the runner are always included.
    If evaluation during training is enabled, the agent is additionally evaluated with _eval_agent over n_evals
    episodes (run until termination), and its metrics, e.g. the sum of rewards collected during each episode, are
    included too. Disabling evaluation leads to faster training.
    :param runner: The update runner object, containing information about the current status of the actor's/critic's
        training, the state of the environment and training hyperparameters.
    :param update_step: The number of the update step (currently unused).
    :return: A dictionary with the evaluation metrics (only if enabled) plus the 'actor_loss' and 'critic_loss'
        entries.
    """
    losses = {"actor_loss": runner.actor_loss, "critic_loss": runner.critic_loss}
    if not self.eval_during_training:
        return losses
    eval_metrics = self._eval_agent(
        self.config.eval_rng,
        runner.actor_training,
        runner.critic_training,
        self.config.n_evals
    )
    # Build a new dict instead of mutating the evaluation result. Losses win on a key clash.
    return {**eval_metrics, **losses}
def _create_training(
    self,
    rng: PRNGKeyArray,
    network: type[flax.linen.Module],
    optimizer_params: OptimizerParams
) -> TrainState:
    """
    Builds the TrainState used to train the actor or the critic.
    :param rng: Random key for initialization.
    :param network: The actor or critic network class.
    :param optimizer_params: A NamedTuple containing the parametrization of the optimizer.
    :return: A TrainState holding the network's apply function, its initial parameters and the optimizer.
    """
    network_module, initial_params = self._init_network(rng, network)
    optimizer = self._init_optimizer(optimizer_params)
    return TrainState.create(apply_fn=network_module.apply, tx=optimizer, params=initial_params)
@partial(jax.jit, static_argnums=(0,))
def _create_update_runner(
    self,
    rng: PRNGKeyArray,
    actor_training: TrainState,
    critic_training: TrainState,
    hyperparams: HyperParameters
) -> Runner:
    """
    Initializes the update runner as a Runner object. The runner holds config.batch_size independently reset
    environments (and one random key per environment), which are used for sampling trajectories. It has one
    TrainState for the actor and one for the critic network, so that all trajectory batches train the same
    parameters. The stored actor and critic losses start at zero.
    :param rng: Random key for initialization.
    :param actor_training: The actor TrainState object used in training.
    :param critic_training: The critic TrainState object used in training.
    :param hyperparams: An instance of HyperParameters for training.
    :return: An update runner object to be used in trajectory sampling and training.
    """
    # The first sub-key is not needed; the 3-way split is kept so the derived keys stay the same.
    _, reset_rng, runner_rng = jax.random.split(rng, 3)
    reset_rngs = jax.random.split(reset_rng, self.config.batch_size)
    runner_rngs = jax.random.split(runner_rng, self.config.batch_size)
    _, obs, envstate = jax.vmap(self.env_reset)(reset_rngs)
    return Runner(
        actor_training=actor_training,
        critic_training=critic_training,
        envstate=envstate,
        obs=obs,
        rng=runner_rngs,
        hyperparams=hyperparams,
        actor_loss=jnp.zeros(1),
        critic_loss=jnp.zeros(1),
    )
@partial(jax.jit, static_argnums=(0,))
def _add_next_values(
    self,
    traj_batch: Transition,
    last_obs: ObsType,
    critic_training: TrainState
) -> Transition:
    """
    Fills in the next-state values of the trajectory batch: every step's value is shifted by one position, and the
    critic's estimate for the final observation of each trajectory is used for the last step.
    :param traj_batch: The batch of trajectories.
    :param last_obs: The obs at the end of every trajectory in the batch.
    :param critic_training: The critic TrainState object (either mid- or post-training).
    :return: The batch of trajectories with the updated next-state values.
    """
    batched_critic = jax.vmap(critic_training.apply_fn, in_axes=(None, 0))
    frozen_params = lax.stop_gradient(critic_training.params)
    bootstrap_values = batched_critic(frozen_params, last_obs)
    # Append the bootstrap value after each trajectory's values, then drop the first column (1-step lag) so the
    # next-state values line up with the step rewards.
    values_with_bootstrap = jnp.concatenate(
        [traj_batch.value.squeeze(), bootstrap_values[..., jnp.newaxis]],
        axis=-1,
    )
    return traj_batch._replace(next_value=values_with_bootstrap[:, 1:])
@partial(jax.jit, static_argnums=(0,))
def _add_advantages(self, traj_batch: Transition, advantage: ReturnsType) -> Transition:
    """
    Returns a copy of the trajectory batch whose 'advantage' field holds the given advantages.
    :param traj_batch: The batch of trajectories.
    :param advantage: The advantage over the trajectory batch.
    :return: The batch of trajectories with the updated advantage.
    """
    return traj_batch._replace(advantage=advantage)
@partial(jax.jit, static_argnums=(0,))
def _returns(
    self,
    traj_batch: Transition,
    last_next_state_value: Float[Array, "batch_size"],
    gamma: float,
    gae_lambda: float
) -> ReturnsType:
    """
    Calculates the returns of every step in the trajectory batch with a reverse lax.scan over time steps. Because
    rollouts use lax.scan, trajectories may end without episode termination (episodes may be truncated). Also,
    since the environment is not re-initialized per sampling step, trajectories do not start at the initial state.
    :param traj_batch: The batch of trajectories.
    :param last_next_state_value: The value of the last next state in each trajectory.
    :param gamma: Discount factor
    :param gae_lambda: The GAE λ factor.
    :return: The returns over the episodes of the trajectory batch.
    """
    def _swap_leading_axes(x):
        return jnp.swapaxes(x, 0, 1)

    rewards = traj_batch.reward.squeeze()
    # Discount is gamma on non-terminal steps and zero on terminal ones.
    continues = 1.0 - traj_batch.terminated.astype(jnp.float32).squeeze()
    discounts = (continues * gamma).astype(jnp.float32)
    # Next-state values: shift the values by one step and append the bootstrap value of the last next state.
    values_with_bootstrap = jnp.concatenate(
        [traj_batch.value.squeeze(), last_next_state_value[..., jnp.newaxis]],
        axis=-1,
    )
    next_values = values_with_bootstrap[:, 1:]
    # lax.scan iterates over the leading axis, so swap the first two axes to iterate over steps.
    rewards = _swap_leading_axes(rewards)
    discounts = _swap_leading_axes(discounts)
    next_values = _swap_leading_axes(next_values)
    lambdas = jnp.ones_like(discounts) * gae_lambda
    # The scan runs in reverse, starting from the value at the end of the trajectory.
    initial_carry = jnp.take(next_values, -1, axis=0)
    _, returns = lax.scan(
        self._trajectory_returns,
        initial_carry,
        (rewards, discounts, next_values, lambdas),
        reverse=True,
    )
    return jax.tree_util.tree_map(_swap_leading_axes, returns)
@partial(jax.jit, static_argnums=(0,))
def _advantages(
    self,
    traj_batch: Transition,
    gamma: float,
    gae_lambda: float
) -> ReturnsType:
    """
    Calculates the advantage of every step in the trajectory batch with a reverse lax.scan over time steps. Note
    that because lax.scan is used in sampling trajectories, they do not necessarily finish with episode termination
    (episodes may be truncated). Also, since the environment is not re-initialized per sampling step, trajectories
    do not start at the initial state.
    :param traj_batch: The batch of trajectories. Its 'next_value' field must already be filled in (as done by
        _add_next_values).
    :param gamma: Discount factor
    :param gae_lambda: The GAE λ factor.
    :return: The advantages over the trajectory batch, with trajectories on the first axis and steps on the second
        (the same layout as the squeezed fields of traj_batch).
    """
    rewards_t = traj_batch.reward.squeeze()
    values_t = traj_batch.value.squeeze()
    terminated_t = traj_batch.terminated.squeeze()
    next_state_values_t = traj_batch.next_value.squeeze()
    gamma_t = jnp.ones_like(terminated_t) * gamma
    gae_lambda_t = jnp.ones_like(terminated_t) * gae_lambda
    # lax.scan iterates over the leading axis, so swap the first two axes to iterate over steps.
    rewards_t, values_t, next_state_values_t, terminated_t, gamma_t, gae_lambda_t = jax.tree_util.tree_map(
        lambda x: jnp.swapaxes(x, 0, 1),
        (rewards_t, values_t, next_state_values_t, terminated_t, gamma_t, gae_lambda_t)
    )
    traj_runner = (rewards_t, values_t, next_state_values_t, terminated_t, gamma_t, gae_lambda_t)
    # NOTE: the rollout does not necessarily end with termination. Here the value of the state after the last step
    # is the critic's estimate (next_value), and the reverse scan starts from a zero advantage carry. Traditional
    # approaches instead end the rollout at termination, where the advantage is zero. Per the original author,
    # training is still successful and the influence of this choice is negligible.
    end_advantage = jnp.zeros(self.config.batch_size)
    _, advantages = jax.lax.scan(self._trajectory_advantages, end_advantage, traj_runner, reverse=True)
    return jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 0, 1), advantages)
@partial(jax.jit, static_argnums=(0,))
def _make_rollout_runners(self, update_runner: Runner) -> Tuple[StepRunnerType, ...]:
    """
    Splits the update runner into per-environment step runners for the rollout. Environment states, observations
    and random keys are mapped over their leading (batch) axis. The actor and critic TrainStates are shared and get
    broadcast along that axis in the output.
    :param update_runner: The Runner object, containing information about the current status of the actor's/
        critic's training, the state of the environment and training hyperparameters.
    :return: tuple with step runners to be used in rollout.
    """
    def _pack(envstate, obs, actor_training, critic_training, rng):
        return envstate, obs, actor_training, critic_training, rng

    batched_pack = jax.vmap(_pack, in_axes=(0, 0, None, None, 0))
    return batched_pack(
        update_runner.envstate,
        update_runner.obs,
        update_runner.actor_training,
        update_runner.critic_training,
        update_runner.rng,
    )
@partial(jax.jit, static_argnums=(0,))
def _rollout(self, step_runner: StepRunnerType, i_step: int) -> Tuple[StepRunnerType, Transition]:
    """
    Performs one rollout step, to be used as the body of lax.scan. The agent:
        - samples an action from the actor
        - evaluates the critic value and the action log-probability (without gradients through the parameters)
        - performs an environment step
        - records the step as a transition
    :param step_runner: A tuple containing information on the environment state, the actor and critic training
        (parameters and networks) and a random key.
    :param i_step: Unused, required for lax.scan.
    :return: The updated step_runner tuple and the rollout step transition.
    """
    envstate, obs, actor_training, critic_training, rng = step_runner
    rng, action_key = jax.random.split(rng)
    action = self._sample_action(action_key, actor_training, obs)
    value = critic_training.apply_fn(lax.stop_gradient(critic_training.params), obs)
    log_prob = self._log_prob(actor_training, lax.stop_gradient(actor_training.params), obs, action)
    rng, next_obs, next_envstate, reward, _done, info = self.env_step(rng, envstate, action)
    transition = self._make_transition(
        obs=obs,
        action=action,
        value=value,
        log_prob=log_prob,
        reward=reward,
        next_obs=next_obs,
        terminated=info["terminated"],
    )
    next_step_runner = (next_envstate, next_obs, actor_training, critic_training, rng)
    return next_step_runner, transition
@partial(jax.jit, static_argnums=(0,))
def _process_trajectory(self, update_runner: Runner, traj_batch: Transition, last_obs: ObsType) -> Transition:
    """
    Completes a batch of sampled trajectories with next-state values and advantages. The trajectories are not
    guaranteed to end with termination, so the value after the last step is estimated with the critic network. This
    assumption has been shown to have no influence by the end of training.
    :param update_runner: The Runner object, containing information about the current status of the actor's/
        critic's training, the state of the environment and training hyperparameters.
    :param traj_batch: The batch of trajectories, as collected in rollout.
    :param last_obs: The obs at the end of every trajectory in the batch.
    :return: A batch of trajectories that includes next-state values and advantages.
    """
    hyperparams = update_runner.hyperparams
    squeezed_batch = jax.tree_util.tree_map(lambda leaf: leaf.squeeze(), traj_batch)
    valued_batch = self._add_next_values(squeezed_batch, last_obs, update_runner.critic_training)
    advantages = self._advantages(valued_batch, hyperparams.gamma, hyperparams.gae_lambda)
    return self._add_advantages(valued_batch, advantages)
@staticmethod
def _actor_minibatch_update(
    i_minibatch: int,
    minibatch_runner: Tuple[TrainState, ActorLossInputType, float],
    grad_fn: Callable[[Any], ActorLossInputType]
) -> Annotated[Tuple[TrainState, ActorLossInputType, float], "n_minibatch"]:
    """
    Performs a minibatch update of the actor network. Not jitted, so that the grad_fn argument can be
    passed. This choice doesn't hurt performance. To be called using a lambda function for defining grad_fn.
    :param i_minibatch: Number of minibatch update.
    :param minibatch_runner: A tuple containing the TrainState object, the loss input arguments and the KL divergence.
    :param grad_fn: The gradient function of the training loss. It must return (grads, kl).
    :return: Minibatch runner with an updated TrainState and the KL divergence returned by grad_fn for this
        minibatch (the incoming KL value is discarded).
    """
    actor_training, actor_loss_input, _ = minibatch_runner
    *traj_batch, hyperparams = actor_loss_input
    # Select minibatch i_minibatch along the leading axis of every trajectory array.
    # jax.tree_util.tree_map replaces the deprecated (and since removed) jax.tree_map alias.
    traj_minibatch = jax.tree_util.tree_map(lambda x: jnp.take(x, i_minibatch, axis=0), traj_batch)
    grads, kl = grad_fn(actor_training, *traj_minibatch, hyperparams)
    actor_training = actor_training.apply_gradients(grads=grads.params)
    return actor_training, actor_loss_input, kl
@staticmethod
def _critic_minibatch_update(
    i_minibatch: int,
    minibatch_runner: Tuple[TrainState, CriticLossInputType],
    grad_fn: Callable[[Any], CriticLossInputType]
) -> Tuple[TrainState, CriticLossInputType]:
    """
    Performs a minibatch update of the critic network. Not jitted, so that the grad_fn argument can be
    passed. This choice doesn't hurt performance. To be called using a lambda function for defining grad_fn.
    :param i_minibatch: Number of minibatch update.
    :param minibatch_runner: A tuple containing the TrainState object and the loss input arguments.
    :param grad_fn: The gradient function of the training loss.
    :return: Minibatch runner with an updated TrainState.
    """
    critic_training, critic_loss_input = minibatch_runner
    *traj_batch, hyperparams = critic_loss_input
    # Select minibatch i_minibatch along the leading axis of every trajectory array.
    # jax.tree_util.tree_map replaces the deprecated (and since removed) jax.tree_map alias.
    traj_minibatch = jax.tree_util.tree_map(lambda x: jnp.take(x, i_minibatch, axis=0), traj_batch)
    grads = grad_fn(critic_training, *traj_minibatch, hyperparams)
    critic_training = critic_training.apply_gradients(grads=grads.params)
    return critic_training, critic_loss_input
@partial(jax.jit, static_argnums=(0,))
def _actor_epoch(
    self,
    epoch_runner: Tuple[TrainState, ActorLossInputType, Float[Array, "1"], int, float]
) -> Tuple[TrainState, ActorLossInputType, Float[Array, "1"], int, float]:
    """
    Performs one epoch of Gradient Ascent on the actor, i.e. one pass over all minibatches.
    :param epoch_runner: A tuple containing the following information about the update:
    - actor_training: TrainState object for actor training
    - actor_loss_input: tuple with the inputs required by the actor loss function.
    - kl: The KL divergence collected during the update (used in checking for early stopping).
    - epoch: The number of the current training epoch.
    - kl_threshold: The KL divergence threshold for early stopping.
    :return: The updated epoch runner, with the KL divergence of the last minibatch update and the epoch counter
    incremented by one.
    """
    actor_training, actor_loss_input, kl, epoch, kl_threshold = epoch_runner
    # The KL slot of the carry is overwritten by every minibatch update with the (floating-point) KL returned by the
    # gradient function. Seed it with a float rather than the int 0, so the lax.fori_loop carry does not depend on
    # weak-type promotion from int to float.
    minibatch_runner = (actor_training, actor_loss_input, 0.0)
    n_minibatch_updates = self.config.batch_size // self.config.minibatch_size
    minibatch_runner = lax.fori_loop(0, n_minibatch_updates, self._actor_minibatch_fn, minibatch_runner)
    actor_training, _, kl = minibatch_runner
    return actor_training, actor_loss_input, kl, epoch + 1, kl_threshold
@partial(jax.jit, static_argnums=(0,))
def _actor_training_cond(
    self,
    epoch_runner: Tuple[TrainState, ActorLossInputType, Float[Array, "1"], int, float]
) -> Bool[Array, "1"]:
    """
    Loop predicate of the lax.while_loop over actor epochs. Training continues while there are epochs left and the
    KL divergence of the last epoch has not exceeded the early-stopping threshold.
    :param epoch_runner: A tuple of (actor TrainState, actor loss input, KL divergence of the last epoch, current
    epoch counter, KL-divergence threshold for early stopping).
    :return: True while the loop over epochs should continue, False once it must stop.
    """
    _, _, last_kl, epoch, kl_threshold = epoch_runner
    epochs_left = jnp.less(epoch, self.config.actor_epochs)
    kl_within_threshold = jnp.less_equal(last_kl, kl_threshold)
    return jnp.logical_and(epochs_left, kl_within_threshold)
@partial(jax.jit, static_argnums=(0,))
def _actor_update(self, update_runner: Runner, traj_batch: Transition) -> Tuple[TrainState, Float[Array, "1"]]:
    """
    Prepares the input and performs Gradient Ascent for the actor network.
    :param update_runner: The Runner object, containing information about the current status of the actor's/
    critic's training, the state of the environment and training hyperparameters.
    :param traj_batch: The batch of trajectories.
    :return: The actor training object after at most actor_epochs epochs of Gradient Ascent (fewer if KL early
    stopping triggers), and the actor loss evaluated on the full trajectory batch.
    """
    actor_loss_input = self._actor_loss_input(update_runner, traj_batch)
    # The epoch counter starts at 0. With the `epoch < actor_epochs` condition of _actor_training_cond, this runs up
    # to actor_epochs epochs; starting at 1 silently dropped one epoch (and all of them when actor_epochs == 1).
    # The initial KL of -inf passes the `kl <= kl_threshold` check, so early stopping never skips the first epoch.
    start_kl, start_epoch = -jnp.inf, 0
    actor_epoch_runner = (
        update_runner.actor_training,
        actor_loss_input,
        start_kl,
        start_epoch,
        update_runner.hyperparams.kl_threshold
    )
    actor_epoch_runner = lax.while_loop(self._actor_training_cond, self._actor_epoch, actor_epoch_runner)
    actor_training, _, _, _, _ = actor_epoch_runner
    # Report the loss of the updated actor on the whole (non-minibatched) trajectory batch; the KL output is unused.
    actor_loss, _ = self._actor_loss(
        actor_training,
        traj_batch.obs,
        traj_batch.action,
        traj_batch.log_prob,
        traj_batch.advantage,
        update_runner.hyperparams
    )
    return actor_training, actor_loss
@partial(jax.jit, static_argnums=(0,))
def _critic_epoch(
    self,
    i_epoch: int,
    epoch_runner: Tuple[TrainState, CriticLossInputType]
) -> Tuple[TrainState, CriticLossInputType]:
    """
    Runs one epoch of Gradient Descent on the critic, i.e. one pass over all minibatches.
    :param i_epoch: Index of the current epoch; not used, but part of the lax.fori_loop body signature.
    :param epoch_runner: Pair of the critic TrainState and the critic loss input.
    :return: The epoch runner with the updated critic TrainState and the unchanged loss input.
    """
    critic_training, critic_loss_input = epoch_runner
    n_minibatch_updates = self.config.batch_size // self.config.minibatch_size
    updated_training, _ = lax.fori_loop(
        0,
        n_minibatch_updates,
        self._critic_minibatch_fn,
        (critic_training, critic_loss_input),
    )
    return updated_training, critic_loss_input
@partial(jax.jit, static_argnums=(0,))
def _critic_update(self, update_runner: Runner, traj_batch: Transition) -> Tuple[TrainState, Float[Array, "1"]]:
    """
    Prepares the input and performs Gradient Descent for the critic network.
    :param update_runner: The Runner object, containing information about the current status of the actor's/
    critic's training, the state of the environment and training hyperparameters.
    :param traj_batch: The batch of trajectories.
    :return: The critic training object after critic_epochs epochs of Gradient Descent, and the critic loss
    evaluated on the full trajectory batch.
    """
    critic_loss_input = self._critic_loss_input(update_runner, traj_batch)
    critic_epoch_runner = (update_runner.critic_training, critic_loss_input)
    critic_epoch_runner = lax.fori_loop(0, self.config.critic_epochs, self._critic_epoch, critic_epoch_runner)
    critic_training, _ = critic_epoch_runner
    # The second element of the loss input is used as the critic targets; it is reshaped back from its minibatch
    # layout to rows of rollout_length so the loss can be evaluated on the whole batch.
    critic_targets = critic_loss_input[1].reshape(-1, self.config.rollout_length)
    critic_loss = self._critic_loss(critic_training, traj_batch.obs, critic_targets, update_runner.hyperparams)
    return critic_training, critic_loss
@partial(jax.jit, static_argnums=(0,))
def _update_step(self, i_update_step: int, update_runner: Runner) -> Runner:
    """
    An update step of the actor and critic networks. This entails:
    - performing rollout for sampling a batch of trajectories.
    - processing the trajectories (e.g. evaluating the advantages), given the last observation per trajectory.
    - updating the actor and critic network parameters via the respective loss functions.
    - storing the resulting actor and critic losses in the runner.
    In this approach, the update_runner already has a batch of environments initialized. The environments are not
    re-initialized at the beginning of every update step, which means that trajectories do not necessarily start
    from an initial state (this led to better results when benchmarking with CartPole-v1). Moreover, the use of
    lax.scan for rollout means that the trajectories do not necessarily stop with episode termination (episodes can
    be truncated in trajectory sampling).
    :param i_update_step: Unused; required by the lax.fori_loop body signature.
    :param update_runner: The runner object, containing information about the current status of the actor's/
    critic's training, the state of the environment and training hyperparameters.
    :return: The updated runner
    """
    rollout_runners = self._make_rollout_runners(update_runner)
    scan_rollout_fn = lambda x: lax.scan(self._rollout, x, None, self.config.rollout_length)
    rollout_runners, traj_batch = jax.vmap(scan_rollout_fn)(rollout_runners)
    last_envstate, last_obs, _, _, rng = rollout_runners
    traj_batch = self._process_trajectory(update_runner, traj_batch, last_obs)
    # Both updates start from the networks stored in the incoming runner.
    actor_training, actor_loss = self._actor_update(update_runner, traj_batch)
    critic_training, critic_loss = self._critic_update(update_runner, traj_batch)
    # The Runner is updated through its dataclass-style `replace`, which returns a new instance.
    update_runner = update_runner.replace(
        envstate=last_envstate,
        obs=last_obs,
        actor_training=actor_training,
        critic_training=critic_training,
        rng=rng,
        actor_loss=jnp.expand_dims(actor_loss, axis=-1),
        critic_loss=jnp.expand_dims(critic_loss, axis=-1)
    )
    return update_runner
@partial(jax.jit, static_argnums=(0,))
def _checkpoint(self, update_runner: Runner, metrics: Dict[str, Float[Array, "1"]], i_training_step: int) -> None:
    """
    Jit-compatible entry point for checkpointing: hands the arguments to the host-side _checkpoint_base through
    jax.experimental.io_callback, which returns nothing.
    :param update_runner: The runner object, containing information about the current status of the actor's/
    critic's training, the state of the environment and training hyperparameters.
    :param metrics: Dictionary of evaluation metrics (return per environment evaluation)
    :param i_training_step: Training step
    :return:
    """
    host_fn = self._checkpoint_base
    callback_args = (update_runner, metrics, i_training_step)
    jax.experimental.io_callback(host_fn, None, *callback_args)
def _checkpoint_base(
    self,
    update_runner: Runner,
    metrics: Dict[str, Float[Array, "1"]],
    i_training_step: int
) -> None:
    """
    Host-side checkpointing, meant to be invoked through a Python callback. Does nothing unless checkpointing is
    enabled. Otherwise it saves the training runner together with the evaluation metrics "terminated",
    "truncated", "final_rewards" and "returns".
    :param update_runner: The runner object, containing information about the current status of the actor's/
    critic's training, the state of the environment and training hyperparameters.
    :param metrics: Dictionary of evaluation metrics (return per episode evaluation)
    :param i_training_step: Training step
    :return:
    """
    if not self.checkpointing:
        return
    metric_keys = ("terminated", "truncated", "final_rewards", "returns")
    ckpt = {"runner": update_runner}
    ckpt.update({key: metrics[key] for key in metric_keys})
    save_args = orbax_utils.save_args_from_target(ckpt)
    # The step is offset by the maximum step reached in a previous training. This offset is zero by default for a
    # fresh training. When training is continued, the step-zero checkpoint replaces the last checkpoint of the
    # previous training; the two checkpoints are the same.
    checkpoint_step = i_training_step + self.previous_training_max_step
    self.checkpoint_manager.save(
        checkpoint_step,
        ckpt,
        save_kwargs={'save_args': save_args},
    )
@partial(jax.jit, static_argnums=(0,))
def _training_step(
    self,
    update_runner: Runner,
    i_training_batch: int
) -> Tuple[Runner, Dict[str, Float[Array, "1"]]]:
    """
    Performs the training steps of one training batch (up to eval_frequency update steps), then optionally
    evaluates and checkpoints the agent.
    :param update_runner: The runner object, containing information about the current status of the actor's/
    critic's training, the state of the environment and training hyperparameters.
    :param i_training_batch: Training batch loop counter.
    :return: tuple with updated runner and dictionary of metrics (empty when evaluation during training is off).
    """
    # Steps still to be performed after the i_training_batch previous batches (each of eval_frequency steps),
    # capped to one batch and to at least one step. (The previous formula subtracted
    # n_steps // eval_frequency * i_training_batch, which collapsed later batches to a single step whenever
    # n_steps > eval_frequency ** 2.)
    remaining_steps = self.config.n_steps - self.config.eval_frequency * i_training_batch
    n_training_steps = jnp.clip(remaining_steps, 1, self.config.eval_frequency)
    update_runner = lax.fori_loop(0, n_training_steps, self._update_step, update_runner)
    if self.eval_during_training:
        metrics = self._generate_metrics(runner=update_runner, update_step=i_training_batch)
        i_training_step = self.config.eval_frequency * (i_training_batch + 1)
        i_training_step = jnp.minimum(i_training_step, self.config.n_steps)
        if self.checkpointing:
            self._checkpoint(update_runner, metrics, i_training_step)
    else:
        metrics = {}
    return update_runner, metrics
@partial(jax.jit, static_argnums=(0,))
def train(self, rng: PRNGKeyArray, hyperparams: HyperParameters) -> Tuple[Runner, Dict[str, Float[Array, "1"]]]:
    """
    Trains the agent. A jax_tqdm progressbar has been added in the lax.scan loop.
    :param rng: Random key for initialization. This is the original key for training.
    :param hyperparams: An instance of HyperParameters for training.
    :return: The final state of the step runner after training and the training metrics accumulated over all
    training batches and steps (an empty dictionary when evaluation during training is disabled).
    """
    rng, *_rng = jax.random.split(rng, 4)
    actor_init_rng, critic_init_rng, runner_rng = _rng
    actor_training = self._create_training(
        actor_init_rng, self.config.actor_network, hyperparams.actor_optimizer_params
    )
    critic_training = self._create_training(
        critic_init_rng, self.config.critic_network, hyperparams.critic_optimizer_params
    )
    update_runner = self._create_update_runner(runner_rng, actor_training, critic_training, hyperparams)
    # Checkpoint initial state
    if self.eval_during_training:
        metrics_start = self._generate_metrics(runner=update_runner, update_step=0)
        if self.checkpointing:
            # This is training step 0. _checkpoint_base already adds previous_training_max_step to the step it
            # receives. Passing that offset here as well would save the initial state at twice the previous maximum
            # step, instead of replacing the last checkpoint of the previous training.
            self._checkpoint(update_runner, metrics_start, 0)
    # Initialize agent updating functions, which can be avoided to be done within the training loops.
    actor_grad_fn = jax.grad(self._actor_loss, has_aux=True, allow_int=True)
    self._actor_minibatch_fn = lambda x, y: self._actor_minibatch_update(x, y, actor_grad_fn)
    critic_grad_fn = jax.grad(self._critic_loss, allow_int=True)
    self._critic_minibatch_fn = lambda x, y: self._critic_minibatch_update(x, y, critic_grad_fn)
    # Train, evaluate, checkpoint
    n_training_batches = self.config.n_steps // self.config.eval_frequency
    progressbar_desc = f'Training batch (training steps = batch x {self.config.eval_frequency})'
    runner, metrics = lax.scan(
        scan_tqdm(n_training_batches, desc=progressbar_desc)(self._training_step),
        update_runner,
        jnp.arange(n_training_batches),
        n_training_batches
    )
    if self.eval_during_training:
        # Prepend the evaluation of the initial (untrained) state to the per-batch metrics.
        metrics = {
            key: jnp.concatenate((metrics_start[key][jnp.newaxis, :], metrics[key]), axis=0)
            for key in metrics.keys()
        }
    else:
        metrics = {}
    return runner, metrics
@abstractmethod
def _trajectory_returns(self, value: Float[Array, "batch_size"], traj: Transition) -> Tuple[float, float]:
    """
    Abstract hook: computes the return of every episode step across a batch of trajectories.
    :param value: Critic values of the trajectory steps, including the value of the last state.
    :param traj: The trajectory batch.
    :return: A tuple of returns.
    """
    raise NotImplementedError
@abstractmethod
def _trajectory_advantages(self, value: Float[Array, "batch_size"], traj: Transition) -> Tuple[float, float]:
    """
    Calculates the advantages per episode step over a batch of trajectories. To be implemented by subclasses.
    :param value: The values of the steps in the trajectory according to the critic (including the one of the last
    state).
    :param traj: The trajectory batch.
    :return: The advantages per episode step of the trajectory batch.
    """
    raise NotImplementedError
@abstractmethod
def _actor_loss(
    self,
    training: TrainState,
    obs: Float[Array, "n_rollout batch_size obs_size"],
    action: Float[Array, "n_rollout batch_size"],
    log_prob_old: Float[Array, "n_rollout batch_size"],
    advantage: ReturnsType,
    hyperparams: HyperParameters
) -> Tuple[Float[Array, "1"], Float[Array, "1"]]:
    """
    Calculates the actor loss. The concrete objective (how the advantage and the old log-probabilities enter it) is
    defined by each subclass.
    :param training: The actor TrainState object.
    :param obs: The obs in the trajectory batch.
    :param action: The actions in the trajectory batch.
    :param log_prob_old: Log-probabilities of the old policy collected over the trajectory batch.
    :param advantage: The advantage over the trajectory batch.
    :param hyperparams: The HyperParameters object used for training.
    :return: A tuple of the actor loss and a KL divergence (presumably between the old and the current policy). The
    KL divergence is compared against the KL threshold for early stopping of the actor epochs.
    """
    raise NotImplementedError
@abstractmethod
def _critic_loss(
    self,
    training: TrainState,
    obs: Float[Array, "n_rollout batch_size obs_size"],
    targets: Float[Array, "batch_size n_rollout"],
    hyperparams: HyperParameters
) -> float:
    """
    Abstract hook for the critic's training objective.
    :param training: TrainState of the critic.
    :param obs: Observations of the trajectory batch.
    :param targets: Returns over the trajectory batch, used as regression targets for the critic.
    :param hyperparams: HyperParameters object of the training run.
    :return: The critic loss.
    """
    raise NotImplementedError
@abstractmethod
def _actor_loss_input(self, update_runner: Runner, traj_batch: Transition) -> ActorLossInputType:
    """
    Prepares the input required by the actor loss function. The input is reshaped so that it is split into
    minibatches.
    :param update_runner: The runner object used in training.
    :param traj_batch: The batch of trajectories.
    :return: The actor loss input. The minibatch-split arrays come first and the hyperparameters are the last
    element, matching how the minibatch update unpacks it (`*traj_batch, hyperparams = actor_loss_input`).
    """
    raise NotImplementedError
@abstractmethod
def _critic_loss_input(self, update_runner: Runner, traj_batch: Transition) -> CriticLossInputType:
    """
    Abstract hook that builds the critic loss arguments, reshaped so they can be split into minibatches.
    :param update_runner: The Runner object of the ongoing training.
    :param traj_batch: The batch of trajectories.
    :return: Tuple of inputs for the critic loss function.
    """
    raise NotImplementedError
@abstractmethod
def _entropy(self, training: TrainState, obs: ObsType) -> Float[Array, "1"]:
    """
    Abstract hook, presumably returning the entropy of the actor's policy at the given observations. To be
    implemented by subclasses.
    :param training: The actor TrainState object.
    :param obs: The observations.
    :return: The entropy.
    """
    # NotImplementedError, not NotImplemented: raising the NotImplemented constant is itself a TypeError.
    raise NotImplementedError
@abstractmethod
def _log_prob(
    self,
    training: TrainState,
    params: FrozenDict,
    obs: ObsType,
    action: ActionType
) -> Float[Array, "n_actors"]:
    """
    Abstract hook, presumably returning the log-probabilities of `action` at `obs` under the actor evaluated with
    `params`. To be implemented by subclasses.
    :param training: The actor TrainState object.
    :param params: Network parameters to evaluate the policy with.
    :param obs: The observations.
    :param action: The actions whose log-probabilities are requested.
    :return: The log-probabilities.
    """
    # NotImplementedError, not NotImplemented: raising the NotImplemented constant is itself a TypeError.
    raise NotImplementedError
@abstractmethod
def _sample_action(
    self,
    rng: PRNGKeyArray,
    training: TrainState,
    obs: ObsType
) -> ActionType:
    """
    Abstract hook, presumably sampling an action from the actor's policy at `obs` using `rng`. To be implemented
    by subclasses.
    :param rng: Random key for sampling.
    :param training: The actor TrainState object.
    :param obs: The observations.
    :return: The sampled action.
    """
    # NotImplementedError, not NotImplemented: raising the NotImplemented constant is itself a TypeError.
    raise NotImplementedError
""" METHODS FOR APPLYING AGENT"""
@abstractmethod
def policy(self, training: TrainState, obs: ObsType) -> ActionType:
    """
    Evaluates the action of the optimal policy (argmax) according to the trained agent for the given state.
    To be implemented by subclasses.
    :param training: The actor TrainState object.
    :param obs: The current obs of the episode step in array format.
    :return: The selected action.
    """
    # NotImplementedError, not NotImplemented: raising the NotImplemented constant is itself a TypeError.
    raise NotImplementedError
""" METHODS FOR PERFORMANCE EVALUATION """
def _eval_agent(
    self,
    rng: PRNGKeyArray,
    actor_training: TrainState,
    critic_training: TrainState,
    n_episodes: int = 1
) -> Dict[str, Float[Array, "1"] | Bool[Array, "1"]]:
    """
    Evaluates the agent for n_episodes complete episodes. Each episode runs in its own 'lax.while_loop', and the
    loops are vectorized over episodes with jax.vmap.
    :param rng: A random key used for evaluating the agent; split into one key per episode.
    :param actor_training: The actor TrainState object (either mid- or post-training).
    :param critic_training: The critic TrainState object (either mid- or post-training). Not used by this method.
    :param n_episodes: The number of evaluation episodes.
    :return: Dictionary of per-episode metrics built by _eval_metrics (terminated, truncated, final_rewards and
    returns).
    """
    rng_eval = jax.random.split(rng, n_episodes)
    rng, obs, envstate = jax.vmap(self.env_reset)(rng_eval)
    eval_runner = (
        envstate,
        obs,
        actor_training,
        jnp.zeros(1, dtype=jnp.bool).squeeze(),  # terminated
        jnp.zeros(1, dtype=jnp.bool).squeeze(),  # truncated
        jnp.zeros(1).squeeze(),  # reward of the last step
        jnp.zeros(1).squeeze(),  # accumulated return
        rng,
    )
    # Identity vmap that broadcasts the shared (in_axes=None) entries to one copy per episode, so the whole runner
    # can be vmapped along axis 0 below.
    eval_runners = jax.vmap(
        lambda s, t, u, v, w, x, y, z: (s, t, u, v, w, x, y, z),
        in_axes=(0, 0, None, None, None, None, None, 0)
    )(*eval_runner)
    eval_runner = jax.vmap(lambda x: lax.while_loop(self._eval_cond, self._eval_body, x))(eval_runners)
    _, _, _, terminated, truncated, final_rewards, returns, _ = eval_runner
    return self._eval_metrics(terminated, truncated, final_rewards, returns)
def _eval_metrics(
    self,
    terminated: Bool[Array, "1"],
    truncated: Bool[Array, "1"],
    final_rewards: Float[Array, "1"],
    returns: Float[Array, "1"]
) -> Dict[str, Float[Array, "1"] | Bool[Array, "1"]]:
    """
    Packs the evaluation results into a metrics dictionary. Subclasses may override this to add case-specific
    metrics.
    :param terminated: Whether the episode ended by termination.
    :param truncated: Whether the episode ended by truncation.
    :param final_rewards: Reward collected in the final step of the episode.
    :param returns: Sum of rewards collected during the episode.
    :return: Dictionary with the four inputs keyed by their names.
    """
    return dict(
        terminated=terminated,
        truncated=truncated,
        final_rewards=final_rewards,
        returns=returns,
    )
@partial(jax.jit, static_argnums=(0,))
def _eval_body(self, eval_runner: EvalRunnerType) -> EvalRunnerType:
    """
    Advances an evaluation episode by one environment step; used as the body of 'lax.while_loop'.
    :param eval_runner: Tuple of (environment state, observation, actor TrainState, terminated flag, truncated
    flag, last reward, accumulated return, random key).
    :return: The eval_runner after the step. The flags come from the step info, the reward is that of this step,
    and the return has the reward added.
    """
    envstate, obs, actor_training, _, _, _, returns, rng = eval_runner
    action = self.policy(actor_training, obs)
    rng, next_obs, next_envstate, reward, _, info = self.env_step(rng, envstate, action)
    return (
        next_envstate,
        next_obs,
        actor_training,
        info["terminated"],
        info["truncated"],
        reward,
        returns + reward,
        rng,
    )
@partial(jax.jit, static_argnums=(0,))
def _eval_cond(self, eval_runner: EvalRunnerType) -> Bool[Array, "1"]:
    """
    Loop predicate for the evaluation 'lax.while_loop': the episode continues as long as it has been neither
    terminated nor truncated.
    :param eval_runner: Tuple of (environment state, observation, actor TrainState, terminated flag, truncated
    flag, last reward, accumulated return, random key).
    :return: True while the episode is still running, False once the while loop must stop.
    """
    terminated, truncated = eval_runner[3], eval_runner[4]
    episode_over = jnp.logical_or(terminated, truncated)
    return jnp.logical_not(episode_over)
def eval(self, rng: PRNGKeyArray, n_evals: int = 32) -> Dict[str, Float[Array, "1"] | Bool[Array, "1"]]:
    """
    Evaluates the trained agent's performance post-training using the trained agent's actor and critic.
    :param rng: Random key for evaluation.
    :param n_evals: Number of evaluation episodes.
    :return: Dictionary of per-episode evaluation metrics (as returned by _eval_agent).
    """
    eval_metrics = self._eval_agent(rng, self.actor_training, self.critic_training, n_evals)
    return eval_metrics
""" METHODS FOR POST-PROCESSING """
def log_hyperparams(self, hyperparams: HyperParameters) -> None:
    """
    Writes the training hyperparameters, one `name: value` line per field under a 'Hyperparameters:' header, to
    'hyperparameters.txt' in the checkpoint directory. Nothing is written when checkpointing is disabled. To be
    used outside training.
    :param hyperparams: An instance of HyperParameters for training.
    :return:
    """
    lines = ['Hyperparameters:']
    lines.extend(f"{name}: {getattr(hyperparams, name)!s}" for name in hyperparams._fields)
    report = '\n'.join(lines)
    if not self.checkpointing:
        return
    report_path = os.path.join(self.config.checkpoint_dir, 'hyperparameters.txt')
    with open(report_path, "w") as f:
        f.write(report)
def collect_training(
    self,
    runner: Optional[Runner] = None,
    metrics: Optional[Dict[str, Float[Array, "1"]]] = None,
    previous_training_max_step: int = 0
) -> None:
    """
    Collects training or restored checkpoint of output (the final state of the runner after training and the
    collected metrics).
    :param runner: The runner object, containing information about the current status of the actor's/
    critic's training, the state of the environment and training hyperparameters. This is at the state reached at
    the end of training.
    :param metrics: Dictionary of evaluation metrics (return per environment evaluation). May be None or empty
    (e.g. `train` returns an empty dictionary when evaluation during training is disabled); then no evaluation
    steps are recorded.
    :param previous_training_max_step: Maximum step reached during training.
    :return:
    """
    self.agent_trained = True
    self.previous_training_max_step = previous_training_max_step
    self.training_runner = runner
    self.training_metrics = metrics
    # The leading axis of the first metric is taken as the number of evaluations (assumes all metrics share it).
    # Guard against None/{} instead of crashing on an empty values() list.
    n_evals = next(iter(metrics.values())).shape[0] if metrics else 0
    self.eval_steps_in_training = jnp.arange(n_evals) * self.config.eval_frequency
    self._pp()
def _pp(self) -> None:
"""
Post-processes the training results, which includes:
- Setting the policy actor and critic TrainStates of a Runner object (e.g. last in training of restored).
:return:
"""
self.actor_training = self.training_runner.actor_training
self.critic_training = self.training_runner.critic_training
def summarize(
    self,
    metrics: Annotated[NDArray[np.float32], "size_metrics"] | Float[Array, "size_metrics"]
) -> MetricStats:
    """
    Computes summary statistics of a per-episode metric along its last axis.
    :param metrics: Metric value per episode.
    :return: MetricStats holding the raw metric together with its mean, variance, standard deviation, minimum,
    maximum, median and whether any value is NaN.
    """
    last_axis = -1
    stats = {
        "mean": metrics.mean(axis=last_axis),
        "var": metrics.var(axis=last_axis),
        "std": metrics.std(axis=last_axis),
        "min": metrics.min(axis=last_axis),
        "max": metrics.max(axis=last_axis),
        "median": jnp.median(metrics, axis=last_axis),
        "has_nans": jnp.any(jnp.isnan(metrics), axis=last_axis),
    }
    return MetricStats(episode_metric=metrics, **stats)
|