VGGT

Overview

前馈神经网络:

从一张或多张图片中预测3d场景的关键属性:

  1. 相机参数

  2. 深度图

  3. 点云

  4. 3d轨迹

优点:

不依赖传统优化,快,好,不需要先验,从图像序列中一次性输出3d场景信息。

Method

Basic Definition

VGGT的输入是张同一场景不同角度的RGB图片序列,

VGGT transformer是一个映射 ,将这个图片序列映射到了场景3d属性序列中:

对于某一张(一帧)图片而言:

相机参数(内参 + 外参):

  • :相机旋转角
  • :平移向量
  • :视场角 FOV (),假设中心在图像几何中心(内参)

深度图:

  • 设第 张图像为 ,其像素网格定义为:

  • 每个像素的位置用二维坐标 表示。

  • 模型为每张输入图像 预测一个对应的深度图 。其中,每个像素位置 的预测深度值记作 ,满足:

  • 表示:第 张图像中,像素点 所“看到”的真实世界中的三维点距离相机的深度;

3d 点云图:

  • 一张图上面 个点
  • 每一个点对应一个三维坐标
  • 沿用上面说的深度图的定义:,对应一个三维点。

3d点轨迹。

  • 给定一张图像中的某个像素,找到它在其它所有图像帧中的对应位置(同一3d点在不同2d图像中的2d坐标)

  • VGGT 的 transformer 主干网络并不直接输出每个点在其他帧的对应坐标,而为每一帧图像输出一个 dense feature grid:

  • 另外设计了一个模块用来追踪,如输入个点,输出个序列,每一个序列里面包含着个由初始点追踪得到的点。

  • 进行联合的端到端训练。

Frame Order

第一帧作为参考帧,顺序不能变换。

其他帧顺序随意。

Over-complete Predictions

VGGT 的多个输出是冗余的,但冗余是有意设计的,能提升性能。

  • 有了点云,可以用PnP算法估计相机参数

  • 有了相机参数和点云,也可以推算出深度

  • 在训练时显式地监督所有这些子任务,有助于网络学习更一致、更准确的几何表征

Feature Backbone

image

1. 图像编码器

采用进行图片编码。将图像编码成了 ,整个序列为

vggt/vggt/models/aggregator.py__build_patch_embed__函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
if "conv" in patch_embed:
self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
else:
vit_models = {
"dinov2_vitl14_reg": vit_large,
"dinov2_vitb14_reg": vit_base,
"dinov2_vits14_reg": vit_small,
"dinov2_vitg2_reg": vit_giant2,
}

self.patch_embed = vit_models[patch_embed](
img_size=img_size,
patch_size=patch_size,
num_register_tokens=num_register_tokens,
interpolate_antialias=interpolate_antialias,
interpolate_offset=interpolate_offset,
block_chunks=block_chunks,
init_values=init_values,
)

可以自己调用,也可以调用预训练的编码器。传入(B*S, C, H, W)(此处的C是rgb三色)的图片序列之后,经由encoder,最后得到了(B*S, P, C)(此处的C是经过编码嵌入之后token的维度)。

1
2
3
# Reshape to [B*S, C, H, W] for patch embedding
images = images.view(B * S, C_in, H, W)
patch_tokens = self.patch_embed(images)

2.添加特殊token

该函数slice_expand_and_flatten是 VGGT 模型中用于处理特殊 tokens函数。其目标是将形如(1, 2, X, C)的特殊token扩展成适用于B个batch、每个batch有S帧图像的token序列,输出形状为(B*S, X, C),便于与patch tokens一起送入Transformer。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def slice_expand_and_flatten(token_tensor, B, S):
"""
Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:
1) Uses the first position (index=0) for the first frame only
2) Uses the second position (index=1) for all remaining frames (S-1 frames)
3) Expands both to match batch size B
4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token
followed by (S-1) second-position tokens
5) Flattens to (B*S, X, C) for processing

Returns:
torch.Tensor: Processed tokens with shape (B*S, X, C)
"""
# token_tensor : (1, 2, X, C)
# 提取第一帧“query token”,复制 B 次
# (1, 1, X, C) => (B, 1, X, C)
query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
# 提取第二帧“other token”,复制B * (S - 1)次
# (1, 1, X, C) => (B, S-1, X, C)
"""
注意这里没有 .clone() 或 .repeat(),意味着
在计算图上,所有 batch 实际上都共享同一个token
参数的 memory 引用,只是逻辑上“看起来”每个batch
有独立的token。
"""
others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])
# Concatenate
# (B, 1, X, C) + (B, S-1, X, C) => (B, S, X, C)
combined = torch.cat([query, others], dim=1)
# Flatten, 与 patch token 对齐
# (B, S, X, C) => shape (B*S, X, C)
combined = combined.view(B * S, *combined.shape[2:])
return combined

(B, S, X, C)中的X取1或者4.

Token 类型 维度 作用
Camera token 用于预测该帧相机的内外参
Register tokens 帮助注意力学习帧内上下文

从随机矩阵生成camera token和register toke,最后把三个token拼接在一起。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Note: We have two camera tokens, one for the first frame and one for the rest
# The same applies for register tokens
self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))

# Initialize parameters with small values
nn.init.normal_(self.camera_token, std=1e-6)
nn.init.normal_(self.register_token, std=1e-6)

# Expand camera and register tokens to match batch size and sequence length
camera_token = slice_expand_and_flatten(self.camera_token, B, S)
register_token = slice_expand_and_flatten(self.register_token, B, S)

# Concatenate special tokens with patch tokens
tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1)

两种camera token都是可以训练的,但是第一种camera token只参与transformer的训练,使用独立的一组而第二种camera token既参与transformer的训练也参与camera head的训练,且剩余的(S - 1)组token是共享的

register token则是为了捕捉全局上下文,引导注意力对齐。

3.交替注意力机制:

帧Transformer模块(共有24个):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
self.frame_blocks = nn.ModuleList(
[
block_fn(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
ffn_bias=ffn_bias,
init_values=init_values,
qk_norm=qk_norm,
rope=self.rope,
)
for _ in range(depth)
]
)

全局Transformer模块(共24个):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
self.global_blocks = nn.ModuleList(
[
block_fn(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
ffn_bias=ffn_bias,
init_values=init_values,
qk_norm=qk_norm,
rope=self.rope,
)
for _ in range(depth)
]
)

aa_block_size:一组里面有几个block(默认是1),aa_block_num:共有几组(默认是24)

1
self.aa_block_num = self.depth // self.aa_block_size

_process_frame_attention_process_global_attention两个函数中,处理tokens传入Transformer块的形状,如果是_process_frame_attention,传入的是(B*S, P, C);如果是_process_global_attention,传入的是(B, S*P, C),显然前者更为关注一帧之内的情况,而后者更为关注全局的情况。

然后根据aa_block_size:选择一组之内重复的块的个数(默认是1)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
"""
Process frame attention blocks. We keep tokens in shape (B*S, P, C).
"""
# If needed, reshape tokens or positions:
if tokens.shape != (B * S, P, C):
tokens = tokens.view(B, S, P, C).view(B * S, P, C)

if pos is not None and pos.shape != (B * S, P, 2):
pos = pos.view(B, S, P, 2).view(B * S, P, 2)

intermediates = []

# by default, self.aa_block_size=1, which processes one block at a time
for _ in range(self.aa_block_size):
if self.training:
tokens = checkpoint(self.frame_blocks[frame_idx], tokens, pos, use_reentrant=self.use_reentrant)
else:
tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
frame_idx += 1
intermediates.append(tokens.view(B, S, P, C))

return tokens, frame_idx, intermediates

def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
"""
Process global attention blocks. We keep tokens in shape (B, S*P, C).
"""
if tokens.shape != (B, S * P, C):
tokens = tokens.view(B, S, P, C).view(B, S * P, C)

if pos is not None and pos.shape != (B, S * P, 2):
pos = pos.view(B, S, P, 2).view(B, S * P, 2)

intermediates = []

# by default, self.aa_block_size=1, which processes one block at a time
for _ in range(self.aa_block_size):
if self.training:
tokens = checkpoint(self.global_blocks[global_idx], tokens, pos, use_reentrant=self.use_reentrant)
else:
tokens = self.global_blocks[global_idx](tokens, pos=pos)
global_idx += 1
intermediates.append(tokens.view(B, S, P, C))

return tokens, global_idx, intermediates

pos是传Transformer块的一个重要参数,在如果是_process_frame_attention,传入的是(B*S, P, 2);如果是_process_global_attention,传入的是(B, S*P, 2),表示的是patch在这个图片中的空间位置坐标。

pos的生成过程中,position_getter函数根据patch的数量,生成一个(B * S, P, 2)的位置编码。

我们注意到,特殊token不应该被添加位置编码。所以设立pos_special将特殊token所对应的位置编码设为0.同时将原来的位置编码都加一,将(0, 0)对应的位置编码留了出来。

1
2
3
4
5
6
7
8
9
10
pos = None
if self.rope is not None:
pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)
# (self.patch_start_idx = 1 + num_register_tokens)
if self.patch_start_idx > 0:
# do not use position embedding for special tokens (camera and register tokens)
# so set pos to 0 for the special tokens
pos = pos + 1
pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype)
pos = torch.cat([pos_special, pos], dim=1)

根据aa_order,frame_transformerglobal_transformer交替进行,同时将24层交替frame_transformerglobal_transformer的结果拼接存储起来成(B, S, P, 2C)(default),用于以后的预测与分析。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
for _ in range(self.aa_block_num):
for attn_type in self.aa_order:
if attn_type == "frame":
tokens, frame_idx, frame_intermediates = self._process_frame_attention(
tokens, B, S, P, C, frame_idx, pos=pos
)
elif attn_type == "global":
tokens, global_idx, global_intermediates = self._process_global_attention(
tokens, B, S, P, C, global_idx, pos=pos
)
else:
raise ValueError(f"Unknown attention type: {attn_type}")

for i in range(len(frame_intermediates)):
# concat frame and global intermediates, [B x S x P x 2C]
concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
output_list.append(concat_inter)

4.预测头:

这个是把预测头之前的模型架构作为一个整体模块,输出之后的结果。

aggregated_tokens_list是24轮Transformer,每一轮的global和frame输出拼接之后的结果。也就是说,这是一个24个元素的数组,每一个元素是(B, S, P, 2C)的Tensor。

patch_start_idx是patch_token开始的索引,默认是5。

1
aggregated_tokens_list, patch_start_idx = self.aggregator(images)
Camera head

第一帧的Camera位态默认是,,特殊的camera token和register token能够帮助Transformer预测第一个相机。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
"""
Forward pass to predict camera parameters.

Args:
aggregated_tokens_list (list): List of token tensors from the network;
the last tensor is used for prediction.
num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.

Returns:
list: A list of predicted camera encodings (post-activation) from each iteration.
"""
# 取出最后的token (B, S, P, 2*C)
tokens = aggregated_tokens_list[-1]

# 取出其中的camera token (B, S, 2*C)
pose_tokens = tokens[:, :, 0]
pose_tokens = self.token_norm(pose_tokens)

pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
return pred_pose_enc_list

def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
"""
Iteratively refine camera pose predictions.

Args:
pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
num_iterations (int): Number of refinement iterations.

Returns:
list: List of activated camera encodings from each iteration.
"""
B, S, C = pose_tokens.shape # S is expected to be 1.
pred_pose_enc = None
pred_pose_enc_list = []

for _ in range(num_iterations):
# embed_pose 是一个线性层
# 第一轮时没有前一轮的姿态预测,因此使用“空”姿态作为输入
if pred_pose_enc is None:
module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
else:
#对之前预测的 pose embedding 进行 detach(断开反向传播),避免跨时间步传播梯度。
#这是典型的 “预测 - 停梯度 - 再预测” 做法,防止梯度在循环中不断累积.
pred_pose_enc = pred_pose_enc.detach()
module_input = self.embed_pose(pred_pose_enc)

# poseLN_modulation是一个mlp
# 把mlp的输出分成三份成偏移量,缩放量,门控系数
shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)

# 做normalization and modulation.
pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
# 残差链接
pose_tokens_modulated = pose_tokens_modulated + pose_tokens

# 输入四层Transformer
pose_tokens_modulated = self.trunk(pose_tokens_modulated)
# Compute the delta update for the pose encoding.
pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))

if pred_pose_enc is None:
pred_pose_enc = pred_pose_enc_delta
else:
pred_pose_enc = pred_pose_enc + pred_pose_enc_delta

# Apply final activation functions for translation, quaternion, and field-of-view.
activated_pose = activate_pose(
pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act
)
pred_pose_enc_list.append(activated_pose)

return pred_pose_enc_list

def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
"""
Modulate the input tensor using scaling and shifting parameters.
"""
return x * (1 + scale) + shift
Dense Prediction

:第i张图片,Transformer最后输出的结果token。

先通过DPT转换成了dense feature map:

使用3*3的卷积核,分别将映射成了一个深度图和一个点云,和用于3d point track的

即: depth map:

point map:

tracking feature map: .

同时,网络分别预测了针对于depth mappoint mapaleatoric uncertainty map and

Tracking

首先VGGT会经由Dense Prediction的部分给出每帧的trackiing feature map:

接着给定一组图像 和一个查询点 (例如第一张图像中的像素坐标),而是通过一个额外的跟踪模块,使用这些特征图来推断轨迹。

的实现方式基于CoTracker的架构,流程如下:

  • 查询点特征采样:在查询图像(一般是第一张图像)中,使用双线性插值从特征图 中采样得到查询点 的特征向量。
  • 相关性计算:将该特征与其他图像的特征图 做内积,计算得到每一帧图像的相关性热图(correlation map)。
  • 位置预测:使用 Transformer 对所有相关性图进行联合建模,输出该点在每张图像中的预测位置

完整的 tracking 函数可以形式化为:

查询图像可为任意一帧,不依赖时间顺序,因此适用于图像对、视频序列等多种输入类型

Tracking模块在训练时与主干VGGT端到端联合训练,梯度可回传至主网络。

Training

1.Training loss

我们使用多任务损失对 VGGT 模型 f 进行端到端训练:

通过经验发现,相机损失 ( )、深度损失 ( ) 和点图损失 ( ) 的范围相似,不需要相互加权。跟踪损失 ( ) 则通过一个因子 进行降权。

相机损失

相机损失 用于监督相机参数 g:

预测的相机参数:

真实值

Huber loss:

深度损失与点图损失

深度损失实现了 aleatoric-uncertainty loss,即用预测的不确定性图 来加权预测深度 与真实深度之间的差异。

公式为:

其中 是通道广播的逐元素乘积。点图损失的定义类似,但使用的是点图的不确定性

跟踪损失

跟踪损失由下式给出:

此处,外层求和遍历查询图像 中的所有真实查询点 在图像 中的真实对应点,而 是通过应用跟踪模块 得到的相应预测值

此外,模型遵循CoTracker2的方法,应用了一个可见性损失(二元交叉熵)来估计一个点在给定帧中是否可见。

2.Ground Truth Coordinate Normalization

在训练过程中:

首先,所有3D量都被转换到第一个相机的坐标系中,作为世界参考坐标系。

接着,计算点图中所有3D点到原点的平均欧几里得距离。这个平均距离被用作尺度因子来归一化相机平移向量、点图和深度图.

Tasks

1.Camera Pose Estimation (相机位姿估计)

相机位姿估计的目标是确定每张图像拍摄时相机的精确三维位置(translation, 平移)和方向(rotation, 旋转)。简单来说,就是回答“相机在哪里”以及“相机朝向哪个方向”这两个问题。

2.Multi-view Depth Estimation (多视角深度估计)

多视角深度估计的目标是为场景中的每个可见像素(或一个稀疏点集合)预测其到相机的距离(即深度)。这有助于构建场景的3D结构。

3.Dense Point Cloud Reconstruction (密集点云重建)

密集点云重建旨在从一组2D图像中生成一个包含大量3D点的集合,这些点精确地表示了场景的几何形状。它是场景3D结构重建的一种常见形式。

4.Long-term 3D Point Tracking (长期3D点跟踪)

长期3D点跟踪是指在长时间的视频序列中,跟踪特定3D点的运动轨迹。这要求模型不仅能预测点在每一帧的2D位置,还要能处理遮挡、视角变化和光照变化等挑战,保持点身份的一致性。