Files
modular-vr/Assets/Engine/Runtime/Utilities/Probability.cs
2022-12-15 10:49:46 +01:00

74 lines
2.3 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

using System;
using System.Linq;
using MathNet.Numerics.Distributions;
using Random = System.Random;
namespace EscapeRoomEngine.Engine.Runtime.Utilities
{
/// <summary>
/// A normal (Gaussian) distribution described by its mean and standard deviation σ.
/// Marked [Serializable]; both parameters are plain public fields.
/// </summary>
[Serializable]
public struct NormalDistribution
{
    /// <summary>Mean (expected value) of the distribution.</summary>
    public double mean;

    /// <summary>Standard deviation of the distribution.</summary>
    public double σ;

    /// <summary>The standard normal distribution (mean 0, σ 1).</summary>
    public static NormalDistribution Standard => new NormalDistribution { mean = 0, σ = 1 };

    /// <summary>
    /// Fit a distribution to <paramref name="samples"/> using <see cref="Probability.Mean"/> and the
    /// population standard deviation from <see cref="Probability.StandardDeviation(double[], double)"/>.
    /// An empty array yields mean 0 and σ 0.
    /// </summary>
    /// <param name="samples">Observed values to fit.</param>
    /// <exception cref="ArgumentNullException"><paramref name="samples"/> is null.</exception>
    public NormalDistribution(double[] samples) : this()
    {
        if (samples is null)
        {
            throw new ArgumentNullException(nameof(samples));
        }
        mean = Probability.Mean(samples);
        σ = Probability.StandardDeviation(samples, mean);
    }

    /// <summary>
    /// Draw a random value from this distribution by scaling and shifting a standard normal sample.
    /// </summary>
    public double Sample() => σ * Probability.Normal() + mean;

    /// <summary>
    /// Cumulative distribution function: the probability that a sample is ≤ <paramref name="x"/>.
    /// </summary>
    /// <remarks>
    /// When σ is 0 (for example when fitted from a single sample or from identical samples), the
    /// distribution is a point mass at <see cref="mean"/>. The result is then a step: 0 below the
    /// mean and 1 at or above it. The Gaussian formula would evaluate (x − mean) / σ = 0 / 0 at
    /// x == mean and produce NaN.
    /// </remarks>
    public double Cumulative(double x)
    {
        if (σ == 0 && !double.IsNaN(x) && !double.IsNaN(mean))
        {
            return x < mean ? 0.0 : 1.0;
        }
        // The static overload avoids allocating a Normal instance on every call.
        return Normal.CDF(mean, σ, x);
    }
}
/// <summary>
/// Random sampling and descriptive-statistics helpers.
/// </summary>
public static class Probability
{
    // Shared generator. System.Random instance members are not thread-safe, so this
    // should only be used from one thread at a time.
    private static readonly Random _random = new();

    /// <summary>
    /// Sample a random variable from the standard normal distribution (mean 0, standard deviation 1).
    /// The result is NOT clamped. By the three-σ rule, about 99.7% of samples fall within [-3, 3],
    /// but larger magnitudes can occur.
    /// </summary>
    /// <remarks>
    /// Uses the Marsaglia polar method, a rejection-sampling variant of the Box-Muller transform.
    /// Only one of the two independent variates the method yields is returned.
    /// </remarks>
    public static double Normal()
    {
        double u1, u2, square;
        // Draw points uniformly from [-1, 1)² until one lies strictly inside the unit circle.
        // The origin must also be rejected: Log(0) / 0 is infinite and would turn into NaN.
        do
        {
            u1 = 2 * _random.NextDouble() - 1;
            u2 = 2 * _random.NextDouble() - 1;
            square = u1 * u1 + u2 * u2;
        } while (square >= 1.0 || square == 0.0);
        return u1 * Math.Sqrt(-2 * Math.Log(square) / square);
    }

    /// <summary>
    /// Arithmetic mean of <paramref name="samples"/>.
    /// </summary>
    /// <param name="samples">Values to average.</param>
    /// <returns>The mean, or 0 when <paramref name="samples"/> is empty.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="samples"/> is null.</exception>
    public static double Mean(double[] samples)
    {
        if (samples is null)
        {
            throw new ArgumentNullException(nameof(samples));
        }
        if (samples.Length == 0)
        {
            return 0;
        }
        return samples.Sum() / samples.Length;
    }

    /// <summary>
    /// Population standard deviation of <paramref name="samples"/> (divides by n, not n − 1),
    /// computed around their own mean.
    /// </summary>
    /// <returns>The standard deviation, or 0 when <paramref name="samples"/> is empty.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="samples"/> is null.</exception>
    public static double StandardDeviation(double[] samples) => StandardDeviation(samples, Mean(samples));

    /// <summary>
    /// Population standard deviation of <paramref name="samples"/> (divides by n, not n − 1)
    /// around a precomputed <paramref name="mean"/>.
    /// </summary>
    /// <param name="samples">Values to measure.</param>
    /// <param name="mean">Center to measure deviations from, usually <see cref="Mean"/> of the samples.</param>
    /// <returns>The standard deviation, or 0 when <paramref name="samples"/> is empty.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="samples"/> is null.</exception>
    public static double StandardDeviation(double[] samples, double mean)
    {
        if (samples is null)
        {
            throw new ArgumentNullException(nameof(samples));
        }
        if (samples.Length == 0)
        {
            return 0;
        }
        // Accumulate squared deviations in a single pass instead of materializing them in an array.
        var sumOfSquares = 0.0;
        foreach (var sample in samples)
        {
            var deviation = sample - mean;
            sumOfSquares += deviation * deviation;
        }
        return Math.Sqrt(sumOfSquares / samples.Length);
    }
}
}