Add complex number support to `tanh`
This PR
- adds complex number support to `tanh` by documenting special cases. The hyperbolic tangent is an analytical function in the complex plane and has no branch cuts.
- updates the input and output array data types to be any floating-point data type, not just real-valued floating-point data types.
- derives special cases from C99 and tests them against NumPy (script found below).
- adds a warning concerning two special cases where results may differ depending on which C version an array library is compiled against (for libraries compiled from C).
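As a rough illustration of the complex-valued behavior for finite inputs, the sketch below evaluates the hyperbolic tangent from the real-valued identity tanh(x + yj) = (sinh(2x) + j sin(2y)) / (cosh(2x) + cos(2y)) and compares it with the standard-library `cmath.tanh`. The helper name `tanh_from_real_parts` is hypothetical and is not part of this PR or of any array library; the sketch does not handle the special cases documented here.

```python
import cmath
import math


def tanh_from_real_parts(z):
    """Evaluate tanh(z) for finite complex z using only real-valued functions."""
    x, y = z.real, z.imag
    # Shared denominator of the real and imaginary components.
    denom = math.cosh(2.0 * x) + math.cos(2.0 * y)
    return complex(math.sinh(2.0 * x) / denom, math.sin(2.0 * y) / denom)


z = complex(0.5, 0.25)

# Both evaluations should agree to within floating-point rounding error.
print(tanh_from_real_parts(z))
print(cmath.tanh(z))

# tanh is periodic along the imaginary axis with period pi.
print(cmath.isclose(cmath.tanh(z), cmath.tanh(z + complex(0.0, math.pi))))
```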
import numpy as np
import math
def is_equal_float(x, y):
    """Test whether two floating-point numbers are equal with special consideration for zeros and NaNs.

    Parameters
    ----------
    x : float
        First input number.
    y : float
        Second input number.

    Returns
    -------
    bool
        Boolean indicating whether two floating-point numbers are equal.

    Examples
    --------
    >>> is_equal_float(0.0, -0.0)
    False
    >>> is_equal_float(-0.0, -0.0)
    True
    """
    # Handle +-0:
    if x == 0.0 and y == 0.0:
        return math.copysign(1.0, x) == math.copysign(1.0, y)
    # Handle NaNs:
    if x != x:
        return y != y
    # Everything else, including infinities:
    return x == y
def is_equal(x, y):
    """Test whether two complex numbers are equal with special consideration for zeros and NaNs.

    Parameters
    ----------
    x : complex
        First input number.
    y : complex
        Second input number.

    Returns
    -------
    bool
        Boolean indicating whether two complex numbers are equal.

    Examples
    --------
    >>> import numpy as np
    >>> is_equal(complex(np.nan, np.nan), complex(np.nan, np.nan))
    True
    """
    return is_equal_float(x.real, y.real) and is_equal_float(x.imag, y.imag)
# Strided array consisting of input values and expected values:
values = [
    complex(0.0, 0.0),  # 0
    complex(0.0, 0.0),  # 0
    complex(-0.0, -0.0),  # 1
    complex(-0.0, -0.0),  # 1
    complex(1.0, np.inf),  # 2
    complex(np.nan, np.nan),  # 2
    complex(0.0, np.inf),  # 3
    complex(0.0, np.nan),  # 3, seems to be a bug in NumPy, which returns `NaN + NaN j`; however, this does match old C99 behavior (see https://www.open-std.org/jtc1/sc22/wg14/www/docs/n1892.htm#dr_471)
    complex(1.0, np.nan),  # 4
    complex(np.nan, np.nan),  # 4
    complex(0.0, np.nan),  # 5
    complex(0.0, np.nan),  # 5, seems to be a bug in NumPy, which returns `NaN + NaN j`; however, this does match old C99 behavior (see https://www.open-std.org/jtc1/sc22/wg14/www/docs/n1892.htm#dr_471)
    complex(np.inf, 1.0),  # 6
    complex(1.0, 0.0),  # 6
    complex(np.inf, np.inf),  # 7
    complex(1.0, -0.0),  # 7
    complex(np.inf, np.nan),  # 8
    complex(1.0, 0.0),  # 8
    complex(np.nan, 0.0),  # 9
    complex(np.nan, 0.0),  # 9
    complex(np.nan, 1.0),  # 10
    complex(np.nan, np.nan),  # 10
    complex(np.nan, np.nan),  # 11
    complex(np.nan, np.nan)  # 11
]
for i in range(len(values)//2):
    j = i * 2
    v = values[j]
    e = values[j+1]
    actual = np.tanh(v)
    print('Index: {index}'.format(index=str(i)))
    print('Value: {value}'.format(value=str(v)))
    print('Actual: {actual}'.format(actual=str(actual)))
    print('Expected: {expected}'.format(expected=str(e)))
    print('Equal: {is_equal}'.format(is_equal=str(is_equal(actual, e))))
    print('\n')
Index: 0
Value: 0j
Actual: 0j
Expected: 0j
Equal: True
Index: 1
Value: (-0-0j)
Actual: (-0-0j)
Expected: (-0-0j)
Equal: True
Index: 2
Value: (1+infj)
Actual: (nan+nanj)
Expected: (nan+nanj)
Equal: True
Index: 3
Value: infj
Actual: (nan+nanj)
Expected: nanj
Equal: False
Index: 4
Value: (1+nanj)
Actual: (nan+nanj)
Expected: (nan+nanj)
Equal: True
Index: 5
Value: nanj
Actual: (nan+nanj)
Expected: nanj
Equal: False
Index: 6
Value: (inf+1j)
Actual: (1+0j)
Expected: (1+0j)
Equal: True
/path/to/ctanh.py:107: RuntimeWarning: invalid value encountered in tanh
actual = np.tanh(v)
Index: 7
Value: (inf+infj)
Actual: (1-0j)
Expected: (1-0j)
Equal: True
Index: 8
Value: (inf+nanj)
Actual: (1+0j)
Expected: (1+0j)
Equal: True
Index: 9
Value: (nan+0j)
Actual: (nan+0j)
Expected: (nan+0j)
Equal: True
Index: 10
Value: (nan+1j)
Actual: (nan+nanj)
Expected: (nan+nanj)
Equal: True
Index: 11
Value: (nan+nanj)
Actual: (nan+nanj)
Expected: (nan+nanj)
Equal: True
Notes
- NumPy deviates from the expected special-case behavior in two cases: `0 + inf j` and `0 + NaN j`. NumPy adheres to old C99 behavior; however, this was corrected in the C standard in 2014.
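
For illustration only, the two corrected cases could be handled ahead of a generic implementation along the following lines. This is a minimal sketch and not part of this PR: `tanh_with_corrected_cases` is a hypothetical helper, the standard-library `cmath.tanh` stands in for a library's own kernel, and the sketch assumes that the sign of a zero real component is carried through to the result.

```python
import cmath
import math


def tanh_with_corrected_cases(z):
    """Hypothetical helper: return `±0 + NaN j` for `±0 + inf j` and `±0 + NaN j`."""
    z = complex(z)
    # The two cases in which the old and corrected C behavior differ:
    # a zero real component with an infinite or NaN imaginary component.
    if z.real == 0.0 and (math.isinf(z.imag) or math.isnan(z.imag)):
        # Assumption: keep the (signed) zero real component as-is.
        return complex(z.real, math.nan)
    # Stand-in for a library kernel. Note that `cmath.tanh` raises
    # `ValueError` for some other non-finite inputs, which this sketch
    # does not attempt to cover.
    return cmath.tanh(z)


print(tanh_with_corrected_cases(complex(0.0, math.inf)))
print(tanh_with_corrected_cases(complex(0.0, math.nan)))
print(tanh_with_corrected_cases(complex(0.5, 0.25)))
```

A library following the older behavior would return `NaN + NaN j` for the first two inputs, as the NumPy output above shows.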