The keras-team answered a question on GitHub about how to make a custom activation function.
There is also a question that includes code for a custom activation function.
These pages may help you!
Additional comment
These pages were not enough to answer this question, so I am adding the comment below.
PyTorch may be better suited to this kind of customization than Keras. I tried writing such a network, though a very simple one, based on the PyTorch tutorials and "Extending PyTorch with Custom Activation Functions".
I made a custom activation function in which element 1 (counting from 0) of each output vector is equal to twice element 0. A very simple one-layer network was used for training. After training, I checked that the condition was satisfied.
import torch
import matplotlib.pyplot as plt
# Define the custom activation function
# reference: https://towardsdatascience.com/extending-pytorch-with-custom-activation-functions-2d8b065ef2fa
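def silu(input):
    # Not the real SiLU: the name is kept from the reference article.
    # Overwrites column 1 in place with twice column 0.
    input[:, 1] = input[:, 0] * 2
    return input

class SiLU(torch.nn.Module):
    def __init__(self):
        super().__init__()  # init the base class

    def forward(self, input):
        return silu(input)  # apply the custom activation defined above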
# Training
# reference: https://pytorch.org/tutorials/beginner/pytorch_with_examples.html
k = 10
x = torch.rand([k,3])
y = x * 2
model = torch.nn.Sequential(
    torch.nn.Linear(3, 3),
    SiLU()  # custom activation function
)
loss_fn = torch.nn.MSELoss(reduction='sum')
learning_rate = 1e-3
for t in range(2000):
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    if t % 100 == 99:
        print(t, loss.item())
    model.zero_grad()
    loss.backward()
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad
# check the behaviour
yy = model(x) # predicted
print('ground truth')
print(y)
print('predicted')
print(yy)
# plot the first five samples: ground truth (solid) vs. predicted (dotted)
colorlist = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00']
plt.figure()
for i in range(5):
    plt.plot(y[i, :].detach().numpy(), linestyle="solid", label="ground truth_" + str(i), color=colorlist[i])
    plt.plot(yy[i, :].detach().numpy(), linestyle="dotted", label="predicted_" + str(i), color=colorlist[i])
plt.legend()
# check if the custom activation works correctly
plt.figure()
plt.plot(yy[:, 0].detach().numpy() * 2, label='element 0 * 2')
plt.plot(yy[:, 1].detach().numpy(), label='element 1')
plt.legend()
print(yy[:,0]*2)
print(yy[:,1])
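The check above is visual. As a complement, here is a minimal sketch of a numeric check, assuming the same one-layer setup. The names `tie_second_to_first` and `TieSecondToFirst` are hypothetical, chosen only for this illustration, and this variant works on a copy of the input instead of overwriting it in place, which is a design choice of the sketch rather than something the code above requires.

```python
import torch


def tie_second_to_first(x: torch.Tensor) -> torch.Tensor:
    """Return a copy of x whose column 1 equals twice column 0."""
    out = x.clone()            # work on a copy; the caller's tensor is untouched
    out[:, 1] = x[:, 0] * 2    # enforce the constraint on the copy
    return out


class TieSecondToFirst(torch.nn.Module):
    # Hypothetical module name; wraps the function so it fits in nn.Sequential.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return tie_second_to_first(x)


torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(3, 3), TieSecondToFirst())
x = torch.rand(10, 3)
out = model(x)

# Numeric version of the plot-based check: column 1 should equal 2 * column 0.
constraint_holds = torch.allclose(out[:, 1], out[:, 0] * 2)
print(constraint_holds)

# Gradients should still reach the linear layer through the custom activation.
out.sum().backward()
grad_ok = all(p.grad is not None for p in model.parameters())
print(grad_ok)
```

Because column 1 of the output is overwritten, the second row of the linear layer's weight matrix receives a zero gradient; only the rows feeding columns 0 and 2 are actually trained.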