
I have a basic Flask server with a Generator model loaded. I'm sending an input vector via JSON, which is fed to the Generator, which then outputs a prediction (an image). This works. I then want to send this image (or any kind of data I can reconstruct it from on the other end) to another application running on the same machine. From what I gather, it might be best to encode the image as base64, but all of my attempts have failed. Any guidance is appreciated.
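
base64 turns arbitrary bytes into plain ASCII text, so an encoded image can travel inside a JSON string and be decoded back to identical bytes on the receiving side. Below is a minimal standard-library round-trip sketch. The helper names are hypothetical, and the byte string is only a placeholder for what saving a PIL image into a BytesIO buffer would produce.

```python
import base64
import json


def encode_image_bytes(image_bytes: bytes) -> str:
    """Encode raw image bytes (e.g. a JPEG file's contents) as base64 text."""
    return base64.b64encode(image_bytes).decode("ascii")


def decode_image_bytes(image_b64: str) -> bytes:
    """Recover the original bytes from the base64 text."""
    return base64.b64decode(image_b64)


# Placeholder bytes standing in for a real JPEG (not a valid image).
fake_jpeg = b"\xff\xd8\xff\xe0" + bytes(range(256)) + b"\xff\xd9"

# Sender: put the encoded image inside a JSON document.
payload = json.dumps({"image": encode_image_bytes(fake_jpeg)})

# Receiver: parse the JSON and decode the "image" field back to bytes.
received = decode_image_bytes(json.loads(payload)["image"])
```

The receiver ends up with exactly the bytes the sender started from, so it can write them to a .jpg file or hand them to an image library.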

@app.route("/json", methods=['GET', 'POST', 'PUT'])
def getjsondata():

    if request.method=='POST':
        print("received POST")

        data = request.get_json()

        #print(format(data['z']))
        jzf = [float(i) for i in data['z']]
        jzft = torch.FloatTensor(jzf)
        jzftr = jzft.reshape([1, 512])

        z = jzftr.cuda()
        c = None                   # class labels (not used in this example)
        trunc = 1
        img = G(z, c, trunc)
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)

        # now what?
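
The last line of that snippet rescales each generator output value and reorders the axes. Assuming the generator's outputs lie roughly in [-1, 1] (the range this scaling is built for), each value x maps to a pixel value p:

```latex
p = \min\bigl(\max(127.5\,x + 128,\; 0),\; 255\bigr),
\qquad x \in [-1, 1] \;\Rightarrow\; 127.5\,x + 128 \in [0.5,\; 255.5]
```

Casting with .to(torch.uint8) then drops the fractional part, so x = -1 gives 0, x = 0 gives 128 and x = 1 gives 255 (255.5 is clamped first). The permute(0, 2, 3, 1) turns the [N, C, H, W] batch into [N, H, W, C], the channel-last layout that PIL.Image.fromarray expects for an RGB array.
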
prismspecs
  • I see this is tagged [python-imaging-library], but I don't see any PIL code? If the tag is correct, have you tried sending the image back like this: https://stackoverflow.com/questions/7877282/how-to-send-image-generated-by-pil-to-browser – Nick ODell Mar 27 '22 at 19:11
  • @prismspecs Did you come up with a good/clean way of accomplishing this task? – Austin Heller Apr 10 '22 at 15:23
  • @AustinHeller yes, well, I can't tell you how clean it is but it works well for me! I will add a solution – prismspecs Apr 11 '22 at 20:52

1 Answer


I found a solution. The Python Flask server looks like this; it expects a JSON object containing a "z" array of 512 floats and a "truncation" value (a single float).
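
Because the view reshapes z to [1, 512], a malformed body otherwise fails deep inside torch with a less obvious error. A hypothetical standard-library helper that checks the payload shape first might look like this; the name parse_query_payload and its error messages are illustrative, not part of the answer's code.

```python
from typing import Any, Dict, List, Tuple

Z_DIM = 512  # the server reshapes z to [1, 512]


def parse_query_payload(data: Dict[str, Any]) -> Tuple[List[float], float]:
    """Check a decoded body of the form {"z": [512 numbers], "truncation": number}."""
    if "z" not in data or "truncation" not in data:
        raise ValueError("payload needs both 'z' and 'truncation'")
    z = [float(v) for v in data["z"]]
    if len(z) != Z_DIM:
        raise ValueError(f"expected {Z_DIM} values in 'z', got {len(z)}")
    return z, float(data["truncation"])


# In the Flask view this would receive request.get_json(); a literal dict stands in here.
z, trunc = parse_query_payload({"z": [0.0] * Z_DIM, "truncation": 0.7})
```

Returning a ValueError early lets the view reply with a 400 status instead of a server error.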

from flask import Flask, jsonify, request
from flask_cors import CORS
import base64
from torchvision import transforms
import dnnlib
import torch
import PIL.Image  # only needed for the commented-out fromarray alternative below
from io import BytesIO
from datetime import datetime  # only needed for the commented-out saving code below
import legacy

# load the generator (G_ema) from the snapshot onto the GPU
device = torch.device('cuda')
with dnnlib.util.open_url("snapshot.pkl") as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore

# initialize our Flask application
app = Flask(__name__)
CORS(app)

@app.route("/test", methods=['GET', 'POST', 'PUT'])
def test():
    return "OK"

@app.route("/query", methods=['GET', 'POST', 'PUT'])
def getjsondata():

    if request.method=='POST':
        # print("received POST")

        data = request.get_json()

        #print(format(data['z']))
        jzf = [float(i) for i in data['z']]
        jzft = torch.FloatTensor(jzf)
        jzftr = jzft.reshape([1, 512])

        z = jzftr.cuda()
        c = None                   # class labels (not used in this example)
        trunc = float(data['truncation'])
        img = G(z, c, trunc)

        # ToPILImage expects a [C, H, W] tensor, so the permute to [N, H, W, C] is dropped here
        #img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)

        # turn into PIL image (move the tensor off the GPU first)
        pil_img = transforms.ToPILImage()(img[0].cpu()).convert("RGB")
        #pil_img = PIL.Image.fromarray(img[0].permute(1, 2, 0).cpu().numpy(), 'RGB')

        # SAVING...
        #fn = datetime.today().strftime('%Y-%m-%d-%H:%M:%S')
        #pil_img.save('saved_images/' + fn + '.jpg')

        response = serve_pil_image64(pil_img)
        response.headers.add('Access-Control-Allow-Origin', '*')
        # response.headers.add('Content-Transfer-Encoding', 'base64')
        return response


    return 'OK'

def serve_pil_image64(pil_img):
    # encode the image as JPEG in memory, then as base64 text so it fits in a JSON string
    img_io = BytesIO()
    pil_img.save(img_io, 'JPEG', quality=70)
    img_str = base64.b64encode(img_io.getvalue()).decode("utf-8")
    return jsonify({'status': True, 'image': img_str})


if __name__ == '__main__':
    app.run(host='localhost', port=9000, debug=True)

For now, I send that JSON object from a simple JavaScript/HTML page, which then listens for the response (also JSON).

// construct an HTTP request
var xhr = new XMLHttpRequest();

// upon successful completion of request...
xhr.onreadystatechange = function() {
    if (xhr.readyState === XMLHttpRequest.DONE && xhr.status === 200) {
        var json = JSON.parse(xhr.responseText);
        // console.log(json);
        document.getElementById("image_output").src = "data:image/jpeg;base64," + json.image;
    }
};

// the Flask route above is /query
xhr.open("POST", "http://localhost:9000/query");
xhr.setRequestHeader('Content-Type', 'application/json; charset=UTF-8');
// z (512 numbers) and truncation are assumed to be defined elsewhere on the page
xhr.send(JSON.stringify({ z: z, truncation: truncation }));
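
If the other application is a Python program rather than a browser page, a standard-library client could do the same job. This sketch assumes the server above is listening on localhost:9000 at /query and replies with {"status": true, "image": "<base64 JPEG>"}; the function names are hypothetical.

```python
import base64
import json
import urllib.request
from typing import List

SERVER_URL = "http://localhost:9000/query"  # host, port and route from the Flask app above


def build_query_request(z: List[float], truncation: float) -> urllib.request.Request:
    """Build (but do not send) a POST request carrying the latent vector as JSON."""
    body = json.dumps({"z": z, "truncation": truncation}).encode("utf-8")
    return urllib.request.Request(
        SERVER_URL,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def image_bytes_from_response(body: bytes) -> bytes:
    """Decode the JPEG bytes from a {"status": true, "image": "<base64>"} response body."""
    payload = json.loads(body)
    if not payload.get("status"):
        raise RuntimeError("server reported failure")
    return base64.b64decode(payload["image"])


def fetch_image(z: List[float], truncation: float) -> bytes:
    """Send the request and return raw JPEG bytes (needs the server to be running)."""
    with urllib.request.urlopen(build_query_request(z, truncation)) as resp:
        return image_bytes_from_response(resp.read())
```

With the server running, writing fetch_image([0.0] * 512, 0.7) to a file opened in "wb" mode would save the generated JPEG to disk.
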
prismspecs
  • To summarize, it looks like at some point in the redacted code you're converting your tensor into a list of floats and then converting them back into a FloatTensor later. That's been my solution as well, but I hate how massive the JSON becomes with all of those float values. I've been considering converting the float values into bytes and then base64-encoding them. If there's a better solution than that, I'm open to it. Thanks for sharing your approach. – Austin Heller Apr 17 '22 at 23:00
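
The bytes-then-base64 idea from that comment can be sketched with the standard library: pack the values as float32 with struct, then base64 the result. For 512 values that is 2048 bytes, which base64 turns into a 2732-character string. The helper names are hypothetical, and little-endian float32 is an assumed wire format chosen to match torch.FloatTensor's precision; values are rounded to float32 on the way in.

```python
import base64
import struct
from typing import List


def pack_floats_b64(values: List[float]) -> str:
    """Pack floats as little-endian float32 and base64-encode the bytes."""
    raw = struct.pack(f"<{len(values)}f", *values)
    return base64.b64encode(raw).decode("ascii")


def unpack_floats_b64(encoded: str) -> List[float]:
    """Inverse of pack_floats_b64: decode base64, then unpack float32 values."""
    raw = base64.b64decode(encoded)
    count = len(raw) // 4  # 4 bytes per float32
    return list(struct.unpack(f"<{count}f", raw))


z = [0.1 * i for i in range(512)]
encoded = pack_floats_b64(z)  # 2732 characters
restored = unpack_floats_b64(encoded)  # float32-rounded copies of z
```

On the Flask side, numpy.frombuffer(base64.b64decode(encoded), dtype='<f4') would rebuild the vector as an array without parsing any JSON numbers.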