Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- using System.Collections;
- using System.Collections.Generic;
- using UnityEngine;
- using Unity.MLAgents;
- using Unity.MLAgents.Sensors;
/// <summary>
/// ML-Agents agent that ferries gold from "mine" objects to "base" objects.
/// Each action step it moves on the local X/Z plane. While empty-handed it picks up one
/// unit of gold from an active mine within <see cref="PickupRadius"/>. While carrying gold
/// it delivers that unit to an active base within the same radius. Deliveries are rewarded
/// and push back the episode timer (<see cref="life"/>); the episode ends once the timer
/// exceeds <see cref="MaxLife"/>.
/// </summary>
public class FarmerAgent : Agent
{
    // Distance (local units) at which a mine or base counts as reached.
    private const float PickupRadius = 1f;
    // Seconds on the life timer after which the episode ends.
    private const float MaxLife = 1.6f;
    // Seconds subtracted from the life timer for each delivery.
    private const float LifeRefundPerDeposit = 0.8f;
    // Reward granted for picking up a unit of gold at a mine.
    private const float CollectReward = 1f;
    // Reward granted for delivering a unit of gold to a base.
    private const float DepositReward = 10f;
    // Penalty per second of simulated time, applied every step.
    private const float TimePenaltyPerSecond = 0.01f;
    // Gold stocked in every mine at the start of an episode.
    private const int InitialMineGold = 50;
    // Distance from the origin at which mines and bases are spawned.
    private const float SpawnRadius = 3.4f;
    // Y-scale removed from a mine per pickup / added to a base per delivery.
    private const float MineShrinkPerCollect = 0.5f;
    private const float BaseGrowPerDeposit = 0.5f;
    // Y-scale a base is reset to at episode start.
    private const float BaseStartHeight = 0.25f;

    /// <summary>Mine transforms; each is expected to carry a <c>Pickup</c> component.</summary>
    public List<Transform> Mines;
    /// <summary>Base transforms; each is expected to carry a <c>Pickup</c> component.</summary>
    public List<Transform> Base;
    /// <summary>Episode timer in seconds: grows every step, reduced on each delivery.</summary>
    public float life = 0;
    /// <summary>Seconds since the last delivery (reset to 0 on deposit and episode start).</summary>
    public float timeBetweenCollect = 0;
    /// <summary>Number of deliveries made this episode.</summary>
    public int gold = 0;
    /// <summary>True while the agent is carrying one unit of gold.</summary>
    public bool collectedGold = false;
    /// <summary>Event of the current step: -1 none, 0 picked up at a mine, 1 delivered to a base.</summary>
    public int collideType = 0;
    /// <summary>Movement speed: local units per second per unit of action.</summary>
    public float forceMultiplier = 10;

    private Rigidbody rBody;

    void Start()
    {
        rBody = GetComponent<Rigidbody>();
    }

    /// <summary>Resets the arena layout and all per-episode counters.</summary>
    public override void OnEpisodeBegin()
    {
        // The agent is always teleported back to the centre by RandomizePositions,
        // so any momentum (e.g. from falling off the platform) must not carry over.
        if (rBody != null)
        {
            rBody.angularVelocity = Vector3.zero;
            rBody.velocity = Vector3.zero;
        }

        RandomizePositions();

        gold = 0;
        collectedGold = false;
        life = 0;
        timeBetweenCollect = 0;
        collideType = -1;
    }

    /// <summary>X component of a vector of <paramref name="length"/> at <paramref name="direction"/> degrees.</summary>
    public static float LengthDirX(float length, float direction)
    {
        return Mathf.Cos(direction * Mathf.Deg2Rad) * length;
    }

    /// <summary>Y component of a vector of <paramref name="length"/> at <paramref name="direction"/> degrees.</summary>
    public static float LengthDirY(float length, float direction)
    {
        return Mathf.Sin(direction * Mathf.Deg2Rad) * length;
    }

    // Restocks and places mines and bases on opposite sides of a circle, then centres the agent.
    private void RandomizePositions()
    {
        // NOTE(review): one angle is shared by every mine and (offset by 180 degrees) by every
        // base, so all mines spawn on the same spot, as do all bases. Confirm this is intended.
        float angle = Random.value * 360f;
        float oppositeAngle = angle + 180f;

        for (int i = 0; i < Mines.Count; i++)
        {
            Transform mine = Mines[i];
            Pickup pickup = mine.GetComponent<Pickup>();
            pickup.collectable = true;
            pickup.goldCount = InitialMineGold;
            mine.localScale = new Vector3(1, 1, 1);
            // The first mine's active state is left untouched; the others are toggled at random.
            if (i > 0) mine.gameObject.SetActive(Random.value > 0.5f);
            mine.localPosition = new Vector3(LengthDirX(SpawnRadius, angle), 0, LengthDirY(SpawnRadius, angle));
        }

        for (int i = 0; i < Base.Count; i++)
        {
            Transform homeBase = Base[i];
            if (i > 0) homeBase.gameObject.SetActive(Random.value > 0.5f);
            homeBase.localScale = new Vector3(1, BaseStartHeight, 1);
            homeBase.localPosition = new Vector3(LengthDirX(SpawnRadius, oppositeAngle), 0, LengthDirY(SpawnRadius, oppositeAngle));

            Pickup pickup = homeBase.GetComponent<Pickup>();
            pickup.goldCount = 0;
            // Bases only become collectable once the agent is carrying gold.
            pickup.collectable = false;
        }

        transform.localPosition = new Vector3(0, 0.5f, 0);
    }

