25b0229183c19adae430fae364362c0541b45bce
[openvswitch] / python / ovs / jsonrpc.py
1 # Copyright (c) 2010, 2011 Nicira Networks
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at:
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 import errno
16 import os
17
18 import ovs.json
19 import ovs.poller
20 import ovs.reconnect
21 import ovs.stream
22 import ovs.timeval
23 import ovs.vlog
24
# Module-wide logger for the diagnostics emitted by this file.
vlog = ovs.vlog.Vlog("jsonrpc")

# Status value meaning "end of file": Connection.recv() stores and returns it
# when the stream yields no more data.  Stream failures are reported as
# positive errno values, so this negative value never collides with them.
EOF = -1
27
28
class Message(object):
    """One JSON-RPC message: a request, notification, reply, or error reply.

    The message kind is held in 'type' (one of the T_* constants below); the
    other attributes mirror the JSON-RPC members of the same name and are
    None when that member is absent.  is_valid() checks which members must
    be present for each kind."""

    T_REQUEST = 0               # Request.
    T_NOTIFY = 1                # Notification.
    T_REPLY = 2                 # Successful reply.
    T_ERROR = 3                 # Error reply.

    # Human-readable name of each message kind, used in diagnostics.
    __type_names = {T_REQUEST: "request",
                    T_NOTIFY: "notification",
                    T_REPLY: "reply",
                    T_ERROR: "error"}

    # Per-kind mask of mandatory members.  A set bit means the member must be
    # present (non-None); a clear bit means it must be absent.  The bit used
    # for each member is given by __member_bits.
    __member_masks = {T_REQUEST: 0x11001,
                      T_NOTIFY: 0x11000,
                      T_REPLY: 0x00101,
                      T_ERROR: 0x00011}

    # Members in the order they are validated, paired with their mask bit.
    __member_bits = (("method", 0x10000),
                     ("params", 0x1000),
                     ("result", 0x100),
                     ("error", 0x10),
                     ("id", 0x1))

    def __init__(self, type_, method, params, result, error, id):
        self.type = type_
        self.method = method
        self.params = params
        self.result = result
        self.error = error
        self.id = id

    # Next value handed out by _create_id(); shared by the whole class.
    _next_id = 0

    @staticmethod
    def _create_id():
        """Returns the next id from the class-wide counter (0, 1, 2, ...)."""
        allocated = Message._next_id
        Message._next_id = allocated + 1
        return allocated

    @staticmethod
    def create_request(method, params):
        """Returns a request for 'method' carrying a freshly allocated id."""
        return Message(Message.T_REQUEST, method=method, params=params,
                       result=None, error=None, id=Message._create_id())

    @staticmethod
    def create_notify(method, params):
        """Returns a notification (a request without an id) for 'method'."""
        return Message(Message.T_NOTIFY, method=method, params=params,
                       result=None, error=None, id=None)

    @staticmethod
    def create_reply(result, id):
        """Returns a successful reply carrying 'result' for request 'id'."""
        return Message(Message.T_REPLY, method=None, params=None,
                       result=result, error=None, id=id)

    @staticmethod
    def create_error(error, id):
        """Returns an error reply carrying 'error' for request 'id'."""
        return Message(Message.T_ERROR, method=None, params=None,
                       result=None, error=error, id=id)

    @staticmethod
    def type_to_string(type_):
        """Returns the name of message kind 'type_' (KeyError if unknown)."""
        return Message.__type_names[type_]

    def __check_member(self, value, name, must_have):
        """Returns None if member 'name', whose current value is 'value', is
        present exactly when 'must_have' is nonzero; otherwise returns a
        string describing the violation."""
        required = must_have != 0
        if (value is not None) == required:
            return None
        return "%s %s have \"%s\"" % (Message.type_to_string(self.type),
                                      "must" if required else "must not",
                                      name)

    def is_valid(self):
        """Returns None if this message is well formed, otherwise a string
        describing the first problem found."""
        if self.params is not None and type(self.params) != list:
            return "\"params\" must be JSON array"

        mask = Message.__member_masks.get(self.type)
        if mask is None:
            return "invalid JSON-RPC message type %s" % self.type

        for name, bit in Message.__member_bits:
            problem = self.__check_member(getattr(self, name), name,
                                          mask & bit)
            if problem:
                return problem
        return None

    @staticmethod
    def from_json(json):
        """Converts decoded JSON 'json' into a Message.

        Returns the new Message on success, or a string describing why
        'json' is not a valid JSON-RPC message.  The caller's dict is never
        modified."""
        if type(json) != dict:
            return "message is not a JSON object"

        # Consume members from a private copy; leftovers are unexpected.
        members = dict(json)

        method = None
        if "method" in members:
            method = members.pop("method")
            if type(method) not in (str, unicode):
                return "method is not a JSON string"

        params, result, error, id_ = tuple(
            members.pop(key, None)
            for key in ("params", "result", "error", "id"))
        if members:
            return ("message has unexpected member \"%s\""
                    % members.popitem()[0])

        # The kind is inferred from which members are present.
        msg_type = (Message.T_REPLY if result is not None
                    else Message.T_ERROR if error is not None
                    else Message.T_REQUEST if id_ is not None
                    else Message.T_NOTIFY)

        msg = Message(msg_type, method, params, result, error, id_)
        problem = msg.is_valid()
        return msg if problem is None else problem

    def to_json(self):
        """Returns this message as a dict suitable for JSON encoding.

        Members that are None are omitted, except that a reply always carries
        "error", an error reply always carries "result", and a notification
        always carries "id", each possibly as None."""
        always_present = {"result": self.type == Message.T_ERROR,
                          "error": self.type == Message.T_REPLY,
                          "id": self.type == Message.T_NOTIFY}
        out = {}
        for key in ("method", "params", "result", "error", "id"):
            value = getattr(self, key)
            if value is not None or always_present.get(key, False):
                out[key] = value
        return out

    def __str__(self):
        parts = [Message.type_to_string(self.type)]
        if self.method is not None:
            parts.append("method=\"%s\"" % self.method)
        for key in ("params", "result", "error", "id"):
            value = getattr(self, key)
            if value is not None:
                parts.append(key + "=" + ovs.json.to_string(value))
        return ", ".join(parts)
178
179
class Connection(object):
    """A JSON-RPC connection over an already-open ovs.stream.Stream.

    Outgoing messages are serialized into 'output', which run() flushes to
    the stream as far as it can without blocking.  Incoming data collects in
    'input' and is fed to an incremental ovs.json.Parser by recv().

    The first error is latched in 'status' (a positive errno value, or EOF
    when the stream yields no more data).  From then on the connection does
    no further I/O and its methods report that status."""

    def __init__(self, stream):
        self.name = stream.name
        self.stream = stream
        self.status = 0         # 0 while healthy, else errno value or EOF.
        self.input = ""         # Received data not yet consumed by parser.
        self.output = ""        # Serialized data not yet sent.
        self.parser = None      # ovs.json.Parser for a partial message.

    def close(self):
        """Closes the underlying stream.  Calling close() again is a no-op."""
        if self.stream is not None:
            self.stream.close()
            self.stream = None

    def run(self):
        """Sends as much queued output as the stream accepts without
        blocking.  A send failure other than EAGAIN is logged and latched
        into 'status'."""
        if self.status:
            return

        while self.output:
            retval = self.stream.send(self.output)
            if retval >= 0:
                self.output = self.output[retval:]
            else:
                # Stream.send() reports failure as a negative errno value.
                if retval != -errno.EAGAIN:
                    vlog.warn("%s: send error: %s" %
                              (self.name, os.strerror(-retval)))
                    self.error(-retval)
                break

    def wait(self, poller):
        """Registers with 'poller' what run() needs: stream housekeeping, and
        writability while output is still queued."""
        if not self.status:
            self.stream.run_wait(poller)
            if self.output:
                # Stream.send_wait() takes the poller to register on.
                self.stream.send_wait(poller)

    def get_status(self):
        """Returns 0 if the connection is healthy, else the latched error."""
        return self.status

    def get_backlog(self):
        """Returns the number of queued, unsent bytes (0 after an error)."""
        if self.status != 0:
            return 0
        else:
            return len(self.output)

    def __log_msg(self, title, msg):
        vlog.dbg("%s: %s %s" % (self.name, title, msg))

    def send(self, msg):
        """Queues Message 'msg' for transmission without blocking.

        Returns 0, or the latched error status.  Transmission starts at once
        if nothing else was queued; otherwise the pending data is flushed by
        later run() calls."""
        if self.status:
            return self.status

        self.__log_msg("send", msg)

        was_empty = not self.output
        self.output += ovs.json.to_string(msg.to_json())
        if was_empty:
            self.run()
        return self.status

    def send_block(self, msg):
        """Like send(), but blocks until all queued output has been sent or
        an error occurs.  Returns 0 or the error status."""
        error = self.send(msg)
        if error:
            return error

        while True:
            self.run()
            if not self.get_backlog() or self.get_status():
                return self.status

            poller = ovs.poller.Poller()
            self.wait(poller)
            poller.block()

    def recv(self):
        """Tries to receive one Message without blocking.

        Returns (0, msg) when a complete message was parsed,
        (errno.EAGAIN, None) when more data is needed, and (status, None)
        after a receive, parse, or protocol error, or at EOF."""
        if self.status:
            return self.status, None

        while True:
            if not self.input:
                error, data = self.stream.recv(4096)
                if error:
                    if error == errno.EAGAIN:
                        return error, None
                    else:
                        # XXX rate-limit
                        vlog.warn("%s: receive error: %s"
                                  % (self.name, os.strerror(error)))
                        self.error(error)
                        return self.status, None
                elif not data:
                    self.error(EOF)
                    return EOF, None
                else:
                    self.input += data
            else:
                if self.parser is None:
                    self.parser = ovs.json.Parser()
                # feed() returns how much of 'input' it consumed.
                self.input = self.input[self.parser.feed(self.input):]
                if self.parser.is_done():
                    msg = self.__process_msg()
                    if msg:
                        return 0, msg
                    else:
                        return self.status, None

    def recv_block(self):
        """Blocks until recv() returns something other than EAGAIN and
        returns that (error, msg) pair.  Keeps flushing output meanwhile."""
        while True:
            error, msg = self.recv()
            if error != errno.EAGAIN:
                return error, msg

            self.run()

            poller = ovs.poller.Poller()
            self.wait(poller)
            self.recv_wait(poller)
            poller.block()

    def transact_block(self, request):
        """Sends Message 'request' and blocks until the response carrying
        the same id arrives.

        Returns (error, reply).  On success 'error' is 0 and 'reply' is a
        T_REPLY or T_ERROR Message.  Otherwise 'reply' is None.  Other
        messages received while waiting are discarded."""
        id_ = request.id

        error = self.send(request)
        reply = None
        while not error:
            error, reply = self.recv_block()
            # An error reply answers the request just as a successful reply
            # does.  Waiting only for T_REPLY would block forever on it.
            if (reply is not None
                    and reply.type in (Message.T_REPLY, Message.T_ERROR)
                    and reply.id == id_):
                break
        return error, reply

    def __process_msg(self):
        """Finishes the current parse and converts the result to a Message.
        Returns None, latching EPROTO, if the JSON or the message is bad."""
        json = self.parser.finish()
        self.parser = None
        # finish() yields a string when parsing failed.
        if type(json) in [str, unicode]:
            # XXX rate-limit
            vlog.warn("%s: error parsing stream: %s" % (self.name, json))
            self.error(errno.EPROTO)
            return

        msg = Message.from_json(json)
        if not isinstance(msg, Message):
            # XXX rate-limit
            vlog.warn("%s: received bad JSON-RPC message: %s"
                      % (self.name, msg))
            self.error(errno.EPROTO)
            return

        self.__log_msg("received", msg)
        return msg

    def recv_wait(self, poller):
        """Arranges for 'poller' to wake when recv() may make progress.
        Wakes immediately if buffered input or an error is already present."""
        if self.status or self.input:
            poller.immediate_wake()
        else:
            self.stream.recv_wait(poller)

    def error(self, error):
        """Latches 'error' as the connection status (the first error wins),
        closes the stream, and discards queued output."""
        if self.status == 0:
            self.status = error
            if self.stream is not None:
                self.stream.close()
            self.output = ""
