Advertisement
Guest User

Untitled

a guest
Jan 21st, 2021
83
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C# 7.10 KB | None | 0 0
  1. using System.Collections;
  2. using System.Collections.Generic;
  3. using UnityEngine;
  4.  
  5. using Unity.MLAgents;
  6. using Unity.MLAgents.Sensors;
  7.  
  /// <summary>
  /// ML-Agents agent that ferries gold from mines to bases.
  /// </summary>
  /// <remarks>
  /// The agent picks up one unit when within <see cref="InteractionRange"/> of an active mine
  /// that still has gold. It is rewarded again for dropping that unit at an active base.
  /// <see cref="life"/> accumulates elapsed time; each delivery refunds part of it, and the
  /// episode ends once it exceeds <see cref="MaxLife"/>.
  /// </remarks>
  public class FarmerAgent : Agent
  {
      // Tuning constants. Values are identical to the literals previously inlined in the methods.
      private const float SpawnRadius = 3.4f;          // distance of mines/bases from the parent's origin
      private const int MineStartingGold = 50;
      private const float InteractionRange = 1f;       // max distance to collect from a mine / deposit at a base
      private const float TimePenaltyPerSecond = 0.01f;
      private const float CollectReward = 1f;
      private const float DepositReward = 10f;
      private const float DepositLifeRefund = 0.8f;    // subtracted from life on each delivery
      private const float PileScaleStep = 0.5f;        // y-scale change per unit mined / deposited
      private const float MaxLife = 1.6f;              // episode ends when life exceeds this

      public List<Transform> Mines;
      public List<Transform> Base;

      /// <summary>Seconds accumulated this episode, minus refunds from deliveries.</summary>
      public float life = 0;
      /// <summary>Seconds since the last successful delivery (or since the episode began).</summary>
      public float timeBetweenCollect = 0;
      /// <summary>Units delivered to bases this episode.</summary>
      public int gold = 0;
      /// <summary>True while the agent is carrying a unit of gold.</summary>
      public bool collectedGold = false;
      /// <summary>Event this step: -1 none, 0 collected at a mine, 1 delivered to a base.</summary>
      public int collideType = 0;

      /// <summary>Scale applied to the action vector when translating the agent each step.</summary>
      public float forceMultiplier = 10;

      Rigidbody rBody;

      private void Start()
      {
          rBody = GetComponent<Rigidbody>();
      }

      /// <summary>Resets the stage layout and all per-episode counters.</summary>
      public override void OnEpisodeBegin()
      {
          // NOTE(review): ML-Agents can call OnEpisodeBegin from OnEnable, i.e. before Start(),
          // so fetch the Rigidbody lazily. Unity's overloaded == also treats destroyed objects as null.
          if (rBody == null)
          {
              rBody = GetComponent<Rigidbody>();
          }

          // If the agent fell off the stage, kill its momentum. Its position is reset by
          // RandomizePositions() below.
          if (transform.localPosition.y < 0f)
          {
              rBody.angularVelocity = Vector3.zero;
              rBody.velocity = Vector3.zero;
          }

          RandomizePositions();

          gold = 0;
          collectedGold = false;
          life = 0f;
          timeBetweenCollect = 0f;
          collideType = -1;
      }

      /// <summary>X component of a vector of <paramref name="length"/> at <paramref name="direction"/> degrees.</summary>
      public static float LengthDirX(float length, float direction)
      {
          return Mathf.Cos(direction * Mathf.Deg2Rad) * length;
      }

      /// <summary>Y component of a vector of <paramref name="length"/> at <paramref name="direction"/> degrees.</summary>
      public static float LengthDirY(float length, float direction)
      {
          return Mathf.Sin(direction * Mathf.Deg2Rad) * length;
      }

      /// <summary>
      /// Resets mines and bases and places them, then moves the agent to the origin.
      /// </summary>
      /// <remarks>
      /// Every mine goes to one random angle on a circle of <see cref="SpawnRadius"/>, and every
      /// base goes to the diametrically opposite point. Apart from index 0, each mine and each
      /// base is enabled with 50% probability.
      /// </remarks>
      private void RandomizePositions()
      {
          float angle = Random.value * 360f;
          var minePosition = new Vector3(LengthDirX(SpawnRadius, angle), 0f, LengthDirY(SpawnRadius, angle));
          var basePosition = new Vector3(LengthDirX(SpawnRadius, angle + 180f), 0f, LengthDirY(SpawnRadius, angle + 180f));

          for (int i = 0; i < Mines.Count; i++)
          {
              Transform mine = Mines[i];
              var pickup = mine.GetComponent<Pickup>();
              pickup.collectable = true;
              pickup.goldCount = MineStartingGold;
              mine.localScale = new Vector3(1f, 1f, 1f);
              // The active state of the first mine is left untouched.
              if (i > 0)
              {
                  mine.gameObject.SetActive(Random.value > 0.5f);
              }
              mine.localPosition = minePosition;
          }

          for (int i = 0; i < Base.Count; i++)
          {
              Transform baseTransform = Base[i];
              // The active state of the first base is left untouched.
              if (i > 0)
              {
                  baseTransform.gameObject.SetActive(Random.value > 0.5f);
              }
              baseTransform.localScale = new Vector3(1f, 0.25f, 1f);
              baseTransform.localPosition = basePosition;
              var pickup = baseTransform.GetComponent<Pickup>();
              pickup.goldCount = 0;
              pickup.collectable = false;
          }

          transform.localPosition = new Vector3(0f, 0.5f, 0f);
      }

      /// <summary>
      /// Adds observations for every mine, then every base, then the agent's own state.
      /// </summary>
      /// <remarks>
      /// NOTE(review): the order and count of AddObservation calls define the observation vector.
      /// They must stay in sync with the Behavior Parameters' configured vector observation size.
      /// </remarks>
      public override void CollectObservations(VectorSensor sensor)
      {
          AddPickupObservations(sensor, Mines);
          AddPickupObservations(sensor, Base);

          sensor.AddObservation(transform.localPosition);
          sensor.AddObservation(collectedGold);
          sensor.AddObservation(timeBetweenCollect);
          sensor.AddObservation(life);
          sensor.AddObservation(collideType);
      }

      /// <summary>
      /// Appends the observation block for each pickup in <paramref name="pickups"/>: position,
      /// active flag, collectable flag, type, and distance to the agent.
      /// </summary>
      private void AddPickupObservations(VectorSensor sensor, List<Transform> pickups)
      {
          foreach (Transform pickupTransform in pickups)
          {
              var pickup = pickupTransform.GetComponent<Pickup>();
              sensor.AddObservation(pickupTransform.localPosition);
              sensor.AddObservation(pickupTransform.gameObject.activeSelf);
              sensor.AddObservation(pickup.collectable);
              sensor.AddObservation(pickup.type);
              // Refresh the cached distance right before reading it.
              pickup.CalculateDistanceToAgent(transform);
              sensor.AddObservation(pickup.calculatedDistanceToAgent);
          }
      }

      /// <summary>
      /// Applies one step: advances timers, moves the agent on the XZ plane, resolves mine/base
      /// interactions, refreshes pickup state, and ends the episode when life runs out.
      /// </summary>
      /// <param name="vectorAction">Continuous actions: [0] = X movement, [1] = Z movement.</param>
      public override void OnActionReceived(float[] vectorAction)
      {
          collideType = -1;
          life += Time.deltaTime;
          timeBetweenCollect += Time.deltaTime;
          AddReward(-TimePenaltyPerSecond * Time.deltaTime);

          Vector3 controlSignal = Vector3.zero;
          controlSignal.x = vectorAction[0];
          controlSignal.z = vectorAction[1];
          // Move by translating the transform directly rather than applying a Rigidbody force.
          transform.localPosition += controlSignal * forceMultiplier * Time.deltaTime;

          if (collectedGold)
          {
              TryDepositAtBase();
          }
          else
          {
              TryCollectFromMine();
          }

          // Mines are collectable only while empty-handed; bases only while carrying.
          RefreshPickups(Mines, !collectedGold);
          RefreshPickups(Base, collectedGold);

          if (life > MaxLife)
          {
              EndEpisode();
          }
      }

      /// <summary>
      /// Takes one unit of gold from the first active, in-range mine that still has gold.
      /// </summary>
      private void TryCollectFromMine()
      {
          foreach (Transform mine in Mines)
          {
              if (!mine.gameObject.activeSelf)
              {
                  continue;
              }

              float distance = Vector3.Distance(transform.localPosition, mine.localPosition);
              if (distance < InteractionRange)
              {
                  var pickup = mine.GetComponent<Pickup>();
                  if (pickup.goldCount > 0)
                  {
                      pickup.goldCount--;
                      // Shrink the mine, but never to a negative (mirrored) scale.
                      Vector3 scale = mine.localScale;
                      scale.y = Mathf.Max(0f, scale.y - PileScaleStep);
                      mine.localScale = scale;

                      collectedGold = true;
                      collideType = 0;
                      AddReward(CollectReward);
                      // The agent carries a single unit. Stop so overlapping mines are not all
                      // drained and rewarded in the same step.
                      return;
                  }
              }
          }
      }

      /// <summary>
      /// Delivers the carried unit to the first active base within range.
      /// </summary>
      private void TryDepositAtBase()
      {
          foreach (Transform baseTransform in Base)
          {
              if (!baseTransform.gameObject.activeSelf)
              {
                  continue;
              }

              float distance = Vector3.Distance(transform.localPosition, baseTransform.localPosition);
              if (distance < InteractionRange)
              {
                  collectedGold = false;
                  gold++;
                  var pickup = baseTransform.GetComponent<Pickup>();
                  pickup.goldCount++;
                  baseTransform.localScale += new Vector3(0f, PileScaleStep, 0f);
                  AddReward(DepositReward);
                  timeBetweenCollect = 0f;
                  collideType = 1;
                  life -= DepositLifeRefund;
                  // Only one unit is carried, so it can be delivered only once.
                  return;
              }
          }
      }

      /// <summary>
      /// Sets the collectable flag on every pickup in <paramref name="pickups"/> and refreshes
      /// its distance to the agent.
      /// </summary>
      private void RefreshPickups(List<Transform> pickups, bool collectable)
      {
          foreach (Transform pickupTransform in pickups)
          {
              var pickup = pickupTransform.GetComponent<Pickup>();
              pickup.collectable = collectable;
              pickup.CalculateDistanceToAgent(transform);
          }
      }
  }
  200.  
  201.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement