Files
modular-vr/Assets/Engine/Runtime/Utilities/Probability.cs
2023-03-22 07:54:00 +01:00

111 lines
3.6 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

using System;
using System.Linq;
using MathNet.Numerics.Distributions;
using UnityEngine;
using Random = System.Random;
namespace EscapeRoomEngine.Engine.Runtime.Utilities
{
/// <summary>
/// The representation of a normal distribution with a certain mean μ and standard deviation σ.
/// </summary>
[Serializable]
public struct NormalDistribution
{
    /// <summary>
    /// The mean of this distribution.
    /// </summary>
    public float μ;

    /// <summary>
    /// The standard deviation of this distribution.
    /// </summary>
    public float σ;

    /// <summary>
    /// Generate a standard normal distribution (μ = 0, σ = 1).
    /// </summary>
    public static NormalDistribution Standard => new() { μ = 0, σ = 1 };

    /// <summary>
    /// Estimate a normal distribution from a list of samples.
    /// </summary>
    /// <param name="samples">The observed samples.</param>
    /// <remarks>
    /// μ is the mean of <paramref name="samples"/> and σ their standard deviation around that mean.
    /// An empty list yields μ = 0 and σ = 1 (the standard normal distribution);
    /// a single sample yields σ = 0.
    /// </remarks>
    public NormalDistribution(float[] samples) : this()
    {
        μ = Probability.Mean(samples);
        // With no samples there is no spread to measure, so fall back to a unit standard deviation.
        σ = samples.Length == 0 ? 1.0f : Probability.StandardDeviation(samples, μ);
    }

    /// <summary>
    /// Sample a random value from this distribution by scaling and shifting a standard normal variate.
    /// </summary>
    /// <returns>A random value distributed as N(μ, σ²).</returns>
    public float Sample() => σ * Probability.Normal() + μ;

    /// <summary>
    /// Evaluate the cumulative distribution function (CDF) of this distribution.
    /// </summary>
    /// <param name="x">The point at which to evaluate the CDF.</param>
    /// <returns>P(X ≤ <paramref name="x"/>) for X drawn from this distribution.</returns>
    // The static MathNet helper computes the same value as an instance's CumulativeDistribution,
    // without allocating a Normal object on every call.
    public float Cumulative(float x) => (float)Normal.CDF(μ, σ, x);

    /// <summary>
    /// Evaluate the inverse CDF (quantile function) of this distribution.
    /// </summary>
    /// <param name="x">A probability in [0, 1].</param>
    /// <returns>The value v such that P(X ≤ v) = <paramref name="x"/>.</returns>
    // Static MathNet helper: same result as InverseCumulativeDistribution, no per-call allocation.
    public float InverseCumulative(float x) => (float)Normal.InvCDF(μ, σ, x);
}
/// <summary>
/// This class is used for probability calculations.
/// </summary>
public static class Probability
{
    // Shared generator for all sampling done by this class.
    // NOTE(review): System.Random is not thread-safe; this assumes Normal() is only called from one
    // thread at a time (e.g. Unity's main thread) — confirm.
    private static readonly Random _random = new();

    /// <summary>
    /// Sample a random variable from the standard normal distribution (μ = 0, σ = 1).
    /// The result is not clamped; by the three-σ rule, about 99.7% of all samples lie within [-3, 3].
    /// </summary>
    /// <remarks>
    /// The random variable is generated with the Marsaglia polar method, a rejection-sampling variant of the
    /// Box-Muller transform that avoids trigonometric functions. The method produces two independent variates;
    /// only one of them is returned.
    /// </remarks>
    /// <returns>A standard normally distributed random value.</returns>
    public static float Normal()
    {
        double u1, u2, square;
        // Rejection-sample a point uniformly inside the unit circle. The origin must be rejected as well:
        // for square == 0 the expression below evaluates log(0) / 0 and the result becomes NaN.
        do
        {
            u1 = 2 * _random.NextDouble() - 1;
            u2 = 2 * _random.NextDouble() - 1;
            square = u1 * u1 + u2 * u2;
        } while (square >= 1 || square == 0);

        return (float)(u1 * Math.Sqrt(-2 * Math.Log(square) / square));
    }

    /// <summary>
    /// Calculate the arithmetic mean of a list of samples.
    /// </summary>
    /// <param name="samples">The samples to average.</param>
    /// <returns>The mean of <paramref name="samples"/>, or 0 if it is empty.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="samples"/> is null.</exception>
    public static float Mean(float[] samples)
    {
        if (samples is null)
        {
            throw new ArgumentNullException(nameof(samples));
        }

        return samples.Length == 0 ? 0 : samples.Sum() / samples.Length;
    }

    /// <summary>
    /// Calculate the (population) standard deviation of a list of samples.
    /// </summary>
    /// <param name="samples">The samples to measure.</param>
    /// <returns>The standard deviation of <paramref name="samples"/>, or 0 if it is empty.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="samples"/> is null.</exception>
    public static float StandardDeviation(float[] samples) => StandardDeviation(samples, Mean(samples));

    /// <summary>
    /// Calculate the (population) standard deviation of a list of samples around an already known mean,
    /// without recalculating it.
    /// </summary>
    /// <param name="samples">The samples to measure.</param>
    /// <param name="mean">The mean to measure the deviations from.</param>
    /// <returns>
    /// The square root of the average squared deviation from <paramref name="mean"/> (dividing by n, not n - 1),
    /// or 0 if <paramref name="samples"/> is empty.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="samples"/> is null.</exception>
    public static float StandardDeviation(float[] samples, float mean)
    {
        if (samples is null)
        {
            throw new ArgumentNullException(nameof(samples));
        }
        if (samples.Length == 0)
        {
            return 0;
        }

        // Accumulate the squared deviations directly instead of materializing them in a temporary array.
        var sumOfSquares = 0.0;
        foreach (var sample in samples)
        {
            var deviation = sample - mean;
            sumOfSquares += deviation * deviation;
        }

        return (float)Math.Sqrt(sumOfSquares / samples.Length);
    }
}
}