Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- using System.Collections;
- using System.Collections.Generic;
- using UnityEngine;
- using Unity.MLAgents;
- using CardLogic;
- using System.Linq;
- using Unity.MLAgents.Sensors;
/// <summary>
/// ML-Agents agent that drives the <see cref="AI"/> unit on the same GameObject.
/// Each decision can play one card from the hand (action 0) and pick a target tile
/// relative to the agent's position (actions 1 and 2).
/// </summary>
public class MonsterAgent : Agent
{
    // How many nearby units are encoded per observation; missing slots are zero-padded
    // so the observation length stays constant.
    private const int MaxObservedUnits = 10;

    // How many hand cards are encoded per observation; missing slots are zero-padded.
    private const int MaxObservedHandCards = 15;

    // Scale applied to actions 1 and 2 when turning them into a tile offset.
    private const int TileSelectionRange = 10;

    // If any playable card carries this keyword, only such cards may be chosen.
    private const string PriorityKeyword = "Priority";

    // The unit controlled by this agent.
    private AI UnitAI;

    // AI units under this agent's parent transform; re-fetched on every CollectObservations
    // call because units can be removed during play.
    private AI[] localUnits;

    // Helper.StringToInt of the Description() of the last card this agent played;
    // fed back to the policy as an observation.
    private int latestCard;

    /// <summary>
    /// Caches the controlled unit and its neighbours, and hooks the unit's death and
    /// stack-end events.
    /// </summary>
    public override void Initialize()
    {
        UnitAI = GetComponent<AI>();
        localUnits = transform.parent.GetComponentsInChildren<AI>();

        // Every death of this unit ends the episode. A death where both units are the same
        // is not penalised, because suicide is used to clear the board.
        // NOTE(review): parameters are presumably (dead unit, source unit, damage, flag) — confirm against AI.OnDeath.
        UnitAI.OnDeath.DynamicCalls += (Unit deadUnit, Unit source, int damage, bool flag) =>
        {
            if (deadUnit != source)
            {
                AddReward(-1);
            }
            EndEpisode();
        };

        // Reset the per-turn play counter once the card stack has resolved.
        UnitAI.OnStackEnd.AddListener(delegate
        {
            UnitAI.currPlayCard = 0;
        });
    }

    /// <summary>
    /// Called at the start of each episode. Intentionally empty: this agent resets nothing itself.
    /// </summary>
    public override void OnEpisodeBegin() { }

    /// <summary>
    /// Applies one decision. Expects at least three continuous actions:
    /// [0] card choice (negative ends the turn, otherwise picks a playable card),
    /// [1] and [2] x/y tile offset, scaled by <see cref="TileSelectionRange"/>.
    /// </summary>
    /// <param name="vectorAction">Continuous action vector produced by the policy.</param>
    public override void OnActionReceived(float[] vectorAction)
    {
        if (UnitAI.isMyTurn && !GameData.lockHand)
        {
            HandleCardAction(vectorAction[0]);
        }

        // Deliberately evaluated after the card action: if playing a card locks the hand,
        // the same action vector also picks that card's target tile.
        // NOTE(review): GameData.lockHand is global and this runs whether or not it is this
        // agent's turn, so other agents may also select tiles — confirm this is intended.
        if (GameData.lockHand)
        {
            SelectTile(vectorAction[1], vectorAction[2]);
        }
    }

    /// <summary>
    /// Plays a card or ends the turn.
    /// A negative action ends the turn. A non-negative action plays the playable card at
    /// index round(clamp01(action) * (count - 1)) while the per-turn limit is not reached.
    /// The turn ends once the limit is hit or no card is playable.
    /// </summary>
    /// <param name="cardAction">Action 0 from the policy.</param>
    private void HandleCardAction(float cardAction)
    {
        if (cardAction < 0)
        {
            // Explicit pass.
            UnitAI.isMyTurn = false;
            return;
        }
        // NaN fails every comparison; the original code ignored it, so do the same.
        if (float.IsNaN(cardAction) || UnitAI.currPlayCard >= UnitAI.maxPlayCard)
        {
            return;
        }

        List<Card> playable = GetPlayableCards();
        if (playable.Count > 0)
        {
            // Saturate to [0, 1] so out-of-range actions (unclipped policy output, heuristics)
            // cannot index past the end of the list.
            int cardIndex = Mathf.RoundToInt(Mathf.Clamp01(cardAction) * (playable.Count - 1));
            Card card = playable[cardIndex];
            latestCard = Helper.StringToInt(card.Description());
            UnitAI.AddToStack(card);
            UnitAI.currPlayCard++;
        }
        if (UnitAI.currPlayCard >= UnitAI.maxPlayCard || playable.Count == 0)
        {
            UnitAI.isMyTurn = false;
        }
    }

    /// <summary>
    /// Returns the hand cards that pass all their CheckPlay keywords. If any of those carry
    /// the priority keyword, only the priority cards are returned.
    /// </summary>
    private List<Card> GetPlayableCards()
    {
        List<Card> playable = UnitAI.hand.Where(CanPlay).ToList();
        List<Card> priority = playable.Where(IsPriority).ToList();
        return priority.Count > 0 ? priority : playable;
    }

    /// <summary>True when every CheckPlay keyword on the card (if any) allows playing it now.</summary>
    private bool CanPlay(Card card)
    {
        foreach (var keyword in card.keywords)
        {
            Game.Keyword.CheckPlay check = keyword.Get<Game.Keyword.CheckPlay>();
            if (check != null && !check.CheckPlay(UnitAI, card))
            {
                return false;
            }
        }
        return true;
    }

    /// <summary>True when the card has a keyword named <see cref="PriorityKeyword"/>.</summary>
    private static bool IsPriority(Card card)
    {
        return card.keywords.Exists(key => key.keyword == PriorityKeyword);
    }

    /// <summary>
    /// Selects the tile at this agent's position plus (xAction, yAction) * TileSelectionRange,
    /// floored to integer coordinates with z = 0.
    /// </summary>
    private void SelectTile(float xAction, float yAction)
    {
        Vector3Int selectedPos = new Vector3Int(
            Mathf.FloorToInt(transform.position.x + (xAction * TileSelectionRange)),
            Mathf.FloorToInt(transform.position.y + (yAction * TileSelectionRange)),
            0);
        GameData.main.Placeholder.Select(selectedPos);
    }

    /// <summary>
    /// Writes the observation vector in a fixed order: own unit, turn/hand state, the
    /// nearest <see cref="MaxObservedUnits"/> other units (zero-padded), up to
    /// <see cref="MaxObservedHandCards"/> hand card ids (zero-padded), then the board layout.
    /// The total size must stay in sync with the Behavior Parameters' vector observation size.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        AddUnitObservation(sensor, UnitAI);
        sensor.AddObservation(UnitAI.currPlayCard);
        sensor.AddObservation(UnitAI.maxPlayCard);
        sensor.AddObservation(UnitAI.hand.Count);
        sensor.AddObservation(UnitAI.startDeck.Count);
        sensor.AddObservation(UnitAI.isMyTurn);
        sensor.AddObservation(GameData.lockHand);
        sensor.AddObservation(GameData.main.Placeholder.Range);
        sensor.AddObservation(latestCard);

        // Refresh in case units were removed since the last step.
        localUnits = transform.parent.GetComponentsInChildren<AI>();
        sensor.AddObservation(localUnits.Length);

        List<AI> nearestUnits = localUnits
            .Where(other => other != UnitAI)
            .OrderBy(other => Vector2.Distance(other.transform.position, UnitAI.transform.position))
            .Take(MaxObservedUnits)
            .ToList();
        foreach (AI neighbour in nearestUnits)
        {
            AddUnitObservation(sensor, neighbour);
        }
        for (int slot = nearestUnits.Count; slot < MaxObservedUnits; slot++)
        {
            AddEmptyUnitObservation(sensor);
        }

        int observedCards = 0;
        foreach (Card card in UnitAI.hand.Take(MaxObservedHandCards))
        {
            sensor.AddObservation(Helper.StringToInt(card.Description()));
            observedCards++;
        }
        for (; observedCards < MaxObservedHandCards; observedCards++)
        {
            sensor.AddObservation(0);
        }

        GameData.main.Placeholder.LayoutObservation(sensor, transform.position, 10);
    }

    /// <summary>
    /// Encodes one unit with 8 AddObservation calls: name id, position (3 values), health,
    /// max health, defence, speed, faction and facing. That is 10 values, assuming the stats
    /// are scalars. <see cref="AddEmptyUnitObservation"/> mirrors this layout.
    /// </summary>
    private void AddUnitObservation(VectorSensor sensor, AI unit)
    {
        sensor.AddObservation(Helper.StringToInt(unit.transform.name));
        sensor.AddObservation(unit.transform.position);
        sensor.AddObservation(unit.health);
        sensor.AddObservation(unit.maxHealth);
        sensor.AddObservation(unit.defence);
        sensor.AddObservation(unit.speed);
        sensor.AddObservation((int)unit.faction);
        sensor.AddObservation((int)unit.facing);
    }

    /// <summary>Zero-filled placeholder with the same layout as <see cref="AddUnitObservation"/>.</summary>
    private static void AddEmptyUnitObservation(VectorSensor sensor)
    {
        sensor.AddObservation(0);            // Name
        sensor.AddObservation(Vector3.zero); // Position
        sensor.AddObservation(0);            // Health
        sensor.AddObservation(0);            // Max health
        sensor.AddObservation(0);            // Defence
        sensor.AddObservation(0);            // Speed
        sensor.AddObservation(0);            // Faction
        sensor.AddObservation(0);            // Facing
    }
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement