| using System; |
| using System.Collections; |
| using System.Collections.Generic; |
| using System.Collections.ObjectModel; |
| using System.Linq; |
| using UnityEngine; |
|
|
/// <summary>
/// Discrete actions available to the grid agent.
/// Up/Down change the agent's grid Y (world Z) by +1/-1, Left/Right change its
/// grid X (world X) by -1/+1, and None is the action chosen at the goal cell,
/// which sends the agent back to its start cell.
/// </summary>
[System.Serializable]
public enum Action
{
    /// <summary>Move one cell towards larger grid Y.</summary>
    Up = 0,

    /// <summary>Move one cell towards smaller grid Y.</summary>
    Down = 1,

    /// <summary>Move one cell towards smaller grid X.</summary>
    Left = 2,

    /// <summary>Move one cell towards larger grid X.</summary>
    Right = 3,

    /// <summary>No movement; selected only in the goal state.</summary>
    None = 4
}
/// <summary>
/// Tabular Q-learning agent that walks a rectangular grid of
/// <see cref="GrizSizeX"/> x <see cref="GrizSizeY"/> cells.
/// Each step is driven by a coroutine: it waits <see cref="RestTime"/> seconds,
/// feeds the current cell and its reward to <see cref="Q_Learning_Agent"/>, and
/// invokes the delegate registered for the returned action, which moves the agent
/// and schedules the next step. Moves that would leave the grid or enter a road
/// block are marked with a Q-value of negative infinity so the greedy policy
/// never prefers them.
/// </summary>
[Serializable]
public class Agent : MonoBehaviour
{
    #region Fields
    /// <summary>Reward stored for the goal cell.</summary>
    private const float GoalReward = 100f;

    /// <summary>Probability that an eligible cell receives a road block.</summary>
    private const float RoadBlockProbability = 0.2f;

    /// <summary>The four movement actions in enum order; <see cref="Action.None"/> is excluded.</summary>
    private static readonly Action[] MovementActions = { Action.Up, Action.Down, Action.Left, Action.Right };

    /// <summary>
    /// Single shared generator used for tie-breaking shuffles. Creating a new
    /// System.Random per call can repeat sequences when calls are close in time.
    /// </summary>
    private static readonly System.Random Shuffler = new System.Random();

    [SerializeField]
    private int _step;
    [SerializeField]
    private int _iteration;
    [SerializeField]
    private int _currentGridX;
    [SerializeField]
    private int _currentGridY;
    [SerializeField]
    private (int,int)? _previousState = null;
    [SerializeField]
    private Action? _previousAction = null;
    [SerializeField]
    private float? _previousReward = null;
    [SerializeField]
    private GUIController _gUIController;
    [SerializeField]
    [Range(0f, 1f)]
    private float _learningRate;
    [SerializeField]
    [Range(0f, 1f)]
    private float _discountingFactor;

    // NOTE(review): the next two values are exposed but never read by this class;
    // the action selection below is purely greedy with random tie-breaking.
    [SerializeField]
    private int _mimumumStateActionPairFrequencies;
    [SerializeField]
    private float _estimatedBestPossibleRewardValue;
    [SerializeField]
    private Coroutine _waitThenActionCoroutine;
    [SerializeField]
    private bool _isPause;
    [SerializeField]
    [Range(0.001f, 30f)]
    private float _restTime;
    [SerializeField]
    private GameObject _roadBlock;
    // NOTE(review): not referenced anywhere else in this class.
    [SerializeField]
    private GameObject _Goodies;

    /// <summary>Number of decisions taken since the last (re)initialisation.</summary>
    public int Step { get => _step; set => _step = value; }
    /// <summary>Number of episodes started since the last (re)initialisation.</summary>
    public int Iteration { get => _iteration; set => _iteration = value; }
    public int CurrentGridX { get => _currentGridX; set => _currentGridX = value; }
    public int CurrentGridY { get => _currentGridY; set => _currentGridY = value; }
    /// <summary>State seen on the previous decision, or null before the first one.</summary>
    public (int, int)? PreviousState { get => _previousState; set => _previousState = value; }
    /// <summary>Action returned by the previous decision, or null before the first one.</summary>
    public Action? PreviousAction { get => _previousAction; set => _previousAction = value; }
    /// <summary>Reward received with the previous state, or null before the first decision.</summary>
    public float? PreviousReward { get => _previousReward; set => _previousReward = value; }
    public GUIController GUIController { get => _gUIController; set => _gUIController = value; }
    public float LearningRate { get => _learningRate; set => _learningRate = value; }
    public float DiscountingFactor { get => _discountingFactor; set => _discountingFactor = value; }
    public int MimumumStateActionPairFrequencies { get => _mimumumStateActionPairFrequencies; set => _mimumumStateActionPairFrequencies = value; }
    public float EstimatedBestPossibleRewardValue { get => _estimatedBestPossibleRewardValue; set => _estimatedBestPossibleRewardValue = value; }
    /// <summary>Handle of the pending decision coroutine, or null when none is scheduled.</summary>
    public Coroutine WaitThenActionCoroutine { get => _waitThenActionCoroutine; set => _waitThenActionCoroutine = value; }
    /// <summary>While true, the decision coroutine idles instead of acting.</summary>
    public bool IsPause { get => _isPause; set => _isPause = value; }
    /// <summary>Seconds to wait between two consecutive decisions.</summary>
    public float RestTime { get => _restTime; set => _restTime = value; }
    public GameObject RoadBlock { get => _roadBlock; set => _roadBlock = value; }
    public GameObject Goodies { get => _Goodies; set => _Goodies = value; }

    public (int,int) StartState;
    public (int,int) FinalState = (7,9);

    public int StartX;
    public int StartY;

    public int GrizSizeX;
    public int GrizSizeY;

    /// <summary>Q-table keyed by (state, action). Negative infinity marks a blocked move.</summary>
    public Dictionary<((int,int),Action),float> StateActionPairQValue { get; set; }

    /// <summary>Reward received on entering each grid cell.</summary>
    public Dictionary<(int, int), float> StateRewardGrid { get; set; }

    /// <summary>Maps each action to the method that carries it out.</summary>
    public Dictionary<Action, System.Action> ActionDelegatesDictonary { get; set; }
    #endregion

    #region Q_Learning_Agent
    /// <summary>
    /// Performs one Q-learning step: updates the Q-value of the previous
    /// (state, action) pair from the new percept, then greedily picks the action
    /// for <paramref name="currentState"/>.
    /// </summary>
    /// <param name="currentState">Grid cell the agent is in now.</param>
    /// <param name="rewardSignal">Reward associated with <paramref name="currentState"/>.</param>
    /// <returns>The action to execute next; <see cref="Action.None"/> in the goal state.</returns>
    private Action Q_Learning_Agent((int,int) currentState, float rewardSignal)
    {
        UpdateStep();

        // The value of the terminal state is its own reward. It has to be written
        // when the goal is the *current* percept; doing it one call later would use
        // the reward of the cell the agent was reset to.
        if (currentState == FinalState)
        {
            StateActionPairQValue[(currentState, Action.None)] = rewardSignal;
        }

        bool hasPrevious = PreviousState.HasValue && PreviousAction.HasValue && PreviousReward.HasValue;

        // A previous state equal to the goal means the agent has just been reset:
        // there is no real transition goal -> start, so nothing is learned from it.
        if (hasPrevious && PreviousState.Value != FinalState)
        {
            ((int, int), Action) stateActionPair = (PreviousState.Value, PreviousAction.Value);
            float oldValue = StateActionPairQValue[stateActionPair];

            // Blocked moves stay at negative infinity; updating them would yield NaN.
            if (!float.IsNegativeInfinity(oldValue))
            {
                float target = PreviousReward.Value + (DiscountingFactor * MaxStateActionPairQValue(currentState));
                StateActionPairQValue[stateActionPair] = oldValue + (LearningRate * (target - oldValue));
            }
        }

        PreviousState = currentState;
        PreviousAction = ArgMaxActionExploration(currentState);
        PreviousReward = rewardSignal;
        return PreviousAction.Value;
    }

    /// <summary>
    /// Returns the largest Q-value available in <paramref name="currentState"/>.
    /// For the goal state this is the value stored for <see cref="Action.None"/>.
    /// </summary>
    private float MaxStateActionPairQValue((int, int) currentState)
    {
        if (currentState == FinalState)
        {
            return StateActionPairQValue[(currentState, Action.None)];
        }

        float best = float.NegativeInfinity;
        foreach (Action action in MovementActions)
        {
            best = Mathf.Max(StateActionPairQValue[(currentState, action)], best);
        }
        return best;
    }

    /// <summary>Returns the four movement actions in a uniformly random order.</summary>
    private static Action[] ShuffledActions()
    {
        Action[] shuffled = (Action[])MovementActions.Clone();

        // Fisher-Yates shuffle.
        for (int i = shuffled.Length - 1; i > 0; i--)
        {
            int j = Shuffler.Next(i + 1);
            Action swap = shuffled[i];
            shuffled[i] = shuffled[j];
            shuffled[j] = swap;
        }
        return shuffled;
    }

    /// <summary>
    /// Picks the movement action with the highest Q-value in
    /// <paramref name="currentState"/>; ties are broken randomly because the
    /// candidates are visited in shuffled order. Returns <see cref="Action.None"/>
    /// in the goal state.
    /// </summary>
    private Action ArgMaxActionExploration((int, int) currentState)
    {
        if (currentState == FinalState)
        {
            return Action.None;
        }

        Action argMaxAction = Action.None;
        float best = float.NegativeInfinity;

        foreach (Action action in ShuffledActions())
        {
            float value = StateActionPairQValue[(currentState, action)];

            // ">=" so that a movement action is returned even if every move is blocked.
            if (value >= best)
            {
                best = value;
                argMaxAction = action;
            }
        }
        return argMaxAction;
    }

    /// <summary>Moves one cell towards smaller X and schedules the next decision.</summary>
    private void Left()
    {
        MoveBy(-1, 0);
    }

    /// <summary>Moves one cell towards larger X and schedules the next decision.</summary>
    private void Right()
    {
        MoveBy(1, 0);
    }

    /// <summary>Moves one cell towards larger Y (world Z) and schedules the next decision.</summary>
    private void Up()
    {
        MoveBy(0, 1);
    }

    /// <summary>Moves one cell towards smaller Y (world Z) and schedules the next decision.</summary>
    private void Down()
    {
        MoveBy(0, -1);
    }

    /// <summary>
    /// Action taken in the goal state: puts the agent back on its start cell,
    /// counts a new iteration and schedules the next decision.
    /// </summary>
    private void None()
    {
        ResetAgentToStart();
        UpdateIteration();
        ScheduleNextAction(RestTime, (CurrentGridX, CurrentGridY));
    }

    /// <summary>
    /// Shifts both the transform and the grid coordinates by the given offset,
    /// then schedules the next decision for the new cell.
    /// </summary>
    private void MoveBy(int deltaX, int deltaY)
    {
        transform.position += new Vector3(deltaX, 0f, deltaY);
        CurrentGridX += deltaX;
        CurrentGridY += deltaY;
        ScheduleNextAction(RestTime, (CurrentGridX, CurrentGridY));
    }

    /// <summary>Places the agent on <see cref="StartState"/> and clears the grid colours.</summary>
    private void ResetAgentToStart()
    {
        transform.position = new Vector3(StartState.Item1, 1f, StartState.Item2);
        CurrentGridX = StartState.Item1;
        CurrentGridY = StartState.Item2;
        Grid.instance.ClearColors();
    }

    /// <summary>Starts the decision coroutine and remembers its handle.</summary>
    private void ScheduleNextAction(float waitTime, (int, int) gridCoordinate)
    {
        WaitThenActionCoroutine = StartCoroutine(WaitThenAction(waitTime, gridCoordinate));
    }

    /// <summary>Stops the pending decision coroutine, if there is one.</summary>
    private void StopPendingAction()
    {
        // StopCoroutine rejects a null handle, so only call it when a step is pending.
        if (WaitThenActionCoroutine != null)
        {
            StopCoroutine(WaitThenActionCoroutine);
            WaitThenActionCoroutine = null;
        }
    }

    /// <summary>
    /// Waits while paused, then for <paramref name="waitTime"/> seconds, then runs
    /// one Q-learning step for <paramref name="GridCoordinate"/> and executes the
    /// chosen action.
    /// </summary>
    private IEnumerator WaitThenAction(float waitTime, (int,int) GridCoordinate)
    {
        while (IsPause)
        {
            yield return null;
        }
        yield return new WaitForSeconds(waitTime);

        // Honour a pause that was requested during the rest delay.
        while (IsPause)
        {
            yield return null;
        }

        if (!StateRewardGrid.TryGetValue(GridCoordinate, out float reward))
        {
            // The agent is outside the reward grid (e.g. empty grid or every move blocked).
            Debug.LogError($"Agent reached cell {GridCoordinate} which is outside the grid; exploration stopped.");
            WaitThenActionCoroutine = null;
            yield break;
        }

        ActionDelegatesDictonary[Q_Learning_Agent(GridCoordinate, reward)]();
    }
    #endregion

    #region Unity
    /// <summary>
    /// Reads the goal position from the grid, registers the action delegates,
    /// picks a random start cell and builds the Q-table and reward grid.
    /// </summary>
    private void Start()
    {
        FinalState = Grid.instance.goalPosition;

        ActionDelegatesDictonary = new Dictionary<Action, System.Action>
        {
            [Action.Left] = Left,
            [Action.Right] = Right,
            [Action.Up] = Up,
            [Action.Down] = Down,
            [Action.None] = None
        };

        StartX = UnityEngine.Random.Range(0, GrizSizeX);
        StartY = UnityEngine.Random.Range(0, GrizSizeY);
        Initialized();
    }

    /// <summary>
    /// Clears the learning history and counters and puts the agent on
    /// (<see cref="StartX"/>, <see cref="StartY"/>).
    /// </summary>
    private void ResetRunState()
    {
        PreviousAction = null;
        PreviousReward = null;
        PreviousState = null;
        Step = 0;
        Iteration = 0;
        transform.position = new Vector3(StartX, 1f, StartY);
        StartState = (StartX, StartY);
        CurrentGridX = StartState.Item1;
        CurrentGridY = StartState.Item2;
    }

    /// <summary>
    /// First-time set-up: creates a zeroed Q-table and reward grid, scatters road
    /// blocks, marks every move into a road block or out of the grid as blocked,
    /// and assigns the goal reward.
    /// </summary>
    private void Initialized()
    {
        ResetRunState();
        StateActionPairQValue = new Dictionary<((int, int), Action), float>();
        StateRewardGrid = new Dictionary<(int, int), float>();

        Array allActions = Enum.GetValues(typeof(Action));
        for (int i = 0; i < GrizSizeX; i++)
        {
            for (int j = 0; j < GrizSizeY; j++)
            {
                foreach (Action action in allActions)
                {
                    StateActionPairQValue[((i, j), action)] = 0f;
                }
                StateRewardGrid[(i, j)] = 0f;
            }
        }

        for (int i = 0; i < GrizSizeX; i++)
        {
            for (int j = 0; j < GrizSizeY; j++)
            {
                // Road blocks are never placed in the start's or the goal's row or
                // column, which always leaves an unobstructed route between them.
                bool mayHoldRoadBlock = i != StartState.Item1 && i != FinalState.Item1
                    && j != StartState.Item2 && j != FinalState.Item2;

                if (mayHoldRoadBlock && UnityEngine.Random.Range(0f, 1f) <= RoadBlockProbability)
                {
                    Instantiate(RoadBlock, new Vector3(i, 0.5f, j), Quaternion.identity);
                    BlockMovesInto(i, j);
                }

                BlockMovesLeavingGrid(i, j);
            }
        }

        // Assigned last so nothing above can overwrite it, even when the goal
        // lies on the border of the grid.
        StateRewardGrid[FinalState] = GoalReward;
    }

    /// <summary>Marks every move from a neighbouring cell into (x, y) as blocked.</summary>
    private void BlockMovesInto(int x, int y)
    {
        if (x + 1 < GrizSizeX)
        {
            StateActionPairQValue[((x + 1, y), Action.Left)] = float.NegativeInfinity;
        }
        if (x - 1 >= 0)
        {
            StateActionPairQValue[((x - 1, y), Action.Right)] = float.NegativeInfinity;
        }
        if (y + 1 < GrizSizeY)
        {
            StateActionPairQValue[((x, y + 1), Action.Down)] = float.NegativeInfinity;
        }
        if (y - 1 >= 0)
        {
            StateActionPairQValue[((x, y - 1), Action.Up)] = float.NegativeInfinity;
        }
    }

    /// <summary>Marks the moves that would take the agent from (x, y) off the grid as blocked.</summary>
    private void BlockMovesLeavingGrid(int x, int y)
    {
        if (x == 0)
        {
            StateActionPairQValue[((x, y), Action.Left)] = float.NegativeInfinity;
        }
        if (y == 0)
        {
            StateActionPairQValue[((x, y), Action.Down)] = float.NegativeInfinity;
        }
        if (x == GrizSizeX - 1)
        {
            StateActionPairQValue[((x, y), Action.Right)] = float.NegativeInfinity;
        }
        if (y == GrizSizeY - 1)
        {
            StateActionPairQValue[((x, y), Action.Up)] = float.NegativeInfinity;
        }
    }

    /// <summary>
    /// Resets the run state and every learned Q-value to zero while keeping the
    /// negative-infinity markers, because the road blocks stay in the scene.
    /// </summary>
    private void ReInitialized()
    {
        ResetRunState();

        Array allActions = Enum.GetValues(typeof(Action));
        for (int i = 0; i < GrizSizeX; i++)
        {
            for (int j = 0; j < GrizSizeY; j++)
            {
                foreach (Action action in allActions)
                {
                    ((int, int), Action) key = ((i, j), action);
                    bool isBlocked = StateActionPairQValue.TryGetValue(key, out float value)
                        && float.IsNegativeInfinity(value);
                    if (!isBlocked)
                    {
                        StateActionPairQValue[key] = 0f;
                    }
                }
            }
        }
    }

    /// <summary>Colours the cell the agent currently occupies, once per frame.</summary>
    private void Update()
    {
        Grid.instance.UpdateColor(CurrentGridX, CurrentGridY);
    }

    /// <summary>
    /// Starts the decision loop after a one second delay. If a loop is already
    /// running it is replaced, so two loops can never drive the agent at once.
    /// </summary>
    public void StartExploring()
    {
        StopPendingAction();
        UpdateIteration();

        // The current cell equals StartState after Start() or Stop(); using it also
        // keeps percept and position in sync when exploration is restarted mid-run.
        ScheduleNextAction(1f, (CurrentGridX, CurrentGridY));
    }

    /// <summary>Stops the decision loop (if any) and resets the agent and its learned values.</summary>
    public void Stop()
    {
        StopPendingAction();
        ReInitialized();
    }

    /// <summary>Increments the step counter and shows it in the GUI when one is assigned.</summary>
    private void UpdateStep()
    {
        Step++;

        // Explicit comparison: "?." skips Unity's overloaded null check for
        // unassigned or destroyed objects.
        if (GUIController != null)
        {
            GUIController.UpdateStepText(Step.ToString());
        }
    }

    /// <summary>Increments the iteration counter and shows it in the GUI when one is assigned.</summary>
    private void UpdateIteration()
    {
        Iteration++;

        // Explicit comparison: "?." skips Unity's overloaded null check for
        // unassigned or destroyed objects.
        if (GUIController != null)
        {
            GUIController.UpdateInterationText(Iteration.ToString());
        }
    }
    #endregion
}