I have the following implementation in my PyTorch-based code, which involves nested for loops. The nested loops, together with the if condition, make the code very slow to execute. I tried to replace the loops with broadcasting, as done in NumPy and PyTorch, but that did not work. Any help with eliminating the for loops would be appreciated.
Here are the links I have read: PyTorch, NumPy.
#!/usr/bin/env python
# coding: utf-8
import torch
batch_size=32
mask=torch.FloatTensor(batch_size).uniform_() > 0.8  # boolean mask: True -> student branch, False -> teacher branch
teacher_count=510
student_count=420
feature_dim=750
student_output=torch.zeros([batch_size,student_count])
teacher_output=torch.zeros([batch_size,teacher_count])
student_adjacency_mat=torch.randint(0,2,(student_count,student_count))  # random 0/1 adjacency (high is exclusive in randint)
teacher_adjacency_mat=torch.randint(0,2,(teacher_count,teacher_count))
student_feat=torch.rand([batch_size,feature_dim])
student_graph=torch.rand([student_count,feature_dim])
teacher_feat=torch.rand([batch_size,feature_dim])
teacher_graph=torch.rand([teacher_count,feature_dim])
for m in range(batch_size):
    if mask[m]==1:
        # student branch: output[m, i] = sum_j A[i, j] * <student_feat[m], student_graph[j]>
        for i in range(student_count):
            for j in range(student_count):
                student_output[m][i]=student_output[m][i]+student_adjacency_mat[i][j]*torch.dot(student_feat[m],student_graph[j])
    if mask[m]==0:
        # teacher branch: same computation with the teacher tensors
        for i in range(teacher_count):
            for j in range(teacher_count):
                teacher_output[m][i]=teacher_output[m][i]+teacher_adjacency_mat[i][j]*torch.dot(teacher_feat[m],teacher_graph[j])
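For context, each inner double loop computes output[m, i] = sum_j A[i, j] * dot(feat[m], graph[j]), so I believe it should reduce to chained matrix products with the mask applied afterwards. Below is a rough sketch of the vectorized form I am aiming for (it reuses the names defined above; I have not verified that it matches the loop output exactly):

# Sketch of the vectorized form I think is equivalent (not verified against the loops):
# output[m, i] = sum_j A[i, j] * <feat[m], graph[j]> = (feat @ graph.T @ A.T)[m, i]
student_vec = student_feat @ student_graph.t() @ student_adjacency_mat.float().t()  # [batch_size, student_count]
teacher_vec = teacher_feat @ teacher_graph.t() @ teacher_adjacency_mat.float().t()  # [batch_size, teacher_count]
# zero out rows to match the loop: student rows where mask is True, teacher rows where it is False
student_output_vec = student_vec * mask.unsqueeze(1).float()
teacher_output_vec = teacher_vec * (~mask).unsqueeze(1).float()

If this is the right direction, I would appreciate a confirmation (or a correction), especially on how the per-batch mask should be applied.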