[ml-agents 2.0] MLDino with GAIL, BC

Unity3D/ml-agent 2021. 11. 16. 23:04

#unity #mlagents #reinforcementlearning #imitationlearning #machinelearning

// CactusGenerator.cs
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

public class CactusGenerator : MonoBehaviour
{
    public GameObject prefab;
    public Transform targetParent;
    [SerializeField]
    private float spawnTime = 3f;
    private Coroutine routine;
    private List<GameObject> cactusList = new List<GameObject>();

    public void StartGenerate()
    {
        if (this.routine != null) StopCoroutine(this.routine);
        this.routine = StartCoroutine(this.GenerateImpl());    
    }

    private IEnumerator GenerateImpl()
    {
        while (true) {
            var go = Instantiate(this.prefab, targetParent);
            var pos = go.transform.localPosition;
            pos.x = 7.25f;
            go.transform.localPosition = pos;
            this.cactusList.Add(go);
            yield return new WaitForSeconds(this.spawnTime);
        }
    }

    public void StopGenerate() 
    {
        if (this.routine != null) StopCoroutine(this.routine);
        this.routine = null;
    }

    public void ClearCactusList()
    {
        // Cacti destroy themselves in CactusMove once off screen,
        // so the list may hold stale (already destroyed) references.
        foreach (var go in this.cactusList) {
            if (go != null) Destroy(go);
        }

        this.cactusList.Clear();
    }
}

// CactusMove.cs
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

public class CactusMove : MonoBehaviour
{
    [SerializeField]
    private float moveSpeed = 3f;

    // Update is called once per frame
    void Update()
    {
        this.transform.Translate(Vector2.left * this.moveSpeed * Time.deltaTime);

        if (this.transform.localPosition.x <= -6f)
        {
            Destroy(this.gameObject);
        }
    }
}

// DinoAgent.cs
using System.Collections;
using System.Collections.Generic;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using UnityEngine;
using UnityEngine.Events;

public class DinoAgent : Agent
{
    [SerializeField]
    private float jumpForce = 300f;
    private Rigidbody2D rBody;
    private Animator anim;
    public UnityAction onDie;
    public UnityAction onEpisodeBegin;
    private bool isGround = false;
    private float elapsedReward = 0;

    // Initialize() is guaranteed to run before the first OnEpisodeBegin(),
    // unlike Start(), so the component references are ready in time.
    public override void Initialize()
    {
        this.rBody = this.GetComponent<Rigidbody2D>();
        this.anim = this.GetComponent<Animator>();
    }

    public override void OnEpisodeBegin()
    {
        this.transform.localPosition = Vector3.zero;
        this.elapsedReward = 0;
        this.onEpisodeBegin?.Invoke();
        this.anim.Rebind();
        this.anim.Update(0f);
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(this.transform.localPosition.y);   // 1 observation
        sensor.AddObservation(this.rBody.velocity.y);            // 1 observation
    }

    public override void OnActionReceived(ActionBuffers actions)
    {
        Vector3 dir = Vector3.zero;
        dir.y = actions.DiscreteActions[0]; 

        if (dir.y == 1 && this.isGround && this.rBody.velocity.y == 0)
        {
            this.anim.SetBool("IsJump", true);
            this.rBody.AddForce(dir * this.jumpForce);
        }
    }
    

    private void OnCollisionStay2D(Collision2D collision)
    {
        if (collision.collider.CompareTag("Ground"))
        {
            this.isGround = true;
            this.AddReward(0.01f);
            this.elapsedReward += 0.01f;
        }
    }

    private void OnCollisionExit2D(Collision2D collision)
    {
        if (collision.collider.CompareTag("Ground")) {
            this.isGround = false;
        }
    }

    private void OnCollisionEnter2D(Collision2D collision)
    {
        if (collision.collider.CompareTag("Ground"))
        {
            this.anim.SetBool("IsJump", false);
        }
        else if (collision.collider.CompareTag("Cactus"))
        {
            this.anim.SetTrigger("Die");
            this.onDie?.Invoke();
            this.elapsedReward *= 0.5f;
            this.AddReward(-this.elapsedReward);
            this.EndEpisode(); 
        }
    }

    public override void Heuristic(in ActionBuffers actionsOut)
    {
        // Also write 0 explicitly: the action buffer is reused between
        // calls, so a previous jump could otherwise stick.
        var action = actionsOut.DiscreteActions;
        action[0] = Input.GetKey(KeyCode.Space) ? 1 : 0;
    }
}
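
Nothing in DinoAgent (or in TrainingArea below) calls RequestDecision(), so the agent presumably relies on a DecisionRequester component added in the Inspector. A minimal sketch of that assumed setup follows; the component and its fields are the standard ML-Agents API, but the values here are guesses:

// Assumed setup, not shown in the original post: without a DecisionRequester
// (or manual RequestDecision() calls) the agent never receives actions.
using Unity.MLAgents;
using UnityEngine;

[RequireComponent(typeof(DinoAgent))]
public class DinoAgentSetup : MonoBehaviour
{
    void Awake()
    {
        // Ask for a decision every 5 Academy steps and repeat the last
        // action in between; tune DecisionPeriod for your frame rate.
        var requester = gameObject.AddComponent<DecisionRequester>();
        requester.DecisionPeriod = 5;
        requester.TakeActionsBetweenDecisions = true;
    }
}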

// GroundMove.cs
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

public class GroundMove : MonoBehaviour
{
    [SerializeField]
    private float moveSpeed = 3f;
    
    void Update()
    {
        this.transform.Translate(Vector2.left * this.moveSpeed * Time.deltaTime);

        if (this.transform.localPosition.x <= -21f) {
            var pos = this.transform.localPosition;
            pos.x = 21f;
            this.transform.localPosition = pos;
        }
    }
}

// TrainingArea.cs
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

public class TrainingArea : MonoBehaviour
{
    public CactusGenerator generator;
    public DinoAgent agent;

    void Start()
    {
        Application.runInBackground = true;

        this.agent.onDie = () => {
            this.generator.StopGenerate();
            this.generator.ClearCactusList();
        };
        this.agent.onEpisodeBegin = () => {
            this.generator.StartGenerate();
        };

        this.generator.StartGenerate();    
    }

}
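
TrainingArea wires the agent's episode callbacks to the cactus generator. A standard ML-Agents pattern (not shown in the post) is to duplicate the whole training-area object several times in the scene; every copy feeds experience to the same DinoRun behavior, which speeds up PPO sample collection.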

Config before adding imitation learning:

behaviors:
    DinoRun:
        trainer_type: ppo
        hyperparameters:
            batch_size: 10
            buffer_size: 20480
            learning_rate: 0.0003
            beta: 0.001
            epsilon: 0.2
            lambd: 0.99
            num_epoch: 3
            learning_rate_schedule: linear
        network_settings:
            normalize: false
            hidden_units: 128
            num_layers: 2
            vis_encode_type: simple
        reward_signals:
            extrinsic:
                gamma: 0.99
                strength: 1.0
        keep_checkpoints: 5
        max_steps: 500000
        time_horizon: 1000
        summary_freq: 12000
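
This baseline uses only the extrinsic reward. The runs below resume it with a GAIL reward signal added on top; the behavioral_cloning block is kept commented out for comparison.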


https://forum.unity.com/threads/gail-vs-behavioral-cloning-whats-the-difference.944463/ (forum thread: "GAIL vs Behavioral Cloning, what's the difference?")
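
In short: behavioral cloning trains the policy directly on the demonstration's state-action pairs (pure supervised learning), while GAIL trains a discriminator that rewards the agent for behaving like the demonstrations, so it can be blended with the extrinsic reward and the agent can eventually outperform the demonstrator.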

mlagents-learn ./DinoRun.yaml --run-id=DinoRun_02 --time-scale=1 --capture-frame-rate=0 --resume
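
The GAIL config below points at ./demos/DinoRun.demo. That file is produced by the DemonstrationRecorder component while a human plays through the Heuristic (Space to jump). A minimal sketch of the assumed recording setup follows; the component and field names are the ML-Agents Demonstrations API, and the values simply mirror demo_path:

// Assumed recording setup: attach next to DinoAgent, play in the Editor,
// and ML-Agents writes ./demos/DinoRun.demo for demo_path to consume.
using Unity.MLAgents.Demonstrations;
using UnityEngine;

public class DemoRecordingSetup : MonoBehaviour
{
    void Awake()
    {
        var recorder = gameObject.AddComponent<DemonstrationRecorder>();
        recorder.Record = true;                      // record from startup
        recorder.DemonstrationName = "DinoRun";      // -> DinoRun.demo
        recorder.DemonstrationDirectory = "./demos"; // matches the YAML
    }
}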

behaviors:
    DinoRun:
        trainer_type: ppo
        hyperparameters:
            batch_size: 1024
            buffer_size: 20480
            learning_rate: 0.0003
            beta: 0.001
            epsilon: 0.2
            lambd: 0.99
            num_epoch: 3
            learning_rate_schedule: linear
        network_settings:
            normalize: false
            hidden_units: 128
            num_layers: 2
            vis_encode_type: simple
        reward_signals:
            extrinsic:
                gamma: 0.99
                strength: 1.0
            gail:
                strength: 0.8
                demo_path: ./demos/DinoRun.demo

        #behavioral_cloning:
        #    strength: 0.8
        #    demo_path: ./demos/DinoRun.demo

        keep_checkpoints: 5
        max_steps: 500000
        time_horizon: 1000
        summary_freq: 12000

https://github.com/Unity-Technologies/ml-agents/tree/main/config/imitation (sample imitation configs)

https://github.com/Unity-Technologies/ml-agents/blob/0.15.0/docs/Training-Imitation-Learning.md (Training-Imitation-Learning docs)

https://github.com/Unity-Technologies/ml-agents/blob/release_5/docs/ML-Agents-Overview.md#imitation-learning (ML-Agents overview, imitation learning section)

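Following the sample imitation configs linked above, the final run drops the gail strength to 0.01, sets a gail gamma, and raises max_steps to 5,000,000: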
behaviors:
    DinoRun:
        trainer_type: ppo
        hyperparameters:
            batch_size: 1024
            buffer_size: 20480
            learning_rate: 0.0003
            beta: 0.001
            epsilon: 0.2
            lambd: 0.99
            num_epoch: 3
            learning_rate_schedule: linear
        network_settings:
            normalize: false
            hidden_units: 128
            num_layers: 2
            vis_encode_type: simple
        reward_signals:
            extrinsic:
                gamma: 0.99
                strength: 1.0
            gail:
                strength: 0.01
                gamma: 0.99
                demo_path: ./demos/DinoRun.demo

        #behavioral_cloning:
        #    strength: 0.8
        #    demo_path: ./demos/DinoRun.demo

        keep_checkpoints: 5
        max_steps: 5000000
        time_horizon: 1000
        summary_freq: 12000

 

Training finished at 1,000,000 steps.

I kept training up to 1,500,000 steps, but the agent still hits a cactus every once in a while...

Maybe because I collided a few times while recording the demo..
