

RuntimeError: expected scalar type float but found __int64

Problem description

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-30-d9bacc2c4126> in <module>
     44 
     45 gat = GATConv(dataset.num_features, 16)
---> 46 gat(data.x, data.edge_index).shape

D:\Anaconda\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-30-d9bacc2c4126> in forward(self, x, edge_index)
     31 
     32         adj = to_dense_adj(edge_index)
---> 33         attention = torch.where(adj > 0, e, 0)
     34 
     35         attention = F.softmax(attention, dim=1)

RuntimeError: expected scalar type float but found __int64

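Cause analysis:

The third argument passed to torch.where() is the Python integer 0, which is handled as a 64-bit integer (shown as __int64 on Windows builds), while e is a float tensor. The torch.where in this PyTorch version requires both value arguments to have the same dtype and does not promote the integer to float, so it raises the error above. Passing the fill value as a float fixes it.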
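To make the mismatch concrete, the sketch below uses small toy tensors as stand-ins for the post's e and adj (their shapes and values are assumptions, not taken from the original notebook). It checks the dtypes involved: e is float32, a bare 0 turned into a tensor defaults to int64, and the fill value from the solution below is float32. With both value arguments in float32, torch.where succeeds regardless of whether the installed PyTorch version does type promotion.

```python
import torch

# Toy stand-ins for the post's tensors (assumed shapes and values)
e = torch.rand(3, 3)                      # float32 attention scores
adj = torch.tensor([[1., 0., 1.],
                    [0., 1., 0.],
                    [1., 1., 0.]])        # dense float adjacency, like to_dense_adj output

fill_int = torch.tensor(0)                          # what a bare Python 0 becomes: int64
fill_float = torch.tensor(0, dtype=torch.float32)   # float fill value, as in the fix

print(e.dtype, fill_int.dtype, fill_float.dtype)

# Both value arguments are float32, so no dtype mismatch can occur
attention = torch.where(adj > 0, e, fill_float)
print(attention.dtype)
```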

Solution:

attention = torch.where(adj >0, e, torch.tensor(0, dtype=torch.float32))
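A variant that avoids hard-coding float32 is to take the fill value from e itself, so it always matches e's dtype and device (for example if e were float64, half precision, or on a GPU). The helper name mask_attention is hypothetical and is not part of the original post; this is a sketch of one possible approach, not the author's method.

```python
import torch


def mask_attention(e: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
    """Keep e where adj > 0 and put 0 elsewhere.

    Hypothetical helper for illustration; not part of the original post.
    """
    # zeros_like copies dtype and device from e, so both value
    # arguments of torch.where always agree with each other
    return torch.where(adj > 0, e, torch.zeros_like(e))


# Usage with a float64 e: the fill value follows e's dtype automatically
e = torch.tensor([[0.5, 0.2], [0.1, 0.9]], dtype=torch.float64)
adj = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
attention = mask_attention(e, adj)
print(attention.dtype)  # torch.float64

# to_dense_adj returns a leading batch dimension; torch.where broadcasts over it
print(mask_attention(e, adj.unsqueeze(0)).shape)  # torch.Size([1, 2, 2])
```

Because torch.where broadcasts its arguments, the helper also works when adj keeps the [1, N, N] batch shape that to_dense_adj produces for a single graph.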

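Reposted from: https://blog.csdn.net/m0_47256162/article/details/128744085
Copyright belongs to the original author 海洋.之心. If this repost infringes any rights, please contact us and it will be removed.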
