diff --git a/src/PRF/prf.py b/src/PRF/prf.py index fef2016..a02ac00 100644 --- a/src/PRF/prf.py +++ b/src/PRF/prf.py @@ -75,10 +75,10 @@ def listFD(url, ext=''): rows = np.array([int(file[-17:-13]) for file in filelist]) #Bilinear interpolation between four surrounding PRFs - LL = np.where((rows < rownum) & (cols < colnum))[0] #lower left - LR = np.where((rows > rownum) & (cols < colnum))[0] #lower right - UL = np.where((rows < rownum) & (cols > colnum))[0] #upper left - UR = np.where((rows > rownum) & (cols > colnum))[0] #uppper right + LL = np.where((rows <= rownum) & (cols <= colnum))[0] #lower left + LR = np.where((rows >= rownum) & (cols <= colnum))[0] #lower right + UL = np.where((rows <= rownum) & (cols >= colnum))[0] #upper left + UR = np.where((rows >= rownum) & (cols >= colnum))[0] #uppper right dist = np.sqrt((rows-rownum)**2. + (cols-colnum)**2.) surroundinginds = [subset[np.argmin(dist[subset])] for subset in [LL,LR,UL,UR]] #Following https://stackoverflow.com/a/8662355 @@ -88,10 +88,29 @@ def listFD(url, ext=''): prf = hdulist[0].data points.append((cols[ind],rows[ind],prf)) hdulist.close() - points = sorted(points) + # if 0 and 1 match, then 2 will as well: don't compare the numpy arrays though, + # since sorted will throw an error + points = sorted(points, key=lambda x: (x[0], x[1])) (x1, y1, q11), (_x1, y2, q12), (x2, _y1, q21), (_x2, _y2, q22) = points - self.prf = (q11 * (x2 - colnum) * (y2 - rownum) + + # handle edge cases + if x1 == x2: + # Linear interpolation in y direction only + if y1 == y2: + # If both x and y coordinates are equal (single point), return any value + self.prf = q11 + else: + # Interpolate along y axis + self.prf = q11 + (q12 - q11) * (rownum - y1) / (y2 - y1) + + # Handle case where y1 equals y2 (horizontal line) + elif y1 == y2: + # Linear interpolation in x direction only + self.prf = q11 + (q21 - q11) * (colnum - x1) / (x2 - x1) + + # Standard bilinear interpolation when we have a proper rectangle + else: + self.prf = (q11 * (x2 - colnum) * (y2 - rownum) + q21 * (colnum - x1) * (y2 - rownum) + q12 * (x2 - colnum) * ( rownum - y1) + q22 * (colnum - x1) * ( rownum - y1)