tensor([[-0.0701, 0.0035, -0.0785, ..., 0.1628, 0.0201, -0.0419],
[-0.0350, -0.0082, -0.0715, ..., 0.1119, -0.0159, -0.1164],
[-0.0753, 0.0172, -0.0513, ..., 0.1070, 0.1476, -0.0699],
...,
[-0.1425, 0.1393, -0.2316, ..., 0.0169, 0.0897, -0.0431],
[-0.0690, 0.0513, -0.0935, ..., 0.1311, 0.0324, -0.0705],
[-0.0812, 0.0511, -0.0482, ..., 0.1010, 0.0365, -0.0582]],
device='cuda:0', dtype=torch.float16)
tensor([[-0.0701, 0.0035, -0.0785, ..., 0.1628, 0.0201, -0.0419],
[-0.0350, -0.0082, -0.0715, ..., 0.1119, -0.0159, -0.1164],
[-0.0753, 0.0172, -0.0513, ..., 0.1070, 0.1476, -0.0699],
...,
[-0.1425, 0.1393, -0.2316, ..., 0.0169, 0.0897, -0.0431],
[-0.0690, 0.0513, -0.0935, ..., 0.1311, 0.0324, -0.0705],
[-0.0812, 0.0511, -0.0482, ..., 0.1010, 0.0365, -0.0582]],
device='cuda:0', dtype=torch.float16)
torch.Size([1024, 96]) torch.Size([1024, 96])
match: tensor(1., device='cuda:0')