I have a quadratic loss z = (1/2)||Aw-b||^2, where A is a 4x2 matrix, w = [x, y] is a 2D vector, and b is a 4D vector. Plotted over x and y, z is a surface. I want to plot z with the Plotly library, using PyTorch and the function torch.norm to calculate the norm. Here is a worked example for plotting a 3D surface, which I want to modify as follows:

import plotly.graph_objects as go
import numpy as np
import torch

A = torch.tensor([[ 0.1542, -0.0682],
        [ 0.8631,  0.6762],
        [-1.4002,  1.1773],
        [ 0.4614,  0.2431]])

b = torch.tensor([-0.2332, -0.7453,  0.9061,  1.2118])

x = np.arange(-1,1,.01)
y = np.arange(-1,1,.01)
X,Y = np.meshgrid(x,y)

W = ??????

Z = 0.5*torch.norm(torch.matmul(A, W)-b)**2

fig = go.Figure(
    data=[go.Surface(z=Z, x=x, y=y, colorscale="Reds", opacity=0.5)])
fig.update_layout(
    title='My title', 
    autosize=False,
    width=500, 
    height=500,
    margin=dict(l=65, r=50, b=65, t=90), 
    scene_aspectmode='cube'
)
fig.show()

Question:

How should I define W, which has to contain x and y, so that I can plot the surface?

Saeed

2 Answers

You could simply do:

Z = [[0.5 * torch.norm(torch.matmul(A, torch.tensor([float(xx), float(yy)]))-b)**2 for xx in x] for yy in y]

Update: You can improve the performance significantly by using torch's batched matrix multiplication (torch.matmul broadcasts over leading batch dimensions). For this you have to reshape your data into batches of matrices: give A a leading batch dimension of size 1, and reshape W into a batch that contains every mesh point as a 2x1 matrix.

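import plotly.graph_objects as go
import torch

# A gets a leading batch dimension of size 1: shape (1, 4, 2)
A = torch.tensor([[[0.1542, -0.0682],
                   [0.8631,  0.6762],
                   [-1.4002, 1.1773],
                   [0.4614,  0.2431]]])

b = torch.tensor([-0.2332, -0.7453,  0.9061,  1.2118])

x = torch.arange(-1, 1, 0.01)
y = torch.arange(-1, 1, 0.01)
# Every (x, y) pair as a 2x1 column vector: shape (len(x) * len(y), 2, 1)
W = torch.reshape(torch.cartesian_prod(x, y), (len(x) * len(y), 2, 1))

# Batched matmul gives (len(x) * len(y), 4, 1); after the reshape, V[i, j] = A @ [x[i], y[j]] - b
V = torch.reshape(torch.matmul(A, W), (len(x), len(y), 4)) - b
Z = 0.5 * torch.norm(V, dim=2)**2

# Z[i, j] belongs to (x[i], y[j]); go.Surface expects z[i][j] at (x[j], y[i]), so transpose
fig = go.Figure(data=[go.Surface(z=Z.T.numpy(), x=x.numpy(), y=y.numpy())])
fig.show()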
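A possible variant of the same broadcasting idea, as a minimal sketch rather than a definitive implementation: it builds the grid with torch.meshgrid in "xy" indexing (the same layout as np.meshgrid in the question, available in recent PyTorch versions), so that Z[i, j] belongs to (x[j], y[i]). The name `residual` is introduced here for illustration only.

```python
import torch

A = torch.tensor([[0.1542, -0.0682],
                  [0.8631, 0.6762],
                  [-1.4002, 1.1773],
                  [0.4614, 0.2431]])
b = torch.tensor([-0.2332, -0.7453, 0.9061, 1.2118])

x = torch.arange(-1, 1, 0.01)
y = torch.arange(-1, 1, 0.01)

# indexing="xy" mirrors np.meshgrid's default: X[i, j] = x[j], Y[i, j] = y[i]
X, Y = torch.meshgrid(x, y, indexing="xy")

# W[i, j] = [x[j], y[i]]; shape (len(y), len(x), 2)
W = torch.stack((X, Y), dim=-1)

# (len(y), len(x), 2) @ (2, 4) -> (len(y), len(x), 4); b broadcasts over the last axis
residual = W @ A.T - b

# Norm over the last axis only, so every grid point gets its own loss value
Z = 0.5 * torch.norm(residual, dim=-1) ** 2
```

With this layout, Z should be usable without a transpose, e.g. go.Surface(z=Z.numpy(), x=x.numpy(), y=y.numpy()).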
Markus
  • I ran your code and @John Doe's. Yours takes 22 seconds and his takes 17 seconds, even though his has more for loops than yours. Do you have any idea how to make yours faster? In particular, is there any way to avoid loops when calculating `Z`? To be fair, your code is really nice and is written in a Pythonic way. – Saeed Aug 15 '22 at 00:50
  • @Sepide John Doe's took less than 5 seconds for me :/ – DialFrost Aug 15 '22 at 01:45
  • @DialFrost: How about Markus's code? I am speaking in relative terms: even though John Doe's has more loops, it is faster than Markus's. – Saeed Aug 15 '22 at 01:49
  • @Sepide: If performance matters, I will have to take a look at this later. Generally speaking: [about comprehensions](https://stackoverflow.com/questions/30245397/why-is-a-list-comprehension-so-much-faster-than-appending-to-a-list) – Markus Aug 15 '22 at 06:27
  • @Sepide: Observation on my laptop: Both `torch` based solutions, John Doe's and mine, take around 3 secs, while the pure `numpy` solution takes 2.4 secs. I have measured only the formula, not the rendering of the figure. – Markus Aug 15 '22 at 07:15
  • @Markus: You can take your time and get back to me later. Maybe rendering the image is the part that takes most of the time. – Saeed Aug 15 '22 at 07:47
  • @Sepide I think the reason may be related to cache misses. Filling W first puts its entries in contiguous memory addresses. I will run some tests on my side as well. I think there are still a lot of little tricks to optimize it further. – joaopfg Aug 15 '22 at 08:33
  • @Sepide Please, check the updates in my answer. 40 times faster code xD (and I'm still trying to squeeze a little more improvement). – joaopfg Aug 15 '22 at 11:02
  • @Sepide Please check it out again. I optimized it a bit more by JIT compiling the linear algebra operations. I will still try to squeeze out a bit more improvement. – joaopfg Aug 15 '22 at 11:26
  • @Sepide I tried some more tricks, but nothing improved the time much further. I think my current answer is the most optimized you can get – joaopfg Aug 15 '22 at 12:36

After a lot of trial and error, it turns out that the fastest code I found writes the linear algebra operation "by hand" (without using any linear algebra package at all) and JIT-compiles it with Numba.

import plotly.graph_objects as go
from numba import njit, float32


@njit(float32(float32, float32))
def op(x, y):
    # 0.5 * ||A w - b||^2 written out row by row: each term is (A[k, 0]*x + A[k, 1]*y - b[k])^2
    return 0.5 * ((0.1542*x - 0.0682*y + 0.2332)*(0.1542*x - 0.0682*y + 0.2332) +
                  (0.8631*x + 0.6762*y + 0.7453)*(0.8631*x + 0.6762*y + 0.7453) +
                  (-1.4002*x + 1.1773*y - 0.9061)*(-1.4002*x + 1.1773*y - 0.9061) +
                  (0.4614*x + 0.2431*y - 1.2118)*(0.4614*x + 0.2431*y - 1.2118))


interval = [i/100 for i in range(-100, 100)]
# go.Surface expects z[i][j] at (x[j], y[i]), so the outer loop runs over y
z = [[op(xi, yi) for xi in interval] for yi in interval]
fig = go.Figure(data=[go.Surface(z=z, x=interval, y=interval)])
fig.show()
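For reference, the "by hand" expression is just the squared norm from the question expanded row by row, where A_{k1} and A_{k2} denote the two entries of row k of A and b_k the k-th entry of b:

```latex
z(x, y) = \frac{1}{2}\,\lVert A w - b \rVert^2
        = \frac{1}{2} \sum_{k=1}^{4} \left( A_{k1}\,x + A_{k2}\,y - b_k \right)^2,
\qquad w = \begin{bmatrix} x \\ y \end{bmatrix}
```

Each b_k enters with a minus sign, so a negative entry of b shows up as a positive constant in the corresponding term.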
joaopfg
  • I want to do this for an arbitrary number of matrices. The way you hard-coded the values prevents me from doing that. – Saeed Aug 15 '22 at 16:53
  • Are your matrices always as small as the ones you showed in your question? If so, you can pass them as parameters to the JIT-compiled function `op`. The main idea to keep in mind is to avoid using linear algebra packages if your matrices are small. This way you can JIT compile the linear algebra operations and get performance comparable to C/C++. – joaopfg Aug 15 '22 at 16:56
  • @Sepide Also, you should specify in the details of your question that you may also use other matrices. Just a piece of constructive criticism: have some consideration for those who spend time answering your question. Put all the details needed for the answer in the question. That way you have a better chance of getting a good answer. – joaopfg Aug 15 '22 at 17:07
  • You are right, and I will make sure to do that. Thank you for giving me feedback. – Saeed Aug 15 '22 at 19:46