Quantcast
Channel: Active questions tagged python - Stack Overflow
Viewing all articles
Browse latest Browse all 17419

Getting weird output image on inference

$
0
0

I've tried to run inference with an upscaling model on an image in Flutter, but I can't get the output correctly.

Low Scaled image:

Output Image:

Some Log:

flutter: Image converted to normalized Float32Listflutter: Image normalized successfully.flutter: Input tensor created successfully.flutter: Output Shape: 1, 3, 804, 1428

Whole Code

/// Home-page state: initializes an ONNX Runtime super-resolution session,
/// lets the user pick a gallery image, runs inference on it, and shows the
/// upscaled result in a dialog.
class _MyHomePageState extends State<MyHomePage> {
  // ONNX Runtime session; assigned asynchronously in initializeModel().
  // NOTE(review): `late` — tapping "Run Inference" before the model has
  // finished loading would throw; consider guarding or awaiting init.
  late OrtSession ortSession;

  // Gallery image decoded with package:image; null until one is picked.
  img.Image? selectedImage;

  @override
  void initState() {
    OrtEnv.instance.init();
    debugPrint("ORT Environment initialized");
    // Fire-and-forget: the Future is intentionally not awaited here.
    initializeModel();
    super.initState();
  }

  @override
  void dispose() {
    OrtEnv.instance.release();
    super.dispose();
  }

  /// Loads the bundled ONNX model bytes and builds the runtime session.
  Future<void> initializeModel() async {
    final sessionOptions = OrtSessionOptions();
    const assetFileName = 'assets/4xSPANkendata_fp32.onnx';
    final rawAssetFile = await rootBundle.load(assetFileName);
    final bytes = rawAssetFile.buffer.asUint8List();
    ortSession = OrtSession.fromBuffer(bytes, sessionOptions);
    debugPrint("Model initialized");
  }

  /// Runs the model on [selectedImage] and displays the generated image.
  Future<void> inference() async {
    if (selectedImage == null) {
      debugPrint('No image selected');
      return;
    }
    // Normalize pixels to [0, 1] floats.
    // NOTE(review): the shape below is NCHW ([1, 3, H, W]); the helper must
    // therefore produce *planar* channel data, not interleaved RGB — a
    // mismatch here scrambles the model input and yields garbage output.
    Float32List preprocessedImage =
        NormalizeImage.imageToNormalizedFloat32List(selectedImage!);
    final shape = [1, 3, selectedImage!.height, selectedImage!.width];
    debugPrint('Image normalized successfully.');
    final inputOrt =
        OrtValueTensor.createTensorWithDataList(preprocessedImage, shape);
    final inputs = {'input': inputOrt};
    debugPrint('Input tensor created successfully.');
    final runOptions = OrtRunOptions();
    final outputs = await ortSession.runAsync(runOptions, inputs);
    // Input tensor and run options are no longer needed after the run.
    inputOrt.release();
    runOptions.release();
    outputs?.forEach((element) {
      final outputValue = element?.value;
      // ONNX Runtime surfaces the tensor as nested lists: [N][C][H][W].
      if (outputValue is List<List<List<List<double>>>>) {
        img.Image generatedImage = generateImageFromOutput(outputValue);
        showDialog(
          context: context,
          builder: (BuildContext context) {
            return Dialog(
              child: SizedBox(
                width: generatedImage.width.toDouble(),
                height: generatedImage.height.toDouble(),
                child: Image.memory(
                  Uint8List.fromList(img.encodePng(generatedImage)),
                  fit: BoxFit.contain,
                ),
              ),
            );
          },
        );
      } else {
        debugPrint("Output is of unknown type");
      }
      // Release each output tensor once consumed.
      element?.release();
    });
  }

  /// Converts an NCHW float tensor (values assumed in [0, 1]) into an image.
  img.Image generateImageFromOutput(List<List<List<List<double>>>> output) {
    final batchSize = output.length;
    final channels = output[0].length;
    final height = output[0][0].length;
    final width = output[0][0][0].length;
    debugPrint("Output Shape: $batchSize, $channels, $height, $width");
    img.Image imgData = img.Image(width, height);
    for (int y = 0; y < height; y++) {
      for (int x = 0; x < width; x++) {
        // Assume channel order is R, G, B; scale to 0-255 and clamp.
        int r = (output[0][0][y][x] * 255).toInt().clamp(0, 255);
        int g = (output[0][1][y][x] * 255).toInt().clamp(0, 255);
        int b = (output[0][2][y][x] * 255).toInt().clamp(0, 255);
        imgData.setPixelRgba(x, y, r, g, b);
      }
    }
    return imgData;
  }

  /// Lets the user pick a gallery image and decodes it for inference.
  Future<void> _pickImage() async {
    final picker = ImagePicker();
    final pickedFile = await picker.pickImage(source: ImageSource.gallery);
    if (pickedFile != null) {
      final bytes = await pickedFile.readAsBytes();
      setState(() {
        selectedImage = img.decodeImage(Uint8List.fromList(bytes));
      });
    }
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(
        title: const Text('Image Inference'),
      ),
      body: Center(
        child: Column(
          mainAxisAlignment: MainAxisAlignment.center,
          crossAxisAlignment: CrossAxisAlignment.center,
          children: <Widget>[
            if (selectedImage != null)
              Image.memory(Uint8List.fromList(img.encodePng(selectedImage!))),
            ElevatedButton(
              onPressed: _pickImage,
              child: const Text('Select Image'),
            ),
            ElevatedButton(
              onPressed: inference,
              child: const Text('Run Inference'),
            ),
          ],
        ),
      ),
    );
  }
}

Normalize Class

/// Helpers for converting a decoded image into model-ready tensor data.
class NormalizeImage {
  /// Converts [image] into a planar (CHW) [Float32List] normalized to [0, 1].
  ///
  /// The buffer layout matches an NCHW input shape of `[1, 3, height, width]`:
  /// the full red plane first, then green, then blue, each stored row-major.
  ///
  /// Fix: the previous version wrote interleaved RGBRGB... (HWC) data at
  /// `(i * width + j) * 3`, which does not match the `[1, 3, H, W]` shape
  /// passed to the ONNX session — the model received scrambled channels and
  /// produced garbage output.
  static Float32List imageToNormalizedFloat32List(Image image) {
    final int height = image.height;
    final int width = image.width;
    final int planeSize = height * width;
    Float32List float32Image = Float32List(3 * planeSize);
    for (int i = 0; i < height; i++) {
      for (int j = 0; j < width; j++) {
        // Row-major position within one channel plane.
        final int pixelOffset = i * width + j;
        final int pixel = image.getPixel(j, i);
        // Channel planes: R at 0, G at planeSize, B at 2 * planeSize.
        float32Image[pixelOffset] = getRed(pixel) / 255.0;
        float32Image[planeSize + pixelOffset] = getGreen(pixel) / 255.0;
        float32Image[2 * planeSize + pixelOffset] = getBlue(pixel) / 255.0;
      }
    }
    debugPrint("Image converted to normalized Float32List");
    return float32Image;
  }
}

I think the problem is mostly in the generateImageFromOutput method.

I also tried to run inference in Python, and it's working correctly:

"""Reference script: upscale an image with an ONNX super-resolution model.

Fixes relative to the original:
- COLOR_BGR2RGB was applied twice, swapping channels back to BGR.
- The RGB->BGR converted result was computed but the unconverted array was
  written; we now save `output_img` (imwrite expects BGR).
- Output filename typo 'outpust.jpg' corrected to 'output.jpg'.
- Removed the pointless numpy -> torch -> numpy round-trip.
"""
import time

import cv2
import numpy as np
import onnxruntime as rt

# Load the ONNX model (CPU provider by default).
sess = rt.InferenceSession('4xSPANkendata_fp32.onnx')
print("loaded model.")

# Load the input image; OpenCV decodes as BGR, convert once to RGB.
img = cv2.imread('xterArt.png')
in_mat = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# HWC -> CWH with a leading batch axis, float32, scaled to [0, 1].
# (The model was evidently exported with this transposed layout; the output
# is transposed back symmetrically below.)
in_mat = np.transpose(in_mat, (2, 1, 0))[np.newaxis]
in_mat = in_mat.astype(np.float32) / 255

start_time = time.time()

input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name
print("Input name: ", input_name)
print("Output name: ", output_name)

# Run inference directly on the numpy array.
out_mat = sess.run([output_name], {input_name: in_mat})[0]

elapsed_time = time.time() - start_time
print('Inference time: ', elapsed_time)

# CWH -> HWC, rescale to 0-255 bytes, then RGB -> BGR for imwrite.
out_mat = (out_mat.squeeze().transpose((2, 1, 0)) * 255).clip(0, 255).astype(np.uint8)
output_img = cv2.cvtColor(out_mat, cv2.COLOR_RGB2BGR)
cv2.imwrite('output.jpg', output_img)
print("Output image saved.")

Viewing all articles
Browse latest Browse all 17419

Latest Images

Trending Articles



Latest Images

<script src="https://jsc.adskeeper.com/r/s/rssing.com.1596347.js" async> </script>