import math
import numpy as np
import plotly.express as px
import torch

x = np.linspace(-5, 5, 1000)
y = torch.erf(torch.tensor(x)).numpy()
# add y as erf(x) label
fig = px.line(x=x, y=y, labels={"x": "x", "y": "erf(x)"})
fig.show()
Peter Pham
July 19, 2023
I happened to stumble upon a feature request to implement the Metal (MPS) backend for \(\operatorname{erf}^{-1}(x)\) in PyTorch. I thought it would be a good exercise to implement it myself and get a better understanding of how custom torch ops are defined for the MPS backend. (There is also a list of unimplemented torch ops for MPS.)
To work on the inverse error function, we first need to understand the error function, which is defined as:
\[ \operatorname{erf}(x)=\frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^{2}} d t \tag{1}\]
The error function maps a real number in the domain \((-\infty, \infty)\) to a real number from \((-1, 1)\).
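As a quick sanity check of Equation 1 and of the \((-1, 1)\) range, the integral can be evaluated numerically and compared with Python's built-in math.erf. This is a minimal sketch that uses composite Simpson's rule. The helper name erf_simpson and the default of 1000 sub-intervals are my own choices, not something from the original post.

```python
import math


def erf_simpson(x: float, n: int = 1000) -> float:
    """Approximate Equation 1 with composite Simpson's rule on [0, x] (n must be even)."""
    h = x / n  # negative when x < 0, which makes the result odd in x as expected
    total = 1.0 + math.exp(-x * x)  # f(0) + f(x) with f(t) = exp(-t^2)
    for i in range(1, n):
        t = i * h
        total += (4 if i % 2 else 2) * math.exp(-t * t)  # Simpson weights 4, 2, 4, ...
    return 2.0 / math.sqrt(math.pi) * total * h / 3.0


for x in (-3.0, -0.5, 0.0, 0.5, 3.0):
    print(f"x={x:+.1f}  simpson={erf_simpson(x):+.10f}  math.erf={math.erf(x):+.10f}")
```

The values approach \(\pm 1\) quickly; by \(x = \pm 3\) they are already within about \(2 \times 10^{-5}\) of the limits.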
The inverse error function is defined as: \[ \operatorname{erf}^{-1}(\operatorname{erf}(x))=x \tag{2}\]
This means \(\operatorname{erf}^{-1}(x)\) has a domain of \((-1, 1)\) and a range of \((-\infty, \infty)\).
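Equation 2 also gives a slow but simple way to compute \(\operatorname{erf}^{-1}\) directly. Because erf is strictly increasing, bisection on \(\operatorname{erf}(t) = y\) converges to the unique solution for any \(y \in (-1, 1)\). The sketch below is only a reference point for the discussion. The helper name erfinv_bisect is hypothetical, and the function rejects inputs outside the open interval to match the domain above.

```python
import math


def erfinv_bisect(y: float, tol: float = 1e-12) -> float:
    """Solve erf(t) = y for t by bisection, i.e. use Equation 2 as the definition."""
    if not -1.0 < y < 1.0:
        raise ValueError("erfinv is only defined on the open interval (-1, 1)")
    # erf(6) differs from 1 by about 2e-17, below double resolution near 1,
    # so [-6, 6] brackets every y in (-1, 1) that a double can represent
    lo, hi = -6.0, 6.0
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if math.erf(mid) < y:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


print(erfinv_bisect(0.5))
print(erfinv_bisect(math.erf(1.25)))  # round trip back to 1.25
```

Each call needs about 44 halvings of \([-6, 6]\) to reach the \(10^{-12}\) tolerance. That cost is why closed-form approximations such as Equation 3 are attractive.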
There is no closed-form expression for \(\operatorname{erf}^{-1}(x)\). However, it can be approximated with elementary functions, as proposed by Abramowitz and Stegun[1]. The approximation is given by:
\[ \operatorname{erf}^{-1}(x) \approx \operatorname{sgn}(x) \sqrt{\sqrt{\left(\frac{2}{\pi a}+\frac{\ln(1-x^{2})}{2}\right)^{2}- \frac{\ln (1-x^{2})}{a} }- (\frac{2}{\pi a}+\frac{\ln (1-x^{2})}{2})} \tag{3}\]
where \(a=0.147\) or \(a=0.140012\). For the error function itself, the latter is more accurate around \(x=0\), while the former has a smaller maximum error. No analysis of the maximum error was given for \(\operatorname{erf}^{-1}(x)\), so I had to measure it myself.
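For context, Equation 3 is the exact inverse of a companion closed-form approximation of the error function that uses the same constant \(a\). This is why the accuracy remarks above are stated for the error function. Assuming that forward form, the derivation is short. Write \(y = \operatorname{erf}(x)\), \(u = x^2\) and \(L = \ln(1 - y^2)\):

\[\begin{align*} \operatorname{erf}(x) &\approx \operatorname{sgn}(x)\sqrt{1-\exp\left(-x^{2}\,\frac{4/\pi + a x^{2}}{1 + a x^{2}}\right)} \\ \Rightarrow\quad L &= -u\,\frac{4/\pi + a u}{1 + a u} \\ \Rightarrow\quad 0 &= a u^{2} + \left(\frac{4}{\pi} + a L\right) u + L \\ \Rightarrow\quad u &= \sqrt{\left(\frac{2}{\pi a} + \frac{L}{2}\right)^{2} - \frac{L}{a}} - \left(\frac{2}{\pi a} + \frac{L}{2}\right) \end{align*}\]

Because \(L \le 0\), the product of the two roots, \(L/a\), is non-positive. The root with the \(+\) sign on the square root is therefore the non-negative one. Taking \(x = \operatorname{sgn}(y)\sqrt{u}\) and renaming \(y\) to \(x\) gives Equation 3.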
Here is a plot of \(\operatorname{erf}^{-1}(x)\) using the approximation method.
import numpy as np


def erfinv(x, a=0.147):
    """
    Compute the inverse error function using the approximation in Equation 3
    """
    # closed-form approximation (Equation 3): compute the magnitude first
    term = np.sqrt(
        np.sqrt(
            (2 / (np.pi * a) + np.log(1 - x**2) / 2) ** 2 - np.log(1 - x**2) / a
        )
        - (2 / (np.pi * a) + np.log(1 - x**2) / 2)
    )
    # compute the sign (term is 0 at x = 0, so the choice there does not matter)
    sign = 1 if x > 0 else -1
    y = sign * term
    return y


x = np.linspace(-1, 1, 1000)
# Vectorize the erfinv function (applies the scalar function element-wise)
erfinv_vec = np.vectorize(erfinv)
y_rapid = erfinv_vec(x)
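np.vectorize still calls the Python function once per element. Because Equation 3 uses only element-wise operations, it can also be written directly with NumPy array operations. The sketch below is one such variant, under the name erfinv_approx, which is my own. It uses np.sign for \(\operatorname{sgn}(x)\) and np.log1p for \(\ln(1 - x^2)\). log1p keeps precision for small \(x\), where forming \(1 - x^2\) first would round away most of \(x^2\).

```python
import numpy as np


def erfinv_approx(x, a=0.147):
    """Equation 3 applied element-wise with array operations instead of np.vectorize."""
    x = np.asarray(x, dtype=np.float64)
    with np.errstate(divide="ignore"):  # ln(0) = -inf at x = +/-1 is expected
        b = np.log1p(-x * x)  # ln(1 - x^2)
    c = 2.0 / (np.pi * a) + b / 2.0
    # np.sign(0) == 0, which is harmless because the magnitude is also 0 at x = 0
    return np.sign(x) * np.sqrt(np.sqrt(c * c - b / a) - c)


x = np.linspace(-1, 1, 1000)
y_fast = erfinv_approx(x)  # the endpoints map to -inf and +inf
print(y_fast[:3], y_fast[-3:])
```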
Let’s compute the MSE (ignoring inf values) between the approximation and the torch.erfinv implementation using \(a = 0.147\):
MSE: 1.7424258166728075e-07
Max Error: 0.003953770269555346 at x: -0.997997997997998, index: 0
y_rapid: -2.180960329841855, y_torch: -2.1849141001114103
Now repeat the experiment using \(a=0.140012\):
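# compare against torch.erfinv, ignoring the inf values at x = ±1
y_rapid2 = erfinv_vec(x, a=0.140012)
y_torch = torch.erfinv(torch.tensor(x)).numpy()
mask = np.isfinite(y_rapid2) & np.isfinite(y_torch)
x_mask, y_rapid2, y_torch = x[mask], y_rapid2[mask], y_torch[mask]
mse = np.mean((y_rapid2 - y_torch) ** 2)
print(f"MSE: {mse}")
# max error index and x value
max_error_idx = np.argmax(np.abs(y_rapid2 - y_torch))
print(f"Max Error: {np.max(np.abs(y_rapid2 - y_torch))} at x: {x_mask[max_error_idx]}")
print(f"y_rapid2: {y_rapid2[max_error_idx]}, y_torch: {y_torch[max_error_idx]}")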
MSE: 8.462108415214081e-07
Max Error: 0.007140781441233646 at x: -0.997997997997998
y_rapid2: -2.1777733186701766, y_torch: -2.1849141001114103
Both settings perform worst near the boundary. The reported maximum is at \(x \approx -0.998\), the grid point closest to \(-1\) once the infinite values are masked out. This is expected, since \(\operatorname{erf}^{-1}(x)\) diverges as \(x \to \pm 1\). I will use \(a=0.147\) for the approximation because, measured against torch.erfinv, it has both a smaller maximum error and a smaller MSE.
Now that we have a decent approximation of the inverse error function, we can implement it in PyTorch.
Compiling PyTorch from source was more challenging on macOS than on Ubuntu 22.04.
I followed this guide, but I encountered two errors that it didn't cover.
/src/pytorch/torch/csrc/Generator.cpp:208:16: error: cast from 'PyObject *(*)(THPGenerator *, void *)' (aka '_object *(*)(THPGenerator *, void *)') to 'getter' (aka '_object *(*)(_object *, void *)') converts to incompatible function type [-Werror,-Wcast-function-type-strict]
    {"device", (getter)THPGenerator_get_device, nullptr, nullptr, nullptr},
I had to modify CMakeLists.txt to comment out the line that appends -Werror=cast-function-type to CMAKE_CXX_FLAGS:
append_cxx_flag_if_supported("-Wno-unused-but-set-variable" CMAKE_CXX_FLAGS)
append_cxx_flag_if_supported("-Wno-maybe-uninitialized" CMAKE_CXX_FLAGS)
string(APPEND CMAKE_CXX_FLAGS_DEBUG " -fno-omit-frame-pointer -O0")
string(APPEND CMAKE_LINKER_FLAGS_DEBUG " -fno-omit-frame-pointer -O0")
append_cxx_flag_if_supported("-fno-math-errno" CMAKE_CXX_FLAGS)
append_cxx_flag_if_supported("-fno-trapping-math" CMAKE_CXX_FLAGS)
append_cxx_flag_if_supported("-Werror=format" CMAKE_CXX_FLAGS)
# append_cxx_flag_if_supported("-Werror=cast-function-type" CMAKE_CXX_FLAGS)
Then I was finally able to compile PyTorch from source using the following command:
First, I rewrote Equation 3 to avoid recomputing repeated terms, which reduces the number of nodes in the compute graph.
We can define the following terms so that they are reused in the graph:
\[\begin{align*} A &= x^2 \\ B &= \ln(1 - A) \\ C &= \frac{2}{\pi a} + \frac{B}{2} \end{align*}\]
Equation 3 can then be rewritten as: \[ \operatorname{erf}^{-1}(x) \approx \operatorname{sgn}(x) \sqrt{\sqrt{C^2 - \frac{B}{a}} - C} \]
This translates to a total of 17 nodes (not counting constants) in the compute graph for the erfinv function.
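Before building the graph, it is cheap to confirm in NumPy that the factored form is algebraically identical to Equation 3. This is a minimal check on an arbitrary grid that stays inside \((-1, 1)\) so no infinities appear. The helper ln_one_minus_sq is a name I made up to mark each place where Equation 3 recomputes \(\ln(1 - x^2)\).

```python
import numpy as np

a = 0.147
x = np.linspace(-0.999, 0.999, 2001)


def ln_one_minus_sq(v):
    """ln(1 - v^2), evaluated separately wherever it appears in Equation 3."""
    return np.log(1 - v**2)


# Equation 3 as written: ln(1 - x^2) is evaluated three times
direct = np.sign(x) * np.sqrt(
    np.sqrt((2 / (np.pi * a) + ln_one_minus_sq(x) / 2) ** 2 - ln_one_minus_sq(x) / a)
    - (2 / (np.pi * a) + ln_one_minus_sq(x) / 2)
)

# Factored form: A, B and C are each computed once and reused
A = x**2
B = np.log(1 - A)
C = 2 / (np.pi * a) + B / 2
factored = np.sign(x) * np.sqrt(np.sqrt(C**2 - B / a) - C)

max_diff = np.max(np.abs(direct - factored))
print(f"max |direct - factored| = {max_diff:.3e}")
```

The two expressions perform the same floating-point operations, so the difference should be zero or at most a few ulps. The only saving is in how many times each sub-expression is evaluated.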
// constant tensors
auto negOneTensor = [mpsGraph constantWithScalar:-1.0 dataType:inputTensor.dataType];
auto zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputTensor.dataType];
auto halfTensor = [mpsGraph constantWithScalar:0.5 dataType:inputTensor.dataType];
auto oneTensor = [mpsGraph constantWithScalar:1.0 dataType:inputTensor.dataType];
auto twoTensor = [mpsGraph constantWithScalar:2.0 dataType:inputTensor.dataType];
auto piTensor = [mpsGraph constantWithScalar:3.14159265358979323846264338327950288 dataType:inputTensor.dataType];
auto aTensor = [mpsGraph constantWithScalar:0.147 dataType:inputTensor.dataType];
// A = x^2
auto A = [mpsGraph multiplicationWithPrimaryTensor:inputTensor secondaryTensor:inputTensor name:nil];
// B = ln(1 - A)
auto B = [mpsGraph logarithmWithTensor:[mpsGraph subtractionWithPrimaryTensor:oneTensor
                                                              secondaryTensor:A
                                                                         name:nil]
                                  name:nil];
// C = 2 / (pi * a) + B * 0.5
auto C = [mpsGraph
    additionWithPrimaryTensor:[mpsGraph divisionWithPrimaryTensor:twoTensor
                                                  secondaryTensor:[mpsGraph multiplicationWithPrimaryTensor:piTensor
                                                                                            secondaryTensor:aTensor
                                                                                                       name:nil]
                                                             name:nil]
              secondaryTensor:[mpsGraph multiplicationWithPrimaryTensor:B secondaryTensor:halfTensor name:nil]
                         name:nil];
// magnitude: sqrt(sqrt(C^2 - B / a) - C)
auto CSquared = [mpsGraph multiplicationWithPrimaryTensor:C secondaryTensor:C name:nil];
auto CSquaredMinusBDivA = [mpsGraph subtractionWithPrimaryTensor:CSquared
                                                 secondaryTensor:[mpsGraph divisionWithPrimaryTensor:B
                                                                                    secondaryTensor:aTensor
                                                                                               name:nil]
                                                            name:nil];
auto squareRootDiffTerm = [mpsGraph squareRootWithTensor:CSquaredMinusBDivA name:nil];
auto finalDiff = [mpsGraph subtractionWithPrimaryTensor:squareRootDiffTerm secondaryTensor:C name:nil];
auto finalSquareRoot = [mpsGraph squareRootWithTensor:finalDiff name:nil];
// apply sgn(x): keep the magnitude where x >= 0 and negate it elsewhere
auto predicateTensor = [mpsGraph greaterThanOrEqualToWithPrimaryTensor:inputTensor
                                                       secondaryTensor:zeroTensor
                                                                  name:nil];
auto resultPositive = [mpsGraph multiplicationWithPrimaryTensor:finalSquareRoot secondaryTensor:oneTensor name:nil];
auto resultNegative = [mpsGraph multiplicationWithPrimaryTensor:finalSquareRoot
                                                secondaryTensor:negOneTensor
                                                           name:nil];
return [mpsGraph selectWithPredicateTensor:predicateTensor
                       truePredicateTensor:resultPositive
                      falsePredicateTensor:resultNegative
                                      name:nil];
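To debug a graph like this without a Mac, it helps to have a CPU reference that follows the same node sequence. The sketch below is an illustrative NumPy mirror of the MPSGraph code above, not part of the PyTorch PR. The function name erfinv_graph_mirror is hypothetical. The comments tally the graph nodes each line corresponds to, adding up to the 17 counted earlier.

```python
import math

import numpy as np


def erfinv_graph_mirror(x, a=0.147):
    """NumPy mirror of the MPSGraph erfinv graph, grouped by graph nodes."""
    x = np.asarray(x, dtype=np.float64)
    with np.errstate(divide="ignore"):  # ln(0) = -inf at x = +/-1
        A = x * x  # 1 node: multiplication
        B = np.log(1.0 - A)  # 2 nodes: subtraction, logarithm
    C = 2.0 / (np.pi * a) + B * 0.5  # 4 nodes: multiplication, division, multiplication, addition
    c_squared = C * C  # 1 node: multiplication
    inner = c_squared - B / a  # 2 nodes: division, subtraction
    root = np.sqrt(np.sqrt(inner) - C)  # 3 nodes: squareRoot, subtraction, squareRoot
    predicate = x >= 0.0  # 1 node: greaterThanOrEqualTo
    # 3 nodes: multiply by +1, multiply by -1, select (17 nodes in total)
    return np.where(predicate, root * 1.0, root * -1.0)


for v in (-0.9, 0.0, 0.5):
    y = float(erfinv_graph_mirror(v))
    print(f"x={v:+.1f}  y={y:+.6f}  erf(y)={math.erf(y):+.6f}")
```

Applying math.erf to the output gives a quick round-trip check: erf(y) should land close to x, but not exactly on it, because Equation 3 is only an approximation.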
Here’s a quick benchmark of the MPS implementation (M1 MacBook Pro 16”, 16 GB) vs. the CPU for the erfinv function:
import torch
x = torch.arange(-1, 1, 0.00001)
x = x.to("mps")
# measure MPS compute time (IPython %timeit magic)
time = %timeit -o -q torch.erfinv(x)
mps_time = time.average
print("MPS torch.erfinv time: ", mps_time)
x = x.to("cpu")
# measure CPU compute time
time = %timeit -o -q torch.erfinv(x)
cpu_time = time.average
print("CPU torch.erfinv time: ", cpu_time)
print(f"MPS torch.erfinv is {cpu_time / mps_time:.1f}x faster than CPU torch.erfinv")
# compute MSE between y_cpu and y_mps, ignoring non-finite values at x = -1
x = x.to("mps")
y_mps = torch.erfinv(x).to("cpu")
y_cpu = torch.erfinv(x.to("cpu"))
mask = torch.isfinite(y_cpu) & torch.isfinite(y_mps)
mse = torch.square(y_cpu[mask] - y_mps[mask]).mean()
print("MSE between MPS and CPU torch.erfinv: ", mse)
MPS torch.erfinv time: 1.2065902201325765e-05
CPU torch.erfinv time: 0.003775719881440247
MPS torch.erfinv is 312.9x faster than CPU torch.erfinv
MSE between MPS and CPU torch.erfinv: tensor(4.1653e-14)
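PyTorch queues MPS kernels asynchronously. Timing a bare torch.erfinv call on an "mps" tensor may therefore capture mostly launch overhead rather than the time for the kernel to finish. A more conservative measurement waits on the queue with torch.mps.synchronize() before reading the clock. The sketch below assumes PyTorch 2.x, where torch.mps.synchronize exists, and falls back to CPU-only timing when MPS is unavailable. The helper time_erfinv and the repeat count of 100 are my own choices. The script prints its own measurements; they are not the numbers reported above.

```python
import time

import torch


def time_erfinv(device: str, n_repeats: int = 100) -> float:
    """Average wall-clock seconds per torch.erfinv call on `device` (hypothetical helper)."""
    x = torch.arange(-1, 1, 0.00001, device=device)
    # MPS work is queued asynchronously, so synchronize before reading the clock
    sync = torch.mps.synchronize if device == "mps" else (lambda: None)
    torch.erfinv(x)  # warm-up call so one-time setup is not timed
    sync()
    start = time.perf_counter()
    for _ in range(n_repeats):
        torch.erfinv(x)
    sync()
    return (time.perf_counter() - start) / n_repeats


devices = ["cpu"] + (["mps"] if torch.backends.mps.is_available() else [])
times = {device: time_erfinv(device) for device in devices}
for device, seconds in times.items():
    print(f"{device}: {seconds:.3e} s per call")
if "mps" in times:
    print(f"MPS speed-up over CPU: {times['cpu'] / times['mps']:.1f}x")
```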
I thought I was done until I added erfinv to these two test cases:
python3 test/test_mps.py TestNLLLoss.test_unary_ops
python3 test/test_unary_ufuncs.py TestUnaryUfuncs.unary_mem_overlap_cases
I found that the Metal compute graph was not accurate enough to pass TestNLLLoss.test_unary_ops: only 70% of the test cases passed. I ended up adding two processing steps:
- 2 iterations of the Newton-Raphson method to improve the accuracy of the MPS compute graph.
- Logic to return +/- inf when the input is exactly +/- 1.0, since the approximation can be off at the boundary.
Below is the addition to the MPSGraph unary compute graph for erfinv:
// add 2 steps of Newton-Raphson iteration to improve accuracy
// adapted from
// https://github.com/pytorch/pytorch/blob/4154c8ea159fdaecc71ee9af820ac956193c875b/aten/src/ATen/native/Math.h#L191
auto currentEstimated = estimated;
for (int i = 0; i < 2; ++i) {
auto negEstimated = [mpsGraph multiplicationWithPrimaryTensor:currentEstimated
secondaryTensor:negOneTensor
name:nil];
auto estimatedSquared = [mpsGraph multiplicationWithPrimaryTensor:negEstimated
secondaryTensor:currentEstimated
name:nil];
auto estimatedSquaredExp = [mpsGraph exponentWithTensor:estimatedSquared name:nil];
auto twoDivSquareRootPi = [mpsGraph divisionWithPrimaryTensor:twoTensor
secondaryTensor:piSquareRootTensor
name:nil];
auto gradientDenominator = [mpsGraph multiplicationWithPrimaryTensor:twoDivSquareRootPi
secondaryTensor:estimatedSquaredExp
name:nil];
auto changeErf = [mpsGraph subtractionWithPrimaryTensor:[mpsGraph erfWithTensor:currentEstimated name:nil]
secondaryTensor:inputTensor
name:nil];
auto gradient = [mpsGraph divisionWithPrimaryTensor:changeErf secondaryTensor:gradientDenominator name:nil];
currentEstimated = [mpsGraph subtractionWithPrimaryTensor:currentEstimated secondaryTensor:gradient name:nil];
}
// post-processing step: if the input is exactly +1/-1, map it to infinity/-infinity
// because this algorithm might push us onto the wrong side of the asymptote due to rounding
auto onePredicate = [mpsGraph equalWithPrimaryTensor:inputTensor secondaryTensor:oneTensor name:nil];
auto negOnePredicate = [mpsGraph equalWithPrimaryTensor:inputTensor secondaryTensor:negOneTensor name:nil];
auto resultWithInfinity = [mpsGraph selectWithPredicateTensor:onePredicate
truePredicateTensor:infinityTensor
falsePredicateTensor:currentEstimated
name:nil];
return [mpsGraph selectWithPredicateTensor:negOnePredicate
truePredicateTensor:negInfinityTensor
falsePredicateTensor:resultWithInfinity
name:nil];
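Both post-processing steps can be tried out in plain Python to see their effect. This is an illustrative scalar sketch, not the MPS code. The helper names erfinv_initial and erfinv_refined are my own. Equation 3 with \(a = 0.147\) supplies the starting guess, and math.erf stands in for the graph's erfWithTensor. Each Newton-Raphson step applies \(y \leftarrow y - (\operatorname{erf}(y) - x) / \left(\tfrac{2}{\sqrt{\pi}} e^{-y^2}\right)\), the same update as the loop above.

```python
import math

TWO_OVER_SQRT_PI = 2.0 / math.sqrt(math.pi)


def erfinv_initial(x: float, a: float = 0.147) -> float:
    """Equation 3 as the starting guess (exactly +/-1 maps to +/-inf)."""
    if abs(x) == 1.0:
        return math.copysign(math.inf, x)
    b = math.log1p(-x * x)  # ln(1 - x^2)
    c = 2.0 / (math.pi * a) + b / 2.0
    return math.copysign(math.sqrt(math.sqrt(c * c - b / a) - c), x)


def erfinv_refined(x: float, steps: int = 2) -> float:
    """Newton-Raphson on f(y) = erf(y) - x, with f'(y) = 2/sqrt(pi) * exp(-y^2)."""
    if abs(x) == 1.0:  # boundary: return +/-inf instead of iterating on infinities
        return math.copysign(math.inf, x)
    y = erfinv_initial(x)
    for _ in range(steps):
        y -= (math.erf(y) - x) / (TWO_OVER_SQRT_PI * math.exp(-y * y))
    return y


for x in (0.1, 0.5, 0.9, 0.999):
    y0, y2 = erfinv_initial(x), erfinv_refined(x)
    print(f"x={x}: |erf(y) - x| before={abs(math.erf(y0) - x):.2e}, after={abs(math.erf(y2) - x):.2e}")
```

Newton's method roughly squares the error each step, so two steps starting from Equation 3's few-digit accuracy cut the residual by several orders of magnitude. The explicit boundary check matters in this sketch. At \(x = \pm 1\) the guess is already infinite, and a Newton step would compute \(0/0\).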
With the above code added, erfinv passes these PyTorch tests 100%.
➜ python3 test/test_mps.py TestNLLLoss.test_unary_ops
.
----------------------------------------------------------------------
Ran 1 test in 0.283s
OK2
Unfortunately, I discovered that the current algorithm uses too much memory. I need to optimize the MPS compute graph further before this function can be of practical use. I will come back to update this blog and reopen my PR once I have a better solution.
**UPDATE July 19, 2023:** I created another PR that uses a raw Metal kernel instead of the MPSGraph API, for an 18x speed-up. This PR is currently under review.
**UPDATE Aug 16, 2023:** The raw Metal PR was merged into PyTorch master. The code had a minor bug that caused slicing to fail, so I submitted a follow-up PR with the fix.