Haskell OpenGLで物体の回転

Quaternionをつかって物体の回転を行うプログラム。
3D‐CGプログラマーのためのクォータニオン入門―「ベクトル」「行列」「テンソル」「スピノール」との関係が分かる! (I・O BOOKS)の付録の移植。
型クラスは LinearSpace（汎用の代数演算クラス群）→ Quaternion（その具体的なインスタンス）の順に積み上げていく。
LinearSpace は VectorSpace（vector-space パッケージの Data.VectorSpace）の設計を参考に作成した。

linearspace.hs

{-# LANGUAGE TypeFamilies
 #-}

module Data.LinearSpace
      where

-- Fixities mirror the Prelude's arithmetic operators: the multiplicative
-- operators bind like (*) and (/) (infixl 7) and the additive ones like
-- (+) and (-) (infixl 6), so  a ^+^ b ^*^ c  parses as  a ^+^ (b ^*^ c).
infixl 7 ^*^, ^/^
infixl 6 ^+^, ^-^

-- | Types with an addition ('^+^'), an identity element for it ('zero')
-- and additive inverses ('negateE'); instances are intended to satisfy the
-- group laws.
class AdditiveGroup g where
  -- | Additive identity: @zero ^+^ v == v@.
  zero :: g
  -- | Addition.
  (^+^) :: g -> g -> g
  -- | Additive inverse: @v ^+^ negateE v == zero@.
  negateE :: g -> g
  
-- | Subtraction: add the additive inverse of the right operand.
(^-^) :: AdditiveGroup e => e -> e -> e
(^-^) lhs rhs = lhs ^+^ negateE rhs
         
-- | Types with a multiplication ('^*^'), an identity element for it
-- ('identity') and multiplicative inverses ('inverseE').
class MultiplicativeGroup g where
  -- | Multiplicative identity: @identity ^*^ v == v@.
  identity :: g
  -- | Multiplication.
  (^*^) :: g -> g -> g
  -- | Multiplicative inverse: @v ^*^ inverseE v == identity@.
  inverseE :: g -> g
  
-- | Right division: multiply by the multiplicative inverse of the right
-- operand, @a ^/^ b = a ^*^ inverseE b@.
(^/^) :: MultiplicativeGroup e => e -> e -> e
(^/^) lhs rhs = lhs ^*^ inverseE rhs
 
-- | Types that can be scaled by an associated scalar type.
class ScalarMultiplicativeGroup v where
  -- | The scalar type that acts on @v@.
  type Scalar v :: *
  -- | Scale a value by a scalar.
  (*^) :: Scalar v -> v -> v

-- | Marker class for types that are both an additive and a multiplicative
-- group; it adds no methods of its own.
class (AdditiveGroup k, MultiplicativeGroup k) => Field k

-- | A scalable additive group equipped with an inner product, a norm and
-- normalisation.
class (AdditiveGroup v, ScalarMultiplicativeGroup v) => LinearSpace v where
  -- | Inner (dot) product.
  (^.^) :: v -> v -> Scalar v
  -- | Length of a vector.
  norm :: v -> Scalar v
  -- | Rescale a vector; intended to return a vector of unit 'norm'.
  normalize :: v -> v

quaternion.hs

{-# LANGUAGE TypeFamilies
 #-}

module Quaternion
    where

import Data.LinearSpace hiding (norm)

-- | A quaternion @w + x i + y j + z k@: the scalar part @w@ followed by the
-- vector part @(x, y, z)@. All four components are strict.
data Quaternion a
  = Quaternion !a !a !a !a
  deriving (Show, Eq)

-- | Quaternion conjugate: keep the scalar part and negate the vector part.
-- Only 'Num' is required (the original 'Fractional' bound was never used).
conjugate :: Num a => Quaternion a -> Quaternion a
conjugate (Quaternion w x y z) = Quaternion w (negate x) (negate y) (negate z)

-- | Squared norm @||q||^2 = w^2 + x^2 + y^2 + z^2@.
-- Only additions and multiplications are used, so 'Num' suffices.
normSquare :: Num a => Quaternion a -> a
normSquare (Quaternion w x y z) = w*w + x*x + y*y + z*z

-- | Euclidean norm @||q|| = sqrt (normSquare q)@.
-- This module defines its own 'norm' and hides the one exported by
-- "Data.LinearSpace". ('Floating' already implies 'Fractional'.)
norm :: Floating a => Quaternion a -> a
norm = sqrt . normSquare

-- | Component-wise addition; 'zero' is the all-zero quaternion and
-- 'negateE' negates every component. Only 'Num' is needed.
instance Num a => AdditiveGroup (Quaternion a) where
    zero = Quaternion 0 0 0 0
    Quaternion w1 x1 y1 z1 ^+^ Quaternion w2 x2 y2 z2 =
        Quaternion (w1 + w2) (x1 + x2) (y1 + y2) (z1 + z2)
    negateE (Quaternion w x y z) =
        Quaternion (negate w) (negate x) (negate y) (negate z)
        
-- | Scalar multiplication: the scalar type is the component type and the
-- scalar multiplies every component. Only 'Num' is needed.
instance Num a => ScalarMultiplicativeGroup (Quaternion a) where
    type Scalar (Quaternion a) = a
    s *^ Quaternion w x y z = Quaternion (s*w) (s*x) (s*y) (s*z)

-- | Quaternion multiplication, with 'identity' = 1 and 'inverseE' the
-- conjugate divided by the squared norm.
--
-- NOTE: writing v1 = (x1,y1,z1) and v2 = (x2,y2,z2), the vector part
-- computed below is  w1*v2 + w2*v1 - v1 × v2 , so @q1 ^*^ q2@ equals the
-- Hamilton product @q2 * q1@ (operands swapped w.r.t. the usual
-- convention). NOTE(review): other code (trackball axis, rotation-matrix
-- layout) may be written against this order — confirm before changing it.
--
-- For IEEE types, 'inverseE' of the zero quaternion yields non-finite
-- components (division by a zero squared norm).
instance Floating a => MultiplicativeGroup (Quaternion a) where
    identity = Quaternion 1 0 0 0
    Quaternion w1 x1 y1 z1 ^*^ Quaternion w2 x2 y2 z2 = Quaternion w' x' y' z'
      where
        w' = w1*w2 - x1*x2 - y1*y2 - z1*z2
        x' = w1*x2 + x1*w2 - y1*z2 + z1*y2
        y' = w1*y2 + x1*z2 + y1*w2 - z1*x2
        z' = w1*z2 - x1*y2 + y1*x2 + z1*w2
    inverseE q = (1 / normSquare q) *^ conjugate q

-- | The identity quaternion @1 + 0i + 0j + 0k@ (no rotation).
-- ('Floating' already implies 'Fractional'.)
quaternionIdentity :: Floating a => Quaternion a
quaternionIdentity = identity

-- | Norm of a quaternion; an exported alias for this module's 'norm'.
quaternionNorm :: Floating a => Quaternion a -> a
quaternionNorm = norm

-- | Scale a quaternion to unit norm.
-- For IEEE types a zero quaternion has no direction and yields NaN
-- components (1/0 = Infinity, then Infinity * 0 = NaN).
normalize :: Floating a => Quaternion a -> Quaternion a
normalize q = factor *^ q
  where
    factor = 1 / norm q

main.hs

{-# LANGUAGE TypeFamilies
 #-}

import System.Exit ( exitWith, ExitCode(ExitSuccess))

import Control.Applicative
import Data.Foldable as Foldable hiding (mapM_)
import qualified Data.Vector as V
import Data.IORef
import Graphics.UI.GLUT as GLUT

import Data.LinearSpace as LinearSpace
import Quaternion as Q hiding (norm)

-- | Application state shared by the GLUT callbacks through an 'IORef'.
data AppCtx = AppCtx
    { curPos       :: Quaternion GLfloat -- ^ current orientation of the model
    , lastMousePos :: Vector2 GLfloat    -- ^ cursor position (pixels) at the previous mouse event
    , icosahedron  :: DisplayList        -- ^ compiled scene geometry
    , winSize      :: Vector2 GLfloat    -- ^ window size in pixels
    }

-- Instances giving GLUT's Vector3 the Data.LinearSpace operations.

-- Scalar multiplication, applied to each component. Only 'Num' is needed.
instance Num a => ScalarMultiplicativeGroup (Vector3 a) where
    type Scalar (Vector3 a) = a
    s *^ v = (* s) <$> v

-- Component-wise addition and negation. Only 'Num' is needed.
instance Num a => AdditiveGroup (Vector3 a) where
    zero      = Vector3 0 0 0
    v1 ^+^ v2 = liftA2 (+) v1 v2
    negateE   = fmap negate

-- Euclidean inner product, norm and normalisation. The dot product is a
-- total pattern match instead of the partial foldl1; it adds the terms in
-- the same order ((x + y) + z). ('Floating' already implies 'Fractional'.)
instance Floating a => LinearSpace (Vector3 a) where
    Vector3 x1 y1 z1 ^.^ Vector3 x2 y2 z2 = x1 * x2 + y1 * y2 + z1 * z2
    norm v      = sqrt (v ^.^ v)
    -- For IEEE types a zero vector normalises to NaN components (0/0).
    normalize v = (/ norm v) <$> v

-- | Radius parameter of the virtual trackball, in the normalised window
-- coordinates the trackball works in. Only 'Fractional' is needed for the
-- literal; callers with stronger constraints keep working.
vTrackBallR :: Fractional a => a
vTrackBallR = 0.8

-- Colours (RGBA) used for materials and lights.
black, lightblack, white, blue :: Color4 GLfloat
black      = Color4 0.0 0.0 0.0 1.0
lightblack = Color4 0.2 0.2 0.2 1.0
white      = Color4 1.0 1.0 1.0 1.0
blue       = Color4 0.0 0.0 1.0 1.0

-- | Virtual trackball: convert a drag from (startX, startY) to (endX, endY),
-- given in normalised window coordinates, into a rotation quaternion.
-- A zero-length drag yields 'quaternionIdentity'.
--
-- Fix: the clamp that keeps the argument of 'asin' at most 1 used to be
-- applied to the chord length before the division (function application
-- binds tighter than '/'); it is now applied to the quotient itself.
simulateTrackball :: (Floating a, Ord a) => a -> a -> a -> a -> Quaternion a
simulateTrackball startX startY endX endY
  | startX == endX && startY == endY = quaternionIdentity
  | otherwise = makeQuaternion rotateAxis rotateQuantity
  where
    -- Lift a 2-D point onto the trackball surface: near the centre this is
    -- the sphere x^2 + y^2 + z^2 = 2 * vTrackBallR^2, further out the
    -- hyperbola z = vTrackBallR^2 / d; both give z = vTrackBallR at
    -- d = vTrackBallR.
    projectToSphere x y
      | dist < vTrackBallR = sqrt $ 2.0 * (vTrackBallR ** 2) - (x ** 2 + y ** 2)
      | otherwise          = (vTrackBallR ** 2) / dist
      where
        dist = sqrt $ x ** 2.0 + y ** 2.0 -- distance from the origin

    -- Start and end of the drag as pure quaternions on the trackball.
    startPoint = Quaternion 0 startX startY (projectToSphere startX startY)
    endPoint   = Quaternion 0 endX endY (projectToSphere endX endY)

    -- Rotation axis: normalised vector part of the product of the two
    -- pure quaternions.
    rotateAxis = case startPoint ^*^ endPoint of
      Quaternion _ x y z -> LinearSpace.normalize (Vector3 x y z)

    -- Sine of half the rotation angle, clamped to at most 1 so that
    -- 'asin' below stays inside its domain.
    rotateQuantity =
      min 1.0 (quaternionNorm (startPoint ^-^ endPoint)
                 / (2.0 * vTrackBallR * (1 / sqrt 2.0)))

    -- Rotation quaternion (cos (theta/2), sin (theta/2) * axis) with
    -- t = sin (theta/2).
    makeQuaternion (Vector3 x y z) t =
      Quaternion (cos $ asin t) (t * x) (t * y) (t * z)

-- | Initialise GLUT and open the window, compile the scene, create the
-- shared application state and register every callback.
initGLUT :: IO ()
initGLUT = do
  (progName, _args) <- getArgsAndInitialize
  initialDisplayMode    $= [RGBAMode, DoubleBuffered, WithDepthBuffer]
  initialWindowSize     $= Size 512 512
  initialWindowPosition $= Position 100 100
  _window <- createWindow progName
  scene <- createIcosahedron
  ctxRef <- newIORef AppCtx
    { curPos       = quaternionIdentity
    , icosahedron  = scene
    , lastMousePos = Vector2 0 0
    , winSize      = Vector2 512.0 512.0
    }
  keyboardMouseCallback $= Just (keyboardMouse ctxRef)
  displayCallback       $= display ctxRef
  reshapeCallback       $= Just (reshape ctxRef)
  motionCallback        $= Just (motion ctxRef)
  cursor                $= RightArrow
  where
    -- Compile the scene into a display list: three blue faces of a cube
    -- (edge length 1, centred on the origin) followed by a white
    -- icosahedron drawn with back-face culling.
    createIcosahedron :: IO DisplayList
    createIcosahedron = defineNewList Compile $ do
        drawFrame
        drawObject
      where
        -- Emit vertex n of vertexData, using the vertex position as its
        -- normal. Fix: the normal was previously built as Normal3 x y x,
        -- which gave wrong lighting.
        putVertex :: (NormalComponent a, VertexComponent a) => V.Vector (Vertex3 a) -> Int -> IO ()
        putVertex vertexData n = do
            let v@(Vertex3 x y z) = vertexData V.! n
            normal (Normal3 x y z)
            vertex v

        -- Emit every polygon of faces (each a vector of vertex indices), in order.
        drawFaces :: V.Vector (Vertex3 GLfloat) -> V.Vector (V.Vector Int) -> IO ()
        drawFaces vertexData = V.mapM_ (V.mapM_ (putVertex vertexData))

        -- Draw only three faces of the cube: y = -l, z = -l and x = -l.
        drawFrame :: IO ()
        drawFrame = do
            materialDiffuse FrontAndBack $= blue
            materialAmbient FrontAndBack $= black
            renderPrimitive Quads $ drawFaces vertexData tIndices
          where
            l :: GLfloat
            l = 0.5
            vertexData :: V.Vector (Vertex3 GLfloat)
            vertexData = V.fromList
                [ Vertex3 (-l) (-l) (-l), Vertex3 l (-l) (-l), Vertex3 l l (-l), Vertex3 (-l) l (-l)
                , Vertex3 (-l) (-l)   l , Vertex3 l (-l)   l , Vertex3 l l   l , Vertex3 (-l) l   l ]
            tIndices :: V.Vector (V.Vector Int)
            tIndices = V.fromList $ map V.fromList [[0,1,2,3], [0,1,5,4], [3,0,4,7]]

        -- Draw the icosahedron (12 vertices at distance 0.5 from the
        -- origin, 20 triangles), culling back faces while it is drawn.
        drawObject :: IO ()
        drawObject = do
            cullFace $= Just Back
            materialAmbientAndDiffuse FrontAndBack $= white
            renderPrimitive Triangles $ drawFaces vertexData tIndices
            cullFace $= Nothing
          where
            x :: GLfloat
            x = 0.525731112119133606/2.0
            z :: GLfloat
            z = 0.850650808352039932/2.0
            vertexData :: V.Vector (Vertex3 GLfloat)
            vertexData = V.fromList
                [ Vertex3 (-x) 0.0    z,  Vertex3 x 0.0     z,   Vertex3 (-x) 0.0  (-z)
                , Vertex3   x  0.0 (-z),  Vertex3 0.0 z     x,   Vertex3 0.0  z    (-x)
                , Vertex3 0.0 (-z)    x,  Vertex3 0.0 (-z) (-x), Vertex3 z    x    0.0
                , Vertex3 (-z) x    0.0,  Vertex3 z   (-x) 0.0,  Vertex3 (-z) (-x) 0.0 ]
            tIndices :: V.Vector (V.Vector Int)
            tIndices = V.fromList $ map V.fromList
                [ [0,4,1], [0,9,4], [9,5,4], [4,5,8]
                , [4,8,1], [8,10,1], [8,3,10], [5,3,8]
                , [5,2,3], [2,7,3], [7,10,3], [7,6,10]
                , [7,11,6], [11,0,6], [0,1,6], [6,1,10]
                , [9,0,11], [9,11,2], [9,2,5], [7,2,11] ]

    -- Display callback: clear the colour and depth buffers, place the camera
    -- at (0, 0, 4) looking at the origin with +y up, draw the compiled scene
    -- rotated by the current orientation (curPos), then swap buffers.
    display :: IORef AppCtx -> DisplayCallback
    display appCtx = do
             nowCtx <- readIORef appCtx
             clear [ColorBuffer, DepthBuffer]
             loadIdentity
             lookAt (Vertex3 0.0 0.0 4.0) (Vertex3 0.0 0.0 0.0) (Vector3 0.0 1.0 0.0)             
             -- The rotation is applied inside preservingMatrix, so the
             -- modelview matrix is restored once the scene has been drawn.
             preservingMatrix $ do
                       m <- createRotationMatrix $ curPos nowCtx
                       -- NOTE(review): scaling by 1 is a no-op; presumably a
                       -- placeholder (e.g. for zooming).
                       scale (1.0::GLdouble) (1.0::GLdouble) (1.0::GLdouble)
                       multMatrix m
                       callList $ icosahedron nowCtx
             swapBuffers
             where
               -- Build a 4x4 GL matrix from a rotation quaternion.
               -- m_ij below is the usual rotation matrix of a unit quaternion
               -- (row i, column j). The entries are listed row by row but
               -- handed to newMatrix as ColumnMajor, so OpenGL receives the
               -- transpose, i.e. the rotation by the conjugate quaternion.
               -- NOTE(review): this looks deliberate, pairing with the operand
               -- order of (^*^) in the Quaternion module; confirm before
               -- changing either one.
               -- Expects a unit quaternion (motion re-normalises curPos).
               createRotationMatrix::(MatrixComponent a,Fractional a) => Quaternion a -> IO (GLmatrix a)
               createRotationMatrix q@(Quaternion w x y z) = newMatrix ColumnMajor [ m00, m01, m02, m03,
                                                                                     m10, m11, m12, m13,
                                                                                     m20, m21, m22, m23,
                                                                                     m30, m31, m32, m33 ]
                     where 
                       m00 = 1.0 - 2.0 * (y*y + z*z)
                       m01 =       2.0 * (x*y - z*w)
                       m02 =       2.0 * (z*x + w*y)
                       m03 = 0.0
                       m10 =       2.0 * (x*y + z*w)
                       m11 = 1.0 - 2.0 * (z*z + x*x)
                       m12 =       2.0 * (y*z - w*x)
                       m13 = 0.0
                       m20 =       2.0 * (z*x - w*y)
                       m21 =       2.0 * (y*z + x*w)
                       m22 = 1.0 - 2.0 * (y*y + x*x)
                       m23 = 0.0
                       m30 = 0.0
                       m31 = 0.0
                       m32 = 0.0
                       m33 = 1.0

    -- Window-resize handler: reset the viewport and the perspective
    -- projection, then record the new window size for the trackball mapping.
    reshape :: IORef AppCtx -> ReshapeCallback
    reshape appCtx size@(Size w h) = do
        viewport   $= (Position 0 0, size)
        matrixMode $= Projection
        loadIdentity
        -- perspective needs 0 < zNear < zFar (the far plane used to be
        -- -2.0). The camera sits 4 units from the origin (see display), so
        -- [2, 10] encloses the model. max 1 h avoids dividing by zero for a
        -- zero-height window.
        perspective 20.0 (fromIntegral w / fromIntegral (max 1 h)) 2.0 10.0
        matrixMode $= Modelview 0
        modifyIORef' appCtx $ \ctx ->
            ctx { winSize = Vector2 (fromIntegral w) (fromIntegral h) }
    -- Mouse-drag handler: map the previous and current cursor positions
    -- (pixels) to normalised coordinates in [-1, 1] with y pointing up,
    -- compose the resulting trackball rotation onto curPos, remember the
    -- cursor position and request a redraw.
    motion :: IORef AppCtx -> MotionCallback
    motion appCtx (Position x y) = do
        nowCtx <- readIORef appCtx
        let Vector2 w h         = winSize nowCtx
            Vector2 lastX lastY = lastMousePos nowCtx
            curX                = fromIntegral x
            curY                = fromIntegral y
            toNormX px          = (2.0 * px - w) / w
            toNormY py          = (h - 2.0 * py) / h
            rotation            = simulateTrackball (toNormX lastX) (toNormY lastY)
                                                    (toNormX curX) (toNormY curY)
            newPos              = Q.normalize $ rotation ^*^ curPos nowCtx
        -- Force the new orientation now (Quaternion has strict fields) so
        -- successive drag events do not pile up unevaluated thunks in the IORef.
        newPos `seq` writeIORef appCtx nowCtx { curPos       = newPos
                                              , lastMousePos = Vector2 curX curY }
        postRedisplay Nothing

    -- Keyboard/mouse handler: Esc exits the program; pressing the left
    -- button records the cursor position as the start of a trackball drag.
    -- Every other event is ignored.
    keyboardMouse :: IORef AppCtx -> KeyboardMouseCallback
    keyboardMouse _ (Char '\27') Down _ _ = exitWith ExitSuccess
    keyboardMouse ctxRef (MouseButton LeftButton) Down _ (Position px py) =
        modifyIORef ctxRef $ \ctx ->
            ctx { lastMousePos = Vector2 (fromIntegral px) (fromIntegral py) }
    keyboardMouse _ _ _ _ _ = return ()

-- | One-off OpenGL state: black background, clockwise front faces, normal
-- renormalisation, depth testing, lighting and two lights.
initGL :: IO ()
initGL = do
  clearColor     $= Color4 0.0 0.0 0.0 1.0 -- background colour
  frontFace      $= CW
  GLUT.normalize $= Enabled
  depthFunc      $= Just Less              -- enables the depth test
  lighting       $= Enabled
  configureLight (Light 0) lightblack Nothing
  configureLight (Light 1) white (Just (Vertex4 0.0 9.0 0.5 1.0))
  where
    -- Set a light's diffuse colour, optionally its position, then enable it.
    configureLight :: Light -> Color4 GLfloat -> Maybe (Vertex4 GLfloat) -> IO ()
    configureLight lamp colour mPos = do
      diffuse lamp $= colour
      maybe (return ()) (position lamp $=) mPos
      light lamp $= Enabled

-- | Set up GLUT and OpenGL, then hand control to the GLUT event loop.
main :: IO ()
main = initGLUT >> initGL >> mainLoop