in differentiable_robot_model/rigid_body.py [0:0]
def update_joint_state(self, q, qd):
    """Refresh ``self.joint_vel`` and ``self.joint_pose`` for a new joint state.

    Args:
        q: joint position; ``q.shape[0]`` is used as the batch size and ``q``
            is passed (sign-adjusted) to the per-axis rotation helper.
            NOTE(review): presumably a 1-D or (batch, 1) tensor -- confirm
            against what ``x_rot``/``y_rot``/``z_rot`` accept.
        qd: joint velocity; it is matrix-multiplied with ``self.joint_axis``
            to obtain the angular part of the joint velocity.

    Returns:
        None. The method only mutates ``self.joint_vel`` and
        ``self.joint_pose``.
    """
    n_batch = q.shape[0]
    axis = self.joint_axis

    # Joint velocity: zero first component, ``qd`` projected on the joint
    # axis as the second component of the spatial motion vector.
    angular = qd @ axis
    self.joint_vel = SpatialMotionVec(torch.zeros_like(angular), angular)

    # Fixed rotation built from the first row of ``self.rot_angles()``,
    # composed as z_rot(yaw) @ y_rot(pitch) @ x_rot(roll).
    angles = self.rot_angles()
    roll, pitch, yaw = angles[0, 0], angles[0, 1], angles[0, 2]
    about_z = z_rot(yaw)
    about_y = y_rot(pitch)
    about_x = x_rot(roll)
    fixed_rotation = (about_z @ about_y) @ about_x

    # The pose translation is re-read from ``self.trans()`` on every update,
    # alongside the rotation set below.
    self.joint_pose.set_translation(torch.reshape(self.trans(), (1, 3)))

    # Select the rotation helper from the joint axis: x if |axis_x| == 1,
    # otherwise y if |axis_y| == 1. Anything else falls through to z, so an
    # axis with no unit x/y component is always treated with ``z_rot``.
    for col, rot_fn in ((0, x_rot), (1, y_rot)):
        if torch.abs(axis[0, col]) == 1:
            break
    else:
        col, rot_fn = 2, z_rot
    # The sign of the selected axis component flips the rotation direction.
    joint_rotation = rot_fn(torch.sign(axis[0, col]) * q)

    # Tile the fixed rotation across the batch, then apply the joint rotation.
    self.joint_pose.set_rotation(fixed_rotation.repeat(n_batch, 1, 1) @ joint_rotation)