{-# LANGUAGE RecordWildCards #-}

-- | Provides algorithms to extract potential motifs from extended RNA
-- secondary structures. In general, a motif is a set of non-Watson-Crick
-- basepairs and possibly pseudoknotted basepairs sandwiched between two
-- Watson-Crick basepairs. Individual functions below may work differently or
-- impose more restrictions.
--
-- TODO multibranched motifs

module BioInf.RNAmodule.Putative where

import Control.Monad
import Data.ByteString.Char8 (ByteString)
import Data.List (sort,groupBy,sortBy,nub,nubBy)
import Data.Tuple.Select
import qualified Data.ByteString.Char8 as BS
import Text.Printf

import Biobase.Primary
import Biobase.Secondary
import Biobase.Secondary.Isostericity
import BioInf.Secondary.Draw.DotBracket
import qualified Biobase.FR3D as FR3D
import qualified Biobase.FR3D.Import as FR3D



-- * Motif finding

-- | Spits out all potential helical motifs. A potential motif is a set of
-- non-canonical basepairs bracketed by a canonical basepair on top and a
-- canonical basepair at the bottom. This definition currently excludes certain
-- kinds of pseudoknots and multibranched loops, as well as motifs containing a
-- hairpin loop.
--
-- Concretely: for every ordered combination of two 'cWW' pairs where the
-- second is nested in the first (left ends at most 30 positions apart, right
-- ends at most 30 positions apart), the pairs 'between' them are collected.
-- A 'PM' is emitted when there are between 1 and 10 such pairs and each of
-- them is either not 'cWW' or crosses ('pseudoknotted') another one of them.
--
-- TODO need to handle "motif size" and "pseudoknotted" way better. Right now,
-- any pseudoknot in the whole structure will admit the whole structure as one
-- big motif, which is stupid.
--
-- TODO the guard on length bs can possibly be removed?

putativeMotifs :: FR3D.LinFR3D -> [PM]
putativeMotifs FR3D.LinFR3D{..} =
  [ PM (BS.unpack pdbID) top bottom inner
  | top@(pOut,_,_)    <- pairs
  , bottom@(pIn,_,_)  <- pairs
    -- the bottom pair is nested inside the top pair ...
  , baseL pOut <= baseL pIn && baseR pIn <= baseR pOut
    -- ... and not too far away on either strand
  , baseL pIn - baseL pOut <= 30
  , baseR pOut - baseR pIn <= 30
    -- both closing pairs are canonical
  , baseT pOut == cWW && baseT pIn == cWW
  , let inner = filter (between pOut pIn . sel1) pairs
  , not (null inner)
  , length inner <= 10 -- TODO above
    -- an enclosed cWW pair is only tolerated when it is pseudoknotted
  , all (\(b,_,_) -> baseT b /= cWW || pseudoknotted b (map sel1 inner)) inner
  ]

-- | Hairpin motifs.
--
-- Every 'cWW' pair @p@ that encloses ('within') between 1 and 10 other
-- pairs, none of which is 'cWW', yields one 'PMhp' holding @p@ and all of
-- the pairs it encloses.
--
-- TODO the intended definition is the hairpin itself plus the @k@ closest
-- basepairs; probably 2-3 basepairs or up to the first canonical pair,
-- whichever comes last. The code below does not implement that yet.

hairpinMotifs :: FR3D.LinFR3D -> [PM]
hairpinMotifs FR3D.LinFR3D{..} =
  [ PMhp (BS.unpack pdbID) closing inner
  | closing@(p,_,_) <- pairs
  , baseT p == cWW
  , let inner = filter (within p . sel1) pairs
  , not (null inner)
  , all (\(b,_,_) -> baseT b /= cWW) inner
  , length inner <= 10
  ]



-- * helper functions

-- | A putative motif, as produced by 'putativeMotifs' ('PM') or
-- 'hairpinMotifs' ('PMhp').
--
-- NOTE 'outer1', 'outer2' and 'outer' are partial selectors: each exists for
-- only one of the two constructors.
--
-- TODO make this more "feature-complete"

data PM
  = PM
    { pmPDB  :: String    -- ^ PDB identifier the motif was taken from
    , outer1 :: TriPair   -- ^ enclosing (outer) 'cWW' closing pair
    , outer2 :: TriPair   -- ^ nested (inner) 'cWW' closing pair
    , pairs  :: [TriPair] -- ^ pairs bracketed by the two closing pairs
    }
  | PMhp
    { pmPDB  :: String    -- ^ PDB identifier the motif was taken from
    , outer  :: TriPair   -- ^ 'cWW' pair closing the hairpin
    , pairs  :: [TriPair] -- ^ pairs enclosed by the closing pair
    }
  deriving Show

-- | One basepair in three representations: its extended pair index (the part
-- inspected with 'baseL' / 'baseR' / 'baseT' in this module), its 'ExtPair',
-- and the associated 'FR3D.Basepair' record (whose chain and sequence
-- position fields are printed by 'origin').

type TriPair = ( ExtPairIdx
               , ExtPair
               , FR3D.Basepair
               )

-- | @within enclosing x@ is 'True' when @x@ is a pair other than
-- @enclosing@ whose ends both lie inside the span of @enclosing@ (inclusive).
-- Endpoints may coincide, so nucleotides can be shared with the enclosing pair.

within :: ExtPairIdx -> ExtPairIdx -> Bool
within enclosing x
  | x == enclosing = False
  | otherwise      = baseL enclosing <= baseL x && baseR x <= baseR enclosing

-- | @between ij kl x@ is 'True' when pair @x@ lies in the region bracketed by
-- the outer pair @ij@ and the inner pair @kl@: its left end is in
-- @[baseL ij, baseL kl]@, its right end is in @[baseR kl, baseR ij]@, and it is
-- neither @ij@ nor @kl@. Only the two indices of @x@ are checked, so nothing
-- here prevents a bracketed pair from taking part in base triplets with
-- nucleotides outside the region.

between :: ExtPairIdx -> ExtPairIdx -> ExtPairIdx -> Bool
between ij kl x = notClosing && leftInside && rightInside
  where
    notClosing  = x /= ij && x /= kl
    leftInside  = baseL ij <= baseL x && baseL x <= baseL kl
    rightInside = baseR kl <= baseR x && baseR x <= baseR ij

-- | @pseudoknotted z xs@ is 'True' iff @z = (i,j)@ crosses at least one pair
-- @(k,l)@ of @xs@, i.e. @i < k < j < l@ or @k < i < l < j@. Pairs that share
-- an endpoint with @z@ do not count as crossing.

pseudoknotted z xs = any crosses xs where
  i = baseL z
  j = baseR z
  crosses x = (i < k && k < j && j < l) || (k < i && i < l && l < j)
    where
      k = baseL x
      l = baseR x

-- | Given the full sequence and one putative motif, produce a "nice-looking"
-- local view: the motif is cut out of the sequence and its pairs are
-- re-indexed into the cut-out.
--
-- Helical motif ('PM') with outer pair @(i,j)@ and inner pair @(k,l)@: the
-- local sequence is @seq[i..k] ++ "&" ++ seq[l..j]@ (ranges inclusive), so
-- the separator sits at local index @k-i+1@. A global position @g <= k@ maps
-- to @g - i@. Any other position maps to @g - l + (k-i) + 2@, i.e. into the
-- right strand. Both ends of every pair use this same mapping, so pairs lying
-- entirely within one strand are also placed correctly. Positions strictly
-- between @k@ and @l@ are not part of the local sequence and get no
-- meaningful local index.
--
-- Hairpin motif ('PMhp') with closing pair @(i,j)@: the local sequence is
-- @seq[i..j]@ and every position maps to @g - i@.
--
-- Indices are not range-checked. 'BS.take' / 'BS.drop' clamp, so
-- out-of-range pairs silently yield a shorter local sequence.
--
-- TODO need modL (+1) which updates via function (and modR, modP)

mkLocal :: ByteString -> PM -> LM
mkLocal sq (PM pdb ij kl bs) = LM pdb s xs (sort $ ij:kl:bs) (ij,kl) where
  ij' = sel1 ij
  kl' = sel1 kl
  i = baseL ij'
  j = baseR ij'
  k = baseL kl'
  l = baseR kl'
  lenI = k - i
  sLeft  = BS.drop i $ BS.take (k + 1) sq
  sRight = BS.drop l $ BS.take (j + 1) sq
  s = sLeft `BS.append` BS.cons '&' sRight
  -- left strand occupies [0 .. lenI], '&' is at lenI+1, right strand
  -- starts at lenI+2
  toLocal g
    | g <= k    = g - i
    | otherwise = g - l + lenI + 2
  xs = sort . map (relocate . sel1) $ ij : kl : bs
  relocate x = updL (toLocal (baseL x)) . updR (toLocal (baseR x)) $ x
mkLocal sq (PMhp pdb ij bs) = LMhp pdb s xs (sort $ ij:bs) ij where
  ij' = sel1 ij
  i   = baseL ij'
  s   = BS.drop i . BS.take (baseR ij' + 1) $ sq
  xs  = sort . map (relocate . sel1) $ ij : bs
  relocate x = updL (baseL x - i) . updR (baseR x - i) $ x



-- | Local motif: a motif cut out of its structure and re-indexed into its own
-- coordinates by 'mkLocal', together with the global information it came
-- from.
--
-- NOTE 'bounds' and 'boundshp' are partial selectors: each exists for only
-- one of the two constructors.

data LM
  = LM
    { lmOriginPDB   :: String            -- ^ PDB identifier of the source
    , localSequence :: ByteString        -- ^ both strands, joined by @'&'@
    , localIdxs     :: [ExtPairIdx]      -- ^ pairs re-indexed into 'localSequence', sorted
    , globalIdxs    :: [TriPair]         -- ^ the original pairs (incl. closing pairs), sorted
    , bounds        :: (TriPair,TriPair) -- ^ (outer, inner) closing pairs
    }
  | LMhp
    { lmOriginPDB   :: String            -- ^ PDB identifier of the source
    , localSequence :: ByteString        -- ^ the hairpin's sequence span
    , localIdxs     :: [ExtPairIdx]      -- ^ pairs re-indexed into 'localSequence', sorted
    , globalIdxs    :: [TriPair]         -- ^ the original pairs (incl. closing pair), sorted
    , boundshp      :: TriPair           -- ^ pair closing the hairpin
    }
  deriving Show

-- | Basepair classes of the pair @((i,j),t)@: the nucleotides at positions
-- @i@ and @j@ of the sequence @s@ are looked up (with 'BS.index', which fails
-- on out-of-range positions) and handed to 'getClasses' together with @t@.

bpClasses s ((i,j),t) = getClasses ((nucAt i, nucAt j), t) where
  nucAt p = mkNuc (BS.index s p)

-- | Returns true if two motifs are deemed essentially equal. This is not
-- trivial to do, but let's see...
--
-- Two motifs of the same kind (both 'LM' or both 'LMhp') are considered the
-- same when their local pairs, taken in order, have equal basepair classes
-- ('bpClasses', each computed against its own 'localSequence') and equal
-- 'baseP' values. Origin and global positions are ignored. Motifs of
-- different kinds are never the same.

sameMotif :: LM -> LM -> Bool
sameMotif x y = case (x, y) of
    (LM   _ sx xs _ _, LM   _ sy ys _ _) -> similar sx xs sy ys
    (LMhp _ sx xs _ _, LMhp _ sy ys _ _) -> similar sx xs sy ys
    _                                    -> False
  where
    similar sx xs sy ys =
         map (bpClasses sx) xs == map (bpClasses sy) ys
      && map baseP xs == map baseP ys

-- | Exact equality of two local motifs: same constructor, identical
-- 'localSequence' and identical 'localIdxs'. Motifs of different kinds
-- ('LM' vs 'LMhp') are never equal.
--
-- NOTE the origin and the global pairs are deliberately ignored. They carry
-- positions relative to the whole sequence, which makes them useless for
-- comparing motifs.
--
-- TODO Once we make a good library out of all this, everything should be
-- easier, especially once the global information is removed.

eqMotif :: LM -> LM -> Bool
eqMotif (LM   _ sx xs _ _) (LM   _ sy ys _ _) = (sx,xs) == (sy,ys)
eqMotif (LMhp _ sx xs _ _) (LMhp _ sy ys _ _) = (sx,xs) == (sy,ys)
eqMotif _                  _                  = False


-- | Collect elements into groups without presorting (unlike 'groupBy').
--
-- Elements are processed left to right. Each one joins the first existing
-- group that has a member @m@ with @f x m@; otherwise it starts a new group
-- at the end. Groups are never merged, and members are listed most recent
-- first. Because @f@ need not be transitive (@f a b && f b c@ does not imply
-- @f a c@), the result depends on input order. @f@ may be called O(n^2)
-- times.

collectBy :: (a -> a -> Bool) -> [a] -> [[a]]
collectBy f = loop []
  where
    -- everything placed: the accumulated groups are the result
    loop groups []     = groups
    loop groups (e:es) = loop (place e groups) es
    -- first matching group gets e prepended; no match opens a new group
    place e groups = case break (any (f e)) groups of
      (skipped, g:rest) -> skipped ++ (e:g) : rest
      (_, [])           -> groups ++ [[e]]



-- * temporary stuff that needs to be fixed up!

-- | Print a local motif to stdout. The output is the numbering line (when
-- 'drawParts' yields one), then the sequence line and the structure line.
-- After those, one line per local pair shows its extended annotation next
-- to its basepair classes ('bpClasses'). Works the same for 'LM' and 'LMhp'.
--
-- TODO print pairs in aA, bB style. How about triplets?

viewMotif :: LM -> IO ()
viewMotif lm = do
  let s  = localSequence lm
      xs = localIdxs lm
  let Parts{..} = drawParts Always Numbered (s,xs)
  maybe (return ()) putStrLn numbers
  putStrLn sequence
  putStrLn structure
  forM_ (zip extended xs) $ \(e, p) ->
    printf "%s   %s\n" e (show $ bpClasses s p)

-- | Print the origin of a motif as one @Instance:@ line. The line holds the
-- PDB identifier, then chain and 1-based sequence position of both
-- nucleotides of each closing pair. For 'LM' the outer and inner closing
-- pairs are separated by @" / "@; 'LMhp' has a single closing pair.

origin :: LM -> IO ()
origin lm = putStrLn $ "Instance: " ++ lmOriginPDB lm ++ " " ++ closing lm
  where
    closing (LM   _ _ _ _ (a, b)) = pairText (sel3 a) ++ " / " ++ pairText (sel3 b)
    closing (LMhp _ _ _ _ c)      = pairText (sel3 c)
    -- chain and 1-based position of both nucleotides of one basepair
    pairText :: FR3D.Basepair -> String
    pairText bp = printf "%s %4d %s %4d"
      (BS.unpack $ FR3D.chain1 bp) (FR3D.seqpos1 bp + 1)
      (BS.unpack $ FR3D.chain2 bp) (FR3D.seqpos2 bp + 1)