339
340
class Session(object):
    """A JSON-RPC session with reconnection.

    Combines an ovs.reconnect.Reconnect state machine with at most one live
    Connection ('rpc'), an active connection attempt still in progress
    ('stream'), and, for passive sessions, a listening PassiveStream
    ('pstream').  'seqno' changes whenever a connection is dropped or a new
    connection attempt starts."""

    def __init__(self, reconnect, rpc):
        self.reconnect = reconnect  # ovs.reconnect.Reconnect FSM.
        self.rpc = rpc              # Live Connection, or None.
        self.stream = None          # Active connect in progress, or None.
        self.pstream = None         # Listening PassiveStream, or None.
        self.seqno = 0

    @staticmethod
    def open(name):
        """Creates and returns a Session that maintains a JSON-RPC session to
        'name', which should be a string acceptable to ovs.stream.Stream or
        ovs.stream.PassiveStream's initializer.

        If 'name' is an active connection method, e.g. "tcp:127.1.2.3", the new
        session connects and reconnects, with back-off, to 'name'.

        If 'name' is a passive connection method, e.g. "ptcp:", the new session
        listens for connections to 'name'.  It maintains at most one connection
        at any given time.  Any new connection causes the previous one (if any)
        to be dropped."""
        reconnect = ovs.reconnect.Reconnect(ovs.timeval.msec())
        reconnect.set_name(name)
        reconnect.enable(ovs.timeval.msec())

        if ovs.stream.PassiveStream.is_valid_name(name):
            reconnect.set_passive(True, ovs.timeval.msec())

        return Session(reconnect, None)

    @staticmethod
    def open_unreliably(jsonrpc):
        """Returns a Session wrapping the already-connected Connection
        'jsonrpc'.  Reconnection is disabled via set_max_tries(0), and the
        reconnect FSM is put in quiet mode."""
        reconnect = ovs.reconnect.Reconnect(ovs.timeval.msec())
        reconnect.set_quiet(True)
        reconnect.set_name(jsonrpc.name)
        reconnect.set_max_tries(0)
        reconnect.connected(ovs.timeval.msec())
        return Session(reconnect, jsonrpc)

    def close(self):
        """Closes the connection, any pending connection attempt, and any
        listening stream."""
        if self.rpc is not None:
            self.rpc.close()
            self.rpc = None
        if self.stream is not None:
            self.stream.close()
            self.stream = None
        if self.pstream is not None:
            self.pstream.close()
            self.pstream = None

    def __disconnect(self):
        """Drops the live connection or pending connection attempt, if any.
        The listening stream of a passive session is kept."""
        if self.rpc is not None:
            self.rpc.error(EOF)
            self.rpc.close()
            self.rpc = None
            self.seqno += 1
        elif self.stream is not None:
            self.stream.close()
            self.stream = None
            self.seqno += 1

    def __connect(self):
        """Starts a new connection attempt (active sessions) or makes sure a
        listening stream exists (passive sessions).  The reconnect FSM is
        told of the outcome."""
        self.__disconnect()

        name = self.reconnect.get_name()
        if not self.reconnect.is_passive():
            error, self.stream = ovs.stream.Stream.open(name)
            if not error:
                self.reconnect.connecting(ovs.timeval.msec())
            else:
                self.reconnect.connect_failed(ovs.timeval.msec(), error)
        else:
            # Open the listener only if we do not already have one.  An
            # already-open listener counts as success, so the FSM still
            # moves to the listening state instead of retrying the connect.
            if self.pstream is None:
                error, self.pstream = ovs.stream.PassiveStream.open(name)
            else:
                error = 0
            if not error:
                self.reconnect.listening(ovs.timeval.msec())
            else:
                self.reconnect.connect_failed(ovs.timeval.msec(), error)

        self.seqno += 1

    def run(self):
        """Does periodic work.

        The steps are:
        - accept a new connection on a passive session;
        - drive the live connection or the pending connection attempt;
        - carry out the reconnect FSM's action: connect, disconnect, or send
          an "echo" probe."""
        if self.pstream is not None:
            error, stream = self.pstream.accept()
            if error == 0:
                if self.rpc or self.stream:
                    # XXX rate-limit
                    vlog.info("%s: new connection replacing active "
                              "connection" % self.reconnect.get_name())
                    self.__disconnect()
                self.reconnect.connected(ovs.timeval.msec())
                self.rpc = Connection(stream)
            elif error != errno.EAGAIN:
                self.reconnect.listen_error(ovs.timeval.msec(), error)
                self.pstream.close()
                self.pstream = None

        if self.rpc:
            self.rpc.run()
            error = self.rpc.get_status()
            if error != 0:
                self.reconnect.disconnected(ovs.timeval.msec(), error)
                self.__disconnect()
        elif self.stream is not None:
            self.stream.run()
            error = self.stream.connect()
            if error == 0:
                self.reconnect.connected(ovs.timeval.msec())
                self.rpc = Connection(self.stream)
                self.stream = None
            elif error != errno.EAGAIN:
                self.reconnect.connect_failed(ovs.timeval.msec(), error)
                self.stream.close()
                self.stream = None

        action = self.reconnect.run(ovs.timeval.msec())
        if action == ovs.reconnect.CONNECT:
            self.__connect()
        elif action == ovs.reconnect.DISCONNECT:
            self.reconnect.disconnected(ovs.timeval.msec(), 0)
            self.__disconnect()
        elif action == ovs.reconnect.PROBE:
            if self.rpc:
                # The fixed id "echo" lets recv() recognize and suppress the
                # reply to this probe.
                request = Message.create_request("echo", [])
                request.id = "echo"
                self.rpc.send(request)
        else:
            assert action is None

    def wait(self, poller):
        """Registers with 'poller' everything run() may need to act on."""
        if self.rpc is not None:
            self.rpc.wait(poller)
        elif self.stream is not None:
            self.stream.run_wait(poller)
            self.stream.connect_wait(poller)
        if self.pstream is not None:
            self.pstream.wait(poller)
        self.reconnect.wait(poller, ovs.timeval.msec())

    def get_backlog(self):
        """Returns the live connection's unsent byte count, or 0."""
        if self.rpc is not None:
            return self.rpc.get_backlog()
        else:
            return 0

    def get_name(self):
        return self.reconnect.get_name()

    def send(self, msg):
        """Sends 'msg' on the live connection.  Returns 0 or an errno value;
        errno.ENOTCONN when there is no live connection."""
        if self.rpc is not None:
            return self.rpc.send(msg)
        else:
            return errno.ENOTCONN

    def recv(self):
        """Returns the next received Message, or None.

        Echo requests are answered automatically and replies to this
        session's own "echo" probes are swallowed, so neither is returned."""
        if self.rpc is not None:
            error, msg = self.rpc.recv()
            if not error:
                self.reconnect.received(ovs.timeval.msec())
                if msg.type == Message.T_REQUEST and msg.method == "echo":
                    # Echo request.  Send reply.
                    self.send(Message.create_reply(msg.params, msg.id))
                elif msg.type == Message.T_REPLY and msg.id == "echo":
                    # It's a reply to our echo request.  Suppress it.
                    pass
                else:
                    return msg
        return None

    def recv_wait(self, poller):
        if self.rpc is not None:
            self.rpc.recv_wait(poller)

    def is_alive(self):
        """Returns True while connected or connecting, or while the reconnect
        FSM's max_tries is None or positive."""
        if self.rpc is not None or self.stream is not None:
            return True
        else:
            max_tries = self.reconnect.get_max_tries()
            return max_tries is None or max_tries > 0

    def is_connected(self):
        return self.rpc is not None

    def get_seqno(self):
        return self.seqno

    def force_reconnect(self):
        self.reconnect.force_reconnect(ovs.timeval.msec())
528     def force_reconnect(self):
529         self.reconnect.force_reconnect(ovs.timeval.msec())