FunctionApproximator class
keras_gym.FunctionApproximator |
A generic function approximator. |
- class keras_gym.FunctionApproximator(env, optimizer=None, **optimizer_kwargs)
A generic function approximator.
This is the central object that provides an interface between a gym-type environment and function approximators such as value functions and updateable policies.
In order to create a valid function approximator, you need to implement the body method. For example, to implement a simple multi-layer perceptron function approximator you would do something like:
    import gym
    import keras_gym as km
    from tensorflow.keras.layers import Flatten, Dense


    class MLP(km.FunctionApproximator):
        """ multi-layer perceptron with one hidden layer """
        def body(self, S):
            X = Flatten()(S)
            X = Dense(units=4)(X)
            return X


    # environment
    env = gym.make(...)

    # generic function approximator
    mlp = MLP(env, lr=0.001)

    # policy and value function
    pi, v = km.SoftmaxPolicy(mlp), km.V(mlp)
The default heads are simple (multi) linear regression layers, which can be overridden by your own implementation.
Parameters:
- env : environment
A gym-style environment.
- optimizer : keras.optimizers.Optimizer, optional
If left unspecified (optimizer=None), the function approximator's DEFAULT_OPTIMIZER is used. See the Keras documentation for more details.
- **optimizer_kwargs : keyword arguments
Keyword arguments for the optimizer. This is useful when you want to use the default optimizer with a different setting, e.g. a different learning rate.
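The fallback described for optimizer=None can be pictured with a small stand-alone sketch. The classes Adam and Approximator below are hypothetical stand-ins written for this illustration; they are not keras_gym or Keras classes, and the way an explicitly passed optimizer interacts with extra keyword arguments is an assumption of the sketch, not documented behaviour.

```python
class Adam:
    """Hypothetical stand-in for a Keras optimizer (illustration only)."""

    def __init__(self, lr=0.001):
        self.lr = lr


class Approximator:
    """Hypothetical stand-in that mimics the documented optimizer fallback."""

    DEFAULT_OPTIMIZER = Adam

    def __init__(self, optimizer=None, **optimizer_kwargs):
        if optimizer is None:
            # No optimizer given: build the default one from the keyword arguments.
            self.optimizer = self.DEFAULT_OPTIMIZER(**optimizer_kwargs)
        else:
            # Assumption of this sketch: an explicit optimizer is used as-is.
            self.optimizer = optimizer


default = Approximator()                         # default optimizer, default settings
custom_lr = Approximator(lr=0.01)                # default optimizer, custom learning rate
explicit = Approximator(optimizer=Adam(lr=0.5))  # caller-supplied optimizer
```

This mirrors the MLP(env, lr=0.001) call in the class example, where lr is forwarded to the default optimizer rather than handled by the approximator itself.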
- DEFAULT_OPTIMIZER
alias of tensorflow.python.keras.optimizer_v2.adam.Adam
- body(self, S)
This is the part of the computation graph that may be shared between e.g. policy (actor) and value function (critic). It is typically the part of a neural net that does most of the heavy lifting. One may think of body() as an elaborate automatic feature extractor.
Parameters:
- S : nd Tensor, shape: [batch_size, …]
The input state observation.
Returns:
- X : nd Tensor, shape: [batch_size, …]
The intermediate Keras tensor.
- body_q1(self, S, A)
This is similar to body(), except that it takes a state-action pair as input instead of only state observations.
Parameters:
- S : nd Tensor, shape: [batch_size, …]
The input state observation.
- A : nd Tensor, shape: [batch_size, …]
The input actions.
Returns:
- X : nd Tensor, shape: [batch_size, …]
The intermediate Keras tensor.
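To make the idea of a state-action input concrete, here is a plain-Python sketch of one common way to merge the two: flatten each state and append a one-hot encoding of its action. This is an assumption made for illustration, not a description of what keras_gym does by default, and the helpers one_hot, flatten and combine are hypothetical names.

```python
def one_hot(a, num_actions):
    """One-hot encode a discrete action index."""
    return [1.0 if i == a else 0.0 for i in range(num_actions)]


def flatten(x):
    """Recursively flatten nested lists into a flat list of floats."""
    if isinstance(x, (list, tuple)):
        return [v for item in x for v in flatten(item)]
    return [float(x)]


def combine(S, A, num_actions):
    """Build one feature row per (state, action) pair in the batch."""
    return [flatten(s) + one_hot(a, num_actions) for s, a in zip(S, A)]


# Hypothetical batch: two 2x2 state observations and two discrete actions.
S = [[[0.1, 0.2], [0.3, 0.4]],
     [[1.0, 0.0], [0.0, 1.0]]]
A = [1, 0]

X = combine(S, A, num_actions=3)  # shape [batch_size=2, 4 + 3]
```

In an actual body_q1() override the same combination would be expressed with Keras layers acting on S and A.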
- head_pi(self, X)
This is the policy head. It returns logits, i.e. not probabilities. Use a softmax to turn the output into probabilities.
Parameters:
- X : nd Tensor, shape: [batch_size, …]
X is an intermediate tensor in the full forward-pass of the computation graph; it's the output of the last layer of the body() method.
Returns:
- *params : Tensor or tuple of Tensors, shape: [batch_size, …]
These constitute the raw policy distribution parameters.
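The step from logits to probabilities can be sketched in plain Python. The example assumes a discrete action space, so that each row of the head's output holds one logit per action; the values in logits_batch are made up for illustration.

```python
import math


def softmax(logits):
    """Turn one row of logits into a probability distribution."""
    # Subtracting the row maximum avoids overflow and leaves the result unchanged.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical head output, shape [batch_size=2, num_actions=3].
logits_batch = [[2.0, 1.0, 0.1],
                [0.0, 0.0, 0.0]]

probs_batch = [softmax(row) for row in logits_batch]
```

Each row of probs_batch is non-negative and sums to one, and the largest logit in a row maps to the largest probability.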
- head_q1(self, X)
This is the type-I Q-value head. It returns a scalar Q-value \(q(s,a)\in\mathbb{R}\).
Parameters:
- X : nd Tensor, shape: [batch_size, …]
X is an intermediate tensor in the full forward-pass of the computation graph; it's the output of the last layer of the body() method.
Returns:
- Q_sa : 2d Tensor, shape: [batch_size, 1]
The output type-I Q-values \(q(s,a)\in\mathbb{R}\).
- head_q2(self, X)
This is the type-II Q-value head. It returns a vector of Q-values \(q(s,.)\in\mathbb{R}^n\), one entry per action.
Parameters:
- X : nd Tensor, shape: [batch_size, …]
X is an intermediate tensor in the full forward-pass of the computation graph; it's the output of the last layer of the body() method.
Returns:
- Q_s : 2d Tensor, shape: [batch_size, num_actions]
The output type-II Q-values \(q(s,.)\in\mathbb{R}^n\).
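The difference between the two Q-value output shapes can be illustrated with plain Python lists. The numbers below are made up; the sketch only shows how a type-II output of shape [batch_size, num_actions] relates to type-I style values of shape [batch_size, 1] once an action is chosen for each state.

```python
# Hypothetical type-II output for a batch of 2 states and 3 actions.
Q_s = [[0.5, 1.5, -0.2],
       [2.0, 0.1, 0.3]]

# Actions taken in each state (hypothetical).
A = [2, 0]

# Picking the entry that belongs to each taken action gives q(s, a),
# arranged as [batch_size, 1] like the type-I head output.
Q_sa = [[q_row[a]] for q_row, a in zip(Q_s, A)]

# The greedy action per state is the index of the largest entry in each row.
greedy = [max(range(len(q_row)), key=q_row.__getitem__) for q_row in Q_s]
```

A type-II head thus yields the values of all actions from a single forward pass, whereas a type-I head needs the action as part of its input.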
- head_v(self, X)
This is the state value head. It returns a scalar V-value \(v(s)\in\mathbb{R}\).
Parameters:
- X : nd Tensor, shape: [batch_size, …]
X is an intermediate tensor in the full forward-pass of the computation graph; it's the output of the last layer of the body() method.
Returns:
- V : 2d Tensor, shape: [batch_size, 1]
The output state values \(v(s)\in\mathbb{R}\).