MLエンジニアへの道 #43 - ポリシー勾配法

Last Edited: 2/25/2025

この記事では、強化学習におけるポリシー勾配法について紹介します。

ML

ここまで、テーブルや関数で提供される状態-行動価値の推定値を更新して、貪欲またはϵ\epsilon-貪欲ポリシーを導き出す状態-行動価値法について議論してきました。 しかし、テーブル型アプローチはメモリ要件が高く、関数型アプローチは実戦での最適なポリシーへの収束に苦労することがわかりました。 この記事では、これらの問題を解決することを目的とした新しいアプローチ、ポリシー勾配法について説明します。

ポリシー勾配

状態-行動価値の推定値を更新する代わりに、パラメータθ\thetaに関してパラメータ化されたポリシー関数πθ(s,a)\pi_{\theta}(s, a)を適合させることにより、 報酬を最大化するポリシーπ(s,a)\pi_{*}(s, a)を直接学習することができます。最大化しようとする報酬関数J(θ)J(\theta)は、ポリシーの下での状態価値の期待値によって決定されます。

dπ(s)=limtP(st=ss0,π)J(θ)=sSdπθ(s)Vπθ(s)=sSdπθ(s)aAπθ(as)Qπθ(s,a) d_{\pi}(s) = \lim_{t \to \infty} P(s_t = s | s_0, \pi) \\ J(\theta) = \sum_{s \in S} d_{\pi_{\theta}}(s) V_{\pi_{\theta}}(s) = \sum_{s \in S} d_{\pi_{\theta}}(s) \sum_{a \in A} \pi_{\theta}(a | s) Q_{\pi_{\theta}}(s, a)

この期待値を計算するために、マルコフ連鎖の状態の定常分布dπ(s)d_{\pi}(s)を使用できます。次に、報酬関数J(θ)J(\theta)のパラメータθ\thetaに関する勾配を計算し、 最適なパラメータを学習するための勾配上昇法に使用します。この方法は、状態価値関数に対する関数近似よりも滑らかに収束する傾向があり、 高次元または連続的な行動空間でより効果的ですが、局所的な最小値に収束する可能性があり、バイアスバリアンストレードオフに直面する場合があります。

ポリシー勾配定理

報酬関数の勾配を計算するために、まずθVπθ(s)\nabla_{\theta}V_{\pi_{\theta}}(s)(θaAπθ(as)Qπθ(s,a)\nabla_{\theta}\sum_{a \in A} \pi_{\theta}(a | s) Q_{\pi_{\theta}}(s, a)と同等)を計算します。 積の法則を使用して、次のように導関数を導出できます。

ddθVπθ(s)=ddθaAπθ(as)Qπθ(s,a)=aAddθπθ(as)Qπθ(s,a)+πθ(as)ddθQπθ(s,a)=aAddθπθ(as)Qπθ(s,a)+πθ(as)ddθs,rP(s,rs,a)(r+Vπθ(s))=aAddθπθ(as)Qπθ(s,a)+πθ(as)s,rP(s,rs,a)ddθVπθ(s)=aAddθπθ(as)Qπθ(s,a)+πθ(as)sP(ss,a)ddθVπθ(s) \frac{d}{d \theta} V_{\pi_{\theta}}(s) = \frac{d}{d \theta} \sum_{a \in A} \pi_{\theta}(a | s) Q_{\pi_{\theta}}(s, a) \\ = \sum_{a \in A} \frac{d}{d \theta} \pi_{\theta}(a | s) Q_{\pi_{\theta}}(s, a) + \pi_{\theta}(a | s) \frac{d}{d \theta} Q_{\pi_{\theta}}(s, a) \\ = \sum_{a \in A} \frac{d}{d \theta} \pi_{\theta}(a | s) Q_{\pi_{\theta}}(s, a) + \pi_{\theta}(a | s) \frac{d}{d \theta} \sum_{s', r} P(s', r | s, a) (r + V_{\pi_{\theta}}(s')) \\ = \sum_{a \in A} \frac{d}{d \theta} \pi_{\theta}(a | s) Q_{\pi_{\theta}}(s, a) + \pi_{\theta}(a | s) \sum_{s', r} P(s', r | s, a) \frac{d}{d \theta} V_{\pi_{\theta}}(s') \\ = \sum_{a \in A} \frac{d}{d \theta} \pi_{\theta}(a | s) Q_{\pi_{\theta}}(s, a) + \pi_{\theta}(a | s) \sum_{s'} P(s' | s, a) \frac{d}{d \theta} V_{\pi_{\theta}}(s')

上記の表現には再帰的な構造があることに気づきます。この関係を明確にするために、2つの新しい関数を設定します:ρπ(ss,k)\rho_{\pi}(s \to s', k)は、 ポリシーπ\piの下でkkステップで状態ssから状態ss'への状態遷移確率を表し、ϕ(s)=aAddθπθ(as)Qπθ(s,a)\phi(s) = \sum_{a \in A} \frac{d}{d \theta} \pi_{\theta}(a | s) Q_{\pi_{\theta}}(s, a)です。 これらの関数を使用して、再帰的構造を次のように展開できます。

ddθVπθ(s)=ϕ(s)+aAπθ(as)sP(ss,a)ddθVπθ(s)=ϕ(s)+saπθ(as)P(ss,a)ddθVπθ(s)=ϕ(s)+sρπθ(ss,1)ddθVπθ(s)=ϕ(s)+sρπθ(ss,1)[ϕ(s)+sρπθ(ss,1)ddθVπθ(s)]=ϕ(s)+sρπθ(ss,1)ϕ(s)+sρπθ(ss,2)ddθVπθ(s)=xSk=0ρπθ(sx,k)ϕ(s) \frac{d}{d \theta} V_{\pi_{\theta}}(s) = \phi(s) + \sum_{a \in A} \pi_{\theta}(a | s) \sum_{s'} P(s' | s, a) \frac{d}{d \theta} V_{\pi_{\theta}}(s') \\ = \phi(s) + \sum_{s'} \sum_a \pi_{\theta}(a | s) P(s' | s, a) \frac{d}{d \theta} V_{\pi_{\theta}}(s') \\ = \phi(s) + \sum_{s'} \rho_{\pi_{\theta}}(s \to s', 1) \frac{d}{d \theta} V_{\pi_{\theta}}(s') \\ = \phi(s) + \sum_{s'} \rho_{\pi_{\theta}}(s \to s', 1) [\phi(s') + \sum_{s''} \rho_{\pi_{\theta}}(s' \to s'', 1) \frac{d}{d \theta} V_{\pi_{\theta}}(s'')] \\ = \phi(s) + \sum_{s'} \rho_{\pi_{\theta}}(s \to s', 1) \phi(s') + \sum_{s''} \rho_{\pi_{\theta}}(s \to s'', 2) \frac{d}{d \theta} V_{\pi_{\theta}}(s'') \\ = \sum_{x \in S} \sum_{k=0}^{\infty} \rho_{\pi_{\theta}}(s \to x, k) \phi(s)

この展開は、ρπ(sx,0)=1\rho_{\pi}(s \to x, 0) = 1およびρπ(sx,k+1)=ρπ(ss,k)+ρπ(sx,1)\rho_{\pi}(s \to x, k+1) = \rho_{\pi}(s \to s', k) + \rho_{\pi}(s' \to x, 1)であるため有効です。 ここで、報酬関数の勾配が上記の表現に比例していることがわかります。これはさらに次のように導出できます。

ddθJ(θ)ddθVπθ(s)=xSk=0ρπθ(sx,k)ϕ(s)xSk=0ρπθ(sx,k)sk=0ρπθ(sx,k)ϕ(s)=xSdπθ(s)ϕ(s)=xSdπθ(s)aAddθπθ(as)Qπθ(s,a)=xSdπθ(s)aAπθ(as)Qπθ(s,a)ddθπθ(as)πθ(as)=Esdπθ,aπθ[Qπθ(s,a)ddθln(πθ(as))] \frac{d}{d \theta} J(\theta) \propto \frac{d}{d \theta} V_{\pi_{\theta}}(s) = \sum_{x \in S} \sum_{k=0}^{\infty} \rho_{\pi_{\theta}}(s \to x, k) \phi(s) \\ \propto \sum_{x \in S} \frac{ \sum_{k=0}^{\infty} \rho_{\pi_{\theta}}(s \to x, k)}{\sum_s \sum_{k=0}^{\infty} \rho_{\pi_{\theta}}(s \to x, k)} \phi(s) \\ = \sum_{x \in S} d_{\pi_{\theta}}(s) \phi(s) \\ = \sum_{x \in S} d_{\pi_{\theta}}(s) \sum_{a \in A} \frac{d}{d \theta} \pi_{\theta}(a | s) Q_{\pi_{\theta}}(s, a) \\ = \sum_{x \in S} d_{\pi_{\theta}}(s) \sum_{a \in A} \pi_{\theta}(a | s) Q_{\pi_{\theta}}(s, a) \frac{\frac{d}{d \theta} \pi_{\theta}(a | s)}{\pi_{\theta}(a | s)} \\ = \text{E}_{s \sim d_{\pi_{\theta}}, a \sim \pi_{\theta}} [Q_{\pi_{\theta}}(s, a) \frac{d}{d \theta} \ln(\pi_{\theta}(a | s))]

この展開を観察することで、報酬関数の勾配が状態-行動価値とポリシーの自然対数の導関数の積の期待値に比例していることがわかります。 この比例関係は ポリシー勾配定理 として知られており、勾配上昇法を実行するための勾配の計算を簡略化することができます。 直感的には、この計算は最高の状態-行動価値やアドバンテージを持つ行動を取る可能性を最大化する方向にパラメータを調整していると解釈できます。

ここでは、学習を安定させるためにQ(s,a)Q(s, a)の代わりにアドバンテージA(s,a)A(s, a)を使用することもできます。また、勾配を目的関数から 自動で計算し、様々な最適化アルゴリズムを用いることのできるDNNのフレームワークを用いる場合、J(θ)=Eπθ[Qπθ(s,a)ln(πθ(as))]J(\theta) = \text{E}_{\pi_{\theta}} [Q_{\pi_{\theta}}(s, a) \ln(\pi_{\theta}(a | s))]を 目的関数として与えることで、Eπθ[Qπθ(s,a)ddθln(πθ(as))]\text{E}_{\pi_{\theta}} [Q_{\pi_{\theta}}(s, a) \frac{d}{d \theta} \ln(\pi_{\theta}(a | s))]を導かせることができます。

REINFORCE

REINFORCE(モンテカルロポリシー勾配)はエピソードをサンプリングし、Qπ(s,a)Q_{\pi}(s, a)の近似としてGtG_t​を使用して勾配(θJ(θ)=Eπ[Gtθln(πθ(atst))]\nabla_{\theta} J(\theta) = \text{E}_{\pi}[G_t \nabla_{\theta} \ln(\pi_{\theta}(a_t | s_t))])を計算します。 以下は、パラメータ化された非線形ポリシー関数(フィードフォワードニューラルポリシーネットワーク)とGymnasiumのFrozen Lake環境でポリシーネットワークをトレーニングするためのREINFORCEアルゴリズムの実装例です。

class PolicyNetwork(nn.Module):
    def __init__(self):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(16, 8)
        self.fc2 = nn.Linear(8, 4)
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.softmax(x, dim=0)
 
def one_hot_encoding(index):
  encoding = np.array([1 if i == index else 0 for i in range(16)])
  return torch.tensor(encoding, dtype=torch.float32)
 
def policy_loss(action, sampled_action, result_sum, gamma, t):
  log_prob = torch.log(action[sampled_action])
  loss = log_prob * gamma**t * result_sum
  loss = -torch.mean(loss)
  return loss
 
def REINFORCE(env, policy_network, optimizer, num_episodes, gamma):
  stats = 0
  for episode in range(num_episodes):
      state, _ = env.reset()
      done = False
      results_list = []
      result_sum = 0.0
 
      # Policy Evaluation
      policy_network.eval()
      while not done:
          action = policy_network(one_hot_encoding(state))
          sampled_action = torch.multinomial(action, 1)[0]
          new_state, reward, done, _, _ = env.step(sampled_action.item())
          results_list.append((state, action, sampled_action.item()))
          result_sum += reward
          state = new_state
 
      # Policy Improvement
      if result_sum != 0:
        policy_network.train()
        for t, (state, action, sampled_action) in enumerate(results_list):
            # Backpropagation using Policy Gradient Theorem
            loss = policy_loss(action, sampled_action, result_sum, gamma, t)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
 
      stats += result_sum
      if episode % 10 == 0 and episode != 0:
          print(f"episode: {episode}, success: {stats/10}")
          stats = 0
 
  print(f"episode: {episode}, success: {stats/episode}")
 
  env.close()
 
learning_rate = 0.001
gamma = 1
num_episodes = 1000
policy_network = PolicyNetwork()
optimizer = torch.optim.Adam(policy_network.parameters(), lr=learning_rate)
REINFORCE(env, policy_network, optimizer, num_episodes, gamma)

エピソードをサンプリングした後、エピソードを遡り、J(θ)=γtGtln(πθ(AtSt))J(\theta) = \gamma^t G_t \ln(\pi_{\theta}(A_t | S_t))でパラメータを更新します。 追加される項は、リターンがゼロの場合にゼロになり、この環境ではリターンはほとんどゼロであるため、学習が始まるまでに時間がかかり、最適なポリシーに収束するにはさらに時間がかかります。

Actor-Critic

ポリシー勾配を計算するためにモンテカルロ法を使用する代わりに、関数近似器qwq_w​をπθ\pi_{\theta}​と一緒にトレーニングし、qwq_w​の出力を使用して各ステップでポリシーを更新することでTD(0)を利用できます(θJ(θ)=Eπ[qw(st,at)θln(πθ(atst))]\nabla_{\theta} J(\theta) = \text{E}_{\pi}[q_w(s_t, a_t) \nabla_{\theta} \ln(\pi_{\theta}(a_t | s_t))])。 この場合、πθ\pi_{\theta}​を探索者(actor)として、qwq_w​を批評家(critic)として解釈できるため、この方法はActor-Criticと呼ばれます。以下は、Frozen Lake環境でのActor-Criticの実装例です。

class QNetwork(nn.Module):
    def __init__(self):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(20, 10)
        self.fc2 = nn.Linear(10, 1)
 
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
 
def q_one_hot_encoding(state, action):
  state_encoding = one_hot_encoding(state)
  action_encoding = np.array([1 if i == action else 0 for i in range(4)])
  action_encoding = torch.tensor(action_encoding, dtype=torch.float32)
  return torch.cat((state_encoding, action_encoding), dim=0)
 
def Actor_Critic(env, policy_network, q_network, policy_optimizer, q_optimizer, num_episodes, gamma):
  torch.autograd.set_detect_anomaly(True)
  for episode in range(num_episodes):
      state, _ = env.reset()
      done = False
      action = policy_network(one_hot_encoding(state))
      sampled_action = torch.multinomial(action, 1)[0]
      q_value = q_network(q_one_hot_encoding(state, sampled_action.item()))
 
      while not done:
          # Forward Pass
          policy_network.eval()
          q_network.eval()
          new_state, reward, done, _, _ = env.step(sampled_action.item())
          next_action = policy_network(one_hot_encoding(new_state))
          next_sampled_action = torch.multinomial(next_action, 1)[0]
 
          # Polocy Update
          policy_network.train()
          loss = policy_loss(action, sampled_action, q_value, gamma, 0)
          policy_optimizer.zero_grad()
          loss.backward()
          policy_optimizer.step()
 
          # Q Update
          q_value_next = q_network(q_one_hot_encoding(new_state, next_sampled_action.item()))
          q_network.train()
          td_loss = (reward + gamma * q_value_next - q_value)**2
          q_optimizer.zero_grad()
          td_loss.backward()
          q_optimizer.step()
 
          state = new_state
          action = next_action
          sampled_action = next_sampled_action
          q_value = q_value_next
 
      if episode % 10 == 0 and episode != 0:
        print(f"episode: {episode}/{num_episodes}")
 
  print(f"episode: {episode}/{num_episodes}")
 
  env.close()
 
q_network = QNetwork()
policy_optimizer = torch.optim.Adam(policy_network.parameters(), lr=learning_rate)
q_optimizer = torch.optim.Adam(q_network.parameters(), lr=learning_rate)
Actor_Critic(env, policy_network, q_network, policy_optimizer, q_optimizer, num_episodes, gamma)

状態が複雑な場合、2つのネットワークは潜在状態埋め込みを生成するために同じネットワークを共有できます。 このActor-Critic法は反復が速い傾向がありますが、データがi.i.d.(独立同一分布)でないため、依然としてポリシーの学習に苦労します。

適格度トレース

前述したように、学習を安定させ、批評家が状態価値のみを出力できるようにするために、ポリシーの目的関数でQ値の代わりにアドバンテージを使用することができます。 この変更により、ポリシーネットワークと批評家ネットワーク間で状態処理部分を共有でき、モデルアーキテクチャが簡素化されます。TD(0)を使用する場合、 探索者と批評家の目的関数は次のようになります:

J(θ)Et[A(st,at)ln(πθ(atst))]+Et[(V(st)Vw(st))2]=Et[Q(st,at)V(st)ln(πθ(atst))]+Et[(V(st)Vw(st))2]Et[rt+γVw(st+1)Vw(st)ln(πθ(atst))]+Et[(rt+γVw(st+1)Vw(st))2] J(\theta) \approx \text{E}_t[A(s_t, a_t) \ln(\pi_{\theta}(a_t | s_t))] + \text{E}_t[(V(s_t) - V_{w}(s_t))^2] \\ = \text{E}_t[Q(s_t, a_t) - V(s_t) \ln(\pi_{\theta}(a_t | s_t))] + \text{E}_t[(V(s_t) - V_{w}(s_t))^2] \\ \approx \text{E}_t[r_t + \gamma V_{w}(s_{t+1}) - V_{w}(s_t) \ln(\pi_{\theta}(a_t | s_t))] + \text{E}_t[(r_t + \gamma V_{w}(s_{t+1}) - V_{w}(s_t))^2]

ここで、アドバンテージとTD(0)ターゲットを使用した価値の差がTD(0)損失δt\delta_{t}​になることに気づくことができます。 単純な損失計算にはTD(0)を使用できますが、一歩先しか見ないため、バイアスが高く分散が低いという特徴があります。 多くの場合、環境に応じてバイアスと分散のバランスを取るために、次のようにより長い適格度トレースでさらに先を見ることが望ましいです。

J(θ)Et[rt+γrt+1+...+γ(Tt+1)r(Tt+1)+γ(Tt)Vw(s(Tt))Vw(st)ln(πθ(atst))]+Et[(rt+γrt+1+...+γ(Tt+1)r(Tt+1)+γ(Tt)Vw(s(Tt))Vw(st))2] J(\theta) \approx \text{E}_t[r_t + \gamma r_{t+1} + ... + \gamma^{(T-t+1)} r_{(T-t+1)} + \gamma^{(T-t)} V_{w}(s_{(T-t)}) - V_{w}(s_t) \ln(\pi_{\theta}(a_t | s_t))] \\ + \text{E}_t[(r_t + \gamma r_{t+1} + ... + \gamma^{(T-t+1)} r_{(T-t+1)} + \gamma^{(T-t)} V_{w}(s_{(T-t)}) - V_{w}(s_t))^2]

次のコードは、Q値の代わりにアドバンテージを使用し、適格度トレースを利用するActor-Critic法を実装しています:

class ActorCriticNetwork(nn.Module):
    def __init__(self):
        super(ActorCriticNetwork, self).__init__()
        self.state_processing = nn.Linear(16, 8)
        self.actor = nn.Linear(8, 4)
        self.critic = nn.Linear(8, 1)
 
    def forward(self, x):
        state_embed = F.relu(self.state_processing(x))
        action_probs = F.softmax(self.actor(state_embed), dim=0)
        value = self.critic(state_embed)
        return action_probs, value
 
def Actor_Critic(env, network, optimizer, num_episodes, gamma, len_trace):
  stats = 0.0
  for episode in range(num_episodes):
      state, _ = env.reset()
      done = False
      results_list = []
      result_sum = 0.0
 
      with torch.no_grad():
          action, _ = network(one_hot_encoding(state))
          sampled_action = torch.multinomial(action, 1)[0]
 
      # Policy Evaluation
      network.eval()
      while not done:
        state, reward, done, _, _ = env.step(sampled_action.item())
 
        with torch.no_grad():
          next_action, next_value = network(one_hot_encoding(state))
          next_sampled_action = torch.multinomial(action, 1)[0]
 
        results_list.append((state, sampled_action.item(), reward))
        result_sum += reward
        action = next_action
        sampled_action = next_sampled_action
 
        # Policy Improvement
        if (len(results_list) == len_trace) or done:
          target_value = 0 if done else next_value
          for t, (state, sampled_action, reward) in enumerate(reversed(results_list)):
              target_value = reward + gamma * target_value
              network.eval()
              action, value = network(one_hot_encoding(state))
              network.train()
              td_loss = target_value - value # advantage / v_target - v
              actor_loss = policy_loss(action, sampled_action, td_loss, gamma, 0)
              critic_loss = torch.mean(td_loss**2)
              loss = actor_loss + critic_loss
              optimizer.zero_grad()
              loss.backward()
              optimizer.step()
        results_list.pop(0)
 
      stats += result_sum
      if episode % 10 == 0 and episode != 0:
          print(f"episode: {episode}, success: {stats/10}")
          stats = 0.0
 
  print(f"episode: {episode}, success: {stats/episode}")
 
  env.close()
 
network = ActorCriticNetwork()
optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
Actor_Critic(env, network, optimizer, num_episodes, gamma, len_trace=5)

目標の状態価値は、報酬を再帰的に合計することで計算されます。アドバンテージを使用することにより、探索家と批評家のネットワークを統合でき、 これらは結合された損失を使って訓練することができます。このモデルでは、バイアスと分散のバランスを取るために適切な適格度トレースの長さを選択することが不可欠です。

GAE

適格度トレースの代わりに、以下のようにハイパーパラメータλ\lambdaでスケールされた将来の状態のTD(0)損失(将来の状態のアドバンテージ)の割引合計を計算することでアドバンテージを計算することもできます。

At=l=0λlγl(rt+l+γVw(st+l+1)Vw(st+l)) A_t = \sum_{l=0}^{\infty} \lambda^l \gamma^l(r_{t+l} + \gamma V_w(s_{t+l+1}) - V_w(s_{t+l}))

λ\lambdaが0に設定されると、初期のTD損失(l=0l=0の場合)のみが非ゼロとなり、アドバンテージはTD(0)損失と同じになります。 一方、λ\lambdaが1に設定されると、すべての将来の報酬を合計することになり、モンテカルロ法と同等になります。 したがって、スケーラーλ\lambdaの値を0から1の間で変更することで、バイアスと分散を制御することができます。 TD(0)とモンテカルロ法の間のスペクトル上でアドバンテージ計算を一般化するこの方法は、 一般化アドバンテージ推定(GAE) と呼ばれ、 様々なActor-Critic法で使用されています。(実際には、将来のTD(0)損失を合計するために上限TTを設定します。)

結論

この記事では、ポリシー勾配法、これらの方法の背後にあるポリシー勾配定理、そしてオンポリシーポリシー勾配法であるREINFORCEとActor-Criticについて紹介しました。 また、アドバンテージと適格度トレースがActor-Critic法とGAEの文脈でどのように実装できるかを示しました。示されたように、これらのオンポリシー法は、 非i.i.d.データと探索の困難さのために学習に苦労します。次の記事では、これらの問題を克服することを目的としたより高度なアルゴリズムのいくつかを紹介します。

リソース