Files
modular-vr/Assets/Engine/Runtime/Utilities/Probability.cs
2022-12-15 23:29:02 +01:00

77 lines
2.4 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

using System;
using System.Linq;
using MathNet.Numerics.Distributions;
using UnityEngine;
using Random = System.Random;
namespace EscapeRoomEngine.Engine.Runtime.Utilities
{
/// <summary>
/// A normal (Gaussian) distribution described by its mean <c>μ</c> and standard deviation <c>σ</c>.
/// </summary>
/// <remarks>
/// A default-initialized instance has μ = 0 and σ = 0; use <see cref="Standard"/> for N(0, 1).
/// </remarks>
[Serializable]
public struct NormalDistribution
{
    /// <summary>The mean (μ) and standard deviation (σ) of the distribution.</summary>
    public float μ, σ;

    /// <summary>The standard normal distribution N(0, 1).</summary>
    public static NormalDistribution Standard => new NormalDistribution { μ = 0, σ = 1 };

    /// <summary>
    /// Estimates a distribution from <paramref name="samples"/>: μ is the arithmetic mean and σ the
    /// standard deviation around that mean, both computed by <see cref="Probability"/>.
    /// </summary>
    /// <param name="samples">Observed values to fit.</param>
    public NormalDistribution(float[] samples) : this()
    {
        μ = Probability.Mean(samples);
        σ = Probability.StandardDeviation(samples, μ);
    }

    /// <summary>
    /// Draws a random value from this distribution by scaling and shifting a standard normal sample
    /// from <see cref="Probability.Normal"/>.
    /// </summary>
    public float Sample() => σ * Probability.Normal() + μ;

    /// <summary>Cumulative distribution function: the probability that a sample is ≤ <paramref name="x"/>.</summary>
    /// <param name="x">The value to evaluate the CDF at.</param>
    // Static Math.NET call: avoids allocating a Normal instance on every evaluation.
    public float Cumulative(float x) => (float)Normal.CDF(μ, σ, x);

    /// <summary>
    /// Inverse cumulative distribution function (quantile function): the value whose cumulative
    /// probability equals <paramref name="x"/>.
    /// </summary>
    /// <param name="x">A probability, expected to lie in [0, 1].</param>
    // Static Math.NET call: avoids allocating a Normal instance on every evaluation.
    public float InverseCumulative(float x) => (float)Normal.InvCDF(μ, σ, x);
}
/// <summary>
/// Probability helpers: standard normal sampling plus mean and standard deviation of float samples.
/// </summary>
public static class Probability
{
    /// <summary>Samples returned by <see cref="Normal"/> are clamped to ±this many standard deviations.</summary>
    private const float NormalClamp = 3f;

    // Shared generator. Note that System.Random is not thread-safe.
    private static readonly Random _random = new();

    /// <summary>
    /// Sample a random variable from the standard normal distribution.
    /// For simplicity, the result is clamped between -3 and 3. This is accurate for 99.7% of all samples, by the three-σ rule.
    /// </summary>
    /// <remarks>
    /// Uses the Marsaglia polar method, a rejection-sampling variant of the Box-Muller transform that
    /// avoids trigonometric functions. Only one of the two independent values it yields is returned.
    /// </remarks>
    /// <returns>A standard normal sample in [-3, 3].</returns>
    public static float Normal()
    {
        float u1, u2, square;
        // Pick a uniformly random point strictly inside the unit circle, excluding the origin:
        // square == 0 would make Log(square) / square produce NaN.
        do
        {
            u1 = 2 * (float)_random.NextDouble() - 1;
            u2 = 2 * (float)_random.NextDouble() - 1;
            square = u1 * u1 + u2 * u2;
        } while (square >= 1f || square == 0f);

        var sample = u1 * (float)Math.Sqrt(-2 * (float)Math.Log(square) / square);
        // Enforce the documented ±3σ clamp.
        return Math.Max(-NormalClamp, Math.Min(NormalClamp, sample));
    }

    /// <summary>Arithmetic mean of <paramref name="samples"/>.</summary>
    /// <param name="samples">The values to average.</param>
    /// <returns>The mean, or 0 when <paramref name="samples"/> is empty.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="samples"/> is null.</exception>
    public static float Mean(float[] samples)
    {
        if (samples == null)
        {
            throw new ArgumentNullException(nameof(samples));
        }
        if (samples.Length == 0)
        {
            return 0;
        }
        return samples.Sum() / samples.Length;
    }

    /// <summary>
    /// Population standard deviation of <paramref name="samples"/> (divides by N, not N - 1).
    /// </summary>
    /// <param name="samples">The values to measure.</param>
    /// <returns>The standard deviation, or 0 when <paramref name="samples"/> is empty.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="samples"/> is null.</exception>
    public static float StandardDeviation(float[] samples) => StandardDeviation(samples, Mean(samples));

    /// <summary>
    /// Population standard deviation of <paramref name="samples"/> around a precomputed <paramref name="mean"/>:
    /// sqrt(Σ(x - mean)² / N).
    /// </summary>
    /// <param name="samples">The values to measure.</param>
    /// <param name="mean">The mean to measure deviations from, normally <see cref="Mean"/> of the same samples.</param>
    /// <returns>The standard deviation, or 0 when <paramref name="samples"/> is empty.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="samples"/> is null.</exception>
    public static float StandardDeviation(float[] samples, float mean)
    {
        if (samples == null)
        {
            throw new ArgumentNullException(nameof(samples));
        }
        if (samples.Length == 0)
        {
            return 0;
        }
        // Accumulate the squared deviations in double without a temporary array.
        // This matches the rounding of summing them with Enumerable.Sum.
        double sumOfSquares = 0;
        foreach (var sample in samples)
        {
            var deviation = sample - mean;
            sumOfSquares += deviation * deviation;
        }
        return (float)Math.Sqrt((float)sumOfSquares / samples.Length);
    }
}
}