Advertisement
Guest User

Untitled

a guest
Jan 6th, 2021
29
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C# 10.86 KB | None | 0 0
  1. using System.Collections;
  2. using System.Collections.Generic;
  3. using UnityEngine;
  4. using Unity.MLAgents;
  5. using CardLogic;
  6. using System.Linq;
  7. using Unity.MLAgents.Sensors;
  8. using Unity.MLAgents.Actuators;
  9.  
  /// <summary>
  /// ML-Agents agent that chooses cards to play and target tiles for the <see cref="AI"/> unit on the same GameObject.
  /// </summary>
  /// <remarks>
  /// Discrete action branches as used by this class:
  ///   0 - card choice: 0 ends the turn, k &gt; 0 plays database card k - 1 when it is in the playable hand;
  ///   1 - tile X offset, 2 - tile Y offset: the selected tile is floor(position) + action - <see cref="TileOffset"/>.
  /// </remarks>
  public class MonsterAgent : Agent
  {
      // Number of actions enumerated on branch 0. The mask writes indices up to this value - 1, so the
      // Behavior Parameters branch 0 size must be at least this. NOTE(review): confirm it is exactly 1000.
      private const int CardBranchSize = 1000;

      // Number of tile offsets masked per axis on branches 1 and 2. NOTE(review): confirm against the branch sizes.
      private const int TileBranchSize = 25;

      // Action value that means "no offset" on branches 1 and 2. NOTE(review): with 25 slots this gives offsets
      // -15..+9, which is asymmetric; confirm that is intended before changing it (trained models depend on it).
      private const int TileOffset = 15;

      // Fixed-size observation slots for nearby units and hand cards (padded with zeros).
      private const int MaxObservedUnits = 10;
      private const int MaxObservedHandCards = 15;

      // Keyword that forces a card to be played before any non-priority card.
      private const string PriorityKeyword = "Priority";

      [Tooltip("Log every received action and the unmasked card actions. Expensive; enable only while debugging.")]
      [SerializeField] private bool logDecisions = false;

      private AI UnitAI;
      private AI[] localUnits;
      private int latestCard;

      /// <summary>
      /// Caches the controlled unit and its siblings and hooks the unit's death/damage/stack events for reward shaping.
      /// </summary>
      public override void Initialize()
      {
          UnitAI = GetComponent<AI>();
          localUnits = transform.parent.GetComponentsInChildren<AI>();

          // Event arguments are presumably (victim, attacker, damage, flag); only the two units are used here.
          UnitAI.OnDeath.DynamicCalls += (Unit u, Unit d, int dmg, bool b) => HandleUnitDeath(u, d);
          UnitAI.OnDamage.DynamicCalls += (Unit u, Unit d, int dmg, bool b) => HandleUnitDamage(u, d);

          // Reset the per-turn card counter whenever the card stack finishes.
          UnitAI.OnStackEnd.AddListener(delegate {
              UnitAI.currPlayCard = 0;
          });
      }

      /// <summary>
      /// Reward shaping for a unit death, followed by ending this agent's episode.
      /// </summary>
      /// <param name="dead">Unit that died (first OnDeath argument).</param>
      /// <param name="killer">Unit credited with the kill (second OnDeath argument). NOTE(review): treated as possibly null.</param>
      private void HandleUnitDeath(Unit dead, Unit killer)
      {
          // Suicide should not count since it is used to clear the board.
          if (dead != killer)
          {
              bool killedByEnemy = killer != null && dead.faction != killer.faction;

              MonsterAgent deadAgent = dead.GetComponent<MonsterAgent>();
              if (deadAgent != null)
              {
                  // Every local unit of the victim's faction shares a small penalty for losing a unit to the enemy.
                  if (killedByEnemy) AddRewardToFaction(dead, -0.1f);
                  deadAgent.AddReward(-1f);
              }

              MonsterAgent killerAgent = killer != null ? killer.GetComponent<MonsterAgent>() : null;
              if (killerAgent != null)
              {
                  if (killedByEnemy)
                  {
                      // The killer's faction shares a small bonus; the killer gets the full reward.
                      AddRewardToFaction(killer, 0.1f);
                      killerAgent.AddReward(1f);
                  }
                  else
                  {
                      // Killing a unit of the same faction is penalised.
                      killerAgent.AddReward(-1f);
                  }
              }
          }

          EndEpisode();
      }

      /// <summary>
      /// Reward shaping for damage: the damaged unit is penalised, the attacker is rewarded for hitting an
      /// enemy and penalised for friendly fire.
      /// </summary>
      /// <param name="damaged">Unit that took damage (first OnDamage argument).</param>
      /// <param name="attacker">Unit that dealt the damage (second OnDamage argument). NOTE(review): treated as possibly null.</param>
      private void HandleUnitDamage(Unit damaged, Unit attacker)
      {
          MonsterAgent damagedAgent = damaged.GetComponent<MonsterAgent>();
          if (damagedAgent != null) damagedAgent.AddReward(-0.1f);

          if (attacker == null) return;
          MonsterAgent attackerAgent = attacker.GetComponent<MonsterAgent>();
          if (attackerAgent == null) return;

          attackerAgent.AddReward(damaged.faction != attacker.faction ? 0.1f : -0.05f);
      }

      /// <summary>
      /// Adds <paramref name="reward"/> to every agent in <c>localUnits</c> that shares <paramref name="member"/>'s faction.
      /// Destroyed units and units without a <see cref="MonsterAgent"/> are skipped.
      /// </summary>
      private void AddRewardToFaction(Unit member, float reward)
      {
          foreach (AI unit in localUnits)
          {
              if (unit == null || unit.faction != member.faction) continue;

              MonsterAgent unitAgent = unit.GetComponent<MonsterAgent>();
              if (unitAgent != null) unitAgent.AddReward(reward);
          }
      }

      // Nothing to reset per episode here. NOTE(review): board reset presumably happens elsewhere.
      public override void OnEpisodeBegin() { }

      /// <summary>
      /// Applies the chosen actions: on this unit's turn (hand unlocked) branch 0 plays a card or ends the turn;
      /// while the hand is locked, branches 1 and 2 select a tile relative to this agent.
      /// </summary>
      public override void OnActionReceived(ActionBuffers actionBuffers)
      {
          var actions = actionBuffers.DiscreteActions;
          if (logDecisions)
          {
              foreach (int value in actions)
              {
                  Debug.Log(value);
              }
          }

          if (UnitAI.isMyTurn && !GameData.lockHand)
          {
              int cardAction = actions[0];
              if (cardAction == 0)
              {
                  // Action 0 ends the turn.
                  UnitAI.isMyTurn = false;
              }
              else if (UnitAI.currPlayCard < UnitAI.maxPlayCard)
              {
                  List<Card> playable = GetPlayableCards();
                  TryPlayCard(playable, cardAction - 1);

                  if (UnitAI.currPlayCard >= UnitAI.maxPlayCard || playable.Count == 0)
                  {
                      UnitAI.isMyTurn = false;
                  }
              }
          }

          // While the hand is locked (presumably after a card was played), pick the card's target tile.
          if (GameData.lockHand)
          {
              Vector3Int selectedPos = new Vector3Int(
                  Mathf.FloorToInt(transform.position.x) + actions[1] - TileOffset,
                  Mathf.FloorToInt(transform.position.y) + actions[2] - TileOffset,
                  0);

              GameData.main.Placeholder.Select(selectedPos);
          }
      }

      /// <summary>
      /// Plays the hand card whose name matches database card <paramref name="databaseIndex"/>, if it is playable.
      /// </summary>
      /// <param name="playable">Cards currently allowed to be played (see <see cref="GetPlayableCards"/>).</param>
      /// <param name="databaseIndex">Zero-based index into <c>Database.main.cards</c> (branch 0 action minus one).</param>
      /// <returns>True when a card was added to the stack.</returns>
      private bool TryPlayCard(List<Card> playable, int databaseIndex)
      {
          var cards = Database.main.cards;
          if (databaseIndex < 0 || databaseIndex >= cards.Count) return false;

          var wantedName = cards[databaseIndex].card.name;
          int playIndex = playable.FindIndex(card => card.name == wantedName);
          if (playIndex < 0) return false; // Chosen card is not in the playable hand (e.g. mask not applied).

          Card chosen = playable[playIndex];
          latestCard = Helper.StringToInt(chosen.Description());
          UnitAI.AddToStack(chosen);
          UnitAI.currPlayCard++;
          return true;
      }

      /// <summary>
      /// Masks card actions that are not playable right now and tile offsets that hit no available tile.
      /// </summary>
      public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
      {
          WriteCardMask(actionMask);
          WriteTileMask(actionMask);
      }

      /// <summary>
      /// Branch 0: leaves "End Turn" (0) and every playable database card (index + 1) open; masks everything else.
      /// </summary>
      private void WriteCardMask(IDiscreteActionMask actionMask)
      {
          var cards = Database.main.cards;
          HashSet<string> playableNames = new HashSet<string>(GetPlayableCards().Select(card => card.name));

          List<int> masked = new List<int>();
          // Action i + 1 plays database card i.
          for (int i = 0; i < cards.Count && i + 1 < CardBranchSize; i++)
          {
              if (!playableNames.Contains(cards[i].card.name)) masked.Add(i + 1);
          }

          // Actions after the last database card (cards.Count + 1 and up) do not map to any card.
          for (int action = cards.Count + 1; action < CardBranchSize; action++)
          {
              masked.Add(action);
          }

          if (logDecisions)
          {
              HashSet<int> maskedSet = new HashSet<int>(masked);
              Debug.Log(string.Join(", ", Enumerable.Range(0, CardBranchSize)
                  .Where(action => !maskedSet.Contains(action))
                  .Select(action => action == 0 ? "End Turn (0)" : $"{cards[action - 1].card.name} ({action})")));
          }

          actionMask.WriteMask(0, masked);
      }

      /// <summary>
      /// Branches 1 and 2: masks X/Y offsets whose coordinate matches no available tile. X and Y are masked
      /// independently, so some open (x, y) combinations may still not be an available tile.
      /// </summary>
      private void WriteTileMask(IDiscreteActionMask actionMask)
      {
          var tiles = GameData.main.Placeholder.AvailableTiles;
          if (tiles.Count == 0) return;

          var tileXs = tiles.Select(tile => tile.x).ToList();
          var tileYs = tiles.Select(tile => tile.y).ToList();
          // Same flooring as OnActionReceived so the mask and the selected tile agree for negative positions.
          int originX = Mathf.FloorToInt(transform.position.x);
          int originY = Mathf.FloorToInt(transform.position.y);

          List<int> maskX = Enumerable.Range(0, TileBranchSize)
              .Where(action => !tileXs.Contains(originX + action - TileOffset))
              .ToList();
          List<int> maskY = Enumerable.Range(0, TileBranchSize)
              .Where(action => !tileYs.Contains(originY + action - TileOffset))
              .ToList();

          // ML-Agents does not allow masking every action of a branch; skip the mask if no offset reaches a tile.
          if (maskX.Count < TileBranchSize) actionMask.WriteMask(1, maskX);
          if (maskY.Count < TileBranchSize) actionMask.WriteMask(2, maskY);
      }

      /// <summary>
      /// Cards in hand that pass every CheckPlay keyword. If any of them carries the Priority keyword,
      /// only the priority cards are returned, since those must be played first.
      /// </summary>
      private List<Card> GetPlayableCards()
      {
          List<Card> playable = UnitAI.hand.Where(card => IsPlayable(card)).ToList();
          List<Card> priority = playable.Where(card => HasPriorityKeyword(card)).ToList();
          return priority.Count > 0 ? priority : playable;
      }

      // True when none of the card's CheckPlay keywords vetoes playing it for this unit.
      private bool IsPlayable(Card card)
      {
          return card.keywords
              .Select(keyword => keyword.Get<Game.Keyword.CheckPlay>())
              .All(check => check == null || check.CheckPlay(UnitAI, card));
      }

      // True when the card has the Priority keyword.
      private static bool HasPriorityKeyword(Card card)
      {
          return card.keywords.Exists(key => key.keyword == PriorityKeyword);
      }

      /// <summary>
      /// Observation vector: own unit, turn/hand state, up to <see cref="MaxObservedUnits"/> nearest local units
      /// (zero-padded), up to <see cref="MaxObservedHandCards"/> hand cards as database index + 1 (zero-padded),
      /// then the placeholder layout around this agent. The order must not change for trained models.
      /// </summary>
      public override void CollectObservations(VectorSensor sensor)
      {
          AddUnitObservation(sensor, UnitAI);
          sensor.AddObservation(UnitAI.currPlayCard);
          sensor.AddObservation(UnitAI.maxPlayCard);
          sensor.AddObservation(UnitAI.hand.Count);
          sensor.AddObservation(UnitAI.startDeck.Count);

          sensor.AddObservation(UnitAI.isMyTurn);
          sensor.AddObservation(GameData.lockHand);
          sensor.AddObservation(GameData.main.Placeholder.Range);
          sensor.AddObservation(latestCard);

          // Refresh the list in case units were removed since the last step.
          localUnits = transform.parent.GetComponentsInChildren<AI>();
          sensor.AddObservation(localUnits.Length);

          List<AI> nearest = localUnits
              .Where(unit => unit != UnitAI)
              .OrderBy(unit => Vector2.Distance(unit.transform.position, UnitAI.transform.position))
              .Take(MaxObservedUnits)
              .ToList();
          foreach (AI unit in nearest)
          {
              AddUnitObservation(sensor, unit);
          }
          for (int i = nearest.Count; i < MaxObservedUnits; i++)
          {
              AddEmptyUnitObservation(sensor);
          }

          int shownCards = 0;
          foreach (Card card in UnitAI.hand.Take(MaxObservedHandCards))
          {
              sensor.AddObservation(CardActionIndex(card));
              shownCards++;
          }
          for (; shownCards < MaxObservedHandCards; shownCards++)
          {
              sensor.AddObservation(0);
          }

          GameData.main.Placeholder.LayoutObservation(sensor, transform.position, 10);
      }

      // Database index + 1 of the card with the same name (0 when it is not in the database), matching branch 0 actions.
      private static int CardActionIndex(Card card)
      {
          return Database.main.cards.FindIndex(entry => entry.card.name == card.name) + 1;
      }

      /// <summary>
      /// Appends one unit: name id, position (3 floats), health, max health, defence, speed, faction, facing.
      /// That is 8 calls / 10 floats, assuming the stats are scalars; must stay in sync with <see cref="AddEmptyUnitObservation"/>.
      /// </summary>
      private void AddUnitObservation(VectorSensor sensor, AI unit)
      {
          sensor.AddObservation(Helper.StringToInt(unit.transform.name));
          sensor.AddObservation(unit.transform.position);
          sensor.AddObservation(unit.health);
          sensor.AddObservation(unit.maxHealth);
          sensor.AddObservation(unit.defence);
          sensor.AddObservation(unit.speed);
          sensor.AddObservation((int)unit.faction);
          sensor.AddObservation((int)unit.facing);
      }

      // Zero padding with the same layout as AddUnitObservation, for missing unit slots.
      private static void AddEmptyUnitObservation(VectorSensor sensor)
      {
          sensor.AddObservation(0); // Name
          sensor.AddObservation(Vector3.zero); // Pos
          sensor.AddObservation(0); // Health
          sensor.AddObservation(0); // Max Health
          sensor.AddObservation(0); // Defence
          sensor.AddObservation(0); // Speed
          sensor.AddObservation(0); // Faction
          sensor.AddObservation(0); // Facing
      }
  }
  273.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement