JMANI

Lecture 3: Dummy Q-learning (table) by Sung Kim 본문

AI/Reinforcement Learning

Lecture 3: Dummy Q-learning (table) by Sung Kim

jmani 2022. 5. 19. 17:10

link : https://www.youtube.com/watch?v=Vd-gmo-qO5E&list=PLlMkM4tgfjnKsCWav-Z2F-MMFRx-2gMGG&index=4

  • 일반적으로 Env 알 수 없음
  • (1) state와 (2) action 일 때, (3) quality(reward)를 받아 잘 가고있는지 판단할 수 있음

  • Max Q: Q의 최대값
  • argmax Q: Q가 최대값이 되게하는 action
  • *: optimal, 항상 최대값을 가지는 방향으로 움직이는 것
  • pi: policy

  • 현재의 Q값은 action을 수행해서 받는 r(reward)와 다음 단계에서 얻을 수 있는 max Q의 합
  • 반복

  1. table를 만들어서 Q를 0으로 초기화
  2. s를 가져옴
  3. 무한 반복
    • 어떤 action을 취함
    • reward 받음
    • 다음 s로 이동
    • Q 업데이트(위의 수식 사용)
    • s 변환
import gym
import numpy as np
import matplotlib.pyplot as plt
from gym.envs.registration import register
import random as pr

# argmax가 똑같을 때는 랜덤한 방향으로 이동하는 함수
def rargmax(vector):
    # Argmax that chooses randomly among eligible maximum indices.
    m = np.amax(vector)
    indices = np.nonzero(vector == m)[0]
    return pr.choice(indices)

# Register Frozen with is_slippery False
register(
    id='FrozenLake-v3',
    entry_point='gym.envs.toy_text:FrozenLakeEnv',
    kwargs={'map_name': '4x4', 'is_slippery': False}
)

env = gym.make('FrozenLake-v3')

# Initialize table with all zeros
Q = np.zeros([env.observation_space.n, env.action_space.n])  # 16, 4
# Set learning parameters
num_episodes = 2000

# create lists to contain total rewards ans steps per episode
rList = []
for i in range(num_episodes):
    # Reset environment and get first new observation
    state = env.reset()
    rAll = 0
    done = False

    # The Q-Table learning algorithm
    while not done:
        action = rargmax(Q[state, :])

        # Get new sate and reward from environment
        new_state, reward, done, _ = env.step(action)

        # Update Q-Table with new knowledge using learning rate
        Q[state, action] = reward + np.max(Q[new_state, :])

        rAll += reward
        state = new_state

    rList.append(rAll)

print("Success rate: " + str(sum(rList)/num_episodes))
print("Final Q-Table Values")
print("LEFT DOWN RIGHT UP")
print(Q)
plt.bar(range(len(rList)), rList, color="blue")
plt.show()

95% 성공, Q 출력
초반 이후부터 모두 성공

Comments