using System;
using System.Linq;
using MathNet.Numerics.Distributions;
using UnityEngine;
using Random = System.Random;
namespace EscapeRoomEngine.Engine.Runtime.Utilities
{
/// <summary>
/// The representation of a normal distribution with a certain mean μ and standard deviation σ.
/// </summary>
[Serializable]
public struct NormalDistribution
{
    /// <summary>
    /// The mean of this distribution.
    /// </summary>
    public float μ;

    /// <summary>
    /// The standard deviation of this distribution.
    /// </summary>
    /// <remarks>
    /// NOTE(review): presumably expected to be non-negative; MathNet's <c>Normal</c> rejects a negative
    /// standard deviation, so <see cref="Cumulative"/> and <see cref="InverseCumulative"/> would throw — confirm.
    /// </remarks>
    public float σ;

    /// <summary>
    /// A standard normal distribution (μ = 0, σ = 1).
    /// </summary>
    public static NormalDistribution Standard => new() { μ = 0, σ = 1 };

    /// <summary>
    /// Fit a distribution to a set of samples, using their mean and their
    /// population standard deviation (as computed by <see cref="Probability"/>).
    /// </summary>
    /// <param name="samples">The samples to fit. An empty array yields μ = 0 and σ = 0.</param>
    /// <exception cref="ArgumentNullException"><paramref name="samples"/> is null.</exception>
    public NormalDistribution(float[] samples) : this()
    {
        if (samples is null)
        {
            throw new ArgumentNullException(nameof(samples));
        }

        μ = Probability.Mean(samples);
        // Reuse the mean that was just computed instead of recalculating it.
        σ = Probability.StandardDeviation(samples, μ);
    }

    /// <summary>
    /// Draw a random value from this distribution by scaling and shifting a standard normal sample.
    /// </summary>
    /// <returns>A random value distributed as N(μ, σ²).</returns>
    public float Sample() => μ + σ * Probability.Normal();

    /// <summary>
    /// Evaluate the cumulative distribution function (CDF) of this distribution,
    /// i.e. the probability that a sample is less than or equal to <paramref name="x"/>.
    /// </summary>
    /// <param name="x">The point at which to evaluate the CDF.</param>
    /// <returns>A probability in [0, 1].</returns>
    // The static MathNet helper avoids allocating a distribution object on every call.
    public float Cumulative(float x) => (float)Normal.CDF(μ, σ, x);

    /// <summary>
    /// Evaluate the inverse CDF (quantile function) of this distribution.
    /// </summary>
    /// <param name="x">A probability, expected to lie in [0, 1].</param>
    /// <returns>The value below which a fraction <paramref name="x"/> of samples fall.</returns>
    public float InverseCumulative(float x) => (float)Normal.InvCDF(μ, σ, x);
}
/// <summary>
/// Helper methods for probability sampling and descriptive statistics.
/// </summary>
public static class Probability
{
    // Shared generator for all sampling in this class. System.Random is not thread-safe;
    // NOTE(review): presumably sampling only happens on one thread (e.g. Unity's main thread) — confirm.
    private static readonly Random _random = new();

    /// <summary>
    /// Sample a random variable from the standard normal distribution (μ = 0, σ = 1).
    /// </summary>
    /// <remarks>
    /// Uses the Marsaglia polar method, a rejection-sampling variant of the Box–Muller transform.
    /// A point is drawn uniformly from the square [-1, 1)². It is rejected unless it lies strictly
    /// inside the unit circle and is not the origin. An accepted point is then transformed into a normal variate.
    /// The result is NOT clamped, so values beyond ±3 do occur (roughly 0.3% of samples).
    /// Each accepted point yields two independent variates; only the one derived from the first
    /// coordinate is returned.
    /// </remarks>
    /// <returns>A finite, standard-normally distributed value.</returns>
    public static float Normal()
    {
        double u1, u2, square;
        // Draw points until one lies inside the unit circle. The origin is rejected too,
        // because Log(0) is -∞ and would turn the result into 0 * ∞ = NaN.
        do
        {
            u1 = 2 * _random.NextDouble() - 1;
            u2 = 2 * _random.NextDouble() - 1;
            square = u1 * u1 + u2 * u2;
        } while (square >= 1 || square == 0);

        return (float)(u1 * Math.Sqrt(-2 * Math.Log(square) / square));
    }

    /// <summary>
    /// Calculate the arithmetic mean of a list of samples.
    /// </summary>
    /// <param name="samples">The samples to average.</param>
    /// <returns>The mean, or 0 if <paramref name="samples"/> is empty.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="samples"/> is null.</exception>
    public static float Mean(float[] samples)
    {
        if (samples is null)
        {
            throw new ArgumentNullException(nameof(samples));
        }

        return samples.Length == 0 ? 0 : samples.Sum() / samples.Length;
    }

    /// <summary>
    /// Calculate the population standard deviation (dividing by n, not n − 1) of a list of samples.
    /// </summary>
    /// <param name="samples">The samples to measure.</param>
    /// <returns>The standard deviation, or 0 if <paramref name="samples"/> is empty.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="samples"/> is null.</exception>
    public static float StandardDeviation(float[] samples) => StandardDeviation(samples, Mean(samples));

    /// <summary>
    /// Calculate the population standard deviation of a list of samples around an already-known mean,
    /// without recalculating the mean.
    /// </summary>
    /// <param name="samples">The samples to measure.</param>
    /// <param name="mean">The mean of <paramref name="samples"/>.</param>
    /// <returns>The standard deviation, or 0 if <paramref name="samples"/> is empty.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="samples"/> is null.</exception>
    public static float StandardDeviation(float[] samples, float mean)
    {
        if (samples is null)
        {
            throw new ArgumentNullException(nameof(samples));
        }

        if (samples.Length == 0)
        {
            return 0;
        }

        // Accumulate squared deviations in double precision instead of materializing them in an array.
        var sumOfSquares = 0.0;
        foreach (var sample in samples)
        {
            var deviation = (double)sample - mean;
            sumOfSquares += deviation * deviation;
        }

        return (float)Math.Sqrt(sumOfSquares / samples.Length);
    }
}
}