The Inverse Error Function

pytorch
metal api
macos
Author

Peter Pham

Published

July 19, 2023

Overview

I happened to stumble upon a feature request to implement the Metal backend for \(\operatorname{erf}^{-1}(x)\) in PyTorch. I thought it would be a good exercise to implement it myself to get a better understanding of defining custom torch ops for the MPS backend. (There’s also a list of unimplemented torch ops for MPS.)

Math

To work on the inverse error function, we first need an understanding of the error function. The error function is defined as:

\[ \operatorname{erf}(x)=\frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^{2}} d t \tag{1}\]

The error function maps a real number in the domain \((-\infty, \infty)\) to a real number in the range \((-1, 1)\).

Code
import math

import numpy as np
import plotly.express as px
import torch


x = np.linspace(-5, 5, 1000)
y = torch.erf(torch.tensor(x)).numpy()
# add y as erf(x) label
fig = px.line(x=x, y=y, labels={"x": "x", "y": "erf(x)"})

fig.show()
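
As a quick numerical check of Equation 1, here is a small sketch that integrates the integrand with np.trapz and compares the result against torch.erf:

Code
# Sanity check of Equation 1: numerically integrate (2/sqrt(pi)) * exp(-t^2)
# from 0 to x and compare against torch.erf
x0 = 1.0
t = np.linspace(0, x0, 10_001)
numeric = (2 / np.sqrt(np.pi)) * np.trapz(np.exp(-(t**2)), t)
print(numeric, torch.erf(torch.tensor(x0)).item())  # both ~0.8427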

The inverse error function is defined as: \[ \operatorname{erf}^{-1}(\operatorname{erf}(x))=x \tag{2}\]

This means the \(\operatorname{erf}^{-1}(x)\) will have a domain of \((-1, 1)\) and a range of \((-\infty, \infty)\).
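
As a sanity check of Equation 2, round-tripping a few values through torch.erf and torch.erfinv should recover the inputs:

Code
# Equation 2 sanity check: erfinv(erf(x)) recovers x (up to floating-point error)
x_check = torch.linspace(-3, 3, 7, dtype=torch.float64)
print(torch.allclose(torch.erfinv(torch.erf(x_check)), x_check))  # True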

There is no closed-form solution for \(\operatorname{erf}^{-1}(x)\); however, it can be approximated using elementary functions, as proposed by Abramowitz and Stegun[1]. The approximation is given by:

\[ \operatorname{erf}^{-1}(x) \approx \operatorname{sgn}(x) \sqrt{\sqrt{\left(\frac{2}{\pi a}+\frac{\ln(1-x^{2})}{2}\right)^{2}- \frac{\ln (1-x^{2})}{a} }- \left(\frac{2}{\pi a}+\frac{\ln (1-x^{2})}{2}\right)} \tag{3}\]

where \(a=0.147\) or \(a=0.140012\); the latter is more accurate around \(x=0\) for the error function, but the former has a smaller maximum error for the error function. No analysis was given for the maximum error of \(\operatorname{erf}^{-1}(x)\), so I will experiment with it myself.

Here is a plot of the \(\operatorname{erf}^{-1}(x)\) using the approximation method.

Code
def erfinv(x, a=0.147):
    """
    Compute the inverse error function using the approximation in Equation 3
    """

    # compute the unsigned term of Equation 3
    term = np.sqrt(
        np.sqrt(
            (2 / (np.pi * a) + np.log(1 - x**2) / 2) ** 2 - np.log(1 - x**2) / a
        )
        - (2 / (np.pi * a) + np.log(1 - x**2) / 2)
    )
    # apply the sign of x (scalar logic; vectorized below with np.vectorize)
    sign = 1 if x > 0 else -1
    y = sign * term

    return y


x = np.linspace(-1, 1, 1000)
# Vectorize the erfinv function
erfinv_vec = np.vectorize(erfinv)
y_rapid = erfinv_vec(x)
Code
# Add y as erfinv(x) label
fig = px.line(x=x, y=y_rapid, labels={"x": "x", "y": "erfinv(x)"})
fig.show()

Let’s compute the MSE (ignoring inf values) between the approximation and the torch.erfinv implementation using:

\(a = 0.147\)

Code
x = np.linspace(-1, 1, 1000)
y_rapid = erfinv_vec(x)
y_torch = torch.erfinv(torch.tensor(x)).numpy()
# compute mask where y_rapid isn't infinite
mask = np.isfinite(y_rapid)
x_mask = x[mask]
y_rapid = y_rapid[mask]
y_torch = y_torch[mask]
mse = np.mean((y_rapid - y_torch) ** 2)
print(f"MSE: {mse}")
max_error_idx = np.argmax(np.abs(y_rapid - y_torch))

print(
    f"Max Error: {np.max(np.abs(y_rapid - y_torch))} at x: {x_mask[max_error_idx]}, index: {max_error_idx}"
)
print(f"y_rapid: {y_rapid[max_error_idx]}, y_torch: {y_torch[max_error_idx]}")
MSE: 1.7424258166728075e-07
Max Error: 0.003953770269555346 at x: -0.997997997997998, index: 0
y_rapid: -2.180960329841855, y_torch: -2.1849141001114103

Now repeat the experiment but using:

\(a=0.140012\)

Code
y_rapid2 = erfinv_vec(x, a=0.140012)
y_torch = torch.erfinv(torch.tensor(x)).numpy()
# compute mask where y_rapid isn't infinite
mask = np.isfinite(y_rapid2)
x_mask = x[mask]
y_rapid2 = y_rapid2[mask]
y_torch = y_torch[mask]
Code
mse = np.mean((y_rapid2 - y_torch) ** 2)
print(f"MSE: {mse}")
max_error_idx = np.argmax(np.abs(y_rapid2 - y_torch))
print(f"Max Error: {np.max(np.abs(y_rapid2 - y_torch))} at x: {x_mask[max_error_idx]}")
print(f"y_rapid2: {y_rapid2[max_error_idx]}, y_torch: {y_torch[max_error_idx]}")
MSE: 8.462108415214081e-07
Max Error: 0.007140781441233646 at x: -0.997997997997998
y_rapid2: -2.1777733186701766, y_torch: -2.1849141001114103

Both methods perform worst near \(x=-1\), which is expected since \(\operatorname{erf}^{-1}(x)\) is asymptotic at \(x=-1\). I will use \(a=0.147\) for the approximation method since it has both a smaller maximum error and a smaller MSE relative to the torch.erfinv implementation.
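
To visualize this, here is a quick sketch (assuming pandas is available) that plots the absolute error against torch.erfinv for both values of \(a\):

Code
import pandas as pd

# Absolute error of the approximation vs torch.erfinv for both values of a,
# sampled strictly inside (-1, 1) to avoid the infinities at the endpoints
x_err = np.linspace(-0.999, 0.999, 1000)
y_ref = torch.erfinv(torch.tensor(x_err)).numpy()
df = pd.DataFrame({
    "x": x_err,
    "a=0.147": np.abs(erfinv_vec(x_err, a=0.147) - y_ref),
    "a=0.140012": np.abs(erfinv_vec(x_err, a=0.140012) - y_ref),
})
fig = px.line(df, x="x", y=["a=0.147", "a=0.140012"], labels={"value": "absolute error"})
fig.show()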

PyTorch Implementation

Now that we have a decent approximation of the inverse error function, we can implement it in PyTorch.

Compiling PyTorch from source

My experience compiling PyTorch from source on macOS was more challenging than on Ubuntu 22.04.

I followed this guide, and in addition encountered two errors that weren’t covered there:

  • I had mismatched installations of protoc (protobuf) from both conda and brew; I had to uninstall both.
  • I encountered errors near the end related to -Wcast-function-type-strict, such as:

/src/pytorch/torch/csrc/Generator.cpp:208:16: error: cast from 'PyObject *(*)(THPGenerator *, void *)' (aka '_object *(*)(THPGenerator *, void *)') to 'getter' (aka '_object *(*)(_object *, void *)') converts to incompatible function type [-Werror,-Wcast-function-type-strict]
  {"device", (getter)THPGenerator_get_device, nullptr, nullptr, nullptr},

I had to modify CMakeLists.txt to comment out the -Werror=cast-function-type flag from CMAKE_CXX_FLAGS:

append_cxx_flag_if_supported("-Wno-unused-but-set-variable" CMAKE_CXX_FLAGS)
append_cxx_flag_if_supported("-Wno-maybe-uninitialized" CMAKE_CXX_FLAGS)
string(APPEND CMAKE_CXX_FLAGS_DEBUG " -fno-omit-frame-pointer -O0")
string(APPEND CMAKE_LINKER_FLAGS_DEBUG " -fno-omit-frame-pointer -O0")
append_cxx_flag_if_supported("-fno-math-errno" CMAKE_CXX_FLAGS)
append_cxx_flag_if_supported("-fno-trapping-math" CMAKE_CXX_FLAGS)
append_cxx_flag_if_supported("-Werror=format" CMAKE_CXX_FLAGS)

# append_cxx_flag_if_supported("-Werror=cast-function-type" CMAKE_CXX_FLAGS)

Then I was finally able to compile PyTorch from source using the following command:

MACOSX_DEPLOYMENT_TARGET=13.0 CC=clang CXX=clang++ USE_MPS=1 USE_PYTORCH_METAL=1 \
DEBUG=1 python setup.py develop

Metal API

  • https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph

Reducing the MPSGraph code

First, I reduce Equation 3 to avoid repeated calculation of terms, which keeps the number of nodes in the compute graph down.

We can define the following intermediate terms so they can be re-used in the graph:

\[\begin{align*} A &= x^2 \\ B &= \log(1 - A) \\ C &= \frac{2}{\pi a} + \frac{B}{2} \\ \end{align*}\]

Then Equation 3 can be rewritten as: \[ \operatorname{erfinv}(x) = \operatorname{sgn}(x) \sqrt{\sqrt{C^2 - \frac{B}{a}} - C} \]

  • \(A\) term requires 1 multiply node
  • \(B\) term requires 1 log node and 1 subtract node
  • \(C\) term requires 1 add node, 1 division node, and 2 multiply nodes
  • \(\operatorname{erfinv}(x)\) requires 2 square root nodes, 2 subtract nodes, 1 multiply node, and 1 division node
    • the \(\operatorname{sgn}(x)\) term requires 4 nodes: greaterThanOrEqualTo, select, and 2 multiply nodes

This translates to a total of 17 nodes (not counting constants) in the compute graph for the erfinv function.
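
As a quick cross-check (my own sketch, not part of the PR), we can verify in NumPy that the factored form matches the scalar implementation from earlier:

Code
# Cross-check: the factored form (A, B, C) matches Equation 3 computed directly
a = 0.147
x_chk = np.linspace(-0.99, 0.99, 101)

A = x_chk**2
B = np.log(1 - A)
C = 2 / (np.pi * a) + B / 2
factored = np.sign(x_chk) * np.sqrt(np.sqrt(C**2 - B / a) - C)

print(np.allclose(factored, erfinv_vec(x_chk)))  # True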

MPSGraph implementation

// constant tensors
auto negOneTensor = [mpsGraph constantWithScalar:-1.0 dataType:inputTensor.dataType];
auto zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputTensor.dataType];
auto halfTensor = [mpsGraph constantWithScalar:0.5 dataType:inputTensor.dataType];
auto oneTensor = [mpsGraph constantWithScalar:1.0 dataType:inputTensor.dataType];
auto twoTensor = [mpsGraph constantWithScalar:2.0 dataType:inputTensor.dataType];
auto piTensor = [mpsGraph constantWithScalar:3.14159265358979323846264338327950288 dataType:inputTensor.dataType];
auto aTensor = [mpsGraph constantWithScalar:0.147 dataType:inputTensor.dataType];


auto A = [mpsGraph multiplicationWithPrimaryTensor:inputTensor secondaryTensor:inputTensor name:nil];
auto B = [mpsGraph logarithmWithTensor:[mpsGraph subtractionWithPrimaryTensor:oneTensor
                                                                    secondaryTensor:A
                                                                                name:nil]
                                        name:nil];
auto C = [mpsGraph
    additionWithPrimaryTensor:[mpsGraph divisionWithPrimaryTensor:twoTensor
                                                    secondaryTensor:[mpsGraph multiplicationWithPrimaryTensor:piTensor
                                                                                            secondaryTensor:aTensor
                                                                                                        name:nil]
                                                                name:nil]
                secondaryTensor:[mpsGraph multiplicationWithPrimaryTensor:B secondaryTensor:halfTensor name:nil]
                            name:nil];
auto CSquared = [mpsGraph multiplicationWithPrimaryTensor:C secondaryTensor:C name:nil];
auto CSquaredMinusBDivA = [mpsGraph subtractionWithPrimaryTensor:CSquared
                                        secondaryTensor:[mpsGraph divisionWithPrimaryTensor:B
                                                                            secondaryTensor:aTensor
                                                                                        name:nil]
                                                    name:nil];
auto squareRootDiffTerm = [mpsGraph squareRootWithTensor:CSquaredMinusBDivA name:nil];
auto finalDiff = [mpsGraph subtractionWithPrimaryTensor:squareRootDiffTerm secondaryTensor:C name:nil];
auto finalSquareRoot = [mpsGraph squareRootWithTensor:finalDiff name:nil];
auto predicateTensor = [mpsGraph greaterThanOrEqualToWithPrimaryTensor:inputTensor
                                                        secondaryTensor:zeroTensor
                                                                    name:nil];
auto resultPositive = [mpsGraph multiplicationWithPrimaryTensor:finalSquareRoot secondaryTensor:oneTensor name:nil];
auto resultNegative = [mpsGraph multiplicationWithPrimaryTensor:finalSquareRoot
                                                secondaryTensor:negOneTensor
                                                            name:nil];
return [mpsGraph selectWithPredicateTensor:predicateTensor
                        truePredicateTensor:resultPositive
                        falsePredicateTensor:resultNegative
                                        name:nil];

Here’s a quick benchmark of the MPS compute (M1 MacBook Pro 16″, 16 GB) vs CPU for the erfinv function. (A caveat: MPS ops execute asynchronously, so %timeit without an explicit torch.mps.synchronize() mostly measures dispatch time rather than end-to-end compute.)

Code
import torch
x = torch.arange(-1, 1, 0.00001)
x = x.to("mps")
# measure MPS compute time
time = %timeit -o -q  torch.erfinv(x)
mps_time = time.average
print("MPS torch.erfinv time: ", mps_time)
x = x.to("cpu")
# measure CPU compute time
time = %timeit -o -q torch.erfinv(x)
cpu_time = time.average
print("CPU torch.erfinv time: ", cpu_time)
print(f"MPS torch.erfinv is {cpu_time/mps_time:.0f}x faster than CPU torch.erfinv")

# compute MSE between y_cpu and y_mps
x = x.to("mps")
y_mps = torch.erfinv(x)
y_cpu = torch.erfinv(x.to("cpu"))
mask = torch.isfinite(y_cpu) & torch.isfinite(y_mps.to("cpu"))
mse = torch.square(y_cpu[mask] - y_mps[mask].to("cpu")).mean()
print("MSE between MPS and CPU torch.erfinv: ", mse)
MPS torch.erfinv time:  1.2065902201325765e-05
CPU torch.erfinv time:  0.003775719881440247
MPS torch.erfinv is 313x faster than CPU torch.erfinv
MSE between MPS and CPU torch.erfinv:  tensor(4.1653e-14)

I thought I was done until I added erfinv to these two test cases:

python3 test/test_mps.py TestNLLLoss.test_unary_ops
python3 test/test_unary_ufuncs.py TestUnaryUfuncs.unary_mem_overlap_cases

I found that the Metal compute graph’s accuracy was not good enough for TestNLLLoss.test_unary_ops, where it passed only 70% of the test cases. I ended up adding two processing steps:

  • 2 iterations of the Newton-Raphson method to improve the accuracy of the MPS compute graph (sketched below)
  • logic to map inputs of exactly +/- 1.0 to +/- inf, since the approximation can be off at the boundary
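
Each Newton-Raphson step solves \(\operatorname{erf}(t) = x\) using the derivative \(\operatorname{erf}'(t) = \frac{2}{\sqrt{\pi}} e^{-t^2}\). Here is a minimal NumPy/PyTorch sketch of the refinement, my own illustration of what the graph code below does:

Code
# Newton-Raphson: solve erf(t) = y for t, starting from the rough approximation.
# Update rule: t <- t - (erf(t) - y) / erf'(t), where erf'(t) = (2/sqrt(pi)) * exp(-t^2)
def refine(t, y, iters=2):
    for _ in range(iters):
        fprime = (2 / np.sqrt(np.pi)) * np.exp(-(t**2))
        t = t - (torch.erf(torch.tensor(t)).numpy() - y) / fprime
    return t

y = np.linspace(-0.9, 0.9, 5)
print(refine(erfinv_vec(y), y))               # refined estimates
print(torch.erfinv(torch.tensor(y)).numpy())  # reference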

Below is the addition to the MPSGraph unary compute graph for erfinv:

   // add 2 steps of Newton-Raphson iteration to improve accuracy
    // adopted from
    // https://github.com/pytorch/pytorch/blob/4154c8ea159fdaecc71ee9af820ac956193c875b/aten/src/ATen/native/Math.h#L191

    auto currentEstimated = estimated;
    for (int i = 0; i < 2; ++i) {
      auto negEstimated = [mpsGraph multiplicationWithPrimaryTensor:currentEstimated
                                                    secondaryTensor:negOneTensor
                                                               name:nil];
      auto estimatedSquared = [mpsGraph multiplicationWithPrimaryTensor:negEstimated
                                                        secondaryTensor:currentEstimated
                                                                   name:nil];
      auto estimatedSquaredExp = [mpsGraph exponentWithTensor:estimatedSquared name:nil];
      auto twoDivSquareRootPi = [mpsGraph divisionWithPrimaryTensor:twoTensor
                                                    secondaryTensor:piSquareRootTensor
                                                               name:nil];
      auto gradientDenominator = [mpsGraph multiplicationWithPrimaryTensor:twoDivSquareRootPi
                                                           secondaryTensor:estimatedSquaredExp
                                                                      name:nil];
      auto changeErf = [mpsGraph subtractionWithPrimaryTensor:[mpsGraph erfWithTensor:currentEstimated name:nil]
                                              secondaryTensor:inputTensor
                                                         name:nil];
      auto gradient = [mpsGraph divisionWithPrimaryTensor:changeErf secondaryTensor:gradientDenominator name:nil];
      currentEstimated = [mpsGraph subtractionWithPrimaryTensor:currentEstimated secondaryTensor:gradient name:nil];
    }

    // post-processing step: if the input is exactly +1/-1, map it to +infinity/-infinity,
    // because the algorithm might land on the wrong side of the asymptote due to rounding
    auto onePredicate = [mpsGraph equalWithPrimaryTensor:inputTensor secondaryTensor:oneTensor name:nil];
    auto negOnePredicate = [mpsGraph equalWithPrimaryTensor:inputTensor secondaryTensor:negOneTensor name:nil];

    auto resultWithInfinity = [mpsGraph selectWithPredicateTensor:onePredicate
                                              truePredicateTensor:infinityTensor
                                             falsePredicateTensor:currentEstimated
                                                             name:nil];
    return [mpsGraph selectWithPredicateTensor:negOnePredicate
                           truePredicateTensor:negInfinityTensor
                          falsePredicateTensor:resultWithInfinity
                                          name:nil];

Adding the above code allowed me to pass the PyTorch automated tests 100%:

 python3 test/test_mps.py TestNLLLoss.test_unary_ops
.
----------------------------------------------------------------------
Ran 1 test in 0.283s

OK

Unfortunately, I discovered that the current algorithm uses too much memory. I will have to optimize the MPS compute graph further before this function can be of practical use, and I will come back to update this blog and reopen my PR once I have a better solution.

**UPDATE July 19, 2023**: I created another PR that uses a raw Metal kernel instead of the MPSGraph API, for an 18x speedup. This PR is currently under review.

**UPDATE Aug 16, 2023**: The raw Metal PR was merged into pytorch master. The code had a minor bug that caused slicing to fail, so I submitted a follow-up PR with the fix.

References

1. EduCare. 2023. Error in functions. https://www.educare.bz/unit/error-in-functions/