
I'm using a server managed by the Slurm Workload Manager, and `import torch` takes around 30-40 seconds. The IT people running it said they couldn't do much to improve it and that it was just hardware-related (maybe they missed something? I searched the internet before asking them and couldn't find much either). By comparison, `import numpy` takes around 1 second.

I would like to know if there is a way to use the saved weights of a PyTorch model ONLY to predict an output for a given input, without importing torch (so no need to import everything related to gradients, etc.). In theory it is just matrix multiplications (I think?), so it should be feasible with numpy alone? I need to do this many times in different jobs, so I cannot cache or pass around the imported torch module, which is why I'm actively looking for a solution (but generally speaking, going from 30-40 seconds to a few is pretty cool anyway).

In case it matters, here is the architecture of my model:

ActionNN(
  (conv_1): Conv2d(5, 16, kernel_size=(3, 3), stride=(1, 1))
  (conv_2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv_3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (norm_layer_1): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  (norm_layer_2): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  (norm_layer_3): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  (gap): AdaptiveAvgPool2d(output_size=(1, 1))
  (mlp): Sequential(
    (0): Linear(in_features=71, out_features=128, bias=True)
    (1): ReLU()
  )
  (layer): Linear(in_features=128, out_features=210, bias=True)
  (layer): Linear(in_features=128, out_features=210, bias=True)
  (layer): Linear(in_features=128, out_features=210, bias=True)
  (layer): Linear(in_features=128, out_features=210, bias=True)
  (layer): Linear(in_features=128, out_features=28, bias=True)
  (layer): Linear(in_features=128, out_features=28, bias=True)
  (layer): Linear(in_features=128, out_features=28, bias=True)
  (sigmoid): Sigmoid()
  (tanh): Tanh()
)
Number of parameters: 152284
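As a quick sanity check (my own arithmetic, assuming every Conv2d uses a 3x3 kernel with a bias and every Linear has bias=True, as the listing shows), the reported parameter count can be reproduced from the layer shapes alone. Because the totals match, the InstanceNorm2d layers (affine=False), the pooling and the activations carry no weights, so only the Conv2d and Linear tensors would need to be exported for a NumPy port.

```python
# Parameter count for the ActionNN listing, assuming 3x3 kernels, bias on every
# Conv2d/Linear, and no learnable parameters in InstanceNorm2d (affine=False).


def conv2d_params(c_in: int, c_out: int, k: int = 3) -> int:
    """Weights (c_out x c_in x k x k) plus one bias per output channel."""
    return c_out * c_in * k * k + c_out


def linear_params(n_in: int, n_out: int) -> int:
    """Weight matrix (n_out x n_in) plus one bias per output feature."""
    return n_out * n_in + n_out


conv_total = conv2d_params(5, 16) + conv2d_params(16, 32) + conv2d_params(32, 64)
mlp_total = linear_params(71, 128)
head_total = 4 * linear_params(128, 210) + 3 * linear_params(128, 28)
total = conv_total + mlp_total + head_total
print(total)  # 152284, matching the reported count
```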

If the network were only fully connected layers, it would be "pretty easy", but because it is a tiny bit more complex (convolutions, instance normalization, pooling), I'm not sure how I should do it.
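A minimal NumPy sketch of the layer types in this model, for a single sample (batch size 1, as described in the comments) and assuming NumPy 1.20+ for `sliding_window_view`. The defaults follow the listing: stride 1, no padding, eps=1e-05, and no affine parameters in the instance norm. The `conv_features` helper is hypothetical: it assumes each conv is followed by instance norm and a ReLU, and presumably the 64 pooled features are later concatenated with 7 other inputs to reach the 71 inputs of `mlp`. The real order is whatever `ActionNN.forward` does, which isn't shown.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def conv2d(x, weight, bias):
    """Conv2d with stride 1 and no padding (nn.Conv2d defaults).

    x: (C_in, H, W), weight: (C_out, C_in, kH, kW), bias: (C_out,)
    Returns (C_out, H - kH + 1, W - kW + 1).
    """
    kh, kw = weight.shape[2:]
    # windows[c, h, w, i, j] == x[c, h + i, w + j]
    windows = sliding_window_view(x, (kh, kw), axis=(1, 2))
    # PyTorch "convolution" is cross-correlation: the kernel is not flipped.
    out = np.einsum("chwij,ocij->ohw", windows, weight, optimize=True)
    return out + bias[:, None, None]


def instance_norm(x, eps=1e-5):
    """InstanceNorm2d with affine=False, track_running_stats=False."""
    mean = x.mean(axis=(1, 2), keepdims=True)
    var = x.var(axis=(1, 2), keepdims=True)  # biased variance, as PyTorch uses
    return (x - mean) / np.sqrt(var + eps)


def global_avg_pool(x):
    """AdaptiveAvgPool2d((1, 1)) followed by flattening: (C, H, W) -> (C,)."""
    return x.mean(axis=(1, 2))


def linear(x, weight, bias):
    """nn.Linear: weight has shape (out_features, in_features)."""
    return weight @ x + bias


def relu(x):
    return np.maximum(x, 0.0)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))  # np.tanh covers the Tanh layer


def conv_features(x, w):
    """HYPOTHETICAL trunk: (conv -> instance norm -> ReLU) x 3, then pooling.

    `w` maps state_dict keys such as "conv_1.weight" to NumPy arrays.
    Adjust the order to match ActionNN.forward.
    """
    for i in (1, 2, 3):
        x = conv2d(x, w[f"conv_{i}.weight"], w[f"conv_{i}.bias"])
        x = relu(instance_norm(x))
    return global_avg_pool(x)
```

Before relying on it, compare the NumPy output against the PyTorch model on a few inputs with `np.allclose`; that one-off check is the only step that needs torch.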

I saved the parameters using torch.save(my_network.state_dict(), my_path).
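Since `torch.load` itself needs torch, one approach is a one-off conversion: load the state_dict once in a job that imports torch anyway, dump every tensor to a NumPy `.npz` file, and have the inference jobs read only that file. This is a sketch; the file name `action_nn.npz` used below is hypothetical, and the keys are the state_dict names such as `conv_1.weight` or `mlp.0.bias`.

```python
import numpy as np


def state_dict_to_npz(state_dict, path):
    """One-off conversion: run ONCE in a job that has already imported torch.

    Accepts a mapping of parameter name -> tensor (e.g. the result of
    torch.load(my_path, map_location="cpu")) or plain array-likes.
    """
    arrays = {}
    for name, value in state_dict.items():
        if hasattr(value, "detach"):  # torch.Tensor: drop autograd info, move to CPU
            value = value.detach().cpu().numpy()
        arrays[name] = np.asarray(value)
    np.savez(path, **arrays)


def load_weights(path):
    """Inference side: needs only NumPy, no torch import."""
    with np.load(path) as data:
        return {name: data[name] for name in data.files}
```

Run `state_dict_to_npz(torch.load(my_path, map_location="cpu"), "action_nn.npz")` once; each inference job then calls `load_weights("action_nn.npz")` with only `import numpy`.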

My script takes 35 seconds on average in total (including `import torch`); without torch it would run in a second or two on average, which would be great.

Here is my profile of `import torch`:

         1226310 function calls (1209639 primitive calls) in 49.994 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     1273   21.590    0.017   21.590    0.017 {method 'read' of '_io.BufferedReader' objects}
     5276   12.145    0.002   12.145    0.002 {built-in method posix.stat}
     1273    7.427    0.006    7.427    0.006 {built-in method io.open_code}
    45/25    5.631    0.125    9.939    0.398 {built-in method _imp.create_dynamic}
        2    0.564    0.282    0.564    0.282 {built-in method _ctypes.dlopen}
     1273    0.288    0.000    0.288    0.000 {built-in method marshal.loads}
       17    0.286    0.017    0.286    0.017 {method 'readline' of '_io.BufferedReader' objects}
2809/2753    0.098    0.000    0.546    0.000 {built-in method builtins.__build_class__}
   1620/1    0.062    0.000   49.997   49.997 {built-in method builtins.exec}
    50145    0.051    0.000    0.119    0.000 {built-in method builtins.getattr}
     1159    0.048    0.000    0.115    0.000 inspect.py:3245(signature)
      424    0.048    0.000    0.113    0.000 assumptions.py:596(__init__)
       13    0.039    0.003    0.039    0.003 {built-in method io.open}
     1411    0.035    0.000    0.045    0.000 library.py:71(impl)
     1663    0.034    0.000   12.209    0.007 <frozen importlib._bootstrap_external>:1536(find_spec)
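The same kind of table can be produced from inside Python with the standard-library `cProfile` and `pstats` modules. This is a sketch; run it in a fresh interpreter, because a module that is already imported is cached and will show almost nothing. Python 3.7+ also offers `python -X importtime -c "import torch"`, which prints a per-module breakdown of import time to stderr.

```python
import cProfile
import importlib
import pstats


def profile_import(module_name: str, top: int = 15) -> pstats.Stats:
    """Profile a single import, sorted by internal time like the table above."""
    profiler = cProfile.Profile()
    profiler.enable()
    importlib.import_module(module_name)
    profiler.disable()
    stats = pstats.Stats(profiler).sort_stats("tottime")
    stats.print_stats(top)  # print the `top` most expensive entries
    return stats
```

For example, `profile_import("torch")` as the first statement of a new script.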
  • I think the inference speed-up you get by using torch instead of numpy (GPU support in torch) should make up for the slow import time of the library. – Joan Plepi Jun 06 '23 at 10:12
  • That's the thing: in my specific case I only use batch sizes of 1 (and I can't group everything into big batches), so GPUs don't help. Which is actually beneficial, because I can use several thousand CPUs across several hundred jobs, which ends up being way cheaper than thousands of GPUs. – FluidMechanics Potential Flows Jun 06 '23 at 14:38
  • Have you tested whether `from torch import ...` is any faster than `import torch`? See https://stackoverflow.com/questions/3591962/python-import-x-or-from-x-import-y-performance – jared Jun 08 '23 at 21:52
  • @FluidMechanicsPotentialFlows - Have a look at pytorch `serve`. This is the benchmarking result - https://github.com/pytorch/serve/blob/master/benchmarks/README.md#sample-latency-graph . You can use standard tools, benefit from their functionality, etc., and still get your inference results in just a few seconds. Furthermore, you will keep the inference framework separate from your application, allowing you to scale up easily in the future. Also, please profile `import torch`; you never know where those huge import times are coming from. On my system, the import takes barely 1 second. – saurabheights Jun 09 '23 at 08:45
  • @jared I just did, and unfortunately the times are approximately the same. – FluidMechanics Potential Flows Jun 09 '23 at 09:42
  • @saurabheights I don't really know how to profile it; any idea how I should? It might help the IT team running our cluster. – FluidMechanics Potential Flows Jun 09 '23 at 09:43
  • @FluidMechanicsPotentialFlows - You can check my results for reference - https://imgur.com/a/DmDLBpB. Code - `echo "import torch" > test.py; python -m cProfile -s tottime test.py | head -n 20`. Let's hope it can help in pinpointing issue. For profiling, see - https://stackoverflow.com/a/582337/1874627 – saurabheights Jun 09 '23 at 11:42
  • I've added the profile to my post. Does that help? – FluidMechanics Potential Flows Jun 09 '23 at 15:45
  • Profiling `import torch` on my PC, I find that `built-in method posix.stat` takes 11 ms (vs 12145 ms in your environment) and `method 'read' of '_io.BufferedReader' objects` takes 6 ms (vs 21590 ms for you.) Why is disk I/O a thousand times slower in your environment? Are you running a network filesystem of some kind? – Nick ODell Jun 09 '23 at 22:30
  • It's not my environment but the supercomputer's, and I have no control over it, which is why I'm looking for alternative answers :/ – FluidMechanics Potential Flows Jun 10 '23 at 13:37
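Following up on the network filesystem question: one hedged way to check, without paying for the import, is to locate the torch package with `importlib.util.find_spec` (which does not execute a top-level package) and look up which mount backs that directory. This sketch is Linux-only and reads `/proc/mounts`; a type such as `nfs`, `lustre` or `gpfs` would be consistent with the slow `stat` and `read` calls in the profile, though that is a guess, not a diagnosis.

```python
import importlib.util
import os


def install_location(package: str) -> str:
    """Directory a package would be imported from, found WITHOUT importing it."""
    spec = importlib.util.find_spec(package)
    if spec is None or spec.origin is None:
        raise ModuleNotFoundError(package)
    return os.path.dirname(os.path.realpath(spec.origin))


def filesystem_type(path: str, mounts_file: str = "/proc/mounts"):
    """Filesystem type (e.g. 'nfs', 'lustre', 'ext4') backing `path`, or None.

    Picks the longest mount point in the mounts file that contains `path`.
    Mount points containing escaped characters (e.g. spaces) are not handled.
    """
    path = os.path.realpath(path)
    best_mount, best_type = "", None
    try:
        with open(mounts_file) as f:
            for line in f:
                fields = line.split()
                if len(fields) < 3:
                    continue
                mount_point, fs_type = fields[1], fields[2]
                prefix = mount_point.rstrip("/") + "/"
                inside = path == mount_point or path.startswith(prefix)
                if inside and len(mount_point) >= len(best_mount):
                    best_mount, best_type = mount_point, fs_type
    except OSError:  # mounts file missing, e.g. not on Linux
        return None
    return best_type
```

For example, `print(filesystem_type(install_location("torch")))` from a job on a compute node.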

1 Answer


There is an easy way to save on import time: spin up a server that imports torch and loads the model only once, at startup.

Use Flask, or better yet FastAPI, to spin up a simple HTTP server that runs inference on each HTTP call.

The server will take 40 seconds to start, but after that each inference call only costs the connection time plus the inference itself.

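from fastapi import FastAPI, Request
import torch

# Load once at startup so individual requests don't pay the import/load cost.
# This assumes the whole model was pickled; with only a state_dict, build the model
# (e.g. ActionNN(...)) and call model.load_state_dict(torch.load(path)) instead.
model = torch.load("<your model here>")
model.eval()  # inference mode (matters for dropout/batch-norm layers)

app = FastAPI()

@app.post("/predict")
async def inference(request: Request):
    data = await request.json()  # Request.json() is a coroutine, so it must be awaited
    x = torch.tensor(data, dtype=torch.float32)  # clients post lists, not tensors
    with torch.no_grad():  # no gradients needed for inference
        prediction = model(x)  # nn.Module has no .predict(); call the model directly
    return {"predictions": prediction.tolist()}  # assumes a single output tensor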

Call the server from any client by POSTing data to http://<your-host>:<port>/predict

See https://fastapi.tiangolo.com/tutorial/first-steps/ for more details.
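A client sketch that uses only the standard library, so the calling job imports neither torch nor `requests`. It assumes the server reads the whole JSON body as the model input, as the answer's handler does; NumPy arrays must be converted with `.tolist()` first, since tensors and arrays can't be posted directly. `<your-host>` and `<port>` are the same placeholders as above.

```python
import json
import urllib.request


def build_predict_request(url: str, x) -> urllib.request.Request:
    """POST `x` (nested lists, e.g. numpy_array.tolist()) as the JSON body."""
    body = json.dumps(x).encode("utf-8")
    return urllib.request.Request(
        url, data=body, headers={"Content-Type": "application/json"}, method="POST"
    )


def predict(url: str, x, timeout: float = 30.0):
    """Send one inference request and return the decoded JSON response."""
    with urllib.request.urlopen(build_predict_request(url, x), timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

Usage: `predict("http://<your-host>:<port>/predict", x.tolist())["predictions"]`.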

MrE
  • That's a brilliant idea. I'm currently trying to implement it, and it should work! The only thing is that I have to communicate through lists instead of tensors, since you can't post tensors, so it's a bit cumbersome to rewrite the code. I'll let you know if it works (it should) asap. – FluidMechanics Potential Flows Jun 09 '23 at 15:45
  • I suspect they have forbidden setting up APIs on the cluster, since I get the following error: `requests.exceptions.ConnectionError: HTTPConnectionPool(host='127.0.0.1', port=8000): Max retries exceeded with url: /predict (Caused by NewConnectionError(': Failed to establish a new connection: [Errno 111] Connection refused'))`. It works locally but not on the cluster. – FluidMechanics Potential Flows Jun 09 '23 at 16:12
  • 127.0.0.1 is localhost: that works locally but not on a server. You need to set the server to bind to host="0.0.0.0" (not 127.0.0.1), or it won't be exposed to the outside. Then you need to call the server via its IP, like http://10.1.10.10:8000/predict, 10.1.10.10 being the private IP of the server on your local network. Note that without some sort of auth, this will be exposed to anyone who can reach the host. – MrE Jun 09 '23 at 20:51
  • I did set up the API on the cluster too, although it was actually on another machine, so maybe two machines on the same cluster cannot communicate. Fair point, thanks! But since it's the same local network, I could use the local IP? I'm going to try. – FluidMechanics Potential Flows Jun 10 '23 at 13:38
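One way for other Slurm jobs to find the server is for the server job to write its hostname and port to a file on the shared filesystem, and for client jobs to read it. This is a hypothetical sketch: the file path is up to you, it assumes node hostnames resolve across the cluster, and the server must still listen on 0.0.0.0 as described above (with uvicorn, e.g. `uvicorn.run(app, host="0.0.0.0", port=8000)`). If the nodes are firewalled from each other, no client code will get around that.

```python
import json
import os
import socket
import time


def publish_address(path: str, port: int) -> None:
    """Server job: record this node's hostname and port in a shared file."""
    tmp = path + ".tmp"
    with open(tmp, "w") as f:
        json.dump({"host": socket.gethostname(), "port": port}, f)
    os.replace(tmp, path)  # rename so readers don't see a half-written file


def read_address(path: str, timeout: float = 600.0, poll: float = 5.0) -> str:
    """Client job: wait for the server's file, then return its /predict URL."""
    deadline = time.monotonic() + timeout
    while not os.path.exists(path):
        if time.monotonic() > deadline:
            raise TimeoutError(f"no server address at {path}")
        time.sleep(poll)
    with open(path) as f:
        info = json.load(f)
    return f"http://{info['host']}:{info['port']}/predict"
```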