Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- using System.Collections;
- using System.Collections.Generic;
- using UnityEngine;
- using Unity.MLAgents;
- using CardLogic;
- using System.Linq;
- using Unity.MLAgents.Sensors;
- using Unity.MLAgents.Actuators;
/// <summary>
/// ML-Agents agent that drives a single <see cref="AI"/> unit: it chooses which card to play
/// (discrete branch 0) and which tile to target (discrete branches 1 and 2), and is rewarded
/// through the unit's damage and death events.
/// </summary>
public class MonsterAgent : Agent
{
    // Size of discrete branch 0: action 0 ends the turn, action i + 1 plays Database.main.cards[i].
    // NOTE(review): assumed to match the branch size configured in Behavior Parameters - confirm.
    private const int CardBranchSize = 1000;

    // Number of options masked on each tile branch, and the offset subtracted from the action
    // to obtain a tile offset relative to this unit.
    // NOTE(review): 25 options with an offset of 15 cover relative offsets -15..+9, which is
    // asymmetric - confirm against the real size of branches 1 and 2.
    private const int TileBranchOptions = 25;
    private const int TileActionOffset = 15;

    // Fixed-size observation slots; changing these changes the observation vector length.
    private const int MaxObservedUnits = 10;
    private const int MaxObservedHandCards = 15;

    // Enables the (expensive) per-decision debug output. Off by default.
    [SerializeField] private bool logDecisions = false;

    private AI UnitAI;
    private AI[] localUnits;
    private int latestCard;

    /// <summary>
    /// Caches the controlled unit and its sibling units, and hooks the unit's death,
    /// damage and stack-end events.
    /// </summary>
    public override void Initialize()
    {
        UnitAI = GetComponent<AI>();
        localUnits = transform.parent.GetComponentsInChildren<AI>();

        UnitAI.OnDeath.DynamicCalls += OnUnitDied;
        UnitAI.OnDamage.DynamicCalls += OnUnitDamaged;
        UnitAI.OnStackEnd.AddListener(delegate
        {
            UnitAI.currPlayCard = 0;
        });
    }

    // Reset
    public override void OnEpisodeBegin() { }

    /// <summary>
    /// Applies the decided actions: branch 0 plays a card or ends the turn while the hand is
    /// unlocked; branches 1 and 2 select a target tile while the hand is locked.
    /// </summary>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        var actions = actionBuffers.DiscreteActions;

        if (logDecisions)
        {
            foreach (int action in actions)
            {
                Debug.Log(action);
            }
        }

        if (UnitAI.isMyTurn && !GameData.lockHand)
        {
            HandleCardAction(actions[0]);
        }

