Possible use of torch.multiprocessing
Consider the simulation loop in the Network.run() function:
```python
# Simulate network activity for `time` timesteps.
for t in range(timesteps):
    for l in self.layers:  # ->
        # Update each layer of nodes.
        if isinstance(self.layers[l], AbstractInput):
            self.layers[l].step(inpts[l][t], self.dt)
        else:
            self.layers[l].step(inpts[l], self.dt)

        # Clamp neurons to spike.
        clamp = clamps.get(l, None)
        if clamp is not None:
            self.layers[l].s[clamp] = 1

    # Run synapse updates.
    for c in self.connections:  # ->
        self.connections[c].update(
            reward=reward, mask=masks.get(c, None), learning=self.learning
        )

    # Get input to all layers.
    inpts.update(self.get_inputs())

    # Record state variables of interest.
    for m in self.monitors:
        self.monitors[m].record()

# Re-normalize connections.
for c in self.connections:  # ->
    self.connections[c].normalize()
```
Where I've marked a ->, there might be an opportunity to use torch.multiprocessing. Since the updates at time t are based on the network state at time t-1, the Nodes updates within a timestep do not depend on one another, and neither do the Connection updates, so each could be performed by a separate process (or thread) at the same time. Letting k be the number of layers and m the number of connections, and given enough CPU / GPU resources, the wall-clock cost of the loops marked with -> would drop from O(k) and O(m), respectively, to O(1), not counting the overhead of dispatching work to the workers and collecting the results.
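A minimal sketch of the independence argument in plain Python. `ToyLayer`, `toy_get_inputs`, `step_sequential` and `step_parallel` are hypothetical names for illustration, not BindsNET API. Each layer's input for step t is built only from spikes emitted at step t-1, so running the layer updates concurrently has to reproduce the sequential loop exactly. A thread pool is used only to keep the sketch self-contained; a process pool would follow the same structure.

```python
from concurrent.futures import ThreadPoolExecutor


class ToyLayer:
    """Hypothetical stand-in for a Nodes object: leaky integrate-and-fire."""

    def __init__(self, n, decay=0.5, thresh=1.0):
        self.v = [0.0] * n    # membrane voltages
        self.s = [False] * n  # spikes emitted on the latest step
        self.decay = decay
        self.thresh = thresh

    def step(self, x, dt):
        # Leaky integration, then spike and reset wherever the threshold is crossed.
        v = [self.decay * vi + dt * xi for vi, xi in zip(self.v, x)]
        self.s = [vi >= self.thresh for vi in v]
        self.v = [0.0 if si else vi for vi, si in zip(v, self.s)]


def toy_get_inputs(layers, ext):
    # Inputs for step t are built from spikes emitted at step t-1 only.
    return {
        "X": list(ext),
        "Y": [2.0 * si for si in layers["X"].s],  # X -> Y, weight 2.0
        "Z": [2.0 * si for si in layers["Y"].s],  # Y -> Z, weight 2.0
    }


def step_sequential(layers, inpts, dt):
    for name in layers:
        layers[name].step(inpts[name], dt)


def step_parallel(layers, inpts, dt, pool):
    # Each task touches only its own layer and reads the fixed `inpts` snapshot,
    # so the tasks are independent and may finish in any order.
    futures = [pool.submit(layers[name].step, inpts[name], dt) for name in layers]
    for f in futures:
        f.result()  # wait for completion and re-raise worker exceptions


def make_layers():
    return {name: ToyLayer(3) for name in ("X", "Y", "Z")}


ext = [0.6, 1.2, 0.1]
seq, par = make_layers(), make_layers()
inp_seq, inp_par = toy_get_inputs(seq, ext), toy_get_inputs(par, ext)
history_seq, history_par = [], []

with ThreadPoolExecutor(max_workers=3) as pool:
    for t in range(5):
        step_sequential(seq, inp_seq, dt=1.0)
        step_parallel(par, inp_par, dt=1.0, pool=pool)
        inp_seq, inp_par = toy_get_inputs(seq, ext), toy_get_inputs(par, ext)
        history_seq.append({k: list(layer.s) for k, layer in seq.items()})
        history_par.append({k: list(layer.s) for k, layer in par.items()})

assert history_seq == history_par  # same spikes, sequential or concurrent
```

In this toy run X fires at t = 0, Y at t = 1 and Z at t = 2: every connection adds one step of delay, which is exactly why the layer updates inside a single step can run side by side.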
I think it'd be good to keep two (?) multiprocessing.Pool objects around, one for Nodes objects and another for Connection objects. Instead of statements of the form:
```python
for l in self.layers:
    self.layers[l].step(...)
```
We might rewrite this as something like:
```python
self.nodes_pool.map(Nodes.step, self.layers)
```
Here, nodes_pool would be defined as an attribute in the Network constructor. That last line probably won't work as written; we'd need to figure out the right syntax (if it exists).
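`self.nodes_pool.map(Nodes.step, self.layers)` cannot work as written for two reasons: iterating over the `self.layers` dict yields the layer names, not the layers, and `step` also needs each layer's input and `dt`. One possible syntax is `Pool.starmap`, which unpacks `(layer, input, dt)` tuples into a call. The sketch below assumes a hypothetical `ToyNodes` class, a hypothetical module-level `step_layer` helper and a hypothetical `ToyNetwork` that creates its pool once in the constructor. It uses `multiprocessing.pool.ThreadPool`, which exposes the same `map`/`starmap` interface as `multiprocessing.Pool`, so it runs anywhere. With a real `multiprocessing.Pool`, each worker would update a pickled copy of its layer, which is why `step_layer` returns the layer and the dict is rebuilt from the results; pickling every layer on every timestep is also extra overhead to measure.

```python
from multiprocessing.pool import ThreadPool


class ToyNodes:
    """Hypothetical stand-in for a Nodes layer (not the BindsNET class)."""

    def __init__(self, n, thresh=1.0):
        self.v = [0.0] * n
        self.s = [False] * n
        self.thresh = thresh

    def step(self, x, dt):
        # Integrate input and flag neurons at or above threshold.
        self.v = [vi + dt * xi for vi, xi in zip(self.v, x)]
        self.s = [vi >= self.thresh for vi in self.v]


def step_layer(layer, x, dt):
    # Module-level helper so the call would also be picklable for a process pool.
    # Returning the layer matters there: workers would mutate a pickled copy.
    layer.step(x, dt)
    return layer


class ToyNetwork:
    """Hypothetical network holding a persistent pool for its Nodes updates."""

    def __init__(self, layers, dt=1.0, processes=None):
        self.layers = layers
        self.dt = dt
        # Created once in the constructor and reused on every timestep.
        self.nodes_pool = ThreadPool(processes)

    def step_layers(self, inpts):
        names = list(self.layers)
        args = [(self.layers[n], inpts[n], self.dt) for n in names]
        # starmap unpacks each tuple into step_layer(layer, x, dt).
        updated = self.nodes_pool.starmap(step_layer, args)
        self.layers = dict(zip(names, updated))

    def close(self):
        self.nodes_pool.close()
        self.nodes_pool.join()


net = ToyNetwork({"A": ToyNodes(2), "B": ToyNodes(2)})
net.step_layers({"A": [0.5, 1.5], "B": [1.0, 0.0]})
net.step_layers({"A": [0.5, 0.0], "B": [0.0, 0.0]})
net.close()
print(net.layers["A"].s, net.layers["B"].s)  # [True, True] [True, False]
```

A second pool for Connection objects would follow the same pattern, with a helper that calls `update(...)` and returns the connection.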
This same idea can also be applied to the Network's reset() and get_inputs() methods.
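For reset(), each layer and connection can presumably be reset independently, so it is the simplest case. get_inputs() is slightly different: when several connections feed the same layer, their contributions have to be combined, so the work splits into a concurrent "map" over connections followed by a sequential per-target "reduce". A minimal sketch, assuming contributions are summed and using hypothetical `propagate` and `get_inputs_parallel` functions on plain lists rather than the actual BindsNET connection objects:

```python
from concurrent.futures import ThreadPoolExecutor


def propagate(w, s):
    """Hypothetical dense connection: w[i][j] is the weight from source i to target j."""
    return [sum(w[i][j] for i in range(len(s)) if s[i]) for j in range(len(w[0]))]


def get_inputs_parallel(spikes, connections, pool):
    """Sketch of a concurrent get_inputs(); `connections` maps (source, target) -> weights."""
    keys = list(connections)
    # Map step: each connection's contribution depends only on its source's spikes.
    futures = [pool.submit(propagate, connections[k], spikes[k[0]]) for k in keys]
    # Reduce step: sum contributions per target, in a fixed order for reproducibility.
    inputs = {}
    for (_, target), fut in zip(keys, futures):
        contrib = fut.result()
        if target in inputs:
            inputs[target] = [a + b for a, b in zip(inputs[target], contrib)]
        else:
            inputs[target] = contrib
    return inputs


spikes = {"X": [True, False], "Y": [False, True]}
connections = {
    ("X", "Z"): [[1.0, 2.0], [3.0, 4.0]],
    ("Y", "Z"): [[0.5, 0.5], [0.25, 0.75]],
    ("X", "Y"): [[1.0], [0.0]],
}
with ThreadPoolExecutor() as pool:
    inputs = get_inputs_parallel(spikes, connections, pool)
print(inputs)  # {'Z': [1.25, 2.75], 'Y': [1.0]}
```

Doing the reduction in a fixed order after all futures complete keeps floating-point sums identical from run to run, regardless of which worker finishes first.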
@djsaunde any progress on this? I'll start looking into it, because I'm working with fairly small networks, where GPUs don't give much of an advantage. This seems like the way to get a speedup in that case.
@Huizerd nope, just an idea we had some time ago. I'm not sure that it will speed things up, but it might be worth a shot. Let me know if you need any help.
Check out this branch for a start on the multiprocessing work (I'm pretty sure it fails as-is). It'll need to be fast-forwarded to the current state of the master branch.