I've been recently working on a self-balancing pair of legs that are supposed to try to keep the character from falling over. Each AgentReset, the legs reset all of its necessary factors such as pos, rot, and vel, and the floor beneath the character changes the rotation randomly by less than 5 degrees. Though, no matter what I seem to do with the number of observations the agent takes in, he still doesn't really seem to learn from his mistakes. Now, I'm new to machine-learning so go easy on me! What am I missing? Thank you!
Some notes: I'm not really sure how RayPerceptionSensorComponent3D works. Maybe someone can put me in the right direction if that might help.
Agent Script :
using MLAgents;
using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using MLAgents.Sensor;
using Random = UnityEngine.Random;
public class BalanceAgent : Agent
{
private BalancingArea area;
public GameObject floor;
public GameObject waist;
public GameObject buttR;
public GameObject buttL;
public GameObject thighR;
public GameObject thighL;
public GameObject legR;
public GameObject legL;
public GameObject footR;
public GameObject footL;
public GameObject[] bodyParts = new GameObject[9];
public HingeJoint[] hingeParts = new HingeJoint[9];
public JointLimits[] jntLimParts = new JointLimits[9];
public Vector3[] posStart = new Vector3[9];
public Vector3[] eulerStart = new Vector3[9];
public RayPerceptionSensorComponent3D raySensors;
float rayDist = 0;
float rayAngle = 0;
Vector3 rayFloorAngle = new Vector3(0,0,0);
Vector3 rayPoint = new Vector3(0,0,0);
int rotAgent = 0;
public void Start() {
bodyParts = new GameObject[] { waist, buttR, buttL, thighR, thighL, legR, legL, footR, footL }; //Waist = 0, footL = 8.
for (int i = 0; i < bodyParts.Length; i++) {
posStart[i] = bodyParts[i].transform.position;
eulerStart[i] = bodyParts[i].transform.eulerAngles;
if (bodyParts[i].GetComponent<HingeJoint>() != null) {
hingeParts[i] = bodyParts[i].GetComponent<HingeJoint>();
jntLimParts[i] = hingeParts[i].limits;
}
}
}
public override void InitializeAgent() {
base.InitializeAgent();
area = GetComponentInParent<BalancingArea>();
}
public override void AgentReset() {
floor.transform.eulerAngles = new Vector3(Random.Range(-5, 5), 0, Random.Range(-5, 5));
print("Reset! - " + rotAgent);
for (int i = 0; i < bodyParts.Length; i++) {
bodyParts[i].transform.position = posStart[i];
bodyParts[i].transform.eulerAngles = eulerStart[i];
if (bodyParts[i].GetComponent<HingeJoint>() != null) {
jntLimParts[i].max = 1;
jntLimParts[i].min = -1;
}
bodyParts[i].GetComponent<Rigidbody>().velocity = Vector3.zero;
bodyParts[i].GetComponent<Rigidbody>().angularVelocity = Vector3.zero;
}
}
public override void AgentAction(float[] vectorAction) {
float buttRDir = 0;
int buttRVec = (int)vectorAction[0];
switch (buttRVec) {
case 1:
buttRDir = 0;
break;
case 2:
buttRDir = .2f;
break;
case 3:
buttRDir = -.2f;
break;
}
jntLimParts[1].max += buttRDir;
jntLimParts[1].min = jntLimParts[1].max - 1;
hingeParts[1].limits = jntLimParts[1];
float buttLDir = 0;
int buttLVec = (int)vectorAction[1];
switch (buttLVec) {
case 1:
buttLDir = 0;
break;
case 2:
buttLDir = .2f;
break;
case 3:
buttLDir = -.2f;
break;
}
jntLimParts[2].max += buttLDir;
jntLimParts[2].min = jntLimParts[2].max - 1;
hingeParts[2].limits = jntLimParts[2];
float thighRDir = 0;
int thighRVec = (int)vectorAction[2];
switch (thighRVec) {
case 1:
thighRDir = 0;
break;
case 2:
thighRDir = .2f;
break;
case 3:
thighRDir = -.2f;
break;
}
jntLimParts[3].max += thighRDir;
jntLimParts[3].min = jntLimParts[3].max - 1;
hingeParts[3].limits = jntLimParts[3];
float thighLDir = 0;
int thighLVec = (int)vectorAction[3];
switch (thighLVec) {
case 1:
thighLDir = 0;
break;
case 2:
thighLDir = .2f;
break;
case 3:
thighLDir = -.2f;
break;
}
jntLimParts[4].max += thighLDir;
jntLimParts[4].min = jntLimParts[4].max - 1;
hingeParts[4].limits = jntLimParts[4];
float legRDir = 0;
int legRVec = (int)vectorAction[4];
switch (legRVec) {
case 1:
legRDir = 0;
break;
case 2:
legRDir = .2f;
break;
case 3:
legRDir = -.2f;
break;
}
jntLimParts[5].max += legRDir;
jntLimParts[5].min = jntLimParts[5].max - 1;
hingeParts[5].limits = jntLimParts[5];
float legLDir = 0;
int legLVec = (int)vectorAction[5];
switch (legLVec) {
case 1:
legLDir = 0;
break;
case 2:
legLDir = .2f;
break;
case 3:
legLDir = -.2f;
break;
}
jntLimParts[6].max += legLDir;
jntLimParts[6].min = jntLimParts[6].max - 1;
hingeParts[6].limits = jntLimParts[6];
float footRDir = 0;
int footRVec = (int)vectorAction[6];
switch (footRVec) {
case 1:
footRDir = 0;
break;
case 2:
footRDir = .2f;
break;
case 3:
footRDir = -.2f;
break;
}
jntLimParts[7].max += footRDir;
jntLimParts[7].min = jntLimParts[7].max - 1;
hingeParts[7].limits = jntLimParts[7];
float footLDir = 0;
int footLVec = (int)vectorAction[7];
switch (footLVec) {
case 1:
footLDir = 0;
break;
case 2:
footLDir = .2f;
break;
case 3:
footLDir = -.2f;
break;
}
jntLimParts[8].max += footLDir;
jntLimParts[8].min = jntLimParts[8].max - 1;
hingeParts[8].limits = jntLimParts[8];
float waistDir = 0;
int waistVec = (int)vectorAction[8];
switch (footLVec) {
case 1:
waistDir = 0;
break;
case 2:
waistDir = .2f;
break;
case 3:
waistDir = -.2f;
break;
}
// waist.transform.Rotate(0, waistDir, 0);
//buttR = vectorAction[0]; //Right or none
//if (buttR == 2) buttR = -1f; //Left
if (waist.transform.position.y > -1.4f) {
AddReward(.02f);
}
else {
AddReward(-.03f);
}
if (waist.transform.position.y <= -3) {
Done();
print("He fell too far...");
}
RaycastHit hit;
Ray r;
if (Physics.Raycast(waist.transform.position, -waist.transform.up, out hit)) {
rayDist = hit.distance;
rayPoint = hit.point;
rayAngle = Vector3.Angle(waist.transform.position, hit.normal);
rayFloorAngle = hit.collider.transform.eulerAngles;
}
}
public override void CollectObservations() {
for(int i = 0; i < bodyParts.Length; i++) {
AddVectorObs(bodyParts[i].transform.position);
AddVectorObs(bodyParts[i].transform.eulerAngles);
AddVectorObs(bodyParts[i].GetComponent<Rigidbody>().velocity);
AddVectorObs(bodyParts[i].GetComponent<Rigidbody>().angularVelocity);
AddVectorObs(jntLimParts[i].max);
AddVectorObs(jntLimParts[i].min);
AddVectorObs(raySensors);
AddVectorObs(rayDist);
AddVectorObs(rayPoint);
AddVectorObs(rayAngle);
AddVectorObs(rayFloorAngle);
}
}
}
Area Script :
using MLAgents;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;
public class BalancingArea : Area
{
public List<BalanceAgent> BalanceAgent { get; private set; }
public BalanceAcademy BalanceAcademy { get; private set; }
public GameObject area;
private void Awake() {
BalanceAgent = transform.GetComponentsInChildren<BalanceAgent>().ToList(); //Grabs all agents in area
BalanceAcademy = FindObjectOfType<BalanceAcademy>(); //Grabs balance acedemy
}
private void Start() {
}
public void ResetAgentPosition(BalanceAgent agent) {
//agent.transform.position = new Vector3(area.transform.position.x, 0, area.transform.position.z);
// agent.transform.eulerAngles = new Vector3(0,0,0);
}
// Update is called once per frame
void Update()
{
}
}
Academy Script :
using MLAgents;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
public class BalanceAcademy : Academy
{
}