RatInABox

Grid cells' grid scales are larger than the input gridscale

Status: Open. charlesdgburns opened this issue 10 months ago · 2 comments

Hiya, I am working on some code to estimate grid scales and was borrowing code from this library to plot idealised grid cells with a rectified cosine. I noticed, however, a deviation between the gridscale input and the actual distance between peaks. I haven't pinned down exactly why this happens; I speculate it is related to the width of the cosine waves and their interference pattern.

A grid cell that I intended to have a scale of 20 ended up with a scale of 23, etc., as shown below:

[Image]

The hotfix is to divide the gridscales by a magic number, 1.15 (not ideal), before tiling the environment and summing the cosine waves. This could be implemented somewhere here.

[Image]

Below is some code to reproduce this behaviour:

```python
# Using Tom George's code for rectified cosine grid cell model

import numpy as np
import matplotlib.pyplot as plt

def rotate(vector, theta):
    """Rotates a vector anticlockwise by angle theta."""
    R = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
    vector_new = np.dot(R, vector)
    return vector_new

def get_vectors_between(pos1, pos2):
    """Calculates vectors between two sets of positions."""
    pos1_ = pos1.reshape(-1, 1, pos1.shape[-1])
    pos2_ = pos2.reshape(1, -1, pos2.shape[-1])
    pos1 = np.repeat(pos1_, pos2_.shape[1], axis=1)
    pos2 = np.repeat(pos2_, pos1_.shape[0], axis=0)
    vectors = pos1 - pos2
    return vectors

def get_flattened_coords(N):
    """Generates flattened coordinates for a square meshgrid."""
    x = np.arange(N)
    y = np.arange(N)
    xv, yv = np.meshgrid(x, y)
    return np.stack((xv.flatten(), yv.flatten()), axis=1)

# --- Customizable parameters ---
N = 100  # Size of the grid
gridscales = np.array([10, 20, 30, 40])  # Grid scales for each neuron
phis = [0, 0, 0, 0]  # Orientations (in radians) for each neuron
# --- End of customizable parameters ---


n_neurons = len(gridscales)
phase_offsets = np.ones(shape=(n_neurons, 2)) * N/2  # Centered phase offsets (unused below; `origin` plays this role)

width_ratio = 4 / (3 * np.sqrt(3))
w = []
for i in range(n_neurons):
    w1 = np.array([1.0, 0.0])
    w1 = rotate(w1, np.pi/6 + phis[i])  # Apply orientation here, such that baseline has a peak due east
    w2 = rotate(w1, np.pi / 3)
    w3 = rotate(w1, 2 * np.pi / 3)
    w.append(np.array([w1, w2, w3]))
w = np.array(w)

pos = get_flattened_coords(N)
origin = np.ones([n_neurons, 2]) * N/2
vecs = get_vectors_between(origin, pos)

# Tile parameters for efficient calculation
w1 = np.tile(np.expand_dims(w[:, 0, :], axis=1), reps=(1, pos.shape[0], 1))
w2 = np.tile(np.expand_dims(w[:, 1, :], axis=1), reps=(1, pos.shape[0], 1))
w3 = np.tile(np.expand_dims(w[:, 2, :], axis=1), reps=(1, pos.shape[0], 1))

adjusted_gridscales = gridscales / 1.15  # THIS IS AN APPROXIMATE FIX FOR GRID SCALE CHANGING.

tiled_gridscales = np.tile(np.expand_dims(adjusted_gridscales, axis=1), reps=(1, pos.shape[0]))


phi_1 = ((2 * np.pi) / tiled_gridscales) * (vecs * w1).sum(axis=-1)
phi_2 = ((2 * np.pi) / tiled_gridscales) * (vecs * w2).sum(axis=-1)
phi_3 = ((2 * np.pi) / tiled_gridscales) * (vecs * w3).sum(axis=-1)

firingrate = (1 / 3) * (np.cos(phi_1) + np.cos(phi_2) + np.cos(phi_3))

# Calculate the firing rate at the width fraction, then shift, scale and rectify at that level
a, b, c = np.array([1,0])@np.array([1,0]), np.array([np.cos(np.pi/3),np.sin(np.pi/3)])@np.array([1,0]), np.array([np.cos(np.pi/3),-np.sin(np.pi/3)])@np.array([1,0])
firing_rate_at_full_width = (1 / 3) * (np.cos(np.pi*width_ratio*a) +
                              np.cos(np.pi*width_ratio*b) +
                              np.cos(np.pi*width_ratio*c))
# Note: the next line overwrites the value above, so only this closed form is used
firing_rate_at_full_width = (1 / 3) * (2*np.cos(np.sqrt(3)*np.pi*width_ratio/2) + 1)
firingrate -= firing_rate_at_full_width
firingrate /= (1 - firing_rate_at_full_width)
firingrate[firingrate < 0] = 0

# Plotting
fig, ax = plt.subplots(1, len(gridscales), figsize=(12, 4))
for i, each_cell in enumerate(gridscales):
    ax[i].imshow(firingrate[i].reshape(N, N), cmap='jet', extent=[0, N, 0, N])
    ax[i].set_title(f"Grid scale: {each_cell} \n Orientation: {phis[i]:.2f} rad")
    ax[i].axis('off')

plt.tight_layout()
plt.show()

# Computing and reporting differences in scale
from skimage.feature import peak_local_max
for i in range(firingrate.shape[0]):
    peaks = peak_local_max(firingrate[i].reshape(N, N))
    peaks = peaks - [N/2, N/2]  # Peak positions relative to the central peak
    sorted_sizes = np.sort(np.linalg.norm(peaks, axis=1))  # sorted_sizes[0] is the central peak itself
    difference = np.mean(sorted_sizes[1:7] - gridscales[i])
    print(f'Mean difference ~{round(difference, 3)} from intended scale {gridscales[i]}')
    print(f'Distance from centre for inner six peaks:\n {sorted_sizes[1:7]} \n ----')
```
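For a quicker, dependency-free look at the same offset, here is a minimal sketch (the helpers `three_cosine_rate` and `first_peak_distance` are hypothetical, not part of RatInABox). It evaluates the un-rectified three-cosine sum along the axis due east of the centre, with the wavelength set equal to the requested gridscale and no 1.15 correction, and finds the peak lying between half and one and a half gridscales away. The rectification step only shifts, scales and clips the rate, which is monotonic, so it should not move the peaks.

```python
import numpy as np

def three_cosine_rate(x, wavelength):
    """Hypothetical helper: un-rectified three-cosine sum along y = 0, using the
    baseline orientation of the reproduction code (w1 at 30 degrees)."""
    angles = np.pi / 6 + np.array([0.0, np.pi / 3, 2 * np.pi / 3])
    w_x = np.cos(angles)  # only x-components matter on the line y = 0
    return np.cos(2 * np.pi / wavelength * np.outer(x, w_x)).mean(axis=1)

def first_peak_distance(gridscale, resolution=1e-3):
    """Hypothetical helper: distance to the first peak east of the centre when
    the cosine wavelength is set equal to the requested gridscale."""
    # Skip the central peak by scanning from 0.5 to 1.5 gridscales
    x = np.arange(0.5 * gridscale, 1.5 * gridscale, resolution)
    rate = three_cosine_rate(x, wavelength=gridscale)
    return x[np.argmax(rate)]

for g in [10, 20, 30, 40]:
    d = first_peak_distance(g)
    print(f"requested {g}: first peak at {d:.2f} (ratio {d / g:.4f})")
```

The printed ratio should sit close to 1.15 for every scale, i.e. the offset is a constant factor rather than something that depends on the cosine width.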

charlesdgburns, Mar 09 '25 19:03

Hi Charles! Thanks for bringing this to my attention. I've figured out what it is, and it's a surprisingly basic error 😅:

Three cosines of wavelength $\lambda$ constructively interfere on the vertices of a regular hexagon with small radius (centre to the midpoint of any edge) $\lambda$ and long radius (centre to any vertex, i.e. the grid scale) $\frac{2}{\sqrt{3}}\lambda$. Thus the magic number is actually $\frac{2}{\sqrt{3}} \approx 1.15$. Currently, the "gridscale" parameter that users enter is (in 2D) the distance to the nearest edge, not the grid scale. Oops!
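The hexagon geometry can be checked directly from the three unit vectors built in the reproduction code, where $\mathbf{w}_2$ and $\mathbf{w}_3$ are $\mathbf{w}_1$ rotated by $\pi/3$ and $2\pi/3$, so $\mathbf{w}_3 = \mathbf{w}_2 - \mathbf{w}_1$. A short derivation, writing $\mathbf{a}$ for the peak nearest the origin:

$$
\begin{aligned}
r(\mathbf{x}) &= \tfrac{1}{3}\sum_{i=1}^{3}\cos\!\left(\tfrac{2\pi}{\lambda}\,\mathbf{w}_i\cdot\mathbf{x}\right), \qquad \mathbf{w}_3 = \mathbf{w}_2 - \mathbf{w}_1,\\
r(\mathbf{x}) = 1 &\iff \mathbf{w}_1\cdot\mathbf{x}\in\lambda\mathbb{Z}\ \text{and}\ \mathbf{w}_2\cdot\mathbf{x}\in\lambda\mathbb{Z},\\
\mathbf{w}_2\cdot\mathbf{a} = 0,\ \ \mathbf{w}_1\cdot\mathbf{a} = \lambda &\ \Rightarrow\ \angle(\mathbf{w}_1,\mathbf{a}) = \tfrac{\pi}{2}-\tfrac{\pi}{3} = \tfrac{\pi}{6},\\
\lVert\mathbf{a}\rVert &= \frac{\lambda}{\cos(\pi/6)} = \frac{2}{\sqrt{3}}\,\lambda \approx 1.1547\,\lambda, \qquad \lVert\mathbf{a}\rVert\cos(\pi/6) = \lambda.
\end{aligned}
$$

The last identity is the small radius of the hexagon. With $\lambda = 20$ the peak spacing is $20 \times 1.1547 \approx 23.09$, which matches the 20 → 23 offset in the original report.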

I'm really quite surprised this hasn't been spotted before by others, me, reviewers, etc., but I'm glad it has been now.

The fix is obviously pretty simple, but as I was writing it I realised the grid cell code may benefit from a more comprehensive restructure to make it all clearer. With the fix, the code also differs from the formula in the methods, so some docs should be updated too. I won't have time to do this right now, but I'll leave this issue open until I do.
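A minimal sketch of the kind of fix described here, assuming the user-facing gridscale should mean the peak-to-peak distance: convert it to the cosine wavelength $\lambda = \frac{\sqrt{3}}{2}\,\text{gridscale}$ before building the phases. The names `wavelength_from_gridscale` and `grid_firing_rate` are hypothetical (not RatInABox API), and the rectification step from the reproduction code is left out.

```python
import numpy as np

def wavelength_from_gridscale(gridscale):
    """Hypothetical helper: convert a grid scale (distance between neighbouring
    peaks) into the wavelength of each cosine, lambda = (sqrt(3) / 2) * gridscale."""
    return np.sqrt(3) / 2 * np.asarray(gridscale, dtype=float)

def grid_firing_rate(pos, gridscale, orientation=0.0, centre=(0.0, 0.0)):
    """Hypothetical helper: un-rectified three-cosine grid cell rate, normalised
    so that peaks equal 1.

    pos: array of shape (n, 2). With orientation=0 the unit vectors sit at 30,
    90 and 150 degrees, as in the reproduction code, so one peak lies due east
    of the centre at distance `gridscale`.
    """
    lam = wavelength_from_gridscale(gridscale)
    angles = orientation + np.pi / 6 + np.array([0.0, np.pi / 3, 2 * np.pi / 3])
    w = np.stack([np.cos(angles), np.sin(angles)], axis=1)  # (3, 2) unit vectors
    rel = np.asarray(pos, dtype=float) - np.asarray(centre, dtype=float)  # (n, 2)
    phases = (2 * np.pi / lam) * (rel @ w.T)  # (n, 3)
    return np.cos(phases).mean(axis=1)

# Usage: the centre and two of its six neighbouring peaks, for gridscale 20
pts = np.array([[0.0, 0.0], [20.0, 0.0], [10.0, 10.0 * np.sqrt(3)]])
print(grid_firing_rate(pts, gridscale=20))  # each value should be 1 up to rounding
```

Keeping the conversion in one helper means the public parameter can keep its intended meaning while the phase formula itself stays unchanged.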

Does this unblock you for the moment?

[Image]

TomGeorge1234, Mar 12 '25 04:03

Thanks for spotting the source of the error so simply!

It definitely helps me understand idealised grid cell models better, and it is a precise fix for the issue.

Glad I could help by noticing it in this code.

Cheers, Charles

charlesdgburns, Mar 13 '25 10:03