import numpy as np

# 3-D scatter of every row of `obs` projected onto the first three rows of VT
# (presumably the leading right-singular vectors of an SVD -- confirm against
# where VT is computed). Rows whose group label is 'Cancer' are drawn as red
# crosses; all other rows are drawn as blue circles.
fig2 = plt.figure()
ax = fig2.add_subplot(111, projection='3d')

# One matrix product replaces three dot products per row:
# coords[j] == (VT[0,:] @ obs[j,:], VT[1,:] @ obs[j,:], VT[2,:] @ obs[j,:]).
coords = obs @ VT[:3, :].T

# Same label test as a per-row loop over range(obs.shape[0]), so only the first
# obs.shape[0] entries of grp are consulted and grp may be a list or an array.
is_cancer = np.array(
    [grp[j] == 'Cancer' for j in range(obs.shape[0])], dtype=bool
)

# Two scatter calls (one per group) instead of one artist per observation.
ax.scatter(coords[is_cancer, 0], coords[is_cancer, 1], coords[is_cancer, 2],
           marker='x', color='r', s=50)
ax.scatter(coords[~is_cancer, 0], coords[~is_cancer, 1], coords[~is_cancer, 2],
           marker='o', color='b', s=50)

ax.view_init(25, 20)
plt.show()