        // Checked separately (not as an else) because lockHand is re-read after the card step.
        if (GameData.lockHand)
        {
            Vector3 origin = transform.position;
            Vector3Int selectedPos = new Vector3Int(
                TileCoordinate(origin.x, actions[1]),
                TileCoordinate(origin.y, actions[2]),
                0
            );
            GameData.main.Placeholder.Select(selectedPos);
        }
    }

    /// <summary>
    /// Masks every card action that is not currently playable, the unused tail of branch 0,
    /// and every tile coordinate that does not appear in the available tiles.
    /// </summary>
    public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
    {
        var playableNames = GetPlayableCards().Select(card => card.name).ToList();
        var databaseCards = Database.main.cards;

        List<int> maskList = new List<int>();
        List<string> allowedLabels = new List<string> { "End Turn (0)" };
        for (int i = 0; i < databaseCards.Count; i++)
        {
            // Action ids are shifted by one because action 0 means "end turn".
            if (playableNames.Contains(databaseCards[i].card.name))
            {
                allowedLabels.Add($"{databaseCards[i].card.name} ({i + 1})");
            }
            else
            {
                maskList.Add(i + 1);
            }
        }

        // Actions above the last card id do not map to any card. The first unused id is
        // Count + 1 (not Count, which is the id of the last card in the database).
        int firstUnusedAction = databaseCards.Count + 1;
        int unusedActions = Mathf.Max(0, CardBranchSize - firstUnusedAction);
        maskList.AddRange(Enumerable.Range(firstUnusedAction, unusedActions));

        if (logDecisions)
        {
            Debug.Log(string.Join(", ", allowedLabels));
        }
        actionMask.WriteMask(0, maskList);

        // Mask the navigation
        var tiles = GameData.main.Placeholder.AvailableTiles;
        if (tiles.Count > 0)
        {
            Vector3 origin = transform.position;
            var tileXs = tiles.Select(tile => tile.x).ToList();
            var tileYs = tiles.Select(tile => tile.y).ToList();

            // The same action-to-coordinate mapping as OnActionReceived is used so that a
            // masked action and the tile it would select can never disagree.
            List<int> blockedX = Enumerable.Range(0, TileBranchOptions)
                .Where(action => !tileXs.Contains(TileCoordinate(origin.x, action)))
                .ToList();
            actionMask.WriteMask(1, blockedX);

            List<int> blockedY = Enumerable.Range(0, TileBranchOptions)
                .Where(action => !tileYs.Contains(TileCoordinate(origin.y, action)))
                .ToList();
            actionMask.WriteMask(2, blockedY);
        }
    }

    /// <summary>
    /// Writes the observation vector: this unit, turn/hand state, up to
    /// <see cref="MaxObservedUnits"/> nearest sibling units (zero padded), up to
    /// <see cref="MaxObservedHandCards"/> hand card ids (zero padded) and the tile layout.
    /// The order and number of observations must stay stable for a trained model.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        AddUnitObservation(sensor, UnitAI);
        sensor.AddObservation(UnitAI.currPlayCard);
        sensor.AddObservation(UnitAI.maxPlayCard);
        sensor.AddObservation(UnitAI.hand.Count);
        sensor.AddObservation(UnitAI.startDeck.Count);
        sensor.AddObservation(UnitAI.isMyTurn);
        sensor.AddObservation(GameData.lockHand);
        sensor.AddObservation(GameData.main.Placeholder.Range);
        sensor.AddObservation(latestCard);

        localUnits = transform.parent.GetComponentsInChildren<AI>(); // Update the list in case of removed units
        sensor.AddObservation(localUnits.Length);

        List<AI> nearestUnits = localUnits
            .Where(unit => unit != UnitAI)
            .OrderBy(unit => Vector2.Distance(unit.transform.position, UnitAI.transform.position))
            .Take(MaxObservedUnits)
            .ToList();
        foreach (AI unit in nearestUnits)
        {
            AddUnitObservation(sensor, unit);
        }
        for (int slot = nearestUnits.Count; slot < MaxObservedUnits; slot++)
        {
            AddEmptyUnitObservation(sensor);
        }

        int observedCards = 0;
        foreach (Card handCard in UnitAI.hand.Take(MaxObservedHandCards))
        {
            // Card id is the database index + 1, so 0 means "no card / unknown card".
            sensor.AddObservation(Database.main.cards.FindIndex(entry => entry.card.name == handCard.name) + 1);
            observedCards++;
        }
        for (int slot = observedCards; slot < MaxObservedHandCards; slot++)
        {
            sensor.AddObservation(0);
        }

        GameData.main.Placeholder.LayoutObservation(sensor, transform.position, 10);
    }

    // Death handler for the controlled unit. `u` is treated as the unit that died and `d` as
    // the unit that caused it. Rewards are skipped for suicide, which is used to clear the board.
    private void OnUnitDied(Unit u, Unit d, int dgm, bool b)
    {
        if (u != d)
        {
            bool killedByEnemy = u.faction != d.faction;

            MonsterAgent victim = u.GetComponent<MonsterAgent>();
            if (victim != null)
            {
                if (killedByEnemy)
                {
                    RewardFactionOf(u, -0.1f);
                }
                victim.AddReward(-1f);
            }

            MonsterAgent killer = d.GetComponent<MonsterAgent>();
            if (killer != null)
            {
                if (killedByEnemy)
                {
                    RewardFactionOf(d, 0.1f);
                }
                // Killing a unit of the own faction is punished instead of rewarded.
                killer.AddReward(killedByEnemy ? 1f : -1f);
            }
        }
        EndEpisode();
    }

    // Damage handler for the controlled unit. `u` is treated as the damaged unit and `d` as
    // the unit that dealt the damage.
    private void OnUnitDamaged(Unit u, Unit d, int dgm, bool b)
    {
        MonsterAgent victim = u.GetComponent<MonsterAgent>();
        if (victim != null)
        {
            victim.AddReward(-0.1f);
        }

        MonsterAgent attacker = d.GetComponent<MonsterAgent>();
        if (attacker != null)
        {
            // The reward belongs to the attacker, not to the agent that owns this handler.
            attacker.AddReward(u.faction != d.faction ? 0.1f : -0.05f);
        }
    }

    // Adds `reward` to every known unit of `member`'s faction that has a MonsterAgent.
    private void RewardFactionOf(Unit member, float reward)
    {
        foreach (AI unit in localUnits)
        {
            if (unit == null || unit.faction != member.faction)
            {
                continue;
            }

            MonsterAgent teammate = unit.GetComponent<MonsterAgent>();
            if (teammate != null)
            {
                teammate.AddReward(reward);
            }
        }
    }

    // Handles branch 0: 0 ends the turn, any other value plays the matching database card
    // if it is currently playable and the play limit has not been reached.
    private void HandleCardAction(int action)
    {
        if (action == 0)
        {
            UnitAI.isMyTurn = false;
            return;
        }
        if (UnitAI.currPlayCard >= UnitAI.maxPlayCard)
        {
            return;
        }

        List<Card> availableCards = GetPlayableCards();
        int cardIndex = action - 1;
        if (cardIndex >= 0 && cardIndex < Database.main.cards.Count)
        {
            var wantedName = Database.main.cards[cardIndex].card.name;
            int playIndex = availableCards.FindIndex(candidate => candidate.name == wantedName);
            if (playIndex >= 0)
            {
                Card chosen = availableCards[playIndex];
                latestCard = Helper.StringToInt(chosen.Description());
                UnitAI.AddToStack(chosen);
                UnitAI.currPlayCard++;
            }
        }

        if (UnitAI.currPlayCard == UnitAI.maxPlayCard || availableCards.Count == 0)
        {
            UnitAI.isMyTurn = false;
        }
    }

    // Returns the hand cards whose CheckPlay keywords all allow playing them. If any of those
    // carries the "Priority" keyword, only the priority cards are returned.
    private List<Card> GetPlayableCards()
    {
        List<Card> playable = UnitAI.hand
            .Where(card => card.keywords
                .Select(keyword => keyword.Get<Game.Keyword.CheckPlay>())
                .All(check => check == null || check.CheckPlay(UnitAI, card)))
            .ToList();

        List<Card> priority = playable.FindAll(card => card.keywords.Exists(key => key.keyword == "Priority"));
        return priority.Count > 0 ? priority : playable;
    }

    // Maps a tile-branch action to a world tile coordinate relative to `origin`.
    private static int TileCoordinate(float origin, int action)
    {
        return Mathf.FloorToInt(origin + action - TileActionOffset);
    }

    // Writes one unit slot: 8 AddObservation calls, one of which is a Vector3 position.
    private void AddUnitObservation(VectorSensor sensor, AI unit)
    {
        sensor.AddObservation(Helper.StringToInt(unit.transform.name));
        sensor.AddObservation(unit.transform.position);
        sensor.AddObservation(unit.health);
        sensor.AddObservation(unit.maxHealth);
        sensor.AddObservation(unit.defence);
        sensor.AddObservation(unit.speed);
        sensor.AddObservation((int)unit.faction);
        sensor.AddObservation((int)unit.facing);
    }

    // Writes an all-zero unit slot with the same layout as AddUnitObservation.
    private static void AddEmptyUnitObservation(VectorSensor sensor)
    {
        sensor.AddObservation(0); // Name
        sensor.AddObservation(Vector3.zero); // Pos
        sensor.AddObservation(0); // Health
        sensor.AddObservation(0); // Max Health
        sensor.AddObservation(0); // Defence
        sensor.AddObservation(0); // Speed
        sensor.AddObservation(0); // Faction
        sensor.AddObservation(0); // Facing
    }
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement