array-api
array-api copied to clipboard
Add complex number support to `expm1`
This PR
- adds complex number support to `expm1` by documenting special cases. The exponential function is an entire function in the complex plane; thus, the function has no branch cuts.
- updates the input and output array data types to be any floating-point data type, not just real-valued floating-point data types.
- derives special cases from the C99 complex exponential (`cexp`) and tests them against NumPy (script below).
import numpy as np
import math
def is_equal_float(x, y):
    """Return whether two floats are identical, treating NaNs and signed zeros specially.

    Unlike ``==``, two NaNs are considered equal to each other, and ``+0.0`` is
    considered distinct from ``-0.0``. Infinities and ordinary values use
    normal equality.

    Parameters
    ----------
    x : float
        First input number.
    y : float
        Second input number.

    Returns
    -------
    bool
        ``True`` when both numbers are the same floating-point value.

    Examples
    --------
    >>> is_equal_float(0.0, -0.0)
    False
    >>> is_equal_float(-0.0, -0.0)
    True
    """
    # NaN is the only value that compares unequal to itself.
    x_nan, y_nan = x != x, y != y
    if x_nan or y_nan:
        # Equal only when *both* sides are NaN.
        return x_nan and y_nan
    if x == 0.0 == y:
        # +0 and -0 compare equal under ==, so inspect their sign bits instead.
        return math.copysign(1.0, x) == math.copysign(1.0, y)
    return x == y
def is_equal(x, y):
    """Return whether two complex numbers are identical component by component.

    The real and imaginary parts are each compared with ``is_equal_float``, so
    NaN components match other NaN components and the sign of a zero matters.

    Parameters
    ----------
    x : complex
        First input number.
    y : complex
        Second input number.

    Returns
    -------
    bool
        ``True`` when both the real and the imaginary parts match.

    Examples
    --------
    >>> import numpy as np
    >>> is_equal(complex(np.nan, np.nan), complex(np.nan, np.nan))
    True
    """
    components_x = (x.real, x.imag)
    components_y = (y.real, y.imag)
    # all() short-circuits exactly like the chained `and`: the imaginary parts
    # are only compared when the real parts already match.
    return all(is_equal_float(a, b) for a, b in zip(components_x, components_y))
# Flat, strided table of test cases: each even index holds an input value and
# the element right after it holds the expected `expm1` result. The trailing
# number identifies the special case; notes record where NumPy's `expm1`
# disagrees with the naive `np.exp(z) - 1.0`.
values = [
    # input                     expected
    complex(0.0, 0.0),          complex(0.0, 0.0),         # 0
    complex(-0.0, 0.0),         complex(-0.0, 0.0),        # 1
    complex(1.0, np.inf),       complex(np.nan, np.nan),   # 2
    complex(1.0, np.nan),       complex(np.nan, np.nan),   # 3
    complex(np.inf, 0.0),       complex(np.inf, 0.0),      # 4: likely NumPy bug - returns (inf+nanj); np.exp(z)-1.0 gives (inf+0j)
    complex(-np.inf, 1.0),      complex(-1.0, 0.0),        # 5
    complex(np.inf, 1.0),       complex(np.inf, np.inf),   # 6
    complex(-np.inf, np.inf),   complex(-1.0, 0.0),        # 7: likely NumPy bug - returns (nan+nanj); np.exp(z)-1.0 gives (-1+0j)
    complex(np.inf, np.inf),    complex(np.inf, np.nan),   # 8: likely NumPy bug - returns (nan+nanj); np.exp(z)-1.0 gives (inf+nanj)
    complex(-np.inf, np.nan),   complex(-1.0, 0.0),        # 9: likely NumPy bug - returns (nan+nanj); np.exp(z)-1.0 gives (-1+0j)
    complex(np.inf, np.nan),    complex(np.inf, np.nan),   # 10: likely NumPy bug - returns (nan+nanj); np.exp(z)-1.0 gives (inf+nanj)
    complex(np.nan, 0.0),       complex(np.nan, 0.0),      # 11: likely NumPy bug - returns (nan+nanj); np.exp(z)-1.0 gives (nan+0j)
    complex(np.nan, 1.0),       complex(np.nan, np.nan),   # 12
    complex(np.nan, np.nan),    complex(np.nan, np.nan),   # 13
]
# Run every (input, expected) pair from the strided `values` table through
# NumPy's `expm1` and print the result next to the naive `exp(z) - 1.0` and
# the expected value. Slicing with a stride of two pairs each input with the
# expected value that follows it, replacing manual index arithmetic.
if len(values) % 2:
    # An unpaired trailing input would otherwise be skipped silently; fail
    # loudly so a malformed table gets noticed.
    raise ValueError('values must hold an even number of entries (input/expected pairs)')
for v, e in zip(values[::2], values[1::2]):
    actual = np.expm1(v)
    print('Value: {value}'.format(value=str(v)))
    print('Actual: {actual}'.format(actual=str(actual)))
    # Naive formula, printed so NumPy's `expm1` can be compared against
    # NumPy's own `exp` for the same special case.
    print('Naive: {naive}'.format(naive=str(np.exp(v)-1.0)))
    print('Expected: {expected}'.format(expected=str(e)))
    # is_equal distinguishes signed zeros and matches NaN with NaN.
    print('Equal: {is_equal}'.format(is_equal=str(is_equal(actual, e))))
    print('\n')
Value: 0j
Actual: 0j
Naive: 0j
Expected: 0j
Equal: True
Value: (-0+0j)
Actual: (-0+0j)
Naive: 0j
Expected: (-0+0j)
Equal: True
/path/to/cexpm1.py:113: RuntimeWarning: invalid value encountered in expm1
actual = np.expm1(v)
Value: (1+infj)
Actual: (nan+nanj)
/path/to/cexpm1.py:116: RuntimeWarning: invalid value encountered in exp
print('Naive: {naive}'.format(naive=str(np.exp(v)-1.0)))
Naive: (nan+nanj)
Expected: (nan+nanj)
Equal: True
Value: (1+nanj)
Actual: (nan+nanj)
Naive: (nan+nanj)
Expected: (nan+nanj)
Equal: True
Value: (inf+0j)
Actual: (inf+nanj)
Naive: (inf+0j)
Expected: (inf+0j)
Equal: False
Value: (-inf+1j)
Actual: (-1+0j)
Naive: (-1+0j)
Expected: (-1+0j)
Equal: True
Value: (inf+1j)
Actual: (inf+infj)
Naive: (inf+infj)
Expected: (inf+infj)
Equal: True
Value: (-inf+infj)
Actual: (nan+nanj)
Naive: (-1+0j)
Expected: (-1+0j)
Equal: False
Value: (inf+infj)
Actual: (nan+nanj)
Naive: (inf+nanj)
Expected: (inf+nanj)
Equal: False
Value: (-inf+nanj)
Actual: (nan+nanj)
Naive: (-1+0j)
Expected: (-1+0j)
Equal: False
Value: (inf+nanj)
Actual: (nan+nanj)
Naive: (inf+nanj)
Expected: (inf+nanj)
Equal: False
Value: (nan+0j)
Actual: (nan+nanj)
Naive: (nan+0j)
Expected: (nan+0j)
Equal: False
Value: (nan+1j)
Actual: (nan+nanj)
Naive: (nan+nanj)
Expected: (nan+nanj)
Equal: True
Value: (nan+nanj)
Actual: (nan+nanj)
Naive: (nan+nanj)
Expected: (nan+nanj)
Equal: True
Notes
- NumPy currently fails for ~~4~~ 6 complex number special cases. Its behavior is inconsistent with `np.exp(z)-1`, as documented in the script above.