    /// <summary>
    /// Emits observations for every mine, then every base, then the agent's own state.
    /// The order and count must stay fixed: trained policies depend on this layout.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        AddPickupObservations(sensor, Mines);
        AddPickupObservations(sensor, Base);

        sensor.AddObservation(transform.localPosition);
        sensor.AddObservation(collectedGold);
        sensor.AddObservation(timeBetweenCollect);
        sensor.AddObservation(life);
        sensor.AddObservation(collideType);
    }

    // Adds position, active flag, collectable flag, type and refreshed distance-to-agent per target.
    private void AddPickupObservations(VectorSensor sensor, List<Transform> targets)
    {
        foreach (Transform target in targets)
        {
            Pickup pickup = target.GetComponent<Pickup>();
            sensor.AddObservation(target.localPosition);
            sensor.AddObservation(target.gameObject.activeSelf);
            sensor.AddObservation(pickup.collectable);
            sensor.AddObservation(pickup.type);
            pickup.CalculateDistanceToAgent(transform);
            sensor.AddObservation(pickup.calculatedDistanceToAgent);
        }
    }

    /// <summary>
    /// Applies one action step: advances timers, moves the agent, resolves a pickup or a
    /// delivery, refreshes pickup flags and ends the episode when the life timer expires.
    /// </summary>
    /// <param name="vectorAction">Two continuous actions: movement along local X and Z.</param>
    public override void OnActionReceived(float[] vectorAction)
    {
        collideType = -1;
        life += Time.deltaTime;
        timeBetweenCollect += Time.deltaTime;
        AddReward(-TimePenaltyPerSecond * Time.deltaTime);

        // Movement is applied directly to the transform rather than through the Rigidbody.
        Vector3 controlSignal = new Vector3(vectorAction[0], 0f, vectorAction[1]);
        transform.localPosition += controlSignal * forceMultiplier * Time.deltaTime;

        if (!collectedGold)
        {
            TryCollectFromMine();
        }
        else
        {
            TryDepositAtBase();
        }

        RefreshPickupStates();

        if (life > MaxLife)
        {
            EndEpisode();
        }
    }

    // Takes one unit of gold from the first active, in-range mine that still holds gold.
    private void TryCollectFromMine()
    {
        foreach (Transform mine in Mines)
        {
            if (!mine.gameObject.activeSelf) continue;
            if (Vector3.Distance(transform.localPosition, mine.localPosition) >= PickupRadius) continue;

            Pickup pickup = mine.GetComponent<Pickup>();
            if (pickup.goldCount <= 0) continue;

            pickup.goldCount--;
            // Clamp so the mine never gets a negative (mirrored) scale.
            Vector3 scale = mine.localScale;
            scale.y = Mathf.Max(0f, scale.y - MineShrinkPerCollect);
            mine.localScale = scale;

            collectedGold = true;
            collideType = 0;
            AddReward(CollectReward);
            // The agent carries at most one unit, so overlapping mines must not all pay out.
            return;
        }
    }

    // Delivers the carried unit of gold to the first active, in-range base.
    private void TryDepositAtBase()
    {
        foreach (Transform homeBase in Base)
        {
            if (!homeBase.gameObject.activeSelf) continue;
            if (Vector3.Distance(transform.localPosition, homeBase.localPosition) >= PickupRadius) continue;

            collectedGold = false;
            gold++;
            homeBase.GetComponent<Pickup>().goldCount++;
            homeBase.localScale += new Vector3(0, BaseGrowPerDeposit, 0);
            AddReward(DepositReward);
            timeBetweenCollect = 0f;
            collideType = 1;
            life -= LifeRefundPerDeposit;
            // Only one unit was carried; do not deliver (and reward) it to several bases.
            return;
        }
    }

    // Mines are collectable only while empty-handed, bases only while carrying; distances are refreshed.
    private void RefreshPickupStates()
    {
        foreach (Transform mine in Mines)
        {
            Pickup pickup = mine.GetComponent<Pickup>();
            pickup.collectable = !collectedGold;
            pickup.CalculateDistanceToAgent(transform);
        }

        foreach (Transform homeBase in Base)
        {
            Pickup pickup = homeBase.GetComponent<Pickup>();
            pickup.collectable = collectedGold;
            pickup.CalculateDistanceToAgent(transform);
        }
    }
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement