pytorch怎么加flatten函数?有人知道吗?