Add API Framework Revel Source Files
[iec.git] / src / foundation / api / revel / cache / redis.go
1 // Copyright (c) 2012-2016 The Revel Framework Authors, All rights reserved.
2 // Revel Framework source code and usage is governed by a MIT style
3 // license that can be found in the LICENSE file.
4
5 package cache
6
7 import (
8         "time"
9
10         "github.com/garyburd/redigo/redis"
11         "github.com/revel/revel"
12 )
13
14 // RedisCache wraps the Redis client to meet the Cache interface.
15 type RedisCache struct {
16         pool              *redis.Pool
17         defaultExpiration time.Duration
18 }
19
20 // NewRedisCache returns a new RedisCache with given parameters
21 // until redigo supports sharding/clustering, only one host will be in hostList
22 func NewRedisCache(host string, password string, defaultExpiration time.Duration) RedisCache {
23         var pool = &redis.Pool{
24                 MaxIdle:     revel.Config.IntDefault("cache.redis.maxidle", 5),
25                 MaxActive:   revel.Config.IntDefault("cache.redis.maxactive", 0),
26                 IdleTimeout: time.Duration(revel.Config.IntDefault("cache.redis.idletimeout", 240)) * time.Second,
27                 Dial: func() (redis.Conn, error) {
28                         protocol := revel.Config.StringDefault("cache.redis.protocol", "tcp")
29                         toc := time.Millisecond * time.Duration(revel.Config.IntDefault("cache.redis.timeout.connect", 10000))
30                         tor := time.Millisecond * time.Duration(revel.Config.IntDefault("cache.redis.timeout.read", 5000))
31                         tow := time.Millisecond * time.Duration(revel.Config.IntDefault("cache.redis.timeout.write", 5000))
32                         c, err := redis.Dial(protocol, host,
33                                 redis.DialConnectTimeout(toc),
34                                 redis.DialReadTimeout(tor),
35                                 redis.DialWriteTimeout(tow))
36                         if err != nil {
37                                 return nil, err
38                         }
39                         if len(password) > 0 {
40                                 if _, err = c.Do("AUTH", password); err != nil {
41                                         _ = c.Close()
42                                         return nil, err
43                                 }
44                         } else {
45                                 // check with PING
46                                 if _, err = c.Do("PING"); err != nil {
47                                         _ = c.Close()
48                                         return nil, err
49                                 }
50                         }
51                         return c, err
52                 },
53                 // custom connection test method
54                 TestOnBorrow: func(c redis.Conn, t time.Time) error {
55                         _, err := c.Do("PING")
56                         return err
57                 },
58         }
59         return RedisCache{pool, defaultExpiration}
60 }
61
62 func (c RedisCache) Set(key string, value interface{}, expires time.Duration) error {
63         conn := c.pool.Get()
64         defer func() {
65                 _ = conn.Close()
66         }()
67         return c.invoke(conn.Do, key, value, expires)
68 }
69
70 func (c RedisCache) Add(key string, value interface{}, expires time.Duration) error {
71         conn := c.pool.Get()
72         defer func() {
73                 _ = conn.Close()
74         }()
75
76         existed, err := exists(conn, key)
77         if err != nil {
78                 return err
79         } else if existed {
80                 return ErrNotStored
81         }
82         return c.invoke(conn.Do, key, value, expires)
83 }
84
85 func (c RedisCache) Replace(key string, value interface{}, expires time.Duration) error {
86         conn := c.pool.Get()
87         defer func() {
88                 _ = conn.Close()
89         }()
90
91         existed, err := exists(conn, key)
92         if err != nil {
93                 return err
94         } else if !existed {
95                 return ErrNotStored
96         }
97
98         err = c.invoke(conn.Do, key, value, expires)
99         if value == nil {
100                 return ErrNotStored
101         }
102         return err
103 }
104
105 func (c RedisCache) Get(key string, ptrValue interface{}) error {
106         conn := c.pool.Get()
107         defer func() {
108                 _ = conn.Close()
109         }()
110         raw, err := conn.Do("GET", key)
111         if err != nil {
112                 return err
113         } else if raw == nil {
114                 return ErrCacheMiss
115         }
116         item, err := redis.Bytes(raw, err)
117         if err != nil {
118                 return err
119         }
120         return Deserialize(item, ptrValue)
121 }
122
123 func generalizeStringSlice(strs []string) []interface{} {
124         ret := make([]interface{}, len(strs))
125         for i, str := range strs {
126                 ret[i] = str
127         }
128         return ret
129 }
130
131 func (c RedisCache) GetMulti(keys ...string) (Getter, error) {
132         conn := c.pool.Get()
133         defer func() {
134                 _ = conn.Close()
135         }()
136
137         items, err := redis.Values(conn.Do("MGET", generalizeStringSlice(keys)...))
138         if err != nil {
139                 return nil, err
140         } else if items == nil {
141                 return nil, ErrCacheMiss
142         }
143
144         m := make(map[string][]byte)
145         for i, key := range keys {
146                 m[key] = nil
147                 if i < len(items) && items[i] != nil {
148                         s, ok := items[i].([]byte)
149                         if ok {
150                                 m[key] = s
151                         }
152                 }
153         }
154         return RedisItemMapGetter(m), nil
155 }
156
157 func exists(conn redis.Conn, key string) (bool, error) {
158         return redis.Bool(conn.Do("EXISTS", key))
159 }
160
161 func (c RedisCache) Delete(key string) error {
162         conn := c.pool.Get()
163         defer func() {
164                 _ = conn.Close()
165         }()
166         existed, err := redis.Bool(conn.Do("DEL", key))
167         if err == nil && !existed {
168                 err = ErrCacheMiss
169         }
170         return err
171 }
172
173 func (c RedisCache) Increment(key string, delta uint64) (uint64, error) {
174         conn := c.pool.Get()
175         defer func() {
176                 _ = conn.Close()
177         }()
178         // Check for existence *before* increment as per the cache contract.
179         // redis will auto create the key, and we don't want that. Since we need to do increment
180         // ourselves instead of natively via INCRBY (redis doesn't support wrapping), we get the value
181         // and do the exists check this way to minimize calls to Redis
182         val, err := conn.Do("GET", key)
183         if err != nil {
184                 return 0, err
185         } else if val == nil {
186                 return 0, ErrCacheMiss
187         }
188         currentVal, err := redis.Int64(val, nil)
189         if err != nil {
190                 return 0, err
191         }
192         sum := currentVal + int64(delta)
193         _, err = conn.Do("SET", key, sum)
194         if err != nil {
195                 return 0, err
196         }
197         return uint64(sum), nil
198 }
199
200 func (c RedisCache) Decrement(key string, delta uint64) (newValue uint64, err error) {
201         conn := c.pool.Get()
202         defer func() {
203                 _ = conn.Close()
204         }()
205         // Check for existence *before* increment as per the cache contract.
206         // redis will auto create the key, and we don't want that, hence the exists call
207         existed, err := exists(conn, key)
208         if err != nil {
209                 return 0, err
210         } else if !existed {
211                 return 0, ErrCacheMiss
212         }
213         // Decrement contract says you can only go to 0
214         // so we go fetch the value and if the delta is greater than the amount,
215         // 0 out the value
216         currentVal, err := redis.Int64(conn.Do("GET", key))
217         if err != nil {
218                 return 0, err
219         }
220         if delta > uint64(currentVal) {
221                 var tempint int64
222                 tempint, err = redis.Int64(conn.Do("DECRBY", key, currentVal))
223                 return uint64(tempint), err
224         }
225         tempint, err := redis.Int64(conn.Do("DECRBY", key, delta))
226         return uint64(tempint), err
227 }
228
229 func (c RedisCache) Flush() error {
230         conn := c.pool.Get()
231         defer func() {
232                 _ = conn.Close()
233         }()
234         _, err := conn.Do("FLUSHALL")
235         return err
236 }
237
238 func (c RedisCache) invoke(f func(string, ...interface{}) (interface{}, error),
239         key string, value interface{}, expires time.Duration) error {
240
241         switch expires {
242         case DefaultExpiryTime:
243                 expires = c.defaultExpiration
244         case ForEverNeverExpiry:
245                 expires = time.Duration(0)
246         }
247
248         b, err := Serialize(value)
249         if err != nil {
250                 return err
251         }
252         conn := c.pool.Get()
253         defer func() {
254                 _ = conn.Close()
255         }()
256         if expires > 0 {
257                 _, err = f("SETEX", key, int32(expires/time.Second), b)
258                 return err
259         }
260         _, err = f("SET", key, b)
261         return err
262 }
263
264 // RedisItemMapGetter implements a Getter on top of the returned item map.
265 type RedisItemMapGetter map[string][]byte
266
267 func (g RedisItemMapGetter) Get(key string, ptrValue interface{}) error {
268         item, ok := g[key]
269         if !ok {
270                 return ErrCacheMiss
271         }
272         return Deserialize(item, ptrValue)
273 }