array-api

Add complex number support to `expm1`

Open kgryte opened this issue 3 years ago • 0 comments

This PR

  • adds complex number support to `expm1` by documenting special cases. The exponential function is entire in the complex plane, and `expm1(z) = exp(z) - 1` differs from it only by a constant, so `expm1` is entire as well and therefore has no branch cuts.
  • updates the input and output array data types to be any floating-point data type, not just real-valued floating-point data types.
  • derives special cases from the C99 complex exponential `cexp` and tests them against NumPy (script found below).
import numpy as np
import math

def is_equal_float(x, y):
    """Return whether two floats are identical, treating signed zeros and NaNs specially.

    Unlike plain ``==``, ``+0.0`` and ``-0.0`` are considered different, and
    a NaN is considered equal to another NaN (but to nothing else). All other
    values, including infinities, are compared with ``==``.

    Parameters
    ----------
    x : float
        First value to compare.
    y : float
        Second value to compare.

    Returns
    -------
    bool
        ``True`` if the two values match under the rules above.

    Examples
    --------
    >>> is_equal_float(0.0, -0.0)
    False
    >>> is_equal_float(-0.0, -0.0)
    True
    """
    if x == 0.0 and y == 0.0:
        # ``0.0 == -0.0`` holds, so the sign bits must be compared explicitly.
        same_sign = math.copysign(1.0, x) == math.copysign(1.0, y)
        return same_sign

    # NaN is the only value that does not compare equal to itself.
    x_is_nan = x != x
    if not x_is_nan:
        # Finite values and infinities compare correctly with ``==``; a NaN
        # ``y`` never equals a non-NaN ``x``.
        return x == y
    return y != y


def is_equal(x, y):
    """Return whether two complex numbers are identical component-wise.

    The real parts and the imaginary parts are each compared with
    ``is_equal_float``, so signed zeros are distinguished and NaN matches NaN.
    The imaginary parts are only inspected when the real parts match.

    Parameters
    ----------
    x : complex
        First value to compare.
    y : complex
        Second value to compare.

    Returns
    -------
    bool
        ``True`` if both the real and the imaginary components match.

    Examples
    --------
    >>> import numpy as np
    >>> is_equal(complex(np.nan, np.nan), complex(np.nan, np.nan))
    True
    """
    real_parts_match = is_equal_float(x.real, y.real)
    if not real_parts_match:
        # Short-circuit exactly like ``and``: hand back the falsy result as-is.
        return real_parts_match
    return is_equal_float(x.imag, y.imag)


# Special-case table for expm1. Each entry below is an (input, expected)
# pair; the comprehension flattens the pairs into a strided list, so for the
# k-th case values[2*k] is the input and values[2*k + 1] the expected result.
# Comments flag the cases where np.expm1 was observed to disagree with the
# naive np.exp(z) - 1.0 (likely NumPy bugs).
values = [
    number
    for case in (
        (complex(0.0, 0.0), complex(0.0, 0.0)),  # 0
        (complex(-0.0, 0.0), complex(-0.0, 0.0)),  # 1
        (complex(1.0, np.inf), complex(np.nan, np.nan)),  # 2
        (complex(1.0, np.nan), complex(np.nan, np.nan)),  # 3
        # 4: NumPy returns (inf+nanj), whereas
        #    np.exp(complex(np.inf, 0.0)) - 1.0 == (inf+0j).
        (complex(np.inf, 0.0), complex(np.inf, 0.0)),
        (complex(-np.inf, 1.0), complex(-1.0, 0.0)),  # 5
        (complex(np.inf, 1.0), complex(np.inf, np.inf)),  # 6
        # 7: NumPy returns (nan+nanj), whereas
        #    np.exp(complex(-np.inf, np.inf)) - 1.0 == (-1+0j).
        (complex(-np.inf, np.inf), complex(-1.0, 0.0)),
        # 8: NumPy returns (nan+nanj), whereas
        #    np.exp(complex(np.inf, np.inf)) - 1.0 == (inf+nanj).
        (complex(np.inf, np.inf), complex(np.inf, np.nan)),
        # 9: NumPy returns (nan+nanj), whereas
        #    np.exp(complex(-np.inf, np.nan)) - 1.0 == (-1+0j).
        (complex(-np.inf, np.nan), complex(-1.0, 0.0)),
        # 10: NumPy returns (nan+nanj), whereas
        #     np.exp(complex(np.inf, np.nan)) - 1.0 == (inf+nanj).
        (complex(np.inf, np.nan), complex(np.inf, np.nan)),
        # 11: NumPy returns (nan+nanj), whereas
        #     np.exp(complex(np.nan, 0.0)) - 1.0 == (nan+0j).
        (complex(np.nan, 0.0), complex(np.nan, 0.0)),
        (complex(np.nan, 1.0), complex(np.nan, np.nan)),  # 12
        (complex(np.nan, np.nan), complex(np.nan, np.nan)),  # 13
    )
    for number in case
]

# For each (input, expected) pair in the strided `values` list, print
# np.expm1's result next to the naive np.exp(z) - 1.0 and report whether the
# result matches the expected special-case value (via is_equal, which treats
# signed zeros and NaNs specially). NumPy may emit "invalid value"
# RuntimeWarnings for some non-finite inputs, as seen in the sample output.
for v, e in zip(values[::2], values[1::2]):
    actual = np.expm1(v)
    print(f'Value: {v!s}')
    print(f'Actual: {actual!s}')
    # Evaluated only after "Actual" is printed so any RuntimeWarning from
    # np.exp interleaves with the output at the same point as before.
    naive = np.exp(v) - 1.0
    print(f'Naive: {naive!s}')
    print(f'Expected: {e!s}')
    print(f'Equal: {is_equal(actual, e)!s}')
    # print('\n') emits two newlines, separating cases by two blank lines.
    print('\n')
Value: 0j
Actual: 0j
Naive: 0j
Expected: 0j
Equal: True


Value: (-0+0j)
Actual: (-0+0j)
Naive: 0j
Expected: (-0+0j)
Equal: True


/path/to/cexpm1.py:113: RuntimeWarning: invalid value encountered in expm1
  actual = np.expm1(v)
Value: (1+infj)
Actual: (nan+nanj)
/path/to/cexpm1.py:116: RuntimeWarning: invalid value encountered in exp
  print('Naive: {naive}'.format(naive=str(np.exp(v)-1.0)))
Naive: (nan+nanj)
Expected: (nan+nanj)
Equal: True


Value: (1+nanj)
Actual: (nan+nanj)
Naive: (nan+nanj)
Expected: (nan+nanj)
Equal: True


Value: (inf+0j)
Actual: (inf+nanj)
Naive: (inf+0j)
Expected: (inf+0j)
Equal: False


Value: (-inf+1j)
Actual: (-1+0j)
Naive: (-1+0j)
Expected: (-1+0j)
Equal: True


Value: (inf+1j)
Actual: (inf+infj)
Naive: (inf+infj)
Expected: (inf+infj)
Equal: True


Value: (-inf+infj)
Actual: (nan+nanj)
Naive: (-1+0j)
Expected: (-1+0j)
Equal: False


Value: (inf+infj)
Actual: (nan+nanj)
Naive: (inf+nanj)
Expected: (inf+nanj)
Equal: False


Value: (-inf+nanj)
Actual: (nan+nanj)
Naive: (-1+0j)
Expected: (-1+0j)
Equal: False


Value: (inf+nanj)
Actual: (nan+nanj)
Naive: (inf+nanj)
Expected: (inf+nanj)
Equal: False


Value: (nan+0j)
Actual: (nan+nanj)
Naive: (nan+0j)
Expected: (nan+0j)
Equal: False


Value: (nan+1j)
Actual: (nan+nanj)
Naive: (nan+nanj)
Expected: (nan+nanj)
Equal: True


Value: (nan+nanj)
Actual: (nan+nanj)
Naive: (nan+nanj)
Expected: (nan+nanj)
Equal: True

Notes

  • NumPy currently fails for ~~4~~ 6 complex number special cases. Its behavior is inconsistent with np.exp(z)-1, as documented in the script above.

kgryte avatar Jun 13 '22 08:06 kgryte