RuntimeError: expected scalar type float but found __int64

Problem description

    ---------------------------------------------------------------------------
    RuntimeError                              Traceback (most recent call last)
    <ipython-input-30-d9bacc2c4126> in <module>
         44
         45 gat = GATConv(dataset.num_features, 16)
    ---> 46 gat(data.x, data.edge_index).shape

    D:\Anaconda\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
       1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
       1109                 or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1110             return forward_call(*input, **kwargs)
       1111         # Do not call functions when jit is used
       1112         full_backward_hooks, non_full_backward_hooks = [], []

    <ipython-input-30-d9bacc2c4126> in forward(self, x, edge_index)
         31
         32         adj = to_dense_adj(edge_index)
    ---> 33         attention = torch.where(adj > 0, e, 0)
         34
         35         attention = F.softmax(attention, dim=1)

    RuntimeError: expected scalar type float but found __int64

Cause analysis:

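The error is raised by the call to torch.where(). The fill value 0 is a Python integer, which PyTorch treats as an int64 scalar, while the other value argument, e, is a float tensor. In the PyTorch version used here, torch.where() expects both value arguments to share the same dtype, so the fix is to pass the fill value as a float.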
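As a minimal sketch of the mismatch, the snippet below compares the dtypes involved. The tensor e is a hypothetical stand-in for the attention scores in the traceback, not the author's actual data, and the names int_fill and float_fill are introduced only for illustration.

```python
import torch

# Hypothetical stand-in for the attention scores `e` from the traceback.
e = torch.rand(3, 3)

# A tensor built from a Python int defaults to int64,
# which is the "__int64" named in the error message.
int_fill = torch.tensor(0)

# Stating the dtype explicitly gives a fill value that matches `e`.
float_fill = torch.tensor(0, dtype=torch.float32)

print(e.dtype)           # torch.float32 (with the default dtype settings)
print(int_fill.dtype)    # torch.int64
print(float_fill.dtype)  # torch.float32
```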

Solution:

    attention = torch.where(adj > 0, e, torch.tensor(0, dtype=torch.float32))
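The fix can be tried in isolation with the sketch below. It assumes a small hand-written adjacency matrix and random scores in place of the author's to_dense_adj(edge_index) output and e, so adj, e, masked, masked_alt, and attention are illustrative values only. It also shows torch.zeros_like(e) as one possible alternative fill value, which copies the dtype, shape, and device of e. That alternative is not part of the original post.

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-ins for the tensors in the traceback.
adj = torch.tensor([[0, 1, 0],
                    [1, 0, 1],
                    [0, 1, 0]])   # dense adjacency matrix (int64)
e = torch.rand(3, 3)              # float32 attention scores

# Fix from the post: pass the fill value as a float32 tensor.
masked = torch.where(adj > 0, e, torch.tensor(0, dtype=torch.float32))

# Alternative (an assumption, not from the post): a zero tensor
# that inherits dtype, shape, and device from `e`.
masked_alt = torch.where(adj > 0, e, torch.zeros_like(e))

# Row-wise softmax, as on line 35 of the traceback.
attention = F.softmax(masked, dim=1)

print(masked.dtype)                      # dtype of the masked scores
print(torch.equal(masked, masked_alt))   # both variants give the same values
```

Both variants keep e where an edge exists and write 0 elsewhere; the zeros_like form may be more convenient when e lives on a GPU or uses a dtype other than float32.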

Reprinted from: https://blog.csdn.net/m0_47256162/article/details/128744085
Copyright belongs to the original author, 海洋.之心. If there is any infringement, please contact us for removal.
