The Crawler example in ML-Agents

This example comes from the official ML-Agents examples. GitHub: https://github.com/Unity-Technologies/ml-agents

This article builds on my previous two articles, so some familiarity with ML-Agents is assumed. For details, see: The use of ML-Agents in Unity reinforcement learning, and ML-Agents commands and configuration.

Reference material: ML-Agents (X) Crawler

The 3DBall task from the previous article was relatively simple. The agent only had to keep the ball balanced on top of the box, the input dimension was low, and the reward function was simple, so training produced good results quickly. Next, we train on a more challenging task.

As shown in the figure above, the goal is to train a four-legged simulated robot to stand up, walk toward the target, and finally eat the green cube. The faster it does this, the better.

Environment

The robot's environment is a flat floor with friction, enclosed by four walls, with a green cube serving as the robot's target.

The environment is very simple, but the agent itself is not simple at all.

The robot consists of a torso and four legs, and each leg is split into an upper segment and a lower segment, giving eight joints in total.

Configurable Joint

Reference material: Learn more about Unity Configurable Joints, Configurable Joint

Only the important parameters are explained here. The four upper legs (attached to the body) and the four lower legs all use this component. Each lower leg is a child object of its upper leg, so it follows the upper leg's movement. For the upper legs, set Angular X Motion and Angular Y Motion to Limited and everything else to Locked; for the lower legs, set only Angular X Motion to Limited. Then click the Edit Angular Limits button to adjust the rotation limits. The joint's position and rotation axis are set with Anchor and Axis.

Ground contact detection

Each part of the body has a GroundContact script attached, which detects whether that part is touching the ground.

using UnityEngine;
using Unity.MLAgents;

namespace Unity.MLAgentsExamples
{

    [DisallowMultipleComponent]
    public class GroundContact : MonoBehaviour
    {
        [HideInInspector] public Agent agent;

        [Header("Ground Check")] public bool agentDoneOnGroundContact; // Whether to reset agent on ground contact.
        public bool penalizeGroundContact; // Whether to penalize on contact.
        public float groundContactPenalty; // Penalty amount (ex: -1).
        public bool touchingGround;
        const string k_Ground = "ground"; // Tag of ground object.

        // On entering a collision with the ground: set touchingGround, optionally apply the penalty, optionally end the episode
        void OnCollisionEnter(Collision col)
        {
            if (col.transform.CompareTag(k_Ground))
            {
                touchingGround = true;
                if (penalizeGroundContact)
                {
                    agent.SetReward(groundContactPenalty);
                }
                if (agentDoneOnGroundContact)
                {
                    agent.EndEpisode();
                }
            }
        }
        // On leaving a collision with the ground: set touchingGround back to false
        void OnCollisionExit(Collision other)
        {
            if (other.transform.CompareTag(k_Ground))
            {
                touchingGround = false;
            }
        }
    }
}
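To make the contact logic concrete outside Unity, here is a minimal Python sketch of the same enter/exit state machine. `FakeAgent`, `ContactTracker`, `on_collision_enter` and `on_collision_exit` are hypothetical stand-ins for Unity's `Agent` and collision callbacks, not ML-Agents APIs. As in the C# code, the penalty uses set-reward semantics (it overwrites the step reward rather than adding to it).

```python
class FakeAgent:
    """Hypothetical stand-in for an ML-Agents Agent: only tracks reward and episode end."""

    def __init__(self):
        self.reward = 0.0
        self.episode_ended = False

    def set_reward(self, value):
        # Mirrors SetReward: overwrite the reward, do not accumulate
        self.reward = value

    def end_episode(self):
        self.episode_ended = True


class ContactTracker:
    """Hypothetical Python version of GroundContact's enter/exit logic."""

    GROUND_TAG = "ground"

    def __init__(self, agent, done_on_contact=False, penalize=False, penalty=-1.0):
        self.agent = agent
        self.done_on_contact = done_on_contact
        self.penalize = penalize
        self.penalty = penalty
        self.touching_ground = False

    def on_collision_enter(self, tag):
        if tag != self.GROUND_TAG:
            return  # only the ground matters
        self.touching_ground = True
        if self.penalize:
            self.agent.set_reward(self.penalty)
        if self.done_on_contact:
            self.agent.end_episode()

    def on_collision_exit(self, tag):
        if tag == self.GROUND_TAG:
            self.touching_ground = False


agent = FakeAgent()
body = ContactTracker(agent, done_on_contact=True, penalize=True, penalty=-1.0)
body.on_collision_enter("wall")    # ignored: not the ground tag
body.on_collision_enter("ground")  # flag set, penalty applied, episode ended
print(body.touching_ground, agent.reward, agent.episode_ended)  # True -1.0 True
body.on_collision_exit("ground")
print(body.touching_ground)  # False
```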

Code analysis

Now let's look at which scripts are attached to the agent.

First comes the Behavior Parameters component, present as always. The vector observation has 32 dimensions and the continuous action space has 20 dimensions.

Next is the equally familiar Decision Requester, with Take Actions Between Decisions set to false.

Then there is the Model Overrider, also a regular, which allows the model to be overridden during training.

Joint Drive Controller

Let's start with the Joint Drive Controller, which is responsible for controlling each joint.

First, the BodyPart class:

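using System.Collections.Generic;
using UnityEngine;
using UnityEngine.Serialization;
using Unity.MLAgents;

    /// <summary>
    ///Stores the action-related and learning-related information of one body part of the agent
    /// </summary>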
    [System.Serializable]
    public class BodyPart
    {
        [Header("Body Part Info")] [Space(10)] public ConfigurableJoint joint;//ConfigurableJoint component of this body part
        public Rigidbody rb;//Rigidbody of this body part
        [HideInInspector] public Vector3 startingPos;//Starting position
        [HideInInspector] public Quaternion startingRot;//Starting rotation

        [Header("Ground & Target Contact")]
        [Space(10)]
        public GroundContact groundContact;//Detect ground contact
        public TargetContact targetContact;//Detect target contact

        [FormerlySerializedAs("thisJDController")]
        [HideInInspector] public JointDriveController thisJdController;//JointDriveController that owns this body part

        [Header("Current Joint Settings")]
        [Space(10)]
        public Vector3 currentEularJointRotation;//Current target rotation of the joint, as Euler angles

        [HideInInspector] public float currentStrength;//Current maximum drive force
        public float currentXNormalizedRot;
        public float currentYNormalizedRot;
        public float currentZNormalizedRot;

        [Header("Other Debug Info")]
        [Space(10)]
        public Vector3 currentJointForce;//Current joint force

        public float currentJointForceSqrMag;//Magnitude of the current joint force
        public Vector3 currentJointTorque;//Current joint torque
        public float currentJointTorqueSqrMag;//Magnitude of the current joint torque
        public AnimationCurve jointForceCurve = new AnimationCurve();//Joint force curve
        public AnimationCurve jointTorqueCurve = new AnimationCurve();//Joint torque curve

        /// <summary>
        ///Reset this body part to its starting pose and clear its velocities and contact flags
        /// </summary>
        public void Reset(BodyPart bp)
        {
            bp.rb.transform.position = bp.startingPos;//position
            bp.rb.transform.rotation = bp.startingRot;//angle
            bp.rb.velocity = Vector3.zero;//speed
            bp.rb.angularVelocity = Vector3.zero;//angular velocity
            if (bp.groundContact)
            {//Clear the ground contact flag
                bp.groundContact.touchingGround = false;
            }

            if (bp.targetContact)
            {//Clear the target contact flag
                bp.targetContact.touchingTarget = false;
            }
        }

        /// <summary>
        ///Maps normalized actions x, y, z in [-1, 1] to a target rotation within the joint's angular limits
        /// </summary>
        public void SetJointTargetRotation(float x, float y, float z)
        {
            x = (x + 1f) * 0.5f;
            y = (y + 1f) * 0.5f;
            z = (z + 1f) * 0.5f;

            //Mathf.Lerp(a, b, t) returns a + (b - a) * t, with t clamped to [0, 1]
            var xRot = Mathf.Lerp(joint.lowAngularXLimit.limit, joint.highAngularXLimit.limit, x);
            var yRot = Mathf.Lerp(-joint.angularYLimit.limit, joint.angularYLimit.limit, y);
            var zRot = Mathf.Lerp(-joint.angularZLimit.limit, joint.angularZLimit.limit, z);

            //Mathf.InverseLerp(a, b, value) returns where value lies between a and b, as a fraction in [0, 1]
            currentXNormalizedRot = Mathf.InverseLerp(joint.lowAngularXLimit.limit, joint.highAngularXLimit.limit, xRot);
            currentYNormalizedRot = Mathf.InverseLerp(-joint.angularYLimit.limit, joint.angularYLimit.limit, yRot);
            currentZNormalizedRot = Mathf.InverseLerp(-joint.angularZLimit.limit, joint.angularZLimit.limit, zRot);

            joint.targetRotation = Quaternion.Euler(xRot, yRot, zRot);//Turn the joint to the target angle
            currentEularJointRotation = new Vector3(xRot, yRot, zRot);//Euler angle of current joint
        }
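        /// <summary>
        ///Sets the joint drive strength: maps a normalized action in [-1, 1] to a maximum force in [0, maxJointForceLimit]
        /// </summary>
        /// <param name="strength">Normalized strength action in [-1, 1]</param>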
        public void SetJointStrength(float strength)
        {
            var rawVal = (strength + 1f) * 0.5f * thisJdController.maxJointForceLimit;
            var jd = new JointDrive
            {
                positionSpring = thisJdController.maxJointSpring,//Spring stiffness of the joint drive
                positionDamper = thisJdController.jointDampen,//Damping of the joint drive
                maximumForce = rawVal//Maximum force applied
            };
            joint.slerpDrive = jd;
            currentStrength = jd.maximumForce;//Current applied force
        }
    } 
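To see what `SetJointTargetRotation` and `SetJointStrength` do with the network's outputs, here is a minimal Python sketch of the same mappings. `lerp` and `inverse_lerp` re-implement Unity's `Mathf.Lerp` and `Mathf.InverseLerp` (both clamp to [0, 1]). The joint limits and maximum force in the example are made-up values, not the Crawler's Inspector settings.

```python
def lerp(a, b, t):
    """Like Unity's Mathf.Lerp: t is clamped to [0, 1]."""
    t = min(max(t, 0.0), 1.0)
    return a + (b - a) * t


def inverse_lerp(a, b, value):
    """Like Unity's Mathf.InverseLerp: fraction of the way from a to b, clamped to [0, 1]."""
    if a == b:
        return 0.0
    return min(max((value - a) / (b - a), 0.0), 1.0)


def action_to_angle(action, low_limit, high_limit):
    """Map a policy output in [-1, 1] to an angle between the joint limits."""
    t = (action + 1.0) * 0.5  # [-1, 1] -> [0, 1]
    return lerp(low_limit, high_limit, t)


def action_to_strength(action, max_force):
    """Map a policy output in [-1, 1] to a drive force in [0, max_force] (no clamping, as in C#)."""
    return (action + 1.0) * 0.5 * max_force


# Hypothetical limits: X rotation between -30 and 60 degrees
for a in (-1.0, 0.0, 1.0, 1.5):
    angle = action_to_angle(a, -30.0, 60.0)
    print(a, angle, inverse_lerp(-30.0, 60.0, angle))
print(action_to_strength(0.0, 20000.0))
```

Because Lerp clamps `t`, the target angle always stays within the joint limits. `SetJointStrength` has no clamp of its own, so it relies on the actions already being in [-1, 1].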

The JointDriveController script manages all the BodyPart objects. It can also update the force and torque of each body part, so that the Agent can collect information about them.

The JointDriveController class:

    /// <summary>
    ///Controls the joint drives of all of the agent's body parts
    /// </summary>
    public class JointDriveController : MonoBehaviour
    {
        [Header("Joint Drive Settings")]
        [Space(10)]
        public float maxJointSpring;//Spring stiffness of the joint drives
        public float jointDampen;//Damping of the joint drives (resists the spring)
        public float maxJointForceLimit;//Maximum force a joint drive can apply
        //float m_FacingDot;// This variable is not used

        //Body parts dictionary, keyed by each part's Transform
        [HideInInspector] public Dictionary<Transform, BodyPart> bodyPartsDict = new Dictionary<Transform, BodyPart>();

        //Body parts in setup order (iterated by CrawlerAgent.CollectObservations)
        [HideInInspector] public List<BodyPart> bodyPartsList = new List<BodyPart>();

        /// <summary>
        ///Create a BodyPart object and add it to the dictionary and the list
        /// </summary>
        public void SetupBodyPart(Transform t)
        {
            var bp = new BodyPart
            {
                rb = t.GetComponent<Rigidbody>(),
                joint = t.GetComponent<ConfigurableJoint>(),
                startingPos = t.position,
                startingRot = t.rotation
            };
            bp.rb.maxAngularVelocity = 100;//The maximum angular velocity is 100

            //Get the GroundContact script (adding it if missing) and point it at the agent
            bp.groundContact = t.GetComponent<GroundContact>();
            if (!bp.groundContact)
            {
                bp.groundContact = t.gameObject.AddComponent<GroundContact>();
                bp.groundContact.agent = gameObject.GetComponent<Agent>();
            }
            else
            {
                bp.groundContact.agent = gameObject.GetComponent<Agent>();
            }

            //Get the TargetContact script, adding it if missing
            bp.targetContact = t.GetComponent<TargetContact>();
            if (!bp.targetContact)
            {
                bp.targetContact = t.gameObject.AddComponent<TargetContact>();
            }

            bp.thisJdController = this;
            bodyPartsDict.Add(t, bp);
            bodyPartsList.Add(bp);
        }
        /// <summary>
        ///Update the current force and torque of each part of the body
        /// </summary>
        public void GetCurrentJointForces()
        {
            foreach (var bodyPart in bodyPartsDict.Values)
            {//Loop over every body part
                if (bodyPart.joint)
                {
                    bodyPart.currentJointForce = bodyPart.joint.currentForce;//Current joint force
                    bodyPart.currentJointForceSqrMag = bodyPart.joint.currentForce.magnitude;//Magnitude of the current joint force (not squared, despite the field name)
                    bodyPart.currentJointTorque = bodyPart.joint.currentTorque;//Current joint torque
                    bodyPart.currentJointTorqueSqrMag = bodyPart.joint.currentTorque.magnitude;//Magnitude of the current joint torque (not squared, despite the field name)
                    if (Application.isEditor)
                    {//In the Unity Editor only: record force and torque curves for debugging, restarting them after 1000 keys
                        if (bodyPart.jointForceCurve.length > 1000)
                        {
                            bodyPart.jointForceCurve = new AnimationCurve();
                        }

                        if (bodyPart.jointTorqueCurve.length > 1000)
                        {
                            bodyPart.jointTorqueCurve = new AnimationCurve();
                        }

                        bodyPart.jointForceCurve.AddKey(Time.time, bodyPart.currentJointForceSqrMag);
                        bodyPart.jointTorqueCurve.AddKey(Time.time, bodyPart.currentJointTorqueSqrMag);
                    }
                }
            }
        }
    }

Although this script is attached to the agent, it does nothing on its own: it has no update callbacks, and its methods run only when other scripts (here, CrawlerAgent) call them.

RigidBody Sensor Component

A RigidBodySensorComponent script is also attached to the agent:

using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Extensions.Sensors
{

    public class RigidBodySensorComponent : SensorComponent
    {
        public Rigidbody RootBody;

        /// Optional GameObject used to determine the root of the poses.
        public GameObject VirtualRoot;

        /// Settings defining what types of observations will be generated.
        [SerializeField]
        public PhysicsSensorSettings Settings = PhysicsSensorSettings.Default();

        /// Optional sensor name. This must be unique for each Agent.
        [SerializeField]
        public string sensorName;

        [SerializeField]
        [HideInInspector]
        RigidBodyPoseExtractor m_PoseExtractor;

        /// Creates a PhysicsBodySensor.
        public override ISensor[] CreateSensors()
        {
            var _sensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{RootBody?.name}" : sensorName;
            return new ISensor[] { new PhysicsBodySensor(GetPoseExtractor(), Settings, _sensorName) };
        }

        /// Get the DisplayNodes of the hierarchy.
        internal IList<PoseExtractor.DisplayNode> GetDisplayNodes()
        {
            return GetPoseExtractor().GetDisplayNodes();
        }

        /// Lazy construction of the PoseExtractor.
        RigidBodyPoseExtractor GetPoseExtractor()
        {
            if (m_PoseExtractor == null)
            {
                ResetPoseExtractor();
            }

            return m_PoseExtractor;
        }

        /// Reset the pose extractor, trying to keep the enabled state of the corresponding poses the same.
        internal void ResetPoseExtractor()
        {
            // Get the current enabled state of each body, so that we can reinitialize with them.
            Dictionary<Rigidbody, bool> bodyPosesEnabled = null;
            if (m_PoseExtractor != null)
            {
                bodyPosesEnabled = m_PoseExtractor.GetBodyPosesEnabled();
            }
            m_PoseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject, VirtualRoot, bodyPosesEnabled);
        }

        /// Toggle the pose at the given index.
        internal void SetPoseEnabled(int index, bool enabled)
        {
            GetPoseExtractor().SetPoseEnabled(index, enabled);
        }

        internal bool IsTrivial()
        {
            if (ReferenceEquals(RootBody, null))
            {
                // It *is* trivial, but this will happen when the sensor is being set up, so don't warn then.
                return false;
            }
            var joints = RootBody.GetComponentsInChildren<Joint>();
            if (joints.Length == 0)
            {
                if (ReferenceEquals(VirtualRoot, null) || ReferenceEquals(VirtualRoot, RootBody.gameObject))
                {
                    return true;
                }
            }
            return false;
        }
    }

}

This component is a newly added experimental feature. It lives in the extensions package (com.unity.ml-agents.extensions) rather than in the main package. The Hierarchy shown in the figure is generated automatically at run time. To use the component, drag the Body object into Root Body and the orientation cube into Virtual Root.

Like other sensor components, it collects observations on its own: CreateSensors creates a PhysicsBodySensor, which implements the ISensor interface, and the ISensor Write method produces the actual observations.

For agents built from joints, adding this component can improve training. What exactly it contributes is left for further exploration.
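The sensor-component pattern described above can be sketched in Python. Everything here is hypothetical (`Sensor`, `RelativePositionSensor`, `SensorComponent` and their methods are illustrative names, not ML-Agents APIs); it only shows the idea that the component is a factory and each sensor writes its own values into a shared observation buffer. What PhysicsBodySensor actually writes is controlled by its PhysicsSensorSettings.

```python
from abc import ABC, abstractmethod


class Sensor(ABC):
    """Hypothetical analogue of ISensor: knows its size and writes its own values."""

    @abstractmethod
    def observation_size(self):
        ...

    @abstractmethod
    def write(self, buffer):
        """Append observations to buffer and return how many values were written."""


class RelativePositionSensor(Sensor):
    """Hypothetical sensor: positions of child bodies relative to a root body."""

    def __init__(self, get_root, get_children):
        # Callables, so every write() reads the current state
        self.get_root = get_root
        self.get_children = get_children

    def observation_size(self):
        return 3 * len(self.get_children())

    def write(self, buffer):
        rx, ry, rz = self.get_root()
        start = len(buffer)
        for cx, cy, cz in self.get_children():
            buffer.extend((cx - rx, cy - ry, cz - rz))
        return len(buffer) - start


class SensorComponent:
    """Hypothetical analogue of a sensor component: create_sensors is the factory."""

    def __init__(self, root, children):
        self.root = root
        self.children = children

    def create_sensors(self):
        return [RelativePositionSensor(lambda: self.root, lambda: self.children)]


component = SensorComponent((0.0, 1.0, 0.0), [(1.0, 1.0, 0.0), (0.0, 0.5, 2.0)])
observations = []
for sensor in component.create_sensors():
    sensor.write(observations)
print(observations)  # [1.0, 0.0, 0.0, 0.0, -0.5, 2.0]
```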

Crawler Agent

Next comes the star of the show, the CrawlerAgent script.

This component inherits from Agent and implements the key elements of reinforcement learning: collecting the agent's observations, receiving its actions, defining rewards, and deciding when an episode ends.

In the Inspector, the Transform, mesh renderer, and materials of each body part are assigned one by one. Now let's look at what each method does.

First, the initialization method Initialize, which runs once when the agent is first enabled, before any episode starts:

 public override void Initialize()
 {
     // The next two lines were not in early versions; adding an orientation object to the agent was found to greatly increase the reward.
     // The reason is worth investigating.
     SpawnTarget(TargetPrefab, transform.position); //Spawn the target
     m_OrientationCube = GetComponentInChildren<OrientationCubeController>();

     m_DirectionIndicator = GetComponentInChildren<DirectionIndicator>();
     m_JdController = GetComponent<JointDriveController>();

     //Setup each body part
     m_JdController.SetupBodyPart(body);
     m_JdController.SetupBodyPart(leg0Upper);
     m_JdController.SetupBodyPart(leg0Lower);
     m_JdController.SetupBodyPart(leg1Upper);
     m_JdController.SetupBodyPart(leg1Lower);
     m_JdController.SetupBodyPart(leg2Upper);
     m_JdController.SetupBodyPart(leg2Lower);
     m_JdController.SetupBodyPart(leg3Upper);
     m_JdController.SetupBodyPart(leg3Lower);
 }
// Spawn the target cube under the agent's parent object
void SpawnTarget(Transform prefab, Vector3 pos)
{
    m_Target = Instantiate(prefab, pos, Quaternion.identity, transform.parent);
}

It first spawns the target, then gets the required components and sets up each body part.

Next, OnEpisodeBegin, which runs at the start of each episode:

public override void OnEpisodeBegin()
{
    // Reset all joints
    foreach (var bodyPart in m_JdController.bodyPartsDict.Values)
    {
        bodyPart.Reset(bodyPart);
    }

    // Give the body a random initial heading (yaw)
    body.rotation = Quaternion.Euler(0, Random.Range(0.0f, 360.0f), 0);

    // Update the position and rotation of the orientation cube (a helper object on the agent) and the direction indicator
    UpdateOrientationObjects();

    // Pick a random target walking speed for this episode
    TargetWalkingSpeed = Random.Range(0.1f, m_maxWalkingSpeed);
}

Next is our old friend CollectObservations, which adds the observations fed into the neural network:

public override void CollectObservations(VectorSensor sensor)
{
    var cubeForward = m_OrientationCube.transform.forward;

    //velocity we want to match
    var velGoal = cubeForward * TargetWalkingSpeed;
    // Gets the average velocity of the rigid body
    var avgVel = GetAvgVelocity();

    // Distance between the average velocity and the goal velocity (1 value)
    sensor.AddObservation(Vector3.Distance(velGoal, avgVel));
    // Average velocity in the orientation cube's local frame (3 values). Worth thinking about why adding this cube makes training more effective
    sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(avgVel));
    // Goal velocity in the orientation cube's local frame (3 values)
    sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(velGoal));
    // Rotation from the body's forward direction to the cube's forward direction, as a quaternion (4 values)
    sensor.AddObservation(Quaternion.FromToRotation(body.forward, cubeForward));

    // Target position in the orientation cube's local frame (3 values)
    sensor.AddObservation(m_OrientationCube.transform.InverseTransformPoint(m_Target.transform.position));

    // Raycast straight down from the body to measure its height above the ground, normalized to [0, 1] (1 value)
    RaycastHit hit;
    float maxRaycastDist = 10;
    if (Physics.Raycast(body.position, Vector3.down, out hit, maxRaycastDist))
    {
        sensor.AddObservation(hit.distance / maxRaycastDist);
    }
    else
        sensor.AddObservation(1);

    // Input from every part of the body
    foreach (var bodyPart in m_JdController.bodyPartsList)
    {
        CollectObservationBodyPart(bodyPart, sensor);
    }
}

Reference article on coordinate systems: https://www.sohu.com/a/221556633_667928
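Several observations are expressed in the orientation cube's frame via `InverseTransformDirection`. As a minimal sketch, assuming the cube is rotated only about the vertical axis (yaw), the world-to-local conversion reduces to a 2D rotation in the x-z plane. The function names are hypothetical; the formulas follow Unity's left-handed, y-up convention, where an object with yaw θ has forward vector (sin θ, 0, cos θ).

```python
import math


def yaw_forward(yaw_degrees):
    """World-space forward vector of an object with the given yaw.
    Unity equivalent: Quaternion.Euler(0, yaw, 0) * Vector3.forward."""
    t = math.radians(yaw_degrees)
    return (math.sin(t), 0.0, math.cos(t))


def world_to_yaw_frame(vec, yaw_degrees):
    """Express a world-space direction in the local frame of an object rotated only
    by yaw_degrees about the y axis (what InverseTransformDirection gives for it)."""
    x, y, z = vec
    t = math.radians(yaw_degrees)
    c, s = math.cos(t), math.sin(t)
    # Undo the yaw: rotate by -yaw in the x-z plane; y is unchanged
    return (x * c - z * s, y, x * s + z * c)


# A cube yawed 90 degrees faces world +x, so a velocity of 2 along +x
# is "straight ahead" (local +z) in the cube's frame.
print(world_to_yaw_frame((2.0, 0.0, 0.0), 90.0))
```

In this frame, a velocity pointing straight at the target always lands on the local z axis, whichever world direction the target lies in. That is one plausible reason the orientation cube helps: the policy does not have to relearn the same behaviour for every world heading.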

public void CollectObservationBodyPart(BodyPart bp, VectorSensor sensor)
{   
    // Whether this part touches the ground (1 value per part, 9 in total)
    sensor.AddObservation(bp.groundContact.touchingGround);

    // For every part except the body, add the normalized current joint strength (8 in total)
    if (bp.rb.transform != body)
    {
        sensor.AddObservation(bp.currentStrength / m_JdController.maxJointForceLimit);
    }
}

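That makes 32 observation dimensions in total: 15 from CollectObservations itself (1 + 3 + 3 + 4 + 3 + 1) plus 17 from the body parts (9 ground contact flags + 8 joint strengths).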
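The total can be checked by tallying the AddObservation calls. In ML-Agents' VectorSensor a bool counts as one value, a Vector3 as three, and a Quaternion as four; the dictionary labels below are just descriptive strings.

```python
# Observation layout of CrawlerAgent.CollectObservations, as listed in the code above
agent_level = {
    "distance(velGoal, avgVel)": 1,
    "avgVel in cube frame": 3,
    "velGoal in cube frame": 3,
    "FromToRotation(body.forward, cubeForward)": 4,
    "target position in cube frame": 3,
    "normalized raycast height": 1,
}

num_body_parts = 9               # body + 4 upper legs + 4 lower legs
num_joints = num_body_parts - 1  # every part except the body reports a joint strength

# touchingGround for every part + normalized strength for every jointed part
per_part = num_body_parts * 1 + num_joints * 1

total = sum(agent_level.values()) + per_part
print(sum(agent_level.values()), per_part, total)  # 15 17 32
```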

Now the action side, OnActionReceived:

public override void OnActionReceived(ActionBuffers actionBuffers)
{
    // The dictionary with all the body parts in it are in the jdController
    var bpDict = m_JdController.bodyPartsDict;

    var continuousActions = actionBuffers.ContinuousActions;
    var i = -1;
    // Pick a new target joint rotation
    bpDict[leg0Upper].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
    bpDict[leg1Upper].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
    bpDict[leg2Upper].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
    bpDict[leg3Upper].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
    bpDict[leg0Lower].SetJointTargetRotation(continuousActions[++i], 0, 0);
    bpDict[leg1Lower].SetJointTargetRotation(continuousActions[++i], 0, 0);
    bpDict[leg2Lower].SetJointTargetRotation(continuousActions[++i], 0, 0);
    bpDict[leg3Lower].SetJointTargetRotation(continuousActions[++i], 0, 0);

    // Update joint strength
    bpDict[leg0Upper].SetJointStrength(continuousActions[++i]);
    bpDict[leg1Upper].SetJointStrength(continuousActions[++i]);
    bpDict[leg2Upper].SetJointStrength(continuousActions[++i]);
    bpDict[leg3Upper].SetJointStrength(continuousActions[++i]);
    bpDict[leg0Lower].SetJointStrength(continuousActions[++i]);
    bpDict[leg1Lower].SetJointStrength(continuousActions[++i]);
    bpDict[leg2Lower].SetJointStrength(continuousActions[++i]);
    bpDict[leg3Lower].SetJointStrength(continuousActions[++i]);
}

This sets the target rotations of the eight joints (two axes for each upper leg and one for each lower leg, 12 values) and the strength of each joint (8 values), for a total of 20 continuous actions.

Next, FixedUpdate, which is called at a fixed time step and is not affected by the frame rate.
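The index order in OnActionReceived is easy to lose track of. This small Python sketch lists which action index drives which joint. It relies on C# evaluating method arguments left to right, so `continuousActions[++i], continuousActions[++i]` reads indices 0 and 1 in that order; the names are just labels for the Transforms in the code.

```python
upper = [f"leg{i}Upper" for i in range(4)]
lower = [f"leg{i}Lower" for i in range(4)]

layout = []
for name in upper:
    layout += [(name, "rotX"), (name, "rotY")]  # upper legs: two rotation axes
for name in lower:
    layout += [(name, "rotX")]                  # lower legs: one rotation axis
for name in upper + lower:
    layout += [(name, "strength")]              # one strength value per joint

for index, (part, what) in enumerate(layout):
    print(index, part, what)
print(len(layout))  # 20
```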

void FixedUpdate()
{
    // Update cube and indicator
    UpdateOrientationObjects();

    // If enabled, show whether each foot touches the ground by swapping its material
    if (useFootGroundedVisualization)
    {
        foot0.material = m_JdController.bodyPartsDict[leg0Lower].groundContact.touchingGround
            ? groundedMaterial
            : unGroundedMaterial;
        foot1.material = m_JdController.bodyPartsDict[leg1Lower].groundContact.touchingGround
            ? groundedMaterial
            : unGroundedMaterial;
        foot2.material = m_JdController.bodyPartsDict[leg2Lower].groundContact.touchingGround
            ? groundedMaterial
            : unGroundedMaterial;
        foot3.material = m_JdController.bodyPartsDict[leg3Lower].groundContact.touchingGround
            ? groundedMaterial
            : unGroundedMaterial;
    }

    var cubeForward = m_OrientationCube.transform.forward;

    // The closer the average velocity is to the goal velocity (TargetWalkingSpeed toward the target), the higher this reward, in [0, 1]
    var matchSpeedReward = GetMatchingVelocityReward(cubeForward * TargetWalkingSpeed, GetAvgVelocity());

    // The dot product of the two unit vectors is 1 when they point the same way and -1 when opposite; (dot + 1) * 0.5 maps it to [0, 1]
    var lookAtTargetReward = (Vector3.Dot(cubeForward, body.forward) + 1) * .5F;
    // Multiplying the two terms gives a high reward only when the agent both faces the target and moves toward it at the goal speed
    AddReward(matchSpeedReward * lookAtTargetReward);
}
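The heading term and the way the two terms are combined can be sketched in Python (a minimal re-implementation for illustration; the function names are hypothetical). It assumes unit-length forward vectors, as `Transform.forward` is in Unity.

```python
def look_at_reward(cube_forward, body_forward):
    """(dot + 1) / 2: 1 when facing the same way, 0.5 at 90 degrees, 0 when opposite."""
    dot = sum(a * b for a, b in zip(cube_forward, body_forward))
    return (dot + 1.0) * 0.5


def step_reward(match_speed_reward, look_reward):
    """Per-step reward added in FixedUpdate: the product of the two terms."""
    return match_speed_reward * look_reward


facing = look_at_reward((0.0, 0.0, 1.0), (0.0, 0.0, 1.0))
sideways = look_at_reward((0.0, 0.0, 1.0), (1.0, 0.0, 0.0))
backwards = look_at_reward((0.0, 0.0, 1.0), (0.0, 0.0, -1.0))
print(facing, sideways, backwards)  # 1.0 0.5 0.0
print(step_reward(1.0, backwards))  # 0.0
```

With a product, a perfect speed match earns nothing while walking backwards toward the target; with a sum, that behaviour would still collect half of the reward.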

UpdateOrientationObjects keeps the orientation cube on the agent positioned at the body and rotated to face the target. It also updates the position and rotation of the direction indicator below the agent:

void UpdateOrientationObjects()
{
    m_OrientationCube.UpdateOrientation(body, m_Target);
    if (m_DirectionIndicator)
    {
        m_DirectionIndicator.MatchOrientation(m_OrientationCube.transform);
    }
}
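The source of OrientationCubeController.UpdateOrientation is not shown in this article. Here is a minimal sketch of the behaviour described above. It assumes the cube is placed at the body and turned only about the vertical axis toward the target, ignoring height differences. `orientation_from` is a hypothetical name, and the yaw uses Unity's convention that forward for yaw θ is (sin θ, 0, cos θ).

```python
import math


def orientation_from(body_pos, target_pos):
    """Hypothetical re-creation of the orientation-cube update: place the cube at the
    body and yaw it toward the target, ignoring the vertical (y) difference."""
    dx = target_pos[0] - body_pos[0]
    dz = target_pos[2] - body_pos[2]
    if dx == 0 and dz == 0:
        yaw = 0.0  # target directly above or below: fall back to no rotation
    else:
        # Forward for yaw t is (sin t, 0, cos t), so t = atan2(x, z)
        yaw = math.degrees(math.atan2(dx, dz))
    return body_pos, yaw


print(orientation_from((0.0, 0.0, 0.0), (3.0, 1.0, 3.0)))  # cube at the body, yaw of about 45 degrees
```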

There is also GetMatchingVelocityReward, which takes the goal velocity and the actual velocity and returns a reward: the smaller the distance between the two, the higher the reward.

public float GetMatchingVelocityReward(Vector3 velocityGoal, Vector3 actualVelocity)
{
    //Distance between the goal velocity and the actual velocity, clamped to the range [0, TargetWalkingSpeed]
    var velDeltaMagnitude = Mathf.Clamp(Vector3.Distance(actualVelocity, velocityGoal), 0, TargetWalkingSpeed);

    //return the value on a declining sigmoid shaped curve that decays from 1 to 0
    //This reward will approach 1 if it matches perfectly and approach zero as it deviates
    return Mathf.Pow(1 - Mathf.Pow(velDeltaMagnitude / TargetWalkingSpeed, 2), 2);
}
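Here is a direct Python port of GetMatchingVelocityReward for checking its shape numerically. Since TargetWalkingSpeed is drawn from Random.Range(0.1f, ...) in OnEpisodeBegin, the division is safe.

```python
import math


def matching_velocity_reward(velocity_goal, actual_velocity, target_speed):
    """Port of GetMatchingVelocityReward: (1 - (d / s)^2)^2 with d clamped to [0, s]."""
    d = math.dist(velocity_goal, actual_velocity)
    d = min(max(d, 0.0), target_speed)
    return (1.0 - (d / target_speed) ** 2) ** 2


goal = (0.0, 0.0, 2.0)
print(matching_velocity_reward(goal, (0.0, 0.0, 2.0), 2.0))  # 1.0: perfect match
print(matching_velocity_reward(goal, (0.0, 0.0, 1.0), 2.0))  # 0.5625: half speed
print(matching_velocity_reward(goal, (0.0, 0.0, 0.0), 2.0))  # 0.0: standing still
```

The curve (1 - x²)² is flat at both ends (x = 0 and x = 1), so small errors near a perfect match are barely penalized, while any error of TargetWalkingSpeed or more gives zero.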

Target Controller

With the main Agent script covered, next is the script that spawns the target. It places a cube at a random position in the arena and moves it elsewhere after it is eaten.

// Called whenever the object is enabled, e.g. right after the target is instantiated
void OnEnable()
{
    m_startingPos = transform.position;
    if (respawnIfTouched)
    {
        MoveTargetToRandomPosition();
    }
}
// Called once per frame
void Update()
{
    if (respawnIfFallsOffPlatform)
    {
        if (transform.position.y < m_startingPos.y - fallDistance)
        {
            Debug.Log($"{transform.name} Fell Off Platform");
            MoveTargetToRandomPosition();
        }
    }
}

// Pick a random point within spawnRadius of the starting position, then reset y to the starting height
public void MoveTargetToRandomPosition()
{
    var newTargetPos = m_startingPos + (Random.insideUnitSphere * spawnRadius);
    newTargetPos.y = m_startingPos.y;
    transform.position = newTargetPos;
}

// When an object with the tag tagToDetect (the agent) collides with the target, fire the event and move the target elsewhere.
// A reward for reaching the cube could be added here, but the existing velocity and heading rewards already let training work without it.
private void OnCollisionEnter(Collision col)
{
    if (col.transform.CompareTag(tagToDetect))
    {
        onCollisionEnterEvent.Invoke(col);
        if (respawnIfTouched)
        {
            MoveTargetToRandomPosition();
        }
    }
}
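MoveTargetToRandomPosition samples inside a sphere and then discards the vertical offset. A minimal Python sketch, with rejection sampling standing in for Unity's `Random.insideUnitSphere` (the function names are hypothetical):

```python
import random


def inside_unit_sphere(rng):
    """Uniform random point inside the unit sphere via rejection sampling."""
    while True:
        p = (rng.uniform(-1, 1), rng.uniform(-1, 1), rng.uniform(-1, 1))
        if p[0] ** 2 + p[1] ** 2 + p[2] ** 2 <= 1.0:
            return p


def random_target_position(start, spawn_radius, rng):
    """Mirror of MoveTargetToRandomPosition: random offset in a sphere, height restored."""
    offset = inside_unit_sphere(rng)
    x = start[0] + offset[0] * spawn_radius
    z = start[2] + offset[2] * spawn_radius
    return (x, start[1], z)  # y is reset to the starting height


rng = random.Random(42)
print(random_target_position((0.0, 0.5, 0.0), 10.0, rng))
```

One consequence of this design: projecting a uniform sphere sample onto the ground plane makes targets land more often near the start position than near the edge of the spawn radius.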

Training parameter configuration

Without any additional features, plain PPO gets the reward above 2500 within 3 million steps, and the agent's movement is close to ideal:

behaviors:
  Crawler:
    trainer_type: ppo
    hyperparameters:
      batch_size: 2048
      buffer_size: 20480
      learning_rate: 0.0003
      beta: 0.005
      epsilon: 0.2
      lambd: 0.95
      num_epoch: 3
      learning_rate_schedule: linear
    network_settings:
      normalize: true
      hidden_units: 512
      num_layers: 3
      vis_encode_type: simple
    reward_signals:
      extrinsic:
        gamma: 0.995
        strength: 1.0
    keep_checkpoints: 5
    max_steps: 10000000
    time_horizon: 1000
    summary_freq: 30000
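For intuition about how these PPO numbers interact, here is the arithmetic. It assumes the usual PPO update scheme used by ML-Agents: the buffer is split into minibatches of `batch_size`, and the whole buffer is passed over `num_epoch` times. `ppo_update_counts` is a hypothetical helper.

```python
def ppo_update_counts(buffer_size, batch_size, num_epoch):
    """Minibatches per epoch and gradient steps per PPO update, assuming the buffer
    is split evenly into minibatches and revisited num_epoch times."""
    if buffer_size % batch_size != 0:
        raise ValueError("buffer_size should be a multiple of batch_size")
    minibatches = buffer_size // batch_size
    return minibatches, minibatches * num_epoch


print(ppo_update_counts(20480, 2048, 3))  # (10, 30): the PPO config above
print(ppo_update_counts(20240, 2024, 3))  # (10, 30): the imitation config below
```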

The configuration file for the SAC algorithm:

behaviors:
  Crawler:
    trainer_type: sac
    hyperparameters:
      learning_rate: 0.0003
      learning_rate_schedule: constant
      batch_size: 256
      buffer_size: 500000
      buffer_init_steps: 0
      tau: 0.005
      steps_per_update: 20.0
      save_replay_buffer: false
      init_entcoef: 1.0
      reward_signal_steps_per_update: 20.0
    network_settings:
      normalize: true
      hidden_units: 512
      num_layers: 3
      vis_encode_type: simple
    reward_signals:
      extrinsic:
        gamma: 0.995
        strength: 1.0
    keep_checkpoints: 5
    max_steps: 5000000
    time_horizon: 1000
    summary_freq: 30000

The configuration using imitation learning (GAIL and behavioral cloning from an expert demo):

behaviors:
  Crawler:
    trainer_type: ppo
    hyperparameters:
      batch_size: 2024
      buffer_size: 20240
      learning_rate: 0.0003
      beta: 0.005
      epsilon: 0.2
      lambd: 0.95
      num_epoch: 3
      learning_rate_schedule: linear
    network_settings:
      normalize: true
      hidden_units: 512
      num_layers: 3
      vis_encode_type: simple
    reward_signals:
      gail:
        gamma: 0.99
        strength: 1.0
        network_settings:
          normalize: true
          hidden_units: 128
          num_layers: 2
          vis_encode_type: simple
        learning_rate: 0.0003
        use_actions: false
        use_vail: false
        demo_path: Project/Assets/ML-Agents/Examples/Crawler/Demos/ExpertCrawler.demo
    keep_checkpoints: 5
    max_steps: 10000000
    time_horizon: 1000
    summary_freq: 30000
    behavioral_cloning:
      demo_path: Project/Assets/ML-Agents/Examples/Crawler/Demos/ExpertCrawler.demo
      steps: 50000
      strength: 0.5
      samples_per_update: 0


Posted on Tue, 09 Nov 2021 17:03:45 -0500 by casbboy