TAPNext online tracking problems

Open CrazyCoder76 opened this issue 9 months ago • 4 comments

Dear team,

I am implementing online tracking in TAPNext and I am facing some problems.

First, after running for several frames, the tracked points drift off the target. Here, tracking_state.query_points always stays the same as the query_points I passed in on the first run. Is this correct, or should I update query_points in the state at every step?

Second, my current frame resolution is 256x256. Can I increase the resolution?

Third, I want to export the model to ONNX. For that, tracking_state has to be both an input and an output of the ONNX model, but, as you know, it is currently a dictionary. If I pass tracking_state to the ONNX model as a tensor, how can I convert it back into a dictionary?

I would really appreciate your help. Thank you.

CrazyCoder76, Apr 25 '25

Hi @CrazyCoder76 !

For the tracking error issue, if this happens after more than ~150 frames (with dense grid) - this is a known limitation of TAPNext (please see the limitations section in the paper or website https://tap-next.github.io/). A partial solution is to reduce the number of tracked points - the error will not vanish but it will be able to track for longer.

For the tracking state: yes, the query_points field of the tracking state is meant to be a fixed-size tensor. However, this does not mean all your points must be queried at frame 0. Online point tracking works like this: the tracker holds all queries (which may be at arbitrary frames), but it only feeds each one to the model at its corresponding frame of the video. Say you have 5 query points: 3 at frame 0 and 2 more at, say, frames 5 and 10. When you start online tracking, the model does not use the last two queries at frames 0 to 4; at frame 5 it starts using the first "late" query (4 queries overall), and at frame 10 it uses all of them. So at frame 0 you do not need to provide the XY coordinates of every query, but you do need to provide the total number of queries so that the model can allocate hidden states for them. If at frame 0 you do not know the XY coordinates of a query at frame 5, I recommend setting them to [0., 0.] and updating them when it is time to query that point.
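A minimal sketch of this bookkeeping, assuming each query row is (t, y, x) with t the query frame (tapnet's usual layout, but treat it as an assumption here). The helpers `build_query_points`, `active_query_mask` and `set_query_location` are hypothetical and not part of the TAPNext API:

```python
import torch

def build_query_points(query_frames, coords):
    """Build a fixed-size [1, Q, 3] tensor of (t, y, x) queries.

    query_frames: one frame index per query.
    coords: one (y, x) tuple per query, or None if the location is not known yet;
    unknown locations become [0., 0.] placeholders to be filled in later.
    """
    rows = []
    for t, yx in zip(query_frames, coords):
        y, x = yx if yx is not None else (0.0, 0.0)
        rows.append([float(t), float(y), float(x)])
    return torch.tensor(rows)[None]  # add a batch dimension

def active_query_mask(query_points, frame_idx):
    """True for queries whose query frame has already been reached."""
    return query_points[0, :, 0] <= frame_idx

def set_query_location(query_points, index, y, x):
    """Fill in a placeholder (in place) once the point's location becomes known."""
    query_points[0, index, 1] = y
    query_points[0, index, 2] = x

# 3 queries at frame 0, then one each at frames 5 and 10 (locations not known yet).
qp = build_query_points([0, 0, 0, 5, 10], [(10, 20), (30, 40), (50, 60), None, None])
for frame_idx in (0, 5, 10):
    print(frame_idx, int(active_query_mask(qp, frame_idx).sum()))  # 3, then 4, then 5 active queries
```

The total number of rows never changes, which is what lets the model allocate hidden states up front; only the contents of the late rows are updated.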

For the resolution: it should work at higher resolutions, but you need to interpolate the position embedding. We will provide the code for that in a couple of days.

Regarding your ONNX question, can you please clarify it? What exactly fails when exporting to ONNX?

artemZholus, Apr 25 '25

Hi @artemZholus, thanks for the quick response. When exporting to ONNX, tracking_state is of type TAPNextTrackingState, but ONNX inputs have to be tensors. How can I convert a TAPNextTrackingState to tensors?
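One common workaround, sketched here under assumptions: expose the state to the exporter as a flat tuple of tensors and rebuild the structured object inside a thin wrapper module. `ToyTrackingState`, its fields, `flatten_state`, `unflatten_state` and `FlatStateWrapper` are all hypothetical, and the sketch assumes a dataclass-like state. The real TAPNextTrackingState has different fields and also nests per-layer caches, which would need the same treatment.

```python
import dataclasses
import torch

@dataclasses.dataclass
class ToyTrackingState:
    # Hypothetical stand-in for TAPNextTrackingState; the real fields differ.
    step: torch.Tensor
    query_points: torch.Tensor
    hidden: torch.Tensor

def flatten_state(state):
    """Return the state's tensor fields as a tuple, in declaration order."""
    return tuple(getattr(state, f.name) for f in dataclasses.fields(state))

def unflatten_state(cls, tensors):
    """Rebuild a dataclass instance from a tuple produced by flatten_state."""
    names = [f.name for f in dataclasses.fields(cls)]
    if len(names) != len(tensors):
        raise ValueError(f"expected {len(names)} tensors, got {len(tensors)}")
    return cls(**dict(zip(names, tensors)))

class FlatStateWrapper(torch.nn.Module):
    """Only plain tensors cross the module boundary, so an exporter sees tensor inputs/outputs."""

    def __init__(self, step_fn, state_cls):
        super().__init__()
        self.step_fn = step_fn  # callable(frame, state) -> (prediction, new_state)
        self.state_cls = state_cls

    def forward(self, frame, *state_tensors):
        state = unflatten_state(self.state_cls, state_tensors)
        prediction, new_state = self.step_fn(frame, state)
        return (prediction, *flatten_state(new_state))
```

The wrapper, not the model, is what would be exported, with the flattened state tensors passed as extra positional inputs. The field order from `dataclasses.fields` is what keeps flattening and unflattening consistent across calls.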

Also, the latest commit throws an error:

Traceback (most recent call last):
  File "E:\research\tapnet\python_tapnext.py", line 89, in <module>
    _, _, _, tracking_state = model.forward(video=frame[None, None], query_points=query_points[None, None])
  File "E:\research\tapnet\tapnet\tapnext\tapnext_torch.py", line 279, in forward
    x, ssm_cache_layer = blk(
  File "E:\research\tapnet.venv\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "E:\research\tapnet.venv\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "E:\research\tapnet\tapnet\tapnext\tapnext_torch.py", line 69, in forward
    x, ssm_cache = self.ssm_block(x, cache, use_linear_scan=use_linear_scan)
  File "E:\research\tapnet.venv\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "E:\research\tapnet.venv\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "E:\research\tapnet\tapnet\tapnext\tapnext_lru_modules.py", line 534, in forward
    x, cache = self.recurrent_block(inputs_normalized, cache, use_linear_scan)
  File "E:\research\tapnet.venv\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "E:\research\tapnet.venv\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "E:\research\tapnet\tapnet\tapnext\tapnext_lru_modules.py", line 400, in forward
    x, rg_lru_state = self.rg_lru(
  File "E:\research\tapnet.venv\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "E:\research\tapnet.venv\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "E:\research\tapnet\tapnet\tapnext\tapnext_lru_modules.py", line 218, in forward
    multiplier = SqrtBoundDerivative.apply(1 - a_square)
  File "E:\research\tapnet.venv\lib\site-packages\torch\autograd\function.py", line 574, in apply
    return super().apply(*args, **kwargs) # type: ignore[misc]
TypeError: SqrtBoundDerivative.forward() takes 2 positional arguments but 3 were given
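The failing call is a custom autograd function, and the error concerns how many positional arguments its `forward` receives. For reference, a minimal sketch of a square root with a bounded derivative, written so that `.apply` takes exactly one tensor. It is modeled on the similarly named helper in RecurrentGemma's PyTorch implementation; the bound of 1000 is an assumption, and this is not necessarily what the tapnet fix looks like.

```python
import torch

MAX_SQRT_GRADIENT = 1000.0  # assumed cap on d/dx sqrt(x)

class SqrtBoundDerivative(torch.autograd.Function):
    """sqrt(x) in the forward pass, with the gradient capped at MAX_SQRT_GRADIENT."""

    @staticmethod
    def forward(ctx, x):
        # `ctx` is supplied by autograd, so callers pass a single tensor: .apply(x)
        ctx.save_for_backward(x)
        return torch.sqrt(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # d/dx sqrt(x) = 1 / sqrt(4x); clamping 4x from below caps the gradient near 0.
        clipped = torch.clamp(4.0 * x, min=MAX_SQRT_GRADIENT ** -2)
        return grad_output / torch.sqrt(clipped)
```

At x = 0 the plain derivative of sqrt is infinite; with this clamp it becomes 1 / sqrt(1e-6) = 1000.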

Thanks

CrazyCoder76, Apr 25 '25

Could you please let me know whether it is possible to convert the online tracking model to ONNX? Thank you.

CrazyCoder76, Apr 25 '25

Hi @CrazyCoder76. We just submitted a fix for the TypeError you got, and the public PyTorch colab now works fine. For the ONNX issue, we will likely get back to you next week with a solution.

artemZholus, Apr 26 '25

For the onnx issue, we will likely get back to you in next week with a solution

It would be nice if it could be exported with torch.export without extra hacks.

bhack, May 04 '25

Hi @bhack, torch.export now works after the most recent commit. Usage example:

import torch
from torch.export import export
from tapnet.tapnext.tapnext_torch import TAPNext  # assumed import path

# `batch` is the example batch (video plus query_points) from the public PyTorch colab
model = TAPNext(image_size=(256, 256))
model.eval() # this is important to ensure the model is in inference mode so we use linear scan which is faster for inference.
with torch.no_grad():
  tapnext_init = export(model, args=(), kwargs=dict(video=batch['video'][:, :4], query_points=batch['query_points']))
  tracks, track_logits, vis_logits, state = model(video=batch['video'][:, :4], query_points=batch['query_points'])
  tapnext_track = export(model, args=(), kwargs=dict(video=batch['video'][:, 4:8], state=state))

tapnext_init and tapnext_track will be two torch exported programs as the model uses different logic for initializing the tracker with query points and the tracking itself. Both methods are torch.compile-able (to reproduce the speed from the paper, please use torch.compile).

Note that TAPNext can track both per frame and per window. Per frame means we pass one frame plus the previous state to the model, and it outputs the prediction for that frame and a new state. Per window means we pass a chunk of consecutive frames plus the state before the first frame in the chunk, and the model outputs the predictions for the entire chunk and the state after its last frame. The motivation for the second approach is even higher throughput (FPS), at the cost of significantly higher latency.
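To make the two modes concrete, a minimal sketch of a driver loop that threads the state through successive calls. It assumes a model whose first call takes `query_points`, whose later calls take `state`, and whose last return value is the new state (mirroring the export example in this comment). `track_in_windows` and the toy model are hypothetical; the toy's "tracks" are just a running sum, so the two modes can be checked against each other.

```python
import torch

def track_in_windows(model, video, query_points, window=1):
    """Run a stateful tracker over `video` ([batch, time, ...]) in chunks of `window` frames."""
    predictions, state = [], None
    for start in range(0, video.shape[1], window):
        chunk = video[:, start:start + window]
        if state is None:
            # First call: initialize the tracker from the query points.
            *outputs, state = model(video=chunk, query_points=query_points)
        else:
            # Later calls: only the new frames plus the previous state.
            *outputs, state = model(video=chunk, state=state)
        predictions.append(outputs[0])  # keep the first output (the tracks)
    return torch.cat(predictions, dim=1)

class ToyRecurrentTracker(torch.nn.Module):
    """Hypothetical stand-in whose 'tracks' are a running sum over time."""

    def forward(self, video, query_points=None, state=None):
        carry = torch.zeros(video.shape[0]) if state is None else state
        tracks = carry[:, None] + torch.cumsum(video, dim=1)
        return tracks, tracks[:, -1]  # (predictions, new state)
```

With `window=1` this is the per-frame mode; larger windows trade latency for throughput. An exported program has fixed shapes, so a final chunk shorter than `window` (as with 7 frames and `window=3`) would presumably need its own export or padding.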

Note also that it is not possible to do torch.export with variable size tensors because the model uses shape-dependent control flow inside of it (e.g. linear scan over the sequence length).

artemZholus, May 08 '25

Note also that it is not possible to do torch.export with variable size tensors because the model uses shape-dependent control flow inside of it (e.g. linear scan over the sequence length).

This is the problem I ran into on my side over the last day while exporting. Is the main issue the Python loop in pscan?

Because I think that at least the number of query points needs to be dynamic (I am testing by AOTI-compiling the exported program).

bhack, May 08 '25

The main issue is not the Python loop (there is a Python loop regardless of what you do, parallel scan or linear scan). I think the issue is how PyTorch's compiler handles variable dimensions. TAPNext uses interleaved space-time processing: the initial tensor of shape [b, t, h * w + q, c] is reshaped to [b * (h * w + q), t, c] (for the SSM), then to [b * t, h * w + q, c] (for the ViT), and this repeats. The problem with such processing is that once you declare, e.g., the batch dim as a variable dimension, PyTorch does not know how to combine a constant dimension with a variable one. For the same reason, we can't export with variable q or t. But I am not an expert in torch.export, so maybe there is another solution.
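A minimal sketch of these interleaved reshapes in plain PyTorch (the function names are hypothetical; the thread mentions that TAPNext itself uses einops.rearrange). With symbolic dimensions, every product such as b * (h * w + q) becomes an expression the exporter has to reason about, which is presumably where the constraints described above come from.

```python
import torch

def space_to_time(x):
    """[b, t, n, c] -> [b * n, t, c]: each token becomes a sequence over time (SSM layers)."""
    b, t, n, c = x.shape
    return x.permute(0, 2, 1, 3).reshape(b * n, t, c)

def time_to_space(x, b):
    """[b * n, t, c] -> [b, t, n, c]: undo space_to_time."""
    bn, t, c = x.shape
    return x.reshape(b, bn // b, t, c).permute(0, 2, 1, 3)

def frames_to_tokens(x):
    """[b, t, n, c] -> [b * t, n, c]: each frame becomes a set of tokens (ViT layers)."""
    b, t, n, c = x.shape
    return x.reshape(b * t, n, c)

def tokens_to_frames(x, b):
    """[b * t, n, c] -> [b, t, n, c]: undo frames_to_tokens."""
    bt, n, c = x.shape
    return x.reshape(b, bt // b, n, c)

# n = h * w + q: image tokens plus query tokens.
b, t, h, w, q, c = 2, 4, 4, 4, 3, 8
x = torch.randn(b, t, h * w + q, c)
print(space_to_time(x).shape, frames_to_tokens(x).shape)
```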

artemZholus, May 08 '25

@artemZholus I'll give you an update. Are you on the official PyTorch Slack?

bhack, May 08 '25

I managed to get a single AOTI-compatible compilation working for the per-frame scenario. I essentially flattened all the inputs to tensors (TAPNextTrackingState and the RecurrentBlockCache) and removed the branching caused by the cache=None options: I pass zeros for the cache on the first invocation and use torch.where() instead of the if conditions. For the einops.rearrange() calls, I had to add allow_ops_in_compiled_graph(). It's a little more tedious, but I didn't want to deal with juggling two different AOTI-compiled artifacts.
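A minimal sketch of the branch-removal idea described here, on a toy recurrent update rather than the real RecurrentBlockCache (the module and its decay rule are hypothetical): the caller always passes a cache tensor, zeros on the first frame, plus a 0-dim `step` tensor, and torch.where picks the initialization or the update.

```python
import torch

class BranchFreeStep(torch.nn.Module):
    """Hypothetical recurrent update with no Python branch on `cache is None`."""

    def __init__(self, decay: float = 0.9):
        super().__init__()
        self.decay = decay

    def forward(self, x, cache, step):
        # Instead of `if cache is None: cache = x`, select with a data-dependent mask.
        is_first = step == 0  # 0-dim bool tensor, broadcast by torch.where
        updated = self.decay * cache + (1.0 - self.decay) * x
        return torch.where(is_first, x, updated)

m = BranchFreeStep(decay=0.5)
x = torch.full((2, 3), 4.0)
first = m(x, torch.zeros(2, 3), torch.tensor(0))  # initialization path: returns x
second = m(x, first, torch.tensor(1))             # update path: 0.5 * 4 + 0.5 * 4
```

Both branches are computed on every call, which is the price of a single graph with no data-dependent Python control flow.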

icoderaven, May 08 '25

@icoderaven Are you able to export with a dynamic Dim (with min/max) for the query points in the single-frame scenario?

A single export is definitely better.

bhack, May 08 '25

No, I did not go down the dynamic path, since my use case always involves single-frame 'real-time' videos of the same dimensions.

icoderaven, May 09 '25

@bhack I went back and tried to use a dynamic Dim for the query points, and it was a little trickier than expected. For anyone else stumbling on this: I had to set the cache size to 1024 + the query-point dimension. (The expected size of the cache dims is 1024 + floor(sqrt(num_query_pts)) ** 2, but we can't use sqrt on Dim objects.) So, e.g., with the forward method using the flattened tensors, it looks like this:

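from torch.export import Dim

# height_dim and width_dim are Dim objects defined elsewhere (not shown here)
query_points_dim = Dim("query_points", min=1, max=1024)
# The actual dimension of cache_dim is smaller but we cannot use sqrt
cache_dim = query_points_dim + 1024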
dynamic_shapes = {
    "frame": {
        1: height_dim,
        2: width_dim,
    },
    "step": {},
    "query_points": {1: query_points_dim},
    "rg_cache": {1: cache_dim},
    "cv_cache": {1: cache_dim},
}

where for me

frame: Float[torch.Tensor, "num_channels height width"]
step: Int64[torch.Tensor, ""]
query_points: Float[torch.Tensor, "batch_size num_queries 3"]
rg_cache: Float[torch.Tensor, "cache_n cache_b cache_w"]
cv_cache: Float[torch.Tensor, "cache_n cache_b cache_c cache_w"]
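A quick worked check of the cache-size relationship quoted in this comment, taking the 1024 constant and the floor(sqrt(q)) ** 2 term as reported. It shows that the `query_points_dim + 1024` bound declared for the Dim is never smaller than the actual size; the helper names are hypothetical.

```python
import math

NUM_IMAGE_TOKENS = 1024  # constant as reported in the comment

def expected_cache_len(num_query_points: int) -> int:
    """Reported relationship: 1024 + floor(sqrt(q)) ** 2 (math.isqrt is an exact integer floor-sqrt)."""
    return NUM_IMAGE_TOKENS + math.isqrt(num_query_points) ** 2

def declared_upper_bound(num_query_points: int) -> int:
    """The looser `query_points_dim + 1024` expression used for cache_dim."""
    return NUM_IMAGE_TOKENS + num_query_points

for q in (1, 10, 1024):
    print(q, expected_cache_len(q), declared_upper_bound(q))
```

For example, q = 10 gives 1024 + 3 ** 2 = 1033 against a declared 1034, and the two coincide whenever q is a perfect square.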

icoderaven, May 29 '25

Have we removed the limit on the maximum number of query points with:

https://github.com/google-deepmind/tapnet/issues/147

I don't think it is very practical to always allocate for the maximum.

How many resources are we wasting with this approach?

bhack, May 30 '25

My understanding is that memory for an AOT-compiled artifact is allocated at run time: the dynamic dimensions just give the compiler hints to optimise the graph so that the compiled code can handle varying numbers of query points up to the specified maximum.

icoderaven, Jun 07 '25

One extremely curious thing I have observed, by the way: with the released weights, the AOT-compiled model is incredibly sensitive to how the input images are resized to 256x256.

The only way I could get the AOT-compiled model to return correct tracks was to preprocess the images with OpenCV's cv2.resize(). None of the standard PyTorch resize ops (torch.nn.functional.interpolate/kornia, torchvision, grid_sample) worked, whether I used bilinear or area-based downsampling. All of these resize ops worked fine with the non-AOT-compiled model, and even with a @torch.compile'd model. I was wondering if you had any insights into this or have seen it before: was all the training data preprocessed with cv2.resize(), so that the weights are somehow overfit to that distribution? (On PyTorch 2.6.)

The pipeline I was using when I observed this puzzling behaviour is: pad to square -> resize to 256x256 -> scale to [-1, 1].
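A minimal NumPy sketch of that pipeline, with the resize step passed in as a function so that cv2.resize and a PyTorch resize can be swapped to compare their outputs. Zero-padding on the bottom/right and uint8 HWC input are assumptions, since the comment does not say how the padding was done; the helper names are hypothetical.

```python
import numpy as np

def pad_to_square(img: np.ndarray) -> np.ndarray:
    """Zero-pad an [H, W, C] image on the bottom/right so that H == W (placement is an assumption)."""
    h, w = img.shape[:2]
    side = max(h, w)
    out = np.zeros((side, side) + img.shape[2:], dtype=img.dtype)
    out[:h, :w] = img
    return out

def scale_to_unit_range(img: np.ndarray) -> np.ndarray:
    """Map uint8 values in [0, 255] to float32 values in [-1, 1]."""
    return img.astype(np.float32) / 255.0 * 2.0 - 1.0

def preprocess(img: np.ndarray, resize_fn) -> np.ndarray:
    """pad to square -> resize (e.g. to 256x256) -> scale to [-1, 1]."""
    return scale_to_unit_range(resize_fn(pad_to_square(img)))
```

With OpenCV this would be called as `preprocess(frame, lambda im: cv2.resize(im, (256, 256), interpolation=cv2.INTER_AREA))`; note that cv2.resize takes its target size as (width, height).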

icoderaven, Jun 07 '25

Hi @icoderaven ,

this is interesting, thanks for reporting. On my side, I actually tried to run TAPNext at larger resolutions, and it pretty much did not work. The reason, I believe, is that the model is overfitted to the values of the positional embedding (which is a known issue in all vision transformers).

My implementation simply resized the positional embedding with torch.nn.functional.interpolate, so this aligns with what you said.

By the way, I wonder whether the same thing happens when you run the JAX model? When we were converting the JAX model to PyTorch, we found that the PyTorch version sometimes produces very different numerical outputs even though the metrics remain the same. The reason they remain the same is that only the coordinates of occluded points change between JAX and PyTorch.
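A minimal sketch of that kind of positional-embedding resize, assuming the embedding is a flattened row-major [1, h * w, C] grid with no extra tokens (TAPNext's actual layout is not given in this thread, and `resize_pos_embed` is a hypothetical name). As the maintainer notes, this alone did not make larger resolutions work well.

```python
import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embed: torch.Tensor, old_grid, new_grid) -> torch.Tensor:
    """Bilinearly resize a [1, old_h * old_w, C] embedding to [1, new_h * new_w, C]."""
    old_h, old_w = old_grid
    new_h, new_w = new_grid
    _, n, c = pos_embed.shape
    if n != old_h * old_w:
        raise ValueError("embedding length does not match the grid")
    # Back to an image-like [1, C, h, w] layout so interpolate treats it as a 2D grid.
    grid = pos_embed.reshape(1, old_h, old_w, c).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(new_h, new_w), mode="bilinear", align_corners=False)
    # Flatten back to a token sequence.
    return grid.permute(0, 2, 3, 1).reshape(1, new_h * new_w, c)
```

Resizing to the original grid size returns the embedding unchanged, which is a cheap sanity check before trying larger inputs.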

artemZholus, Jun 09 '25

Interesting! Good to know re: pos embeddings (it was on my backlog of things to try later). I haven't had a chance to try the JAX model, and it seems unlikely that I will be able to, sadly.

FWIW, this is the tiny code block I used to reliably reproduce the issue with the resize op on my end (padded_img: Float[torch.Tensor, "Channel Dim Dim"]):

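    import cv2
    import numpy as np
    import torch

    resized_shape = (256, 256)  # cv2.resize takes (width, height); identical here since the target is square

    # This doesn't work with the AOT compiled model as input
    # resized_img: torch.Tensor = torch.nn.functional.interpolate(
    #     padded_img.unsqueeze(0), size=resized_shape, mode="area"
    # ).squeeze(0)

    # This works: resize on the CPU with OpenCV (HWC layout), then go back to CHW on the GPU
    hwc = np.ascontiguousarray(padded_img.permute(1, 2, 0).cpu().numpy())
    resized_img = (
        torch.from_numpy(cv2.resize(hwc, resized_shape, interpolation=cv2.INTER_AREA))
        .permute(2, 0, 1)
        .cuda()
    )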

icoderaven, Jun 09 '25

Hi @artemZholus, I'm trying to convert TAPNext to ONNX and I'm running into the same issue that @CrazyCoder76 had:

RuntimeError: Only tuples, lists and Variables are supported as JIT inputs/outputs. Dictionaries and strings are also accepted, but their usage is not recommended. Here, received an input of unsupported type: TAPNextTrackingState

Steps to reproduce: use the public torch_tapnext_demo colab and run this cell at the end:

model.eval() # this is important to ensure the model is in inference mode so we use linear scan which is faster for inference.
model.cuda()
with torch.no_grad():
  tapnext_init = torch.onnx.export(model, args=(), kwargs=dict(video=batch['video'][:, :4], query_points=batch['query_points']), )
  tracks, track_logits, vis_logits, state = model(video=batch['video'][:, :4], query_points=batch['query_points'])
  tapnext_track = torch.onnx.export(model, args=(), kwargs=dict(video=batch['video'][:, 4:8], state=state), )

javirk, Aug 25 '25

Hi, may I know if we can have the code for higher resolution images? Thanks!

TimSong412, Sep 16 '25