Advertisement
Kaelygon

Kaelygon variant of Protect-Images-from-AI-PixelGuard

Jun 10th, 2025 (edited)
349
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 14.48 KB | None | 0 0
  1. #!/usr/bin/env python3
  2.  
  3. import os
  4. import sys
  5. import time
  6. import math
  7. import argparse
  8. import logging
  9. import subprocess
  10. from pathlib import Path
  11.  
  12. import numpy as np
  13. import cv2
  14. import torch
  15. from PIL import Image
  16. from scipy.fftpack import dct, idct
  17. from scipy.ndimage import gaussian_filter
  18. from skimage.color import rgb2lab, lab2rgb
  19. import pywt
  20. import torchvision
  21. from torchvision.models import resnet50, ResNet50_Weights
  22. from torchvision import transforms
  23.  
  24. logging.basicConfig(level=logging.DEBUG, format='%(message)s')
  25.  
  26. #clip extreme values
  27. def clip_percentile(value, low, high):
  28.     value = value - value.mean() #shift mean to 0
  29.  
  30.     low_bound = np.percentile(value, low)
  31.     high_bound = np.percentile(value, high)
  32.  
  33.     value = np.clip(value, low_bound, high_bound)
  34.     return value, low_bound, high_bound
  35.  
  36. #Clip value extremes (percentile/100%)
  37. #Set value mean 0.0 and linearly scale range to -1.0, 1.0
  38. def normalize_max(value, percentile):
  39.     value, lo, hi = clip_percentile(value, percentile, 100.0-percentile)
  40.    
  41.     max_value = max(abs(lo), abs(hi))
  42.     if(not max_value):
  43.         return np.zeros_like(value)
  44.     value = value / max_value #scale max to -+1.0
  45.     return value
  46.  
  47.  
  48. #Scale by lerp such [that 0, 1, max_scale] input and output match
  49. def lerp_factor(value, strength, max_scale):
  50.     if(strength>1.0):
  51.         lerp_value = strength/max_scale
  52.         value = value*(1.0-lerp_value) + max_scale*lerp_value
  53.     else:
  54.         value*=strength
  55.        
  56.     return value
  57.    
  58. class Image_protector:
  59.    
  60.     def __init__(self):
  61.         self.supported_formats = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
  62.        
  63.         #Load pre-trained ResNet50 model
  64.         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  65.         self.model = torchvision.models.resnet50(weights=ResNet50_Weights.DEFAULT).to(self.device)
  66.         self.model.eval()
  67.        
  68.         #Define image preprocessing
  69.         self.resize_size = (256, 256)
  70.         self.preprocess = transforms.Compose([
  71.             transforms.Resize(self.resize_size),
  72.             transforms.ToTensor(),
  73.             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  74.         ])
  75.        
  76.  
  77.     #TODO: Create presets for different levels of protection
  78.     #Apply protections to image
  79.     def protect_image(self, image_path, output_dir='protected_images', intensity=1.0):
  80.        
  81.         #Scale factors by intensity, 0 to 255
  82.         shift_strength          =  2.0
  83.         dct_strength            = 24.0
  84.         wavelet_strength        =  8.0
  85.         fourier_strength        = 20.0
  86.         pertubation_strength    =  4.0
  87.         arr_intensity = np.array( [shift_strength, dct_strength, wavelet_strength, fourier_strength, pertubation_strength] )
  88.         shift_strength, dct_strength, wavelet_strength, fourier_strength, pertubation_strength = lerp_factor(arr_intensity, intensity, 255.0)
  89.          
  90.         print("Mix strengths:\nshift %.1f\ndct  %.1f\nwave %.1f\nfft %.1f\npertub %.1f" % (shift_strength, dct_strength, wavelet_strength, fourier_strength, pertubation_strength))
  91.        
  92.         try:
  93.             file_extension = os.path.splitext(image_path)[1].lower()
  94.             if file_extension not in self.supported_formats:
  95.                 return f"Unsupported file format: {file_extension}"
  96.            
  97.             #Make folder
  98.             os.makedirs(output_dir, exist_ok=True)
  99.             logging.debug(f"Processing image: {image_path}")
  100.             #format pixels
  101.             with Image.open(image_path) as img:
  102.                 image = np.array(img)
  103.                 if len(image.shape) == 2:  #grayscale to RGB
  104.                     image = np.stack((image,)*3, axis=-1)
  105.                 elif image.shape[2] == 4:  #Remove alpha
  106.                     image = image[:,:,:3]
  107.            
  108.             #Convert to BGR
  109.             image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
  110.            
  111.             #Apply protections
  112.             protected_image = image
  113.             protected_image = self.apply_color_shift(protected_image, strength=shift_strength)
  114.             protected_image = self.apply_noise_tf(protected_image, col=0, strength=dct_strength)
  115.             protected_image = self.apply_noise_tf(protected_image, col=1, strength=wavelet_strength)
  116.             protected_image = self.apply_noise_tf(protected_image, col=2, strength=fourier_strength)
  117.             protected_image = self.apply_perturbation(protected_image, strength=pertubation_strength, repeat_count=5, is_adversarial=1)
  118.  
  119.             #Save to file
  120.             final_image_path = os.path.join(output_dir, f'p_{os.path.basename(image_path)}')
  121.             pil_image = Image.fromarray(cv2.cvtColor(protected_image, cv2.COLOR_BGR2RGB))
  122.             if file_extension.lower() in ['.jpg', '.jpeg']:
  123.                 pil_image.save(final_image_path, quality=95)
  124.             else:#png
  125.                 pil_image.save(final_image_path)
  126.  
  127.             logging.debug(f"\nSaved protected image with embedded info: {final_image_path}")
  128.             return f"Image processing complete. Protected image saved as {final_image_path}"
  129.        
  130.         except Exception as e:
  131.             logging.error(f"Error processing image: {str(e)}", exc_info=True)
  132.             return f"Error processing image {image_path}: {str(e)}"
  133.  
  134.  
  135.  
  136.  
  137.     ### Generate top class features and mix them with original image ###
  138.     #is_adversarial=0 increases top class probs
  139.     #is_adversarial=1 reduces   top class probs
  140.     def apply_perturbation(self, image, strength, repeat_count, is_adversarial=0):
  141.         debug_string = "adversarial" if is_adversarial else "reinforcing"
  142.         logging.debug("\nApplying %s perturbation", debug_string)
  143.        
  144.         orig_image = image.astype(np.float32) / 255.0
  145.         perturbed_image = image #original image mixed with accumulated perturbation
  146.         accumulate = 0.0
  147.        
  148.            
  149.         #accumulate features top guess features
  150.         for i in range(repeat_count):
  151.             logging.debug("Tensor pass No. %s/%s", i + 1, repeat_count)
  152.            
  153.             #Prepare image
  154.             perturbed_image = Image.fromarray(perturbed_image)
  155.             img_tensor = self.preprocess(perturbed_image).unsqueeze(0).to(self.device)
  156.             img_tensor = img_tensor.requires_grad_(True)
  157.            
  158.             #Feed image to tensor and get top class
  159.             output = self.model(img_tensor)
  160.             probs = torch.nn.functional.softmax(output, dim=1)
  161.             top_class = probs.argmax(dim=1)
  162.            
  163.             loss = torch.nn.functional.log_softmax(output, dim=1)[0, top_class.item()]
  164.             loss.backward()
  165.  
  166.             #Generate
  167.             perturbation = img_tensor.grad.data
  168.             perturbation*= -1.0 if is_adversarial else 1.0 #invert to subtract the detected features
  169.             perturbation = torch.nn.functional.interpolate(
  170.                 perturbation, size=image.shape[:2], mode='bicubic', align_corners=False
  171.             ).squeeze().permute(1, 2, 0).cpu().numpy()
  172.                
  173.             accumulate+= perturbation
  174.            
  175.             #Mix accumulated perturbations with image
  176.             new_delta = normalize_max(accumulate,0.2)
  177.             new_delta*= (strength/255.0)
  178.            
  179.             perturbed_image = np.clip(orig_image + new_delta, 0.0, 1.0) * 255.0
  180.             perturbed_image = np.round(perturbed_image).astype(np.uint8)
  181.        
  182.         logging.info("Lum delta: %f", (perturbed_image / 255.0).mean() - orig_image.mean())
  183.         return perturbed_image
  184.    
  185.    
  186.    
  187.     ### Add noise to a color channel transform ###
  188.     def apply_noise_tf(self, image, col, strength):
  189.         tf_list = ["dct", "wavelet", "fourier"]
  190.         tf_channel = ["blue", "green", "red"]
  191.         tf_enum = [0,1,2]
  192.        
  193.         logging.debug("Applying %s watermark", tf_list[col])
  194.         col_channel = image[:,:,col].astype(float)
  195.        
  196.         #transform channel
  197.         if(tf_list[col] == "dct"):
  198.             transform = dct(dct(col_channel.T, norm='ortho').T, norm='ortho')
  199.         elif(tf_list[col] == "wavelet"):
  200.             transform, (cH, cV, cD) = pywt.dwt2(col_channel, 'haar')
  201.         else:#fft
  202.             transform = np.fft.fft2(col_channel, norm='ortho')
  203.        
  204.         #add watermark
  205.         noise_map = np.random.normal(0, 1, transform.shape)
  206.         transform+= noise_map
  207.        
  208.         #inverse transform
  209.         if(tf_list[col] == "dct"):
  210.             new_channel = idct(idct(transform.T, norm='ortho').T, norm='ortho')
  211.         elif(tf_list[col] == "wavelet"):
  212.             new_channel = pywt.idwt2((transform, (cH, cV, cD)), 'haar')
  213.         else:#fft
  214.             new_channel = np.fft.ifft2(transform, norm='ortho').real
  215.        
  216.         #add the delta to the original image
  217.         new_delta = new_channel - col_channel
  218.         new_delta = gaussian_filter(new_delta, sigma=2.0)
  219.         new_delta = new_delta - new_delta.mean() #center to zero
  220.         new_delta = new_delta / np.max(np.abs(new_delta)) #normalize
  221.         col_channel+= new_delta * strength
  222.         image[:,:,col] = np.clip(col_channel, 0, 255).astype(np.uint8)
  223.        
  224.         return image
  225.    
  226.    
  227.     ### Split hue and chroma to bins and shift them by random amount ###
  228.     #strength = 0 to 255
  229.     def apply_color_shift(self, image, strength=1.0):
  230.         logging.debug("Applying color shift")
  231.         strength = strength/255.0*180.0 #rgb to degrees
  232.        
  233.         lab = rgb2lab(image)
  234.         L, a, b = lab[..., 0], lab[..., 1], lab[..., 2]
  235.         C = np.sqrt(a**2 + b**2)
  236.         H = np.arctan2(b, a)  #radians [-, ]
  237.  
  238.         hue_rads    = (strength * math.pi / 360.0)
  239.         croma_shift = strength/1.8 #range 0 to 100
  240.  
  241.         num_bins_H = 16
  242.         num_bins_C = 16
  243.  
  244.         C_max = max(1e-12, np.max(C))
  245.  
  246.         #Compute bin indices using uniform bins
  247.         H_idx = np.floor((H + np.pi) / (2 * np.pi) * num_bins_H).astype(np.int32)
  248.         C_idx = np.floor(C / C_max * num_bins_C).astype(np.int32)
  249.  
  250.         #Clamp indices
  251.         H_idx = np.clip(H_idx, 0, num_bins_H - 1)
  252.         C_idx = np.clip(C_idx, 0, num_bins_C - 1)
  253.  
  254.         #Generate jitter arrays once
  255.  
  256.         #Generate random jitter offsets per bin (fixed for the image)
  257.         hue_jitter_offsets = np.random.uniform(-hue_rads, hue_rads, num_bins_H)
  258.         chroma_jitter_offsets = np.random.uniform(-croma_shift, croma_shift, num_bins_C)
  259.  
  260.         #Apply jitter
  261.         H_new = H + (hue_jitter_offsets[H_idx])
  262.         C_new = C + (chroma_jitter_offsets[C_idx])
  263.  
  264.         a_new = C_new * np.cos(H_new)
  265.         b_new = C_new * np.sin(H_new)
  266.  
  267.         lab_mod = np.stack([L, a_new, b_new], axis=-1)
  268.         jittered_image = np.round(lab2rgb(lab_mod) * 255.0)
  269.         jittered_image = np.clip(jittered_image, 0, 255)
  270.  
  271.         return jittered_image.astype(np.uint8)
  272.  
  273.  
  274.     def batch_protect_image(self, image_paths, output_dir='./', **kwargs):
  275.         os.makedirs(output_dir, exist_ok=True)
  276.         total_images = len(image_paths)
  277.         results = []
  278.         for i, image_path in enumerate(image_paths):
  279.             result = self.protect_image(image_path, output_dir, **kwargs)
  280.             results.append(result)
  281.             yield (i + 1) / total_images  #Yield progress
  282.         return results
  283.  
  284.  
  285.  
  286. #Ensure metadata is unchanged
  287. import subprocess
  288. def copy_metadata(original_path, protected_path):
  289.     subprocess.run([
  290.         "exiftool",
  291.         "-TagsFromFile", str(original_path),
  292.         "-all:all",
  293.         "-unsafe",
  294.         "-icc_profile",
  295.         "-o", str(protected_path) + ".tmp",
  296.         str(protected_path)
  297.     ], check=True)
  298.     #Then move tmp back to original
  299.     import os
  300.     os.replace(str(protected_path) + ".tmp", str(protected_path))
  301.     print(f"Metadata copied from {original_path.name} to {protected_path.name}")
  302.  
  303.  
  304. #AI classification test
  305. def classify_image(image_path, model, preprocess, device):
  306.     img = Image.open(image_path).convert('RGB')
  307.     img_t = preprocess(img).unsqueeze(0).to(device)
  308.     with torch.no_grad():
  309.         outputs = model(img_t)
  310.     return torch.nn.functional.softmax(outputs, dim=1).cpu().numpy()
  311.  
  312.  
  313. def predict_image(image_path):
  314.     #Load model
  315.     model = torchvision.models.resnet50(weights=ResNet50_Weights.DEFAULT)
  316.     model.eval()
  317.  
  318.     #Preprocess image
  319.     transform = transforms.Compose([
  320.         transforms.Resize(256),
  321.         transforms.CenterCrop(224),
  322.         transforms.ToTensor(),
  323.         transforms.Normalize([0.485, 0.456, 0.406],
  324.                             [0.229, 0.224, 0.225])
  325.     ])
  326.  
  327.     image = Image.open(image_path)
  328.     input_tensor = transform(image).unsqueeze(0)
  329.  
  330.     #Run model
  331.     with torch.no_grad():
  332.         output = model(input_tensor)
  333.         probs = torch.nn.functional.softmax(output[0], dim=0)
  334.         top_prob, top_class = torch.topk(probs, 5)
  335.        
  336.     weights = ResNet50_Weights.DEFAULT
  337.     model = resnet50(weights=weights)
  338.     imagenet_classes = weights.meta["categories"]
  339.  
  340.     #Print top predictions
  341.     for prob, cls_idx in zip(top_prob, top_class):
  342.         print(f"{imagenet_classes[cls_idx]}: {prob.item():.4f}")
  343.  
  344.  
  345.  
  346.  
  347.  
  348.  
  349. if __name__ == "__main__":
  350.     np.random.seed(int(time.time()))
  351.  
  352.     parser = argparse.ArgumentParser()
  353.     parser.add_argument("images", nargs="+", help="Image file paths")
  354.     parser.add_argument("--intensity", type=float, default=1.0, help="Intensity value (default: 1.0)")
  355.     parser.add_argument("--output", type=str, default="./", help="Output path (default: \"./\")")
  356.     args = parser.parse_args()
  357.  
  358.     image_paths = args.images
  359.  
  360.     protectedFolder = Path("./")
  361.     output_paths = []
  362.     for imgPath in image_paths:
  363.         p = Path(imgPath)
  364.         protectedName = f"p_{p.name}"
  365.         output_paths.append(protectedFolder / protectedName)
  366.  
  367.     protector = Image_protector()
  368.  
  369.     #Process batch
  370.     for progress, inputPath in zip(protector.batch_protect_image(image_paths, output_dir=args.output, intensity=args.intensity), image_paths):
  371.         print(f"Progress: {progress * 100:.2f}% - Processing {inputPath}")
  372.  
  373.     for inputPath, outputPath in zip(image_paths, output_paths):
  374.         #compare classification probs
  375.         print(f"\nOriginal image classification prediction\n",inputPath)
  376.         predict_image(inputPath)
  377.        
  378.         print(f"\nProtected image classification prediction\n",outputPath)
  379.         predict_image(outputPath)
  380.    
  381.     print("")
  382.     #Copy original metadata
  383.     for original_path, protected_path in zip(image_paths, output_paths):
  384.         copy_metadata(Path(original_path), Path(protected_path))